Compare commits

99 Commits

Author SHA1 Message Date
8ca8538268 Update CHANGELOG for v1.13.5. 2025-11-18 13:42:44 -08:00
155c49cc5e build: gitea is pulling the repo information from github. 2025-11-18 13:42:28 -08:00
dda9fd9f07 Bump CHANGELOG for v1.13.4. 2025-11-18 13:37:12 -08:00
c251c1e1b5 fix up goreleaser config. 2025-11-18 13:35:47 -08:00
6eb533f79b Tweak goreleaser config. 2025-11-18 13:23:20 -08:00
ea5ffa4828 Update CHANGELOG for v1.13.3. 2025-11-18 13:11:29 -08:00
aa96e47112 update goreleaser config 2025-11-18 13:01:23 -08:00
d34a417dce fixups 2025-11-18 12:37:15 -08:00
d11e0cf9f9 lib: add byte slice output for HexEncode 2025-11-18 11:58:43 -08:00
aad7d68599 cmd/ski: update display mode 2025-11-18 11:46:58 -08:00
4560868688 cmd: switch programs over to certlib.Fetcher. 2025-11-18 11:08:17 -08:00
8d5406256f cmd/certdump: use certlib.Fetcher. 2025-11-17 19:49:42 -08:00
9280e846fa certlib: add Fetcher
Fetcher is an interface and set of functions for retrieving a
certificate (or chain of certificates) from a spec. It will
determine whether the certificate spec is a file or a server,
and fetch accordingly.
2025-11-17 19:48:57 -08:00
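The commit message above describes what certlib.Fetcher does but not its exact API. As a rough Go sketch only — the name Fetcher comes from the commit, while the method set, the New constructor, and the file-or-server heuristic below are assumptions, not the real certlib code — a fetcher along those lines might look like:

// Hypothetical sketch of a spec-based certificate fetcher; the actual
// certlib.Fetcher API may differ. DER-only file parsing keeps it short;
// real code would also handle PEM.
package certsketch

import (
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"os"
	"strings"
)

// Fetcher retrieves a certificate chain from some source.
type Fetcher interface {
	Fetch() ([]*x509.Certificate, error)
}

type fileFetcher struct{ path string }

func (f fileFetcher) Fetch() ([]*x509.Certificate, error) {
	der, err := os.ReadFile(f.path)
	if err != nil {
		return nil, err
	}
	return x509.ParseCertificates(der)
}

type serverFetcher struct{ addr string }

func (s serverFetcher) Fetch() ([]*x509.Certificate, error) {
	// Verification is skipped because the point is to inspect whatever
	// chain the server presents, not to validate it.
	conn, err := tls.Dial("tcp", s.addr, &tls.Config{InsecureSkipVerify: true})
	if err != nil {
		return nil, err
	}
	defer conn.Close()
	return conn.ConnectionState().PeerCertificates, nil
}

// New decides whether a spec names a file on disk or a host:port server.
func New(spec string) (Fetcher, error) {
	if _, err := os.Stat(spec); err == nil {
		return fileFetcher{path: spec}, nil
	}
	if strings.Contains(spec, ":") {
		return serverFetcher{addr: spec}, nil
	}
	return nil, fmt.Errorf("certsketch: cannot interpret spec %q", spec)
}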
0a71661901 CHANGELOG: bump to v1.13.2. 2025-11-17 15:50:51 -08:00
804f53d27d Refactor bundling into separate package. 2025-11-17 15:08:10 -08:00
cfb80355bb Update CHANGELOG for v1.13.1. 2025-11-17 10:08:05 -08:00
77160395a0 Cleaning up a few things. 2025-11-17 10:07:03 -08:00
37d5e04421 Adding Dockerfile 2025-11-17 09:03:43 -08:00
dc54eeacbc Remove cert bundles generated in testdata. 2025-11-17 08:36:31 -08:00
e2a3081ce5 cmd: add certser command. 2025-11-17 07:18:46 -08:00
3149d958f4 cmd: add certser 2025-11-17 06:55:20 -08:00
f296344acf twofactor: linting fixes 2025-11-16 21:51:38 -08:00
3fb2d88a3f go get rsc.io/qr 2025-11-16 20:44:13 -08:00
150c02b377 Fix subtree. 2025-11-16 18:55:43 -08:00
83f88c49fe Import twofactor. 2025-11-16 18:45:34 -08:00
7c437ac45f Add 'twofactor/' from commit 'c999bf35b0e47de4f63d59abbe0d7efc76c13ced'
git-subtree-dir: twofactor
git-subtree-mainline: 4dc135cfe0
git-subtree-split: c999bf35b0
2025-11-16 18:43:03 -08:00
c999bf35b0 linter fixes. 2025-11-16 18:39:18 -08:00
4dc135cfe0 Update CHANGELOG for v1.11.2. 2025-11-16 13:18:38 -08:00
790113e189 cmd: refactor for code reuse. 2025-11-16 13:15:08 -08:00
8348c5fd65 Update CHANGELOG. 2025-11-16 11:09:02 -08:00
1eafb638a8 cmd: finish linting fixes 2025-11-16 11:03:12 -08:00
3ad562b6fa cmd: continuing linter fixes 2025-11-16 02:54:02 -08:00
0f77bd49dc cmd: continue lint fixes. 2025-11-16 01:32:19 -08:00
f31d74243f cmd: start linting fixes. 2025-11-16 00:36:19 -08:00
a573f1cd20 Update CHANGELOG. 2025-11-15 23:48:54 -08:00
f93cf5fa9c adding lru/mru cache. 2025-11-15 23:48:00 -08:00
b879d62384 cert-bundler: lint fixes 2025-11-15 23:27:50 -08:00
c99ffd4394 cmd: cleaning up programs 2025-11-15 23:17:40 -08:00
ed8c07c1c5 Add 'mru/' from commit '2899885c4220560df4f60e4c052a6ab9773a0386'
git-subtree-dir: mru
git-subtree-mainline: cf2b016433
git-subtree-split: 2899885c42
2025-11-15 22:54:26 -08:00
cf2b016433 certlib: complete overhaul. 2025-11-15 22:54:12 -08:00
2899885c42 linter fixes 2025-11-15 22:46:42 -08:00
f3b4838cf6 Overhauling certlib.
LICENSE to Apache 2.0.
2025-11-15 22:00:29 -08:00
8ed30e9960 certlib: linter autofixes 2025-11-15 21:10:09 -08:00
c7de3919b0 log: linting fixes 2025-11-15 21:06:16 -08:00
840066004a logging: linter fixes 2025-11-15 21:02:19 -08:00
9fb93a3802 mwc: linter fixes 2025-11-15 20:39:21 -08:00
ecc7e5ab1e rand: remove unused package 2025-11-15 20:37:02 -08:00
a934c42aa1 temp fix before removing 2025-11-15 20:36:14 -08:00
948986ba60 testutil: remove unused code
It was probably a WIP for something else; it was started in
2016 and not touched since.
2025-11-15 20:25:37 -08:00
3be86573aa testio: linting fixes 2025-11-15 20:24:00 -08:00
e3a6355edb tee: add tests; linter fixes.
Additionally, disable reassign in testing files.
2025-11-15 20:18:09 -08:00
66d16acebc seekbuf: linter fixes 2025-11-15 19:58:41 -08:00
fdff2e0afe sbuf: linter fixes 2025-11-15 19:53:18 -08:00
3d9625b40b Fix calls to die.With. 2025-11-15 16:10:14 -08:00
547a0d8f32 disable linting until cleanups are finished 2025-11-15 16:00:58 -08:00
876a0a2c2b fileutil: linter fixes. 2025-11-15 15:58:51 -08:00
a37d28e3d7 die: linter feedback fixes. 2025-11-15 15:55:17 -08:00
ddf26e00af dbg: linter feedback updates. 2025-11-15 15:53:57 -08:00
e4db163efe Cleaning up. 2025-11-15 15:48:18 -08:00
571443c282 config: apply linting feedback. 2025-11-15 15:47:29 -08:00
aba5e519a4 First round of linter cleanups. 2025-11-15 15:11:07 -08:00
5fcba0e814 Trying a different config. 2025-11-15 13:34:18 -08:00
928c643d8d Fix linter config. 2025-11-15 13:16:30 -08:00
fd9f9f6d66 Fix linting. 2025-11-15 13:08:38 -08:00
a5b7727c8f Add linting stage. 2025-11-15 13:05:00 -08:00
3135c18d95 ignore goland directory 2025-11-15 01:53:40 -08:00
0dcd18c6f1 clean up code
- travisci is long dead
- golangci-lint the repo
2024-12-02 13:47:43 -08:00
024d552293 add circle ci config 2024-12-02 13:26:34 -08:00
9cd2ced695 There are different keys for different hash sizes. 2024-12-02 13:16:32 -08:00
b92e16fa4d Handle evictions properly when cache is empty. 2023-08-27 18:01:16 -07:00
6fbdece4be Initial import. 2022-02-24 21:39:10 -08:00
619c08a13f Update travis to latest Go versions. 2020-10-31 08:06:11 -07:00
944a57bf0e Switch to go modules. 2020-10-31 07:41:31 -07:00
0857b29624 Actually support clock mocking. 2020-10-31 07:24:34 -07:00
CodeLingo Bot
e95404bfc5 Fix function comments based on best practices from Effective Go
Signed-off-by: CodeLingo Bot <bot@codelingo.io>
2020-10-30 14:57:11 -07:00
ujjwalsh
924654e7c4 Added Support for Linux on Power 2020-10-30 07:50:32 -07:00
9e0979e07f Support clock mocking.
This addresses #15.
2018-12-07 08:23:01 -08:00
Aaron Bieber
bbc82ff8de Pad non-padded secrets. This lets us continue building on <= go1.8.
- Add tests for secrets using various padding methods.
- Add a new method/test to append padding to non-padded secrets.
2018-04-18 13:39:21 -07:00
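The fix described above rests on the fact that Go's base32.StdEncoding expects '=' padding up to a multiple of eight characters, so a non-padded secret can simply be padded before decoding. A minimal sketch of that idea (the pad32 helper is illustrative, not the package's actual function):

package main

import (
	"encoding/base32"
	"fmt"
	"strings"
)

// pad32 appends '=' until the secret length is a multiple of 8,
// which is what base32.StdEncoding expects.
func pad32(secret string) string {
	if n := len(secret) % 8; n != 0 {
		secret += strings.Repeat("=", 8-n)
	}
	return secret
}

func main() {
	secret := "a6mryljlbufszudtjdt42nh5by" // 26 chars, no padding
	key, err := base32.StdEncoding.DecodeString(strings.ToUpper(pad32(secret)))
	fmt.Println(len(key), err) // 16 <nil>
}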
Aaron Bieber
5fd928f69a Decode using WithPadding as pointed out by @gl-sergei.
This makes us print the same 6 digits as oathtool for non-padded
secrets like "a6mryljlbufszudtjdt42nh5by".
2018-04-18 13:39:21 -07:00
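This commit takes the other route: rather than appending padding, the decoder itself is told not to expect any, via base32.StdEncoding.WithPadding(base32.NoPadding) from the standard library (added in Go 1.9, which is why the padding approach above mattered for Go 1.8 and earlier). A small sketch, folding in the uppercase normalisation from the 2017 "always uppercase" commit further down:

package main

import (
	"encoding/base32"
	"fmt"
	"strings"
)

func main() {
	// Lowercase, unpadded secret as handed out by some providers.
	secret := "a6mryljlbufszudtjdt42nh5by"
	enc := base32.StdEncoding.WithPadding(base32.NoPadding)
	key, err := enc.DecodeString(strings.ToUpper(secret))
	if err != nil {
		panic(err)
	}
	fmt.Printf("%d-byte key: %x\n", len(key), key)
}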
Aaron Bieber
acefe4a3b9 Don't assume our secret is base32 encoded.
According to https://en.wikipedia.org/wiki/Time-based_One-time_Password_algorithm
secrets are only base32 encoded in gauthenticator and gauth friendly providers.
2018-04-16 13:14:03 -07:00
a1452cebc9 Travis requires a string for Go 1.10. 2018-04-16 13:03:16 -07:00
6e9812e6f5 Vendor dependencies and add more tests. 2018-04-16 13:03:16 -07:00
Aaron Bieber
8c34415c34 add readme 2018-04-16 12:52:39 -07:00
Paul TREHIOU
2cf2c15def Case insensitive algorithm match 2018-04-16 12:43:27 -07:00
Aaron Bieber
eaad1884d4 Make sure our secret is always uppercase
Non-uppercase secrets that are base32 encoded will fail to decode
unless we upper them.
2017-09-17 18:19:23 -07:00
5d57d844d4 Add license (MIT). 2017-04-13 10:02:20 -07:00
Kyle Isom
31b9d175dd Add travis config. 2017-03-20 14:20:49 -07:00
Aaron Bieber
79e106da2e point to new qr location 2017-03-20 13:18:56 -07:00
Kyle Isom
939b1bc272 Updating imports. 2015-08-12 12:29:34 -07:00
Kyle
89e74f390b Add doc.go, finish YubiKey removal. 2014-04-24 20:43:13 -06:00
Kyle
7881b6fdfc Remove test TOTP client. 2014-04-24 20:40:44 -06:00
Kyle
5bef33245f Remove YubiKey (not currently functional). 2014-04-24 20:37:53 -06:00
Kyle
84250b0501 More documentation. 2014-04-24 20:37:00 -06:00
Kyle Isom
459e9f880f Add function to build Google TOTPs from secret 2014-04-23 16:54:16 -07:00
Kyle Isom
0982f47ce3 Add last night's progress.
Basic functionality for HOTP, TOTP, and YubiKey OTP. Still need YubiKey
HMAC, serialisation, check, and scan.
2013-12-20 17:00:01 -07:00
Kyle Isom
1dec15fd11 add missing files
new files are
	oath_test
	totp code
2013-12-19 00:21:26 -07:00
Kyle Isom
2ee9cae5ba Add basic Google Authenticator TOTP client. 2013-12-19 00:20:00 -07:00
Kyle Isom
dc04475120 HOTP and TOTP-SHA-1 working.
why the frak aren't the SHA-256 and SHA-512 variants working
2013-12-19 00:04:26 -07:00
Kyle Isom
dbbd5116b5 Initial import.
Basic HOTP functionality.
2013-12-18 21:48:14 -07:00
138 changed files with 7136 additions and 3163 deletions

.circleci/config.yml

@@ -5,6 +5,30 @@ version: 2.1
# Define a job to be invoked later in a workflow.
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
jobs:
lint:
working_directory: ~/repo
docker:
- image: cimg/go:1.22.2
steps:
- checkout
- restore_cache:
keys:
- go-mod-v4-{{ checksum "go.sum" }}
- run:
name: Install Dependencies
command: go mod download
- save_cache:
key: go-mod-v4-{{ checksum "go.sum" }}
paths:
- "/go/pkg/mod"
- run:
name: Install golangci-lint
command: |
curl -sSfL https://raw.githubusercontent.com/golangci/golangci-lint/master/install.sh | sh -s -- -b $(go env GOPATH)/bin
- run:
name: Run golangci-lint
command: golangci-lint run --timeout=5m
testbuild:
working_directory: ~/repo
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
@@ -27,16 +51,17 @@ jobs:
- "/go/pkg/mod" - "/go/pkg/mod"
- run: - run:
name: Run tests name: Run tests
command: go test ./... command: go test -race ./...
- run: - run:
name: Run build name: Run build
command: go build ./... command: go build ./...
- store_test_results: - store_test_results:
path: /tmp/test-reports path: /tmp/test-reports
# Invoke jobs via workflows # Invoke jobs via workflows
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows # See: https://circleci.com/docs/2.0/configuration-reference/#workflows
# Linting is disabled while cleanups are ongoing.
workflows:
testbuild:
jobs:
- testbuild
- lint

.github/workflows/release.yml (new file)

@@ -0,0 +1,35 @@
name: Release
on:
push:
tags:
- 'v*'
workflow_dispatch: {}
permissions:
contents: write
jobs:
goreleaser:
name: GoReleaser
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache: true
- name: Run GoReleaser
uses: goreleaser/goreleaser-action@v6
with:
distribution: goreleaser
version: latest
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

.gitignore

@@ -1,4 +1,5 @@
-bazel-bin
-bazel-goutils
-bazel-out
-bazel-testlogs
.idea
cmd/cert-bundler/testdata/pkg/*
# Added by goreleaser init:
dist/
cmd/cert-bundler/testdata/bundle/

golangci-lint configuration

@@ -1,87 +1,522 @@
-run:
-timeout: 5m
-tests: true
-build-tags: []
-modules-download-mode: readonly
# This file is licensed under the terms of the MIT license https://opensource.org/license/mit
# Copyright (c) 2021-2025 Marat Reymers
## Golden config for golangci-lint v2.6.2
#
# This is the best config for golangci-lint based on my experience and opinion.
# It is very strict, but not extremely strict.
# Feel free to adapt it to suit your needs.
# If this config helps you, please consider keeping a link to this file (see the next comment).
# Based on https://gist.github.com/maratori/47a4d00457a92aa426dbd48a18776322
version: "2"
output:
sort-order:
- file
- linter
- severity
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 50
# Exclude some lints for CLI programs under cmd/ (package main).
# The project allows fmt.Print* in command-line tools; keep forbidigo for libraries.
exclude-rules:
- path: ^cmd/
linters:
- forbidigo
- path: cmd/.*
linters:
- forbidigo
- path: .*/cmd/.*
linters:
- forbidigo
formatters:
enable:
- goimports # checks if the code and import statements are formatted according to the 'goimports' command
- golines # checks if code is formatted, and fixes long lines
## you may want to enable
#- gci # checks if code and import statements are formatted, with additional rules
#- gofmt # checks if the code is formatted according to 'gofmt' command
#- gofumpt # enforces a stricter format than 'gofmt', while being backwards compatible
#- swaggo # formats swaggo comments
# All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml
settings:
goimports:
# A list of prefixes, which, if set, checks import paths
# with the given prefixes are grouped after 3rd-party packages.
# Default: []
local-prefixes:
- github.com/my/project
golines:
# Target maximum line length.
# Default: 100
max-len: 120
linters:
enable:
-- errcheck
-- gosimple
-- govet
-- ineffassign
-- staticcheck
-- unused
-- gofmt
-- goimports
-- misspell
-- unparam
-- unconvert
-- goconst
-- gocyclo
-- gosec
-- prealloc
-- copyloopvar
-- revive
-- typecheck
- asasalint # checks for pass []any as any in variadic func(...any)
- asciicheck # checks that your code does not contain non-ASCII identifiers
- bidichk # checks for dangerous unicode character sequences
- bodyclose # checks whether HTTP response body is closed successfully
- canonicalheader # checks whether net/http.Header uses canonical header
- copyloopvar # detects places where loop variables are copied (Go 1.22+)
- cyclop # checks function and package cyclomatic complexity
- depguard # checks if package imports are in a list of acceptable packages
- dupl # tool for code clone detection
- durationcheck # checks for two durations multiplied together
- errcheck # checking for unchecked errors, these unchecked errors can be critical bugs in some cases
- errname # checks that sentinel errors are prefixed with the Err and error types are suffixed with the Error
- errorlint # finds code that will cause problems with the error wrapping scheme introduced in Go 1.13
- exhaustive # checks exhaustiveness of enum switch statements
- exptostd # detects functions from golang.org/x/exp/ that can be replaced by std functions
- fatcontext # detects nested contexts in loops
- forbidigo # forbids identifiers
- funcorder # checks the order of functions, methods, and constructors
- funlen # tool for detection of long functions
- gocheckcompilerdirectives # validates go compiler directive comments (//go:)
- gochecksumtype # checks exhaustiveness on Go "sum types"
- gocognit # computes and checks the cognitive complexity of functions
- goconst # finds repeated strings that could be replaced by a constant
- gocritic # provides diagnostics that check for bugs, performance and style issues
- gocyclo # computes and checks the cyclomatic complexity of functions
- godoclint # checks Golang's documentation practice
- godot # checks if comments end in a period
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
- gosec # inspects source code for security problems
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
- ineffassign # detects when assignments to existing variables are not used
- intrange # finds places where for loops could make use of an integer range
- iotamixing # checks if iotas are being used in const blocks with other non-iota declarations
- loggercheck # checks key value pairs for common logger libraries (kitlog,klog,logr,zap)
- makezero # finds slice declarations with non-zero initial length
- mirror # reports wrong mirror patterns of bytes/strings usage
- mnd # detects magic numbers
- modernize # suggests simplifications to Go code, using modern language and library features
- musttag # enforces field tags in (un)marshaled structs
- nakedret # finds naked returns in functions greater than a specified function length
- nestif # reports deeply nested if statements
- nilerr # finds the code that returns nil even if it checks that the error is not nil
- nilnesserr # reports that it checks for err != nil, but it returns a different nil value error (powered by nilness and nilerr)
- nilnil # checks that there is no simultaneous return of nil error and an invalid value
- noctx # finds sending http request without context.Context
- nolintlint # reports ill-formed or insufficient nolint directives
- nonamedreturns # reports all named returns
- nosprintfhostport # checks for misuse of Sprintf to construct a host with port in a URL
- perfsprint # checks that fmt.Sprintf can be replaced with a faster alternative
- predeclared # finds code that shadows one of Go's predeclared identifiers
- promlinter # checks Prometheus metrics naming via promlint
- protogetter # reports direct reads from proto message fields when getters should be used
- reassign # checks that package variables are not reassigned
- recvcheck # checks for receiver type consistency
- revive # fast, configurable, extensible, flexible, and beautiful linter for Go, drop-in replacement of golint
- rowserrcheck # checks whether Err of rows is checked successfully
- sloglint # ensure consistent code style when using log/slog
- spancheck # checks for mistakes with OpenTelemetry/Census spans
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
- staticcheck # is a go vet on steroids, applying a ton of static analysis checks
- testableexamples # checks if examples are testable (have an expected output)
- testifylint # checks usage of github.com/stretchr/testify
- testpackage # makes you use a separate _test package
- tparallel # detects inappropriate usage of t.Parallel() method in your Go test codes
- unconvert # removes unnecessary type conversions
- unparam # reports unused function parameters
- unqueryvet # detects SELECT * in SQL queries and SQL builders, encouraging explicit column selection
- unused # checks for unused constants, variables, functions and types
- usestdlibvars # detects the possibility to use variables/constants from the Go standard library
- usetesting # reports uses of functions with replacement inside the testing package
- wastedassign # finds wasted assignment statements
- whitespace # detects leading and trailing whitespace
-linters-settings:
-gocyclo:
-min-complexity: 15
-goconst:
-min-len: 3
-min-occurrences: 3
-misspell:
-locale: US
-revive:
-rules:
-- name: exported
-disabled: false
-- name: error-return
-- name: error-naming
-- name: if-return
-- name: var-naming
-- name: package-comments
-disabled: true
-- name: indent-error-flow
-- name: context-as-argument
-gosec:
-excludes:
-- G304 # File path from variable (common in file utilities)
-- G404 # Use of weak random (acceptable for non-crypto use)
-issues:
-exclude-rules:
-# Exclude some linters from running on tests files
-- path: _test\.go
-linters:
-- gocyclo
-- errcheck
-- gosec
-# Exclude embedded content from checks
-- path: ".*\\.txt$"
-linters:
## you may want to enable
#- arangolint # opinionated best practices for arangodb client
#- decorder # checks declaration order and count of types, constants, variables and functions
#- exhaustruct # [highly recommend to enable] checks if all structure fields are initialized
#- ginkgolinter # [if you use ginkgo/gomega] enforces standards of using ginkgo and gomega
#- godox # detects usage of FIXME, TODO and other keywords inside comments
#- goheader # checks is file header matches to pattern
#- inamedparam # [great idea, but too strict, need to ignore a lot of cases by default] reports interfaces with unnamed method parameters
#- interfacebloat # checks the number of methods inside an interface
#- ireturn # accept interfaces, return concrete types
#- noinlineerr # disallows inline error handling `if err := ...; err != nil {`
#- prealloc # [premature optimization, but can be used in some cases] finds slice declarations that could potentially be preallocated
#- tagalign # checks that struct tags are well aligned
#- varnamelen # [great idea, but too many false positives] checks that the length of a variable's name matches its scope
#- wrapcheck # checks that errors returned from external packages are wrapped
#- zerologlint # detects the wrong usage of zerolog that a user forgets to dispatch zerolog.Event
## disabled
#- containedctx # detects struct contained context.Context field
#- contextcheck # [too many false positives] checks the function whether use a non-inherited context
#- dogsled # checks assignments with too many blank identifiers (e.g. x, _, _, _, := f())
#- dupword # [useless without config] checks for duplicate words in the source code
#- err113 # [too strict] checks the errors handling expressions
#- errchkjson # [don't see profit + I'm against of omitting errors like in the first example https://github.com/breml/errchkjson] checks types passed to the json encoding functions. Reports unsupported types and optionally reports occasions, where the check for the returned error can be omitted
#- forcetypeassert # [replaced by errcheck] finds forced type assertions
#- gomodguard # [use more powerful depguard] allow and block lists linter for direct Go module dependencies
#- gosmopolitan # reports certain i18n/l10n anti-patterns in your Go codebase
#- grouper # analyzes expression groups
#- importas # enforces consistent import aliases
#- lll # [replaced by golines] reports long lines
#- maintidx # measures the maintainability index of each function
#- misspell # [useless] finds commonly misspelled English words in comments
#- nlreturn # [too strict and mostly code is not more readable] checks for a new line before return and branch statements to increase code clarity
#- paralleltest # [too many false positives] detects missing usage of t.Parallel() method in your Go test
#- tagliatelle # checks the struct tags
#- thelper # detects golang test helpers without t.Helper() call and checks the consistency of test helpers
#- wsl # [too strict and mostly code is not more readable] whitespace linter forces you to use empty lines
#- wsl_v5 # [too strict and mostly code is not more readable] add or remove empty lines
# All settings can be found here https://github.com/golangci/golangci-lint/blob/HEAD/.golangci.reference.yml
settings:
cyclop:
# The maximal code complexity to report.
# Default: 10
max-complexity: 30
# The maximal average package complexity.
# If it's higher than 0.0 (float) the check is enabled.
# Default: 0.0
package-average: 10.0
depguard:
# Rules to apply.
#
# Variables:
# - File Variables
# Use an exclamation mark `!` to negate a variable.
# Example: `!$test` matches any file that is not a go test file.
#
# `$all` - matches all go files
# `$test` - matches all go test files
#
# - Package Variables
#
# `$gostd` - matches all of go's standard library (Pulled from `GOROOT`)
#
# Default (applies if no custom rules are defined): Only allow $gostd in all files.
rules:
"deprecated":
# List of file globs that will match this list of settings to compare against.
# By default, if a path is relative, it is relative to the directory where the golangci-lint command is executed.
# The placeholder '${base-path}' is substituted with a path relative to the mode defined with `run.relative-path-mode`.
# The placeholder '${config-path}' is substituted with a path relative to the configuration file.
# Default: $all
files:
- "$all"
# List of packages that are not allowed.
# Entries can be a variable (starting with $), a string prefix, or an exact match (if ending with $).
# Default: []
deny:
- pkg: github.com/golang/protobuf
desc: Use google.golang.org/protobuf instead, see https://developers.google.com/protocol-buffers/docs/reference/go/faq#modules
- pkg: github.com/satori/go.uuid
desc: Use github.com/google/uuid instead, satori's package is not maintained
- pkg: github.com/gofrs/uuid$
desc: Use github.com/gofrs/uuid/v5 or later, it was not a go module before v5
"non-test files":
files:
- "!$test"
deny:
- pkg: math/rand$
desc: Use math/rand/v2 instead, see https://go.dev/blog/randv2
"non-main files":
files:
- "!**/main.go"
deny:
- pkg: log$
desc: Use log/slog instead, see https://go.dev/blog/slog
embeddedstructfieldcheck:
# Checks that sync.Mutex and sync.RWMutex are not used as embedded fields.
# Default: false
forbid-mutex: true
errcheck:
# Report about not checking of errors in type assertions: `a := b.(MyStruct)`.
# Such cases aren't reported by default.
# Default: false
check-type-assertions: true
exclude-functions:
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
- git.wntrmute.dev/kyle/goutils/lib.Warn
- git.wntrmute.dev/kyle/goutils/lib.Warnx
- git.wntrmute.dev/kyle/goutils/lib.Err
- git.wntrmute.dev/kyle/goutils/lib.Errx
exhaustive:
# Program elements to check for exhaustiveness.
# Default: [ switch ]
check:
- switch
- map
exhaustruct:
# List of regular expressions to match type names that should be excluded from processing.
# Anonymous structs can be matched by '<anonymous>' alias.
# Has precedence over `include`.
# Each regular expression must match the full type name, including package path.
# For example, to match type `net/http.Cookie` regular expression should be `.*/http\.Cookie`,
# but not `http\.Cookie`.
# Default: []
exclude:
# std libs
- ^net/http.Client$
- ^net/http.Cookie$
- ^net/http.Request$
- ^net/http.Response$
- ^net/http.Server$
- ^net/http.Transport$
- ^net/url.URL$
- ^os/exec.Cmd$
- ^reflect.StructField$
# public libs
- ^github.com/Shopify/sarama.Config$
- ^github.com/Shopify/sarama.ProducerMessage$
- ^github.com/mitchellh/mapstructure.DecoderConfig$
- ^github.com/prometheus/client_golang/.+Opts$
- ^github.com/spf13/cobra.Command$
- ^github.com/spf13/cobra.CompletionOptions$
- ^github.com/stretchr/testify/mock.Mock$
- ^github.com/testcontainers/testcontainers-go.+Request$
- ^github.com/testcontainers/testcontainers-go.FromDockerfile$
- ^golang.org/x/tools/go/analysis.Analyzer$
- ^google.golang.org/protobuf/.+Options$
- ^gopkg.in/yaml.v3.Node$
# Allows empty structures in return statements.
# Default: false
allow-empty-returns: true
funcorder:
# Checks if the exported methods of a structure are placed before the non-exported ones.
# Default: true
struct-method: false
funlen:
# Checks the number of lines in a function.
# If lower than 0, disable the check.
# Default: 60
lines: 100
# Checks the number of statements in a function.
# If lower than 0, disable the check.
# Default: 40
statements: 50
gochecksumtype:
# Presence of `default` case in switch statements satisfies exhaustiveness, if all members are not listed.
# Default: true
default-signifies-exhaustive: false
gocognit:
# Minimal code complexity to report.
# Default: 30 (but we recommend 10-20)
min-complexity: 20
gocritic:
# Settings passed to gocritic.
# The settings key is the name of a supported gocritic checker.
# The list of supported checkers can be found at https://go-critic.com/overview.
settings:
captLocal:
# Whether to restrict checker to params only.
# Default: true
paramsOnly: false
underef:
# Whether to skip (*x).method() calls where x is a pointer receiver.
# Default: true
skipRecvDeref: false
godoclint:
# List of rules to enable in addition to the default set.
# Default: empty
enable:
# Assert no unused link in godocs.
# https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link
- no-unused-link
gosec:
excludes:
- G104 # handled by errcheck
- G301
- G306
govet:
# Enable all analyzers.
# Default: false
enable-all: true
# Disable analyzers by name.
# Run `GL_DEBUG=govet golangci-lint run --enable=govet` to see default, all available analyzers, and enabled analyzers.
# Default: []
disable:
- fieldalignment # too strict
# Settings per analyzer.
settings:
shadow:
# Whether to be strict about shadowing; can be noisy.
# Default: false
strict: true
inamedparam:
# Skips check for interface methods with only a single parameter.
# Default: false
skip-single-param: true
mnd:
ignored-functions:
- args.Error
- flag.Arg
- flag.Duration.*
- flag.Float.*
- flag.Int.*
- flag.Uint.*
- os.Chmod
- os.Mkdir.*
- os.OpenFile
- os.WriteFile
- prometheus.ExponentialBuckets.*
- prometheus.LinearBuckets
ignored-numbers:
- 1
- 2
- 3
- 4
- 8
- 24
- 30
- 365
nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns.
# Default: 30
max-func-lines: 0
nolintlint:
# Exclude the following linters from requiring an explanation.
# Default: []
allow-no-explanation: [ funlen, gocognit, golines ]
# Enable to require an explanation of nonzero length after each nolint directive.
# Default: false
require-explanation: true
# Enable to require nolint directives to mention the specific linter being suppressed.
# Default: false
require-specific: true
perfsprint:
# Optimizes into strings concatenation.
# Default: true
strconcat: false
reassign:
# Patterns for global variable names that are checked for reassignment.
# See https://github.com/curioswitch/go-reassign#usage
# Default: ["EOF", "Err.*"]
patterns:
- ".*"
rowserrcheck:
# database/sql is always checked.
# Default: []
packages:
- github.com/jmoiron/sqlx
sloglint:
# Enforce not using global loggers.
# Values:
# - "": disabled
# - "all": report all global loggers
# - "default": report only the default slog logger
# https://github.com/go-simpler/sloglint?tab=readme-ov-file#no-global
# Default: ""
no-global: all
# Enforce using methods that accept a context.
# Values:
# - "": disabled
# - "all": report all contextless calls
# - "scope": report only if a context exists in the scope of the outermost function
# https://github.com/go-simpler/sloglint?tab=readme-ov-file#context-only
# Default: ""
context: scope
staticcheck:
# SAxxxx checks in https://staticcheck.dev/docs/configuration/options/#checks
# Example (to disable some checks): [ "all", "-SA1000", "-SA1001"]
# Default: ["all", "-ST1000", "-ST1003", "-ST1016", "-ST1020", "-ST1021", "-ST1022"]
checks:
- all
# Incorrect or missing package comment.
# https://staticcheck.dev/docs/checks/#ST1000
- -ST1000
# Use consistent method receiver names.
# https://staticcheck.dev/docs/checks/#ST1016
- -ST1016
# Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008
- -QF1008
# We often explicitly enable old/deprecated ciphers for research.
- -SA1019
# Covered by revive.
- -ST1003
-- all
-# Ignore deprecation warnings in legacy code if needed
-- linters:
-- staticcheck
-text: "SA1019"
-# Maximum issues count per one linter
-max-issues-per-linter: 0
-# Maximum count of issues with the same text
-max-same-issues: 0
-output:
-formats:
-- format: colored-line-number
-path: stdout
-print-issued-lines: true
-print-linter-name: true
usetesting:
# Enable/disable `os.TempDir()` detections.
# Default: false
os-temp-dir: true
exclusions:
# Log a warning if an exclusion rule is unused.
# Default: false
warn-unused: true
# Predefined exclusion rules.
# Default: []
presets:
- std-error-handling
- common-false-positives
rules:
- path: 'ahash/ahash.go'
linters: [ staticcheck, gosec ]
- path: 'twofactor/.*.go'
linters: [ exhaustive, mnd, revive ]
- path: 'backoff/backoff_test.go'
linters: [ testpackage ]
- path: 'dbg/dbg_test.go'
linters: [ testpackage ]
- path: 'log/logger.go'
linters: [ forbidigo ]
- path: 'logging/example_test.go'
linters: [ testableexamples ]
- path: 'main.go'
linters: [ forbidigo, mnd, reassign ]
- path: 'cmd/cruntar/main.go'
linters: [ unparam ]
- source: 'TODO'
linters: [ godot ]
- text: 'should have a package comment'
linters: [ revive ]
- text: 'exported \S+ \S+ should have comment( \(or a comment on this block\))? or be unexported'
linters: [ revive ]
- text: 'package comment should be of the form ".+"'
source: '// ?(nolint|TODO)'
linters: [ revive ]
- text: 'comment on exported \S+ \S+ should be of the form ".+"'
source: '// ?(nolint|TODO)'
linters: [ revive, staticcheck ]
- path: '_test\.go'
linters:
- bodyclose
- dupl
- errcheck
- funlen
- goconst
- gosec
- noctx
- reassign
- wrapcheck

.goreleaser.yaml (new file)

@@ -0,0 +1,456 @@
# This is an example .goreleaser.yml file with some sensible defaults.
# Make sure to check the documentation at https://goreleaser.com
# The lines below are called `modelines`. See `:help modeline`
# Feel free to remove those if you don't want/need to use them.
# yaml-language-server: $schema=https://goreleaser.com/static/schema.json
# vim: set ts=2 sw=2 tw=0 fo=cnqoj
version: 2
before:
hooks:
# You may remove this if you don't use go modules.
- go mod tidy
# you may remove this if you don't need go generate
- go generate ./...
builds:
- id: atping
main: ./cmd/atping/main.go
binary: atping
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: ca-signed
main: ./cmd/ca-signed/main.go
binary: ca-signed
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: cert-bundler
main: ./cmd/cert-bundler/main.go
binary: cert-bundler
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: cert-revcheck
main: ./cmd/cert-revcheck/main.go
binary: cert-revcheck
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: certchain
main: ./cmd/certchain/main.go
binary: certchain
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: certdump
main: ./cmd/certdump/main.go
binary: certdump
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: certexpiry
main: ./cmd/certexpiry/main.go
binary: certexpiry
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: certser
main: ./cmd/certser/main.go
binary: certser
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: certverify
main: ./cmd/certverify/main.go
binary: certverify
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: clustersh
main: ./cmd/clustersh/main.go
binary: clustersh
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: cruntar
main: ./cmd/cruntar/main.go
binary: cruntar
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: csrpubdump
main: ./cmd/csrpubdump/main.go
binary: csrpubdump
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: data_sync
main: ./cmd/data_sync/main.go
binary: data_sync
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: diskimg
main: ./cmd/diskimg/main.go
binary: diskimg
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: dumpbytes
main: ./cmd/dumpbytes/main.go
binary: dumpbytes
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: eig
main: ./cmd/eig/main.go
binary: eig
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: fragment
main: ./cmd/fragment/main.go
binary: fragment
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: host
main: ./cmd/host/main.go
binary: host
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: jlp
main: ./cmd/jlp/main.go
binary: jlp
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: kgz
main: ./cmd/kgz/main.go
binary: kgz
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: minmax
main: ./cmd/minmax/main.go
binary: minmax
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: parts
main: ./cmd/parts/main.go
binary: parts
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: pem2bin
main: ./cmd/pem2bin/main.go
binary: pem2bin
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: pembody
main: ./cmd/pembody/main.go
binary: pembody
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: pemit
main: ./cmd/pemit/main.go
binary: pemit
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: readchain
main: ./cmd/readchain/main.go
binary: readchain
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: renfnv
main: ./cmd/renfnv/main.go
binary: renfnv
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: rhash
main: ./cmd/rhash/main.go
binary: rhash
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: rolldie
main: ./cmd/rolldie/main.go
binary: rolldie
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: showimp
main: ./cmd/showimp/main.go
binary: showimp
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: ski
main: ./cmd/ski/main.go
binary: ski
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: sprox
main: ./cmd/sprox/main.go
binary: sprox
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: stealchain
main: ./cmd/stealchain/main.go
binary: stealchain
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: stealchain-server
main: ./cmd/stealchain-server/main.go
binary: stealchain-server
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: subjhash
main: ./cmd/subjhash/main.go
binary: subjhash
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: tlsinfo
main: ./cmd/tlsinfo/main.go
binary: tlsinfo
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: tlskeypair
main: ./cmd/tlskeypair/main.go
binary: tlskeypair
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: utc
main: ./cmd/utc/main.go
binary: utc
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: yamll
main: ./cmd/yamll/main.go
binary: yamll
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
- id: zsearch
main: ./cmd/zsearch/main.go
binary: zsearch
env:
- CGO_ENABLED=0
goos: [linux, darwin]
goarch: [amd64, arm64]
ignore:
- goos: darwin
goarch: amd64
archives:
- formats: [tar.gz]
# this name template makes the OS and Arch compatible with the results of `uname`.
name_template: >-
{{ .ProjectName }}_
{{- title .Os }}_
{{- if eq .Arch "amd64" }}x86_64
{{- else if eq .Arch "386" }}i386
{{- else }}{{ .Arch }}{{ end }}
{{- if .Arm }}v{{ .Arm }}{{ end }}
# use zip for windows archives
format_overrides:
- goos: windows
formats: [zip]
changelog:
sort: asc
filters:
exclude:
- "^docs:"
- "^test:"
gitea_urls:
api: https://git.wntrmute.dev/api/v1
download: https://git.wntrmute.dev/
# set to true if you use a self-signed certificate
skip_tls_verify: false
release:
github:
owner: kyle
name: goutils
footer: >-
---
Released by [GoReleaser](https://github.com/goreleaser/goreleaser).

.travis.yml (deleted)

@@ -1,26 +0,0 @@
arch:
- amd64
- ppc64le
sudo: false
language: go
go:
- tip
- 1.9
jobs:
exclude:
- go: 1.9
arch: amd64
- go: 1.9
arch: ppc64le
script:
- go get golang.org/x/lint/golint
- go get golang.org/x/tools/cmd/cover
- go get github.com/kisom/goutils/...
- go test -cover github.com/kisom/goutils/...
- golint github.com/kisom/goutils/...
notifications:
email:
recipients:
- coder@kyleisom.net
on_success: change
on_failure: change

CHANGELOG

@@ -1,27 +1,145 @@
-Release 1.2.1 - 2018-09-15
-+ Add missing format argument to Errorf call in kgz.
-Release 1.2.0 - 2018-09-15
-+ Adds the kgz command line utility.
-Release 1.1.0 - 2017-11-16
-+ A number of new command line utilities were added
-+ atping
-+ cruntar
-+ renfnv
-+
-+ ski
-+ subjhash
-+ yamll
-+ new package: ahash
-+ package for loading hashes from an algorithm string
-+ new certificate loading functions in the lib package
-+ new package: tee
-+ emulates tee(1)
CHANGELOG
v1.13.5 - 2025-11-18
Changed:
- build: updating goreleaser config.
v1.13.4 - 2025-11-18
Changed:
- build: updating goreleaser config.
v1.13.3 - 2025-11-18
Added:
- certlib: introduce `Fetcher` for retrieving certificates.
- lib: `HexEncode` gains a byte-slice output variant.
- build: add GoReleaser configuration.
Changed:
- cmd: migrate programs to use `certlib.Fetcher` for certificate retrieval
(includes `certdump`, `ski`, and others).
- cmd/ski: update display mode.
Misc:
- repository fixups and small cleanups.
v1.13.2 - 2025-11-17
Add:
- certlib/bundler: refactor certificate bundling from cmd/cert-bundler
into a separate package.
Changed:
- cmd/cert-bundler: refactor to use bundler package, and update Dockerfile.
v1.13.1 - 2025-11-17
Add:
- Dockerfile for cert-bundler.
v1.13.0 - 2025-11-16
Add:
- cmd/certser: print serial numbers for certificates.
- lib/HexEncode: add a new hex encode function handling multiple output
formats, including with and without colons.
v1.12.4 - 2025-11-16
Changed:
- Linting fixes for twofactor that were previously masked.
v1.12.3 erroneously tagged and pushed
v1.12.2 - 2025-11-16
Changed:
- add rsc.io/qr dependency for twofactor.
v1.12.1 - 2025-11-16
Changed:
- twofactor: Remove go.{mod,sum}.
v1.12.0 - 2025-11-16
Added
- twofactor: the github.com/kisom/twofactor repo has been subtree'd
into this repo.
v1.11.2 - 2025-11-16
Changed
- cmd/ski, cmd/csrpubdump, cmd/tlskeypair: centralize
certificate/private-key/CSR parsing by reusing certlib helpers.
This reduces duplication and improves consistency across commands.
- csr: CSR parsing in the above commands now uses certlib.ParseCSR,
which verifies CSR signatures (behavioral hardening compared to
prior parsing without signature verification).
v1.11.1 - 2025-11-16
Changed
- cmd: complete linting fixes across programs; no functional changes.
v1.11.0 - 2025-11-15
Added
- cache/mru: introduce MRU cache implementation with timestamp utilities.
Changed
- certlib: complete overhaul to simplify APIs and internals.
- repo: widespread linting cleanups across many packages (config, dbg, die,
fileutil, log/logging, mwc, sbuf, seekbuf, tee, testio, etc.).
- cmd: general program cleanups; `cert-bundler` lint fixes.
Removed
- rand: remove unused package.
- testutil: remove unused code.
v1.10.1 — 2025-11-15
Changed
- certlib: major overhaul and refactor.
- repo: linter autofixes ahead of release.
v1.10.0 — 2025-11-14
Added
- cmd: add `cert-revcheck` command.
Changed
- ci/lint: add golangci-lint stage and initial cleanup.
v1.9.1 — 2025-11-15
Fixed
- die: correct calls to `die.With`.
v1.9.0 — 2025-11-14
Added
- cmd: add `cert-bundler` tool.
Changed
- misc: minor updates and maintenance.
v1.8.1 — 2025-11-14
Added
- cmd: add `tlsinfo` tool.
v1.8.0 — 2025-11-14
Baseline
- Initial baseline for this changelog series.

LICENSE

@@ -1,19 +1,194 @@
-Copyright (c) 2015-2023 Kyle Isom <kyle@tyrfingr.is>
-Permission to use, copy, modify, and distribute this software for any
-purpose with or without fee is hereby granted, provided that the above
-copyright notice and this permission notice appear in all copies.
-THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
-WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
-MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
-ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
-WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
-ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
-OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
Copyright 2025 K. Isom <kyle@imap.cc>
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
=======================================================================
The backoff package (written during my time at Cloudflare) is released
under the following license:

README

@@ -2,39 +2,52 @@ GOUTILS
This is a collection of small utility code I've written in Go; the `cmd/` This is a collection of small utility code I've written in Go; the `cmd/`
directory has a number of command-line utilities. Rather than keep all directory has a number of command-line utilities. Rather than keep all
of these in superfluous repositories of their own, or rewriting them of these in superfluous repositories of their own or rewriting them
for each project, I'm putting them here. for each project, I'm putting them here.
The project can be built with the standard Go tooling, or it can be built The project can be built with the standard Go tooling.
with Bazel.
Contents: Contents:
ahash/ Provides hashes from string algorithm specifiers. ahash/ Provides hashes from string algorithm specifiers.
assert/ Error handling, assertion-style. assert/ Error handling, assertion-style.
backoff/ Implementation of an intelligent backoff strategy. backoff/ Implementation of an intelligent backoff strategy.
cache/ Implementations of various caches.
lru/ Least-recently-used cache.
mru/ Most-recently-used cache.
certlib/ Library for working with TLS certificates.
cmd/ cmd/
atping/ Automated TCP ping, meant for putting in cronjobs. atping/ Automated TCP ping, meant for putting in cronjobs.
certchain/ Display the certificate chain from a ca-signed/ Validate whether a certificate is signed by a CA.
TLS connection. cert-bundler/
Create certificate bundles from a source of PEM
certificates.
cert-revcheck/
Check whether a certificate has been revoked or is
expired.
certchain/ Display the certificate chain from a TLS connection.
certdump/ Dump certificate information. certdump/ Dump certificate information.
certexpiry/ Print a list of certificate subjects and expiry times certexpiry/ Print a list of certificate subjects and expiry times
or warn about certificates expiring within a certain or warn about certificates expiring within a certain
window. window.
certverify/ Verify a TLS X.509 certificate, optionally printing certverify/ Verify a TLS X.509 certificate file, optionally printing
the time to expiry and checking for revocations. the time to expiry and checking for revocations.
clustersh/ Run commands or transfer files across multiple clustersh/ Run commands or transfer files across multiple
servers via SSH. servers via SSH.
cruntar/ Untar an archive with hard links, copying instead of cruntar/ (Un)tar an archive with hard links, copying instead of
linking. linking.
csrpubdump/ Dump the public key from an X.509 certificate request. csrpubdump/ Dump the public key from an X.509 certificate request.
data_sync/ Sync the user's homedir to external storage. data_sync/ Sync the user's homedir to external storage.
diskimg/ Write a disk image to a device. diskimg/ Write a disk image to a device.
dumpbytes/ Dump the contents of a file as hex bytes, printing it as
a Go []byte literal.
eig/ EEPROM image generator. eig/ EEPROM image generator.
fragment/ Print a fragment of a file. fragment/ Print a fragment of a file.
host/ Go implementation of the host(1) command.
jlp/ JSON linter/prettifier. jlp/ JSON linter/prettifier.
kgz/ Custom gzip compressor / decompressor that handles 99% kgz/ Custom gzip compressor / decompressor that handles 99%
of my use cases. of my use cases.
minmax/ Generate a minmax code for use in uLisp.
parts/ Simple parts database management for my collection of parts/ Simple parts database management for my collection of
electronic components. electronic components.
pem2bin/ Dump the binary body of a PEM-encoded block. pem2bin/ Dump the binary body of a PEM-encoded block.
@@ -44,37 +57,79 @@ Contents:
in a bundle. in a bundle.
renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash. renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash.
rhash/ Compute the digest of remote files. rhash/ Compute the digest of remote files.
rolldie/ Roll some dice.
showimp/ List the external (e.g. non-stdlib and outside the showimp/ List the external (e.g. non-stdlib and outside the
current working directory) imports for a Go file. current working directory) imports for a Go file.
ski Display the SKI for PEM-encoded TLS material. ski Display the SKI for PEM-encoded TLS material.
sprox/ Simple TCP proxy. sprox/ Simple TCP proxy.
stealchain/ Dump the verified chain from a TLS stealchain/ Dump the verified chain from a TLS connection to a
connection to a server. server.
stealchain- Dump the verified chain from a TLS stealchain-server/
server/ connection from a client. Dump the verified chain from a TLS connection from
a client.
subjhash/ Print or match subject info from a certificate. subjhash/ Print or match subject info from a certificate.
tlsinfo/ Print information about a TLS connection (the TLS version
and cipher suite).
tlskeypair/ Check whether a TLS certificate and key file match. tlskeypair/ Check whether a TLS certificate and key file match.
utc/ Convert times to UTC. utc/ Convert times to UTC.
yamll/ A small YAML linter. yamll/ A small YAML linter.
zsearch/ Search for a string in directory of gzipped files.
config/ A simple global configuration system where configuration config/ A simple global configuration system where configuration
data is pulled from a file or an environment variable data is pulled from a file or an environment variable
transparently. transparently.
iniconf/ A simple INI-style configuration system.
dbg/ A debug printer. dbg/ A debug printer.
die/ Death of a program. die/ Death of a program.
fileutil/ Common file functions. fileutil/ Common file functions.
lib/ Commonly-useful functions for writing Go programs. lib/ Commonly-useful functions for writing Go programs.
log/ A syslog library.
logging/ A logging library. logging/ A logging library.
mwc/ MultiwriteCloser implementation. mwc/ MultiwriteCloser implementation.
rand/ Utilities for working with math/rand.
sbuf/ A byte buffer that can be wiped. sbuf/ A byte buffer that can be wiped.
seekbuf/ A read-seekable byte buffer. seekbuf/ A read-seekable byte buffer.
syslog/ Syslog-type logging. syslog/ Syslog-type logging.
tee/ Emulate tee(1)'s functionality in io.Writers. tee/ Emulate tee(1)'s functionality in io.Writers.
testio/ Various I/O utilities useful during testing. testio/ Various I/O utilities useful during testing.
testutil/ Various utility functions useful during testing. twofactor/ Two-factor authentication.
Each program should have a small README in the directory with more Each program should have a small README in the directory with more
information. information.
All code here is licensed under the ISC license. All code here is licensed under the Apache 2.0 license.
Error handling
--------------
This repo standardizes on Go 1.13+ error wrapping and matching. Libraries and
CLIs should:
- Wrap causes with context using `fmt.Errorf("context: %w", err)`.
- Use typed, structured errors from `certlib/certerr` for certificate-related
operations. These include a typed `*certerr.Error` with `Source` and `Kind`.
- Match errors programmatically:
- `errors.Is(err, certerr.ErrEncryptedPrivateKey)` to detect sentinel states.
- `errors.As(err, &e)` (where `var e *certerr.Error`) to inspect
`e.Source`/`e.Kind`.
Examples:
```
cert, err := certlib.LoadCertificate(path)
if err != nil {
    // sentinel match:
    if errors.Is(err, certerr.ErrEmptyCertificate) {
        // handle empty input
    }

    // typed error match
    var ce *certerr.Error
    if errors.As(err, &ce) {
        switch ce.Kind {
        case certerr.KindParse:
            // parse error handling
        case certerr.KindLoad:
            // file loading error handling
        }
    }
}
```
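On the producing side, the `%w` wrapping that makes these matchers work looks like the following fragment; this is a sketch only (the `loadKey` helper is hypothetical, and the `os`/`fmt` imports are assumed):

```
func loadKey(path string) ([]byte, error) {
    data, err := os.ReadFile(path)
    if err != nil {
        // Wrap the cause with context; errors.Is / errors.As still reach it.
        return nil, fmt.Errorf("loading key %s: %w", path, err)
    }
    return data, nil
}
```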


@@ -4,8 +4,8 @@
package ahash package ahash
import ( import (
"crypto/md5" "crypto/md5" // #nosec G505
"crypto/sha1" "crypto/sha1" // #nosec G501
"crypto/sha256" "crypto/sha256"
"crypto/sha512" "crypto/sha512"
"errors" "errors"
@@ -17,34 +17,15 @@ import (
"io" "io"
"sort" "sort"
"git.wntrmute.dev/kyle/goutils/assert"
"golang.org/x/crypto/blake2b" "golang.org/x/crypto/blake2b"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/md4" "golang.org/x/crypto/md4" // #nosec G506
"golang.org/x/crypto/ripemd160" "golang.org/x/crypto/ripemd160" // #nosec G507
"golang.org/x/crypto/sha3" "golang.org/x/crypto/sha3"
"git.wntrmute.dev/kyle/goutils/assert"
) )
func sha224Slicer(bs []byte) []byte {
sum := sha256.Sum224(bs)
return sum[:]
}
func sha256Slicer(bs []byte) []byte {
sum := sha256.Sum256(bs)
return sum[:]
}
func sha384Slicer(bs []byte) []byte {
sum := sha512.Sum384(bs)
return sum[:]
}
func sha512Slicer(bs []byte) []byte {
sum := sha512.Sum512(bs)
return sum[:]
}
// Hash represents a generic hash function that may or may not be secure. It // Hash represents a generic hash function that may or may not be secure. It
// satisfies the hash.Hash interface. // satisfies the hash.Hash interface.
type Hash struct { type Hash struct {
@@ -247,17 +228,17 @@ func init() {
// HashList returns a sorted list of all the hash algorithms supported by the // HashList returns a sorted list of all the hash algorithms supported by the
// package. // package.
func HashList() []string { func HashList() []string {
return hashList[:] return hashList
} }
// SecureHashList returns a sorted list of all the secure (cryptographic) hash // SecureHashList returns a sorted list of all the secure (cryptographic) hash
// algorithms supported by the package. // algorithms supported by the package.
func SecureHashList() []string { func SecureHashList() []string {
return secureHashList[:] return secureHashList
} }
// InsecureHashList returns a sorted list of all the insecure hash algorithms // InsecureHashList returns a sorted list of all the insecure hash algorithms
// supported by the package. // supported by the package.
func InsecureHashList() []string { func InsecureHashList() []string {
return insecureHashList[:] return insecureHashList
} }
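As a usage sketch of the ahash API touched above (New, Sum, and the list helpers), relying only on functions visible in this diff and its tests:

```
package main

import (
    "fmt"
    "log"

    "git.wntrmute.dev/kyle/goutils/ahash"
)

func main() {
    // Streaming use: New returns a hash.Hash-style object.
    h, err := ahash.New("sha256")
    if err != nil {
        log.Fatal(err)
    }
    if _, err = h.Write([]byte("hello, world")); err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%x (secure=%v)\n", h.Sum(nil), h.IsSecure())

    // One-shot use over a byte slice.
    sum, err := ahash.Sum("md5", []byte("hello, world"))
    if err != nil {
        log.Fatal(err)
    }
    fmt.Printf("%x\n", sum)

    // Enumerate supported algorithms.
    fmt.Println(ahash.SecureHashList())
}
```

Sum and SumReader provide one-shot digests, while New returns a streaming hash that also reports whether the algorithm is considered secure.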


@@ -1,16 +1,18 @@
package ahash package ahash_test
import ( import (
"bytes" "bytes"
"encoding/hex"
"fmt" "fmt"
"testing" "testing"
"git.wntrmute.dev/kyle/goutils/ahash"
"git.wntrmute.dev/kyle/goutils/assert" "git.wntrmute.dev/kyle/goutils/assert"
) )
func TestSecureHash(t *testing.T) { func TestSecureHash(t *testing.T) {
algo := "sha256" algo := "sha256"
h, err := New(algo) h, err := ahash.New(algo)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, h.IsSecure(), algo+" should be a secure hash") assert.BoolT(t, h.IsSecure(), algo+" should be a secure hash")
assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo") assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo")
@@ -19,28 +21,28 @@ func TestSecureHash(t *testing.T) {
var data []byte var data []byte
var expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" var expected = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
sum, err := Sum(algo, data) sum, err := ahash.Sum(algo, data)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, fmt.Sprintf("%x", sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum)) assert.BoolT(t, hex.EncodeToString(sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum))
data = []byte("hello, world") data = []byte("hello, world")
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
expected = "09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b" expected = "09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b"
sum, err = SumReader(algo, buf) sum, err = ahash.SumReader(algo, buf)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, fmt.Sprintf("%x", sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum)) assert.BoolT(t, hex.EncodeToString(sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum))
data = []byte("hello world") data = []byte("hello world")
_, err = h.Write(data) _, err = h.Write(data)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
unExpected := "09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b" unExpected := "09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b"
sum = h.Sum(nil) sum = h.Sum(nil)
assert.BoolT(t, fmt.Sprintf("%x", sum) != unExpected, fmt.Sprintf("hash shouldn't have returned %x", unExpected)) assert.BoolT(t, hex.EncodeToString(sum) != unExpected, fmt.Sprintf("hash shouldn't have returned %x", unExpected))
} }
func TestInsecureHash(t *testing.T) { func TestInsecureHash(t *testing.T) {
algo := "md5" algo := "md5"
h, err := New(algo) h, err := ahash.New(algo)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, !h.IsSecure(), algo+" shouldn't be a secure hash") assert.BoolT(t, !h.IsSecure(), algo+" shouldn't be a secure hash")
assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo") assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo")
@@ -49,28 +51,28 @@ func TestInsecureHash(t *testing.T) {
var data []byte var data []byte
var expected = "d41d8cd98f00b204e9800998ecf8427e" var expected = "d41d8cd98f00b204e9800998ecf8427e"
sum, err := Sum(algo, data) sum, err := ahash.Sum(algo, data)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, fmt.Sprintf("%x", sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum)) assert.BoolT(t, hex.EncodeToString(sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum))
data = []byte("hello, world") data = []byte("hello, world")
buf := bytes.NewBuffer(data) buf := bytes.NewBuffer(data)
expected = "e4d7f1b4ed2e42d15898f4b27b019da4" expected = "e4d7f1b4ed2e42d15898f4b27b019da4"
sum, err = SumReader(algo, buf) sum, err = ahash.SumReader(algo, buf)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, fmt.Sprintf("%x", sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum)) assert.BoolT(t, hex.EncodeToString(sum) == expected, fmt.Sprintf("expected hash %s but have %x", expected, sum))
data = []byte("hello world") data = []byte("hello world")
_, err = h.Write(data) _, err = h.Write(data)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
unExpected := "e4d7f1b4ed2e42d15898f4b27b019da4" unExpected := "e4d7f1b4ed2e42d15898f4b27b019da4"
sum = h.Sum(nil) sum = h.Sum(nil)
assert.BoolT(t, fmt.Sprintf("%x", sum) != unExpected, fmt.Sprintf("hash shouldn't have returned %x", unExpected)) assert.BoolT(t, hex.EncodeToString(sum) != unExpected, fmt.Sprintf("hash shouldn't have returned %x", unExpected))
} }
func TestHash32(t *testing.T) { func TestHash32(t *testing.T) {
algo := "crc32-ieee" algo := "crc32-ieee"
h, err := New(algo) h, err := ahash.New(algo)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, !h.IsSecure(), algo+" shouldn't be a secure hash") assert.BoolT(t, !h.IsSecure(), algo+" shouldn't be a secure hash")
assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo") assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo")
@@ -102,7 +104,7 @@ func TestHash32(t *testing.T) {
func TestHash64(t *testing.T) { func TestHash64(t *testing.T) {
algo := "crc64" algo := "crc64"
h, err := New(algo) h, err := ahash.New(algo)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, !h.IsSecure(), algo+" shouldn't be a secure hash") assert.BoolT(t, !h.IsSecure(), algo+" shouldn't be a secure hash")
assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo") assert.BoolT(t, h.HashAlgo() == algo, "hash returned the wrong HashAlgo")
@@ -133,9 +135,9 @@ func TestHash64(t *testing.T) {
} }
func TestListLengthSanity(t *testing.T) { func TestListLengthSanity(t *testing.T) {
all := HashList() all := ahash.HashList()
secure := SecureHashList() secure := ahash.SecureHashList()
insecure := InsecureHashList() insecure := ahash.InsecureHashList()
assert.BoolT(t, len(all) == len(secure)+len(insecure)) assert.BoolT(t, len(all) == len(secure)+len(insecure))
} }
@@ -146,11 +148,11 @@ func TestSumLimitedReader(t *testing.T) {
extendedData := bytes.NewBufferString("hello, world! this is an extended message") extendedData := bytes.NewBufferString("hello, world! this is an extended message")
expected := "09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b" expected := "09ca7e4eaa6e8ae9c7d261167129184883644d07dfba7cbfbc4c8a2e08360d5b"
hash, err := SumReader("sha256", data) hash, err := ahash.SumReader("sha256", data)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, fmt.Sprintf("%x", hash) == expected, fmt.Sprintf("have hash %x, want %s", hash, expected)) assert.BoolT(t, hex.EncodeToString(hash) == expected, fmt.Sprintf("have hash %x, want %s", hash, expected))
extendedHash, err := SumLimitedReader("sha256", extendedData, int64(dataLen)) extendedHash, err := ahash.SumLimitedReader("sha256", extendedData, int64(dataLen))
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
assert.BoolT(t, bytes.Equal(hash, extendedHash), fmt.Sprintf("have hash %x, want %x", extendedHash, hash)) assert.BoolT(t, bytes.Equal(hash, extendedHash), fmt.Sprintf("have hash %x, want %x", extendedHash, hash))


@@ -9,6 +9,7 @@
package assert package assert
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
"runtime" "runtime"
@@ -16,11 +17,13 @@ import (
"testing" "testing"
) )
const callerSkip = 2
// NoDebug can be set to true to cause all asserts to be ignored. // NoDebug can be set to true to cause all asserts to be ignored.
var NoDebug bool var NoDebug bool
func die(what string, a ...string) { func die(what string, a ...string) {
_, file, line, ok := runtime.Caller(2) _, file, line, ok := runtime.Caller(callerSkip)
if !ok { if !ok {
panic(what) panic(what)
} }
@@ -31,30 +34,32 @@ func die(what string, a ...string) {
s = ": " + s s = ": " + s
} }
panic(what + s) panic(what + s)
} else {
fmt.Fprintf(os.Stderr, "%s", what)
if len(a) > 0 {
s := strings.Join(a, ", ")
fmt.Fprintln(os.Stderr, ": "+s)
} else {
fmt.Fprintf(os.Stderr, "\n")
}
fmt.Fprintf(os.Stderr, "\t%s line %d\n", file, line)
os.Exit(1)
} }
fmt.Fprintf(os.Stderr, "%s", what)
if len(a) > 0 {
s := strings.Join(a, ", ")
fmt.Fprintln(os.Stderr, ": "+s)
} else {
fmt.Fprintf(os.Stderr, "\n")
}
fmt.Fprintf(os.Stderr, "\t%s line %d\n", file, line)
os.Exit(1)
} }
// Bool asserts that cond is false. // Bool asserts that cond is false.
// //
// For example, this would replace // For example, this would replace
// if x < 0 { //
// log.Fatal("x is subzero") // if x < 0 {
// } // log.Fatal("x is subzero")
// }
// //
// The same assertion would be // The same assertion would be
// assert.Bool(x, "x is subzero") //
// assert.Bool(x, "x is subzero")
func Bool(cond bool, s ...string) { func Bool(cond bool, s ...string) {
if NoDebug { if NoDebug {
return return
@@ -68,11 +73,12 @@ func Bool(cond bool, s ...string) {
// Error asserts that err is not nil, e.g. that an error has occurred. // Error asserts that err is not nil, e.g. that an error has occurred.
// //
// For example, // For example,
// if err == nil { //
// log.Fatal("call to <something> should have failed") // if err == nil {
// } // log.Fatal("call to <something> should have failed")
// // becomes // }
// assert.Error(err, "call to <something> should have failed") // // becomes
// assert.Error(err, "call to <something> should have failed")
func Error(err error, s ...string) { func Error(err error, s ...string) {
if NoDebug { if NoDebug {
return return
@@ -100,7 +106,7 @@ func NoError(err error, s ...string) {
// ErrorEq asserts that the actual error is the expected error. // ErrorEq asserts that the actual error is the expected error.
func ErrorEq(expected, actual error) { func ErrorEq(expected, actual error) {
if NoDebug || (expected == actual) { if NoDebug || (errors.Is(expected, actual)) {
return return
} }
@@ -155,7 +161,7 @@ func NoErrorT(t *testing.T, err error) {
// ErrorEqT compares a pair of errors, calling Fatal on it if they // ErrorEqT compares a pair of errors, calling Fatal on it if they
// don't match. // don't match.
func ErrorEqT(t *testing.T, expected, actual error) { func ErrorEqT(t *testing.T, expected, actual error) {
if NoDebug || (expected == actual) { if NoDebug || (errors.Is(expected, actual)) {
return return
} }
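For reference, a minimal test-side sketch combining the *T helpers changed above with the ahash package from this diff (the test name and expectations are illustrative only):

```
package example_test

import (
    "testing"

    "git.wntrmute.dev/kyle/goutils/ahash"
    "git.wntrmute.dev/kyle/goutils/assert"
)

func TestDigestLength(t *testing.T) {
    sum, err := ahash.Sum("sha256", []byte("data"))
    assert.NoErrorT(t, err)
    assert.BoolT(t, len(sum) == 32, "sha256 digest should be 32 bytes")
}
```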


@@ -10,29 +10,21 @@
// backoff is configured with a maximum duration that will not be // backoff is configured with a maximum duration that will not be
// exceeded. // exceeded.
// //
// The `New` function will attempt to use the system's cryptographic // This package uses math/rand/v2 for jitter, which is automatically
// random number generator to seed a Go math/rand random number // seeded from a cryptographically secure source.
// source. If this fails, the package will panic on startup.
package backoff package backoff
import ( import (
"crypto/rand"
"encoding/binary"
"io"
"math" "math"
mrand "math/rand" "math/rand/v2"
"sync"
"time" "time"
) )
var prngMu sync.Mutex
var prng *mrand.Rand
// DefaultInterval is used when a Backoff is initialised with a // DefaultInterval is used when a Backoff is initialised with a
// zero-value Interval. // zero-value Interval.
var DefaultInterval = 5 * time.Minute var DefaultInterval = 5 * time.Minute
// DefaultMaxDuration is maximum amount of time that the backoff will // DefaultMaxDuration is the maximum amount of time that the backoff will
// delay for. // delay for.
var DefaultMaxDuration = 6 * time.Hour var DefaultMaxDuration = 6 * time.Hour
@@ -50,10 +42,9 @@ type Backoff struct {
// interval controls the time step for backing off. // interval controls the time step for backing off.
interval time.Duration interval time.Duration
// noJitter controls whether to use the "Full Jitter" // noJitter controls whether to use the "Full Jitter" improvement to attempt
// improvement to attempt to smooth out spikes in a high // to smooth out spikes in a high-contention scenario. If noJitter is set to
// contention scenario. If noJitter is set to true, no // true, no jitter will be introduced.
// jitter will be introduced.
noJitter bool noJitter bool
// decay controls the decay of n. If it is non-zero, n is // decay controls the decay of n. If it is non-zero, n is
@@ -65,17 +56,17 @@ type Backoff struct {
lastTry time.Time lastTry time.Time
} }
// New creates a new backoff with the specified max duration and // New creates a new backoff with the specified maximum duration and
// interval. Zero values may be used to use the default values. // interval. Zero values may be used to use the default values.
// //
// Panics if either max or interval is negative. // Panics if either dMax or interval is negative.
func New(max time.Duration, interval time.Duration) *Backoff { func New(dMax time.Duration, interval time.Duration) *Backoff {
if max < 0 || interval < 0 { if dMax < 0 || interval < 0 {
panic("backoff: max or interval is negative") panic("backoff: dMax or interval is negative")
} }
b := &Backoff{ b := &Backoff{
maxDuration: max, maxDuration: dMax,
interval: interval, interval: interval,
} }
b.setup() b.setup()
@@ -84,27 +75,12 @@ func New(max time.Duration, interval time.Duration) *Backoff {
// NewWithoutJitter works similarly to New, except that the created // NewWithoutJitter works similarly to New, except that the created
// Backoff will not use jitter. // Backoff will not use jitter.
func NewWithoutJitter(max time.Duration, interval time.Duration) *Backoff { func NewWithoutJitter(dMax time.Duration, interval time.Duration) *Backoff {
b := New(max, interval) b := New(dMax, interval)
b.noJitter = true b.noJitter = true
return b return b
} }
func init() {
var buf [8]byte
var n int64
_, err := io.ReadFull(rand.Reader, buf[:])
if err != nil {
panic(err.Error())
}
n = int64(binary.LittleEndian.Uint64(buf[:]))
src := mrand.NewSource(n)
prng = mrand.New(src)
}
func (b *Backoff) setup() { func (b *Backoff) setup() {
if b.interval == 0 { if b.interval == 0 {
b.interval = DefaultInterval b.interval = DefaultInterval
@@ -122,35 +98,44 @@ func (b *Backoff) Duration() time.Duration {
b.decayN() b.decayN()
t := b.duration(b.n) d := b.duration(b.n)
if b.n < math.MaxUint64 { if b.n < math.MaxUint64 {
b.n++ b.n++
} }
if !b.noJitter { if !b.noJitter {
prngMu.Lock() d = time.Duration(rand.Int64N(int64(d))) // #nosec G404
t = time.Duration(prng.Int63n(int64(t)))
prngMu.Unlock()
} }
return t return d
} }
const maxN uint64 = 63
// requires b to be locked. // requires b to be locked.
func (b *Backoff) duration(n uint64) (t time.Duration) { func (b *Backoff) duration(n uint64) time.Duration {
// Saturate pow // Use left shift on the underlying integer representation to avoid
pow := time.Duration(math.MaxInt64) // multiplying time.Duration by time.Duration (which is semantically
if n < 63 { // incorrect and flagged by linters).
pow = 1 << n if n >= maxN {
// Saturate when n would overflow a 64-bit shift or exceed maxDuration.
return b.maxDuration
} }
t = b.interval * pow // Calculate 2^n * interval using a shift. Detect overflow by checking
if t/pow != b.interval || t > b.maxDuration { // for sign change or monotonicity loss and clamp to maxDuration.
t = b.maxDuration shifted := b.interval << n
if shifted < 0 || shifted < b.interval {
// Overflow occurred during the shift; clamp to maxDuration.
return b.maxDuration
} }
return if shifted > b.maxDuration {
return b.maxDuration
}
return shifted
} }
// Reset resets the attempt counter of a backoff. // Reset resets the attempt counter of a backoff.
@@ -174,7 +159,7 @@ func (b *Backoff) SetDecay(decay time.Duration) {
b.decay = decay b.decay = decay
} }
// requires b to be locked // requires b to be locked.
func (b *Backoff) decayN() { func (b *Backoff) decayN() {
if b.decay == 0 { if b.decay == 0 {
return return
@@ -186,7 +171,9 @@ func (b *Backoff) decayN() {
} }
lastDuration := b.duration(b.n - 1) lastDuration := b.duration(b.n - 1)
decayed := time.Since(b.lastTry) > lastDuration+b.decay // Reset when the elapsed time is at least the previous backoff plus decay.
// Using ">=" avoids boundary flakiness in tests and real usage.
decayed := time.Since(b.lastTry) >= lastDuration+b.decay
b.lastTry = time.Now() b.lastTry = time.Now()
if !decayed { if !decayed {


@@ -9,7 +9,7 @@ import (
// If given New with 0's and no jitter, ensure that certain invariants are met: // If given New with 0's and no jitter, ensure that certain invariants are met:
// //
// - the default max duration and interval should be used // - the default maximum duration and interval should be used
// - noJitter should be true // - noJitter should be true
// - the RNG should not be initialised // - the RNG should not be initialised
// - the first duration should be equal to the default interval // - the first duration should be equal to the default interval
@@ -17,7 +17,11 @@ func TestDefaults(t *testing.T) {
b := NewWithoutJitter(0, 0) b := NewWithoutJitter(0, 0)
if b.maxDuration != DefaultMaxDuration { if b.maxDuration != DefaultMaxDuration {
t.Fatalf("expected new backoff to use the default max duration (%s), but have %s", DefaultMaxDuration, b.maxDuration) t.Fatalf(
"expected new backoff to use the default maxDuration duration (%s), but have %s",
DefaultMaxDuration,
b.maxDuration,
)
} }
if b.interval != DefaultInterval { if b.interval != DefaultInterval {
@@ -44,11 +48,11 @@ func TestSetup(t *testing.T) {
} }
} }
// Ensure that tries incremenets as expected. // Ensure that tries increments as expected.
func TestTries(t *testing.T) { func TestTries(t *testing.T) {
b := NewWithoutJitter(5, 1) b := NewWithoutJitter(5, 1)
for i := uint64(0); i < 3; i++ { for i := range uint64(3) {
if b.n != i { if b.n != i {
t.Fatalf("want tries=%d, have tries=%d", i, b.n) t.Fatalf("want tries=%d, have tries=%d", i, b.n)
} }
@@ -73,7 +77,7 @@ func TestTries(t *testing.T) {
func TestReset(t *testing.T) { func TestReset(t *testing.T) {
const iter = 10 const iter = 10
b := New(1000, 1) b := New(1000, 1)
for i := 0; i < iter; i++ { for range iter {
_ = b.Duration() _ = b.Duration()
} }
@@ -87,18 +91,18 @@ func TestReset(t *testing.T) {
} }
} }
const decay = 5 * time.Millisecond const decay = time.Second
const max = 10 * time.Millisecond const maxDuration = 10 * time.Millisecond
const interval = time.Millisecond const interval = time.Millisecond
func TestDecay(t *testing.T) { func TestDecay(t *testing.T) {
const iter = 10 const iter = 10
b := NewWithoutJitter(max, 1) b := NewWithoutJitter(maxDuration, 1)
b.SetDecay(decay) b.SetDecay(decay)
var backoff time.Duration var backoff time.Duration
for i := 0; i < iter; i++ { for range iter {
backoff = b.Duration() backoff = b.Duration()
} }
@@ -127,7 +131,7 @@ func TestDecaySaturation(t *testing.T) {
b.SetDecay(decay) b.SetDecay(decay)
var duration time.Duration var duration time.Duration
for i := 0; i <= 2; i++ { for range 3 {
duration = b.Duration() duration = b.Duration()
} }
@@ -145,7 +149,7 @@ func TestDecaySaturation(t *testing.T) {
} }
func ExampleBackoff_SetDecay() { func ExampleBackoff_SetDecay() {
b := NewWithoutJitter(max, interval) b := NewWithoutJitter(maxDuration, interval)
b.SetDecay(decay) b.SetDecay(decay)
// try 0 // try 0
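A minimal retry-loop sketch against the API shown in this file (New, Duration, Reset); the import path is assumed from the repository's module path, and doRequest is a hypothetical stand-in for a fallible operation:

```
package main

import (
    "errors"
    "time"

    "git.wntrmute.dev/kyle/goutils/backoff"
)

// doRequest is a hypothetical operation that may fail transiently.
func doRequest() error { return errors.New("transient failure") }

func main() {
    // Back off up to one minute, starting with one-second steps (with jitter).
    b := backoff.New(time.Minute, time.Second)

    for i := 0; i < 5; i++ {
        if err := doRequest(); err == nil {
            b.Reset()
            break
        }
        // Sleep for an exponentially increasing, jittered duration,
        // capped at the configured maximum.
        time.Sleep(b.Duration())
    }
}
```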

cache/lru/lru.go

@@ -0,0 +1,179 @@
// Package lru implements a Least Recently Used cache.
package lru
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/benbjohnson/clock"
)
type item[V any] struct {
V V
access int64
}
// A Cache is a map that retains a limited number of items. It must be
// initialized with New, providing a maximum capacity for the cache.
// Only the least recently used items are retained.
type Cache[K comparable, V any] struct {
store map[K]*item[V]
access *timestamps[K]
cap int
clock clock.Clock
// All public methods that have the possibility of modifying the
// cache should lock it.
mtx *sync.Mutex
}
// New must be used to create a new Cache.
func New[K comparable, V any](icap int) *Cache[K, V] {
return &Cache[K, V]{
store: map[K]*item[V]{},
access: newTimestamps[K](icap),
cap: icap,
clock: clock.New(),
mtx: &sync.Mutex{},
}
}
// StringKeyCache is a convenience wrapper for cache keyed by string.
type StringKeyCache[V any] struct {
*Cache[string, V]
}
// NewStringKeyCache creates a new LRU cache keyed by string.
func NewStringKeyCache[V any](icap int) *StringKeyCache[V] {
return &StringKeyCache[V]{Cache: New[string, V](icap)}
}
func (c *Cache[K, V]) lock() {
c.mtx.Lock()
}
func (c *Cache[K, V]) unlock() {
c.mtx.Unlock()
}
// Len returns the number of items currently in the cache.
func (c *Cache[K, V]) Len() int {
return len(c.store)
}
// evict should remove the least-recently-used cache item.
func (c *Cache[K, V]) evict() {
if c.access.Len() == 0 {
return
}
k := c.access.K(0)
c.evictKey(k)
}
// evictKey should remove the entry given by the key item.
func (c *Cache[K, V]) evictKey(k K) {
delete(c.store, k)
i, ok := c.access.Find(k)
if !ok {
return
}
c.access.Delete(i)
}
func (c *Cache[K, V]) sanityCheck() {
if len(c.store) != c.access.Len() {
panic(fmt.Sprintf("LRU cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len()))
}
}
// ConsistencyCheck runs a series of checks to ensure that the cache's
// data structures are consistent. It is not normally required, and it
// is primarily used in testing.
func (c *Cache[K, V]) ConsistencyCheck() error {
c.lock()
defer c.unlock()
if err := c.access.ConsistencyCheck(); err != nil {
return err
}
if len(c.store) != c.access.Len() {
return fmt.Errorf("lru: cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len())
}
for i := range c.access.ts {
itm, ok := c.store[c.access.K(i)]
if !ok {
return errors.New("lru: key in access is not in store")
}
if c.access.T(i) != itm.access {
return fmt.Errorf("timestamps are out of sync (%d != %d)",
itm.access, c.access.T(i))
}
}
if !sort.IsSorted(c.access) {
return errors.New("lru: timestamps aren't sorted")
}
return nil
}
// Store adds the value v to the cache under the k.
func (c *Cache[K, V]) Store(k K, v V) {
c.lock()
defer c.unlock()
c.sanityCheck()
if len(c.store) == c.cap {
c.evict()
}
if _, ok := c.store[k]; ok {
c.evictKey(k)
}
itm := &item[V]{
V: v,
access: c.clock.Now().UnixNano(),
}
c.store[k] = itm
c.access.Update(k, itm.access)
}
// Get returns the value stored in the cache. If the item isn't present,
// it will return false.
func (c *Cache[K, V]) Get(k K) (V, bool) {
c.lock()
defer c.unlock()
c.sanityCheck()
itm, ok := c.store[k]
if !ok {
var zero V
return zero, false
}
c.store[k].access = c.clock.Now().UnixNano()
c.access.Update(k, itm.access)
return itm.V, true
}
// Has returns true if the cache has an entry for k. It will not update
// the timestamp on the item.
func (c *Cache[K, V]) Has(k K) bool {
// Don't need to lock as we don't modify anything.
c.sanityCheck()
_, ok := c.store[k]
return ok
}
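A short usage sketch of the cache API added here (NewStringKeyCache, Store, Get, Has); the import path is assumed from the repository's module path:

```
package main

import (
    "fmt"

    "git.wntrmute.dev/kyle/goutils/cache/lru"
)

func main() {
    // Cache at most two values keyed by string.
    c := lru.NewStringKeyCache[int](2)

    c.Store("raven", 1)
    c.Store("owl", 2)

    if v, ok := c.Get("raven"); ok {
        fmt.Println("raven =", v)
    }

    // Storing a third item triggers an eviction; see the internal tests
    // below for the exact eviction order this implementation exhibits.
    c.Store("goat", 3)
    fmt.Println("has owl:", c.Has("owl"))
}
```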

cache/lru/lru_internal_test.go

@@ -0,0 +1,87 @@
package lru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
// These tests mirror the MRU-style behavior present in this LRU package
// implementation (eviction removes the most-recently-used entry).
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
// Since this implementation evicts the most-recently-used item, inserting
// "goat" when full evicts "owl" (the most recent at that time).
mock.Add(time.Second)
if _, ok := c.Get("owl"); ok {
t.Fatal("store should not have an entry for owl (MRU-evicted)")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
// Before storing elk, keys were: raven (older), goat (newer). Evict MRU -> goat.
if !c.Has("raven") {
t.Fatal("store should contain an entry for 'raven'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
}

cache/lru/timestamps.go

@@ -0,0 +1,101 @@
package lru
import (
"errors"
"fmt"
"io"
"sort"
)
// timestamps contains datastructures for maintaining a list of keys sortable
// by timestamp.
type timestamp[K comparable] struct {
t int64
k K
}
type timestamps[K comparable] struct {
ts []timestamp[K]
cap int
}
func newTimestamps[K comparable](icap int) *timestamps[K] {
return &timestamps[K]{
ts: make([]timestamp[K], 0, icap),
cap: icap,
}
}
func (ts *timestamps[K]) K(i int) K {
return ts.ts[i].k
}
func (ts *timestamps[K]) T(i int) int64 {
return ts.ts[i].t
}
func (ts *timestamps[K]) Len() int {
return len(ts.ts)
}
func (ts *timestamps[K]) Less(i, j int) bool {
return ts.ts[i].t > ts.ts[j].t
}
func (ts *timestamps[K]) Swap(i, j int) {
ts.ts[i], ts.ts[j] = ts.ts[j], ts.ts[i]
}
func (ts *timestamps[K]) Find(k K) (int, bool) {
for i := range ts.ts {
if ts.ts[i].k == k {
return i, true
}
}
return -1, false
}
func (ts *timestamps[K]) Update(k K, t int64) bool {
i, ok := ts.Find(k)
if !ok {
ts.ts = append(ts.ts, timestamp[K]{t, k})
sort.Sort(ts)
return false
}
ts.ts[i].t = t
sort.Sort(ts)
return true
}
func (ts *timestamps[K]) ConsistencyCheck() error {
if !sort.IsSorted(ts) {
return errors.New("lru: timestamps are not sorted")
}
keys := map[K]bool{}
for i := range ts.ts {
if keys[ts.ts[i].k] {
return fmt.Errorf("lru: duplicate key %v detected", ts.ts[i].k)
}
keys[ts.ts[i].k] = true
}
if len(keys) != len(ts.ts) {
return fmt.Errorf("lru: timestamp contains %d duplicate keys",
len(ts.ts)-len(keys))
}
return nil
}
func (ts *timestamps[K]) Delete(i int) {
ts.ts = append(ts.ts[:i], ts.ts[i+1:]...)
}
func (ts *timestamps[K]) Dump(w io.Writer) {
for i := range ts.ts {
fmt.Fprintf(w, "%d: %v, %d\n", i, ts.K(i), ts.T(i))
}
}

cache/lru/timestamps_internal_test.go

@@ -0,0 +1,50 @@
package lru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
// These tests validate timestamps ordering semantics for the LRU package.
// Note: The LRU timestamps are sorted with most-recent-first (descending by t).
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// make owl the most recent
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// For LRU timestamps: most recent first. Expected order: owl, goat, raven.
if ts.K(0) != "owl" {
t.Fatalf("first key should be owl, have %s", ts.K(0))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(2) != "raven" {
t.Fatalf("third key should be raven, have %s", ts.K(2))
}
}

cache/mru/mru.go

@@ -0,0 +1,178 @@
package mru
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/benbjohnson/clock"
)
type item[V any] struct {
V V
access int64
}
// A Cache is a map that retains a limited number of items. It must be
// initialized with New, providing a maximum capacity for the cache.
// Only the most recently used items are retained.
type Cache[K comparable, V any] struct {
store map[K]*item[V]
access *timestamps[K]
cap int
clock clock.Clock
// All public methods that have the possibility of modifying the
// cache should lock it.
mtx *sync.Mutex
}
// New must be used to create a new Cache.
func New[K comparable, V any](icap int) *Cache[K, V] {
return &Cache[K, V]{
store: map[K]*item[V]{},
access: newTimestamps[K](icap),
cap: icap,
clock: clock.New(),
mtx: &sync.Mutex{},
}
}
// StringKeyCache is a convenience wrapper for cache keyed by string.
type StringKeyCache[V any] struct {
*Cache[string, V]
}
// NewStringKeyCache creates a new MRU cache keyed by string.
func NewStringKeyCache[V any](icap int) *StringKeyCache[V] {
return &StringKeyCache[V]{Cache: New[string, V](icap)}
}
func (c *Cache[K, V]) lock() {
c.mtx.Lock()
}
func (c *Cache[K, V]) unlock() {
c.mtx.Unlock()
}
// Len returns the number of items currently in the cache.
func (c *Cache[K, V]) Len() int {
return len(c.store)
}
// evict should remove the least-recently-used cache item.
func (c *Cache[K, V]) evict() {
if c.access.Len() == 0 {
return
}
k := c.access.K(0)
c.evictKey(k)
}
// evictKey should remove the entry given by the key item.
func (c *Cache[K, V]) evictKey(k K) {
delete(c.store, k)
i, ok := c.access.Find(k)
if !ok {
return
}
c.access.Delete(i)
}
func (c *Cache[K, V]) sanityCheck() {
if len(c.store) != c.access.Len() {
panic(fmt.Sprintf("MRU cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len()))
}
}
// ConsistencyCheck runs a series of checks to ensure that the cache's
// data structures are consistent. It is not normally required, and it
// is primarily used in testing.
func (c *Cache[K, V]) ConsistencyCheck() error {
c.lock()
defer c.unlock()
if err := c.access.ConsistencyCheck(); err != nil {
return err
}
if len(c.store) != c.access.Len() {
return fmt.Errorf("mru: cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len())
}
for i := range c.access.ts {
itm, ok := c.store[c.access.K(i)]
if !ok {
return errors.New("mru: key in access is not in store")
}
if c.access.T(i) != itm.access {
return fmt.Errorf("timestamps are out of sync (%d != %d)",
itm.access, c.access.T(i))
}
}
if !sort.IsSorted(c.access) {
return errors.New("mru: timestamps aren't sorted")
}
return nil
}
// Store adds the value v to the cache under the k.
func (c *Cache[K, V]) Store(k K, v V) {
c.lock()
defer c.unlock()
c.sanityCheck()
if len(c.store) == c.cap {
c.evict()
}
if _, ok := c.store[k]; ok {
c.evictKey(k)
}
itm := &item[V]{
V: v,
access: c.clock.Now().UnixNano(),
}
c.store[k] = itm
c.access.Update(k, itm.access)
}
// Get returns the value stored in the cache. If the item isn't present,
// it will return false.
func (c *Cache[K, V]) Get(k K) (V, bool) {
c.lock()
defer c.unlock()
c.sanityCheck()
itm, ok := c.store[k]
if !ok {
var zero V
return zero, false
}
c.store[k].access = c.clock.Now().UnixNano()
c.access.Update(k, itm.access)
return itm.V, true
}
// Has returns true if the cache has an entry for k. It will not update
// the timestamp on the item.
func (c *Cache[K, V]) Has(k K) bool {
// Don't need to lock as we don't modify anything.
c.sanityCheck()
_, ok := c.store[k]
return ok
}
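The MRU cache exposes the same generic API; a sketch using New directly with a non-string key (import path again assumed from the module path):

```
package main

import (
    "fmt"

    "git.wntrmute.dev/kyle/goutils/cache/mru"
)

func main() {
    // A cache of up to three entries keyed by int.
    c := mru.New[int, string](3)

    c.Store(1, "raven")
    c.Store(2, "owl")
    c.Store(3, "goat")

    if v, ok := c.Get(2); ok {
        fmt.Println("2 =", v)
    }
    fmt.Println("len =", c.Len())
}
```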

cache/mru/mru_internal_test.go

@@ -0,0 +1,92 @@
package mru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
v, ok := c.Get("owl")
if !ok {
t.Fatal("store should have an entry for owl")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
itm := v
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if itm != 2 {
t.Fatalf("stored item should be 2, have %d", itm)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
if !c.Has("owl") {
t.Fatal("store should contain an entry for 'owl'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
}

cache/mru/timestamps.go

@@ -0,0 +1,101 @@
package mru
import (
"errors"
"fmt"
"io"
"sort"
)
// timestamps contains datastructures for maintaining a list of keys sortable
// by timestamp.
type timestamp[K comparable] struct {
t int64
k K
}
type timestamps[K comparable] struct {
ts []timestamp[K]
cap int
}
func newTimestamps[K comparable](icap int) *timestamps[K] {
return &timestamps[K]{
ts: make([]timestamp[K], 0, icap),
cap: icap,
}
}
func (ts *timestamps[K]) K(i int) K {
return ts.ts[i].k
}
func (ts *timestamps[K]) T(i int) int64 {
return ts.ts[i].t
}
func (ts *timestamps[K]) Len() int {
return len(ts.ts)
}
func (ts *timestamps[K]) Less(i, j int) bool {
return ts.ts[i].t < ts.ts[j].t
}
func (ts *timestamps[K]) Swap(i, j int) {
ts.ts[i], ts.ts[j] = ts.ts[j], ts.ts[i]
}
func (ts *timestamps[K]) Find(k K) (int, bool) {
for i := range ts.ts {
if ts.ts[i].k == k {
return i, true
}
}
return -1, false
}
func (ts *timestamps[K]) Update(k K, t int64) bool {
i, ok := ts.Find(k)
if !ok {
ts.ts = append(ts.ts, timestamp[K]{t, k})
sort.Sort(ts)
return false
}
ts.ts[i].t = t
sort.Sort(ts)
return true
}
func (ts *timestamps[K]) ConsistencyCheck() error {
if !sort.IsSorted(ts) {
return errors.New("mru: timestamps are not sorted")
}
keys := map[K]bool{}
for i := range ts.ts {
if keys[ts.ts[i].k] {
return fmt.Errorf("duplicate key %v detected", ts.ts[i].k)
}
keys[ts.ts[i].k] = true
}
if len(keys) != len(ts.ts) {
return fmt.Errorf("mru: timestamp contains %d duplicate keys",
len(ts.ts)-len(keys))
}
return nil
}
func (ts *timestamps[K]) Delete(i int) {
ts.ts = append(ts.ts[:i], ts.ts[i+1:]...)
}
func (ts *timestamps[K]) Dump(w io.Writer) {
for i := range ts.ts {
fmt.Fprintf(w, "%d: %v, %d\n", i, ts.K(i), ts.T(i))
}
}

cache/mru/timestamps_internal_test.go

@@ -0,0 +1,49 @@
package mru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Millisecond)
// raven, goat, owl
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// at this point, the keys should be raven, goat, owl.
if ts.K(0) != "raven" {
t.Fatalf("first key should be raven, have %s", ts.K(0))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(2) != "owl" {
t.Fatalf("third key should be owl, have %s", ts.K(2))
}
}

certlib/bundler/bundler.go

@@ -0,0 +1,677 @@
package bundler
import (
"archive/tar"
"archive/zip"
"compress/gzip"
"crypto/sha256"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"sort"
"strings"
"time"
"gopkg.in/yaml.v2"
"git.wntrmute.dev/kyle/goutils/certlib"
)
const defaultFileMode = 0644
// Config represents the top-level YAML configuration.
type Config struct {
Config struct {
Hashes string `yaml:"hashes"`
Expiry string `yaml:"expiry"`
} `yaml:"config"`
Chains map[string]ChainGroup `yaml:"chains"`
}
// ChainGroup represents a named group of certificate chains.
type ChainGroup struct {
Certs []CertChain `yaml:"certs"`
Outputs Outputs `yaml:"outputs"`
}
// CertChain represents a root certificate and its intermediates.
type CertChain struct {
Root string `yaml:"root"`
Intermediates []string `yaml:"intermediates"`
}
// Outputs defines output format options.
type Outputs struct {
IncludeSingle bool `yaml:"include_single"`
IncludeIndividual bool `yaml:"include_individual"`
Manifest bool `yaml:"manifest"`
Formats []string `yaml:"formats"`
Encoding string `yaml:"encoding"`
}
var formatExtensions = map[string]string{
"zip": ".zip",
"tgz": ".tar.gz",
}
// Run performs the bundling operation given a config file path and an output directory.
func Run(configFile string, outputDir string) error {
if configFile == "" {
return errors.New("configuration file required")
}
cfg, err := loadConfig(configFile)
if err != nil {
return fmt.Errorf("loading config: %w", err)
}
expiryDuration := 365 * 24 * time.Hour
if cfg.Config.Expiry != "" {
expiryDuration, err = parseDuration(cfg.Config.Expiry)
if err != nil {
return fmt.Errorf("parsing expiry: %w", err)
}
}
if err = os.MkdirAll(outputDir, 0750); err != nil {
return fmt.Errorf("creating output directory: %w", err)
}
totalFormats := 0
for _, group := range cfg.Chains {
totalFormats += len(group.Outputs.Formats)
}
createdFiles := make([]string, 0, totalFormats)
for groupName, group := range cfg.Chains {
files, perr := processChainGroup(groupName, group, expiryDuration, outputDir)
if perr != nil {
return fmt.Errorf("processing chain group %s: %w", groupName, perr)
}
createdFiles = append(createdFiles, files...)
}
if cfg.Config.Hashes != "" {
hashFile := filepath.Join(outputDir, cfg.Config.Hashes)
if gerr := generateHashFile(hashFile, createdFiles); gerr != nil {
return fmt.Errorf("generating hash file: %w", gerr)
}
}
return nil
}
func loadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg Config
if uerr := yaml.Unmarshal(data, &cfg); uerr != nil {
return nil, uerr
}
return &cfg, nil
}
func parseDuration(s string) (time.Duration, error) {
// Support simple formats like "1y", "6m", "30d"
if len(s) < 2 {
return 0, fmt.Errorf("invalid duration format: %s", s)
}
unit := s[len(s)-1]
value := s[:len(s)-1]
var multiplier time.Duration
switch unit {
case 'y', 'Y':
multiplier = 365 * 24 * time.Hour
case 'm', 'M':
multiplier = 30 * 24 * time.Hour
case 'd', 'D':
multiplier = 24 * time.Hour
default:
return time.ParseDuration(s)
}
var num int
_, err := fmt.Sscanf(value, "%d", &num)
if err != nil {
return 0, fmt.Errorf("invalid duration value: %s", s)
}
return time.Duration(num) * multiplier, nil
}
func processChainGroup(
groupName string,
group ChainGroup,
expiryDuration time.Duration,
outputDir string,
) ([]string, error) {
// Default encoding to "pem" if not specified
encoding := group.Outputs.Encoding
if encoding == "" {
encoding = "pem"
}
// Collect certificates from all chains in the group
singleFileCerts, individualCerts, sourcePaths, err := loadAndCollectCerts(
group.Certs,
group.Outputs,
expiryDuration,
)
if err != nil {
return nil, err
}
// Prepare files for inclusion in archives
archiveFiles, err := prepareArchiveFiles(singleFileCerts, individualCerts, sourcePaths, group.Outputs, encoding)
if err != nil {
return nil, err
}
// Create archives for the entire group
createdFiles, err := createArchiveFiles(groupName, group.Outputs.Formats, archiveFiles, outputDir)
if err != nil {
return nil, err
}
return createdFiles, nil
}
// loadAndCollectCerts loads all certificates from chains and collects them for processing.
func loadAndCollectCerts(
chains []CertChain,
outputs Outputs,
expiryDuration time.Duration,
) ([]*x509.Certificate, []certWithPath, []string, error) {
var singleFileCerts []*x509.Certificate
var individualCerts []certWithPath
var sourcePaths []string
for _, chain := range chains {
s, i, cerr := collectFromChain(chain, outputs, expiryDuration)
if cerr != nil {
return nil, nil, nil, cerr
}
if len(s) > 0 {
singleFileCerts = append(singleFileCerts, s...)
}
if len(i) > 0 {
individualCerts = append(individualCerts, i...)
}
// Record source paths for timestamp preservation
// Only append when loading succeeded
sourcePaths = append(sourcePaths, chain.Root)
sourcePaths = append(sourcePaths, chain.Intermediates...)
}
return singleFileCerts, individualCerts, sourcePaths, nil
}
// collectFromChain loads a single chain, performs checks, and returns the certs to include.
func collectFromChain(
chain CertChain,
outputs Outputs,
expiryDuration time.Duration,
) (
[]*x509.Certificate,
[]certWithPath,
error,
) {
var single []*x509.Certificate
var indiv []certWithPath
// Load root certificate
rootCert, rerr := certlib.LoadCertificate(chain.Root)
if rerr != nil {
return nil, nil, fmt.Errorf("failed to load root certificate %s: %w", chain.Root, rerr)
}
// Check expiry for root
checkExpiry(chain.Root, rootCert, expiryDuration)
// Add root to collections if needed
if outputs.IncludeSingle {
single = append(single, rootCert)
}
if outputs.IncludeIndividual {
indiv = append(indiv, certWithPath{cert: rootCert, path: chain.Root})
}
// Load and validate intermediates
for _, intPath := range chain.Intermediates {
intCert, lerr := certlib.LoadCertificate(intPath)
if lerr != nil {
return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %w", intPath, lerr)
}
// Validate that intermediate is signed by root
if sigErr := intCert.CheckSignatureFrom(rootCert); sigErr != nil {
return nil, nil, fmt.Errorf(
"intermediate %s is not properly signed by root %s: %w",
intPath,
chain.Root,
sigErr,
)
}
// Check expiry for intermediate
checkExpiry(intPath, intCert, expiryDuration)
// Add intermediate to collections if needed
if outputs.IncludeSingle {
single = append(single, intCert)
}
if outputs.IncludeIndividual {
indiv = append(indiv, certWithPath{cert: intCert, path: intPath})
}
}
return single, indiv, nil
}
// prepareArchiveFiles prepares all files to be included in archives.
func prepareArchiveFiles(
singleFileCerts []*x509.Certificate,
individualCerts []certWithPath,
sourcePaths []string,
outputs Outputs,
encoding string,
) ([]fileEntry, error) {
var archiveFiles []fileEntry
// Track used filenames to avoid collisions inside archives
usedNames := make(map[string]int)
// Handle a single bundle file
if outputs.IncludeSingle && len(singleFileCerts) > 0 {
bundleTime := maxModTime(sourcePaths)
files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true)
if err != nil {
return nil, fmt.Errorf("failed to encode single bundle: %w", err)
}
for i := range files {
files[i].name = makeUniqueName(files[i].name, usedNames)
files[i].modTime = bundleTime
// Best-effort: we do not have a portable birth/creation time.
// Use the same timestamp for created time to track deterministically.
files[i].createTime = bundleTime
}
archiveFiles = append(archiveFiles, files...)
}
// Handle individual files
if outputs.IncludeIndividual {
for _, cp := range individualCerts {
baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path))
files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false)
if err != nil {
return nil, fmt.Errorf("failed to encode individual cert %s: %w", cp.path, err)
}
mt := fileModTime(cp.path)
for i := range files {
files[i].name = makeUniqueName(files[i].name, usedNames)
files[i].modTime = mt
files[i].createTime = mt
}
archiveFiles = append(archiveFiles, files...)
}
}
// Generate manifest if requested
if outputs.Manifest {
manifestContent := generateManifest(archiveFiles)
manifestName := makeUniqueName("MANIFEST", usedNames)
mt := maxModTime(sourcePaths)
archiveFiles = append(archiveFiles, fileEntry{
name: manifestName,
content: manifestContent,
modTime: mt,
createTime: mt,
})
}
return archiveFiles, nil
}
// createArchiveFiles creates archive files in the specified formats.
func createArchiveFiles(
groupName string,
formats []string,
archiveFiles []fileEntry,
outputDir string,
) ([]string, error) {
createdFiles := make([]string, 0, len(formats))
for _, format := range formats {
ext, ok := formatExtensions[format]
if !ok {
return nil, fmt.Errorf("unsupported format: %s", format)
}
archivePath := filepath.Join(outputDir, groupName+ext)
switch format {
case "zip":
if err := createZipArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create zip archive: %w", err)
}
case "tgz":
if err := createTarGzArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create tar.gz archive: %w", err)
}
default:
return nil, fmt.Errorf("unsupported format: %s", format)
}
createdFiles = append(createdFiles, archivePath)
}
return createdFiles, nil
}
func checkExpiry(path string, cert *x509.Certificate, expiryDuration time.Duration) {
now := time.Now()
expiryThreshold := now.Add(expiryDuration)
if cert.NotAfter.Before(expiryThreshold) {
daysUntilExpiry := int(cert.NotAfter.Sub(now).Hours() / 24)
if daysUntilExpiry < 0 {
fmt.Fprintf(
os.Stderr,
"WARNING: Certificate %s has EXPIRED (expired %d days ago)\n",
path,
-daysUntilExpiry,
)
} else {
fmt.Fprintf(os.Stderr, "WARNING: Certificate %s will expire in %d days (on %s)\n", path, daysUntilExpiry, cert.NotAfter.Format("2006-01-02"))
}
}
}
type fileEntry struct {
name string
content []byte
modTime time.Time
createTime time.Time
}
type certWithPath struct {
cert *x509.Certificate
path string
}
// encodeCertsToFiles converts certificates to file entries based on encoding type
// If isSingle is true, certs are concatenated into a single file; otherwise one cert per file.
func encodeCertsToFiles(
certs []*x509.Certificate,
baseName string,
encoding string,
isSingle bool,
) ([]fileEntry, error) {
var files []fileEntry
switch encoding {
case "pem":
pemContent := encodeCertsToPEM(certs)
files = append(files, fileEntry{
name: baseName + ".pem",
content: pemContent,
})
case "der":
if isSingle {
// For single file in DER, concatenate all cert DER bytes
var derContent []byte
for _, cert := range certs {
derContent = append(derContent, cert.Raw...)
}
files = append(files, fileEntry{
name: baseName + ".crt",
content: derContent,
})
} else if len(certs) > 0 {
// Individual DER file (should only have one cert)
files = append(files, fileEntry{
name: baseName + ".crt",
content: certs[0].Raw,
})
}
case "both":
// Add PEM version
pemContent := encodeCertsToPEM(certs)
files = append(files, fileEntry{
name: baseName + ".pem",
content: pemContent,
})
// Add DER version
if isSingle {
var derContent []byte
for _, cert := range certs {
derContent = append(derContent, cert.Raw...)
}
files = append(files, fileEntry{
name: baseName + ".crt",
content: derContent,
})
} else if len(certs) > 0 {
files = append(files, fileEntry{
name: baseName + ".crt",
content: certs[0].Raw,
})
}
default:
return nil, fmt.Errorf("unsupported encoding: %s (must be 'pem', 'der', or 'both')", encoding)
}
return files, nil
}
// encodeCertsToPEM encodes certificates to PEM format.
func encodeCertsToPEM(certs []*x509.Certificate) []byte {
var pemContent []byte
for _, cert := range certs {
pemBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
pemContent = append(pemContent, pem.EncodeToMemory(pemBlock)...)
}
return pemContent
}
func generateManifest(files []fileEntry) []byte {
// Build a sorted list of files by filename to ensure deterministic manifest ordering
sorted := make([]fileEntry, 0, len(files))
for _, f := range files {
// Defensive: skip any existing manifest entry
if f.name == "MANIFEST" {
continue
}
sorted = append(sorted, f)
}
sort.Slice(sorted, func(i, j int) bool { return sorted[i].name < sorted[j].name })
var manifest strings.Builder
for _, file := range sorted {
hash := sha256.Sum256(file.content)
manifest.WriteString(fmt.Sprintf("%x %s\n", hash, file.name))
}
return []byte(manifest.String())
}
// closeWithErr attempts to close all provided closers, joining any close errors with baseErr.
func closeWithErr(baseErr error, closers ...io.Closer) error {
for _, c := range closers {
if c == nil {
continue
}
if cerr := c.Close(); cerr != nil {
baseErr = errors.Join(baseErr, cerr)
}
}
return baseErr
}
func createZipArchive(path string, files []fileEntry) error {
f, zerr := os.Create(path)
if zerr != nil {
return zerr
}
w := zip.NewWriter(f)
for _, file := range files {
hdr := &zip.FileHeader{
Name: file.name,
Method: zip.Deflate,
}
if !file.modTime.IsZero() {
hdr.SetModTime(file.modTime)
}
fw, werr := w.CreateHeader(hdr)
if werr != nil {
return closeWithErr(werr, w, f)
}
if _, werr = fw.Write(file.content); werr != nil {
return closeWithErr(werr, w, f)
}
}
// Check errors on close operations
if cerr := w.Close(); cerr != nil {
_ = f.Close()
return cerr
}
return f.Close()
}
func createTarGzArchive(path string, files []fileEntry) error {
f, terr := os.Create(path)
if terr != nil {
return terr
}
gw := gzip.NewWriter(f)
tw := tar.NewWriter(gw)
for _, file := range files {
hdr := &tar.Header{
Name: file.name,
Uid: 0,
Gid: 0,
Mode: defaultFileMode,
Size: int64(len(file.content)),
ModTime: func() time.Time {
if file.modTime.IsZero() {
return time.Now()
}
return file.modTime
}(),
}
// Set additional times if supported
hdr.AccessTime = hdr.ModTime
if !file.createTime.IsZero() {
hdr.ChangeTime = file.createTime
} else {
hdr.ChangeTime = hdr.ModTime
}
if herr := tw.WriteHeader(hdr); herr != nil {
return closeWithErr(herr, tw, gw, f)
}
if _, werr := tw.Write(file.content); werr != nil {
return closeWithErr(werr, tw, gw, f)
}
}
// Check errors on close operations in the correct order
if cerr := tw.Close(); cerr != nil {
_ = gw.Close()
_ = f.Close()
return cerr
}
if cerr := gw.Close(); cerr != nil {
_ = f.Close()
return cerr
}
return f.Close()
}
func generateHashFile(path string, files []string) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
for _, file := range files {
data, rerr := os.ReadFile(file)
if rerr != nil {
return rerr
}
hash := sha256.Sum256(data)
fmt.Fprintf(f, "%x %s\n", hash, filepath.Base(file))
}
return nil
}
// makeUniqueName ensures that each file name within the archive is unique by appending
// an incremental numeric suffix before the extension when collisions occur.
// Example: "root.pem" -> "root-2.pem", "root-3.pem", etc.
func makeUniqueName(name string, used map[string]int) string {
// If unused, mark and return as-is
if _, ok := used[name]; !ok {
used[name] = 1
return name
}
ext := filepath.Ext(name)
base := strings.TrimSuffix(name, ext)
// Track a counter per base+ext key
key := base + ext
counter := max(used[key], 1)
for {
counter++
candidate := fmt.Sprintf("%s-%d%s", base, counter, ext)
if _, exists := used[candidate]; !exists {
used[key] = counter
used[candidate] = 1
return candidate
}
}
}
// fileModTime returns the file's modification time, or time.Now() if stat fails.
func fileModTime(path string) time.Time {
fi, err := os.Stat(path)
if err != nil {
return time.Now()
}
return fi.ModTime()
}
// maxModTime returns the latest modification time across provided paths.
// If the list is empty or stats fail, returns time.Now().
func maxModTime(paths []string) time.Time {
var zero time.Time
maxTime := zero
for _, p := range paths {
fi, err := os.Stat(p)
if err != nil {
continue
}
mt := fi.ModTime()
if maxTime.IsZero() || mt.After(maxTime) {
maxTime = mt
}
}
if maxTime.IsZero() {
return time.Now()
}
return maxTime
}

certlib/certerr/doc.go

@@ -0,0 +1,33 @@
// Package certerr provides typed errors and helpers for certificate-related
// operations across the repository. It standardizes error construction and
// matching so callers can reliably branch on error source/kind using the
// Go 1.13+ `errors.Is` and `errors.As` helpers.
//
// Guidelines
// - Always wrap underlying causes using the helper constructors or with
// fmt.Errorf("context: %w", err).
// - Do not include sensitive data (keys, passwords, tokens) in error
// messages; add only non-sensitive, actionable context.
// - Prefer programmatic checks via errors.Is (for sentinel errors) and
// errors.As (to retrieve *certerr.Error) rather than relying on error
// string contents.
//
// Typical usage
//
// if err := doParse(); err != nil {
// return certerr.ParsingError(certerr.ErrorSourceCertificate, err)
// }
//
// Callers may branch on error kinds and sources:
//
// var e *certerr.Error
// if errors.As(err, &e) {
// switch e.Kind {
// case certerr.KindParse:
// // handle parse error
// }
// }
//
// Sentinel errors are provided for common conditions like
// `certerr.ErrEncryptedPrivateKey` and can be matched with `errors.Is`.
package certerr
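A self-contained caller-side sketch of these guidelines, assuming the certlib package from this changeset (the file path below is hypothetical):
package main

import (
	"errors"
	"log"

	"git.wntrmute.dev/kyle/goutils/certlib"
	"git.wntrmute.dev/kyle/goutils/certlib/certerr"
)

func main() {
	_, err := certlib.LoadCertificate("testdata/cert.pem") // illustrative path
	var cerr *certerr.Error
	switch {
	case err == nil:
		log.Println("certificate loaded")
	case errors.As(err, &cerr) && cerr.Kind == certerr.KindLoad:
		log.Println("could not read the file:", cerr.Unwrap())
	case errors.As(err, &cerr) && cerr.Kind == certerr.KindParse:
		log.Println("file read but not parseable:", cerr.Unwrap())
	default:
		log.Println("other error:", err)
	}
}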


@@ -37,43 +37,84 @@ const (
	ErrorSourceKeypair ErrorSourceType = 5
)

-// InvalidPEMType is used to indicate that we were expecting one type of PEM
+// ErrorKind is a broad classification describing what went wrong.
+type ErrorKind uint8
+
+const (
+	KindParse ErrorKind = iota + 1
+	KindDecode
+	KindVerify
+	KindLoad
+)
+
+func (k ErrorKind) String() string {
+	switch k {
+	case KindParse:
+		return "parse"
+	case KindDecode:
+		return "decode"
+	case KindVerify:
+		return "verify"
+	case KindLoad:
+		return "load"
+	default:
+		return "unknown"
+	}
+}
+
+// Error is a typed, wrapped error with structured context for programmatic checks.
+// It implements error and supports errors.Is/As via Unwrap.
+type Error struct {
+	Source ErrorSourceType // which domain produced the error (certificate, private key, etc.)
+	Kind   ErrorKind       // operation category (parse, decode, verify, load)
+	Op     string          // optional operation or function name
+	Err    error           // wrapped cause
+}
+
+func (e *Error) Error() string {
+	// Keep message format consistent with existing helpers: "failed to <kind> <source>: <err>"
+	// Do not include Op by default to preserve existing output expectations.
+	return fmt.Sprintf("failed to %s %s: %v", e.Kind.String(), e.Source.String(), e.Err)
+}
+
+func (e *Error) Unwrap() error { return e.Err }
+
+// InvalidPEMTypeError is used to indicate that we were expecting one type of PEM
 // file, but saw another.
-type InvalidPEMType struct {
+type InvalidPEMTypeError struct {
 	have string
 	want []string
 }

-func (err *InvalidPEMType) Error() string {
+func (err *InvalidPEMTypeError) Error() string {
 	if len(err.want) == 1 {
 		return fmt.Sprintf("invalid PEM type: have %s, expected %s", err.have, err.want[0])
-	} else {
-		return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", "))
 	}
+	return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", "))
 }

-// ErrInvalidPEMType returns a new InvalidPEMType error.
+// ErrInvalidPEMType returns a new InvalidPEMTypeError error.
 func ErrInvalidPEMType(have string, want ...string) error {
-	return &InvalidPEMType{
+	return &InvalidPEMTypeError{
 		have: have,
 		want: want,
 	}
 }

 func LoadingError(t ErrorSourceType, err error) error {
-	return fmt.Errorf("failed to load %s from disk: %w", t, err)
+	return &Error{Source: t, Kind: KindLoad, Err: err}
 }

 func ParsingError(t ErrorSourceType, err error) error {
-	return fmt.Errorf("failed to parse %s: %w", t, err)
+	return &Error{Source: t, Kind: KindParse, Err: err}
 }

 func DecodeError(t ErrorSourceType, err error) error {
-	return fmt.Errorf("failed to decode %s: %w", t, err)
+	return &Error{Source: t, Kind: KindDecode, Err: err}
 }

 func VerifyError(t ErrorSourceType, err error) error {
-	return fmt.Errorf("failed to verify %s: %w", t, err)
+	return &Error{Source: t, Kind: KindVerify, Err: err}
 }

 var ErrEncryptedPrivateKey = errors.New("private key is encrypted")


@@ -0,0 +1,56 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package certerr
import (
"errors"
"strings"
"testing"
)
func TestTypedErrorWrappingAndFormatting(t *testing.T) {
cause := errors.New("bad data")
err := DecodeError(ErrorSourceCertificate, cause)
// Ensure we can retrieve the typed error
var e *Error
if !errors.As(err, &e) {
t.Fatalf("expected errors.As to retrieve *certerr.Error, got %T", err)
}
if e.Kind != KindDecode {
t.Fatalf("unexpected kind: %v", e.Kind)
}
if e.Source != ErrorSourceCertificate {
t.Fatalf("unexpected source: %v", e.Source)
}
// Check message format (no trailing punctuation enforced by content)
msg := e.Error()
if !strings.Contains(msg, "failed to decode certificate") || !strings.Contains(msg, "bad data") {
t.Fatalf("unexpected error message: %q", msg)
}
}
func TestErrorsIsOnWrappedSentinel(t *testing.T) {
err := DecodeError(ErrorSourcePrivateKey, ErrEncryptedPrivateKey)
if !errors.Is(err, ErrEncryptedPrivateKey) {
t.Fatalf("expected errors.Is to match ErrEncryptedPrivateKey")
}
}
func TestInvalidPEMTypeMessageSingle(t *testing.T) {
err := ErrInvalidPEMType("FOO", "CERTIFICATE")
want := "invalid PEM type: have FOO, expected CERTIFICATE"
if err.Error() != want {
t.Fatalf("unexpected error message: got %q, want %q", err.Error(), want)
}
}
func TestInvalidPEMTypeMessageMultiple(t *testing.T) {
err := ErrInvalidPEMType("FOO", "CERTIFICATE", "NEW CERTIFICATE REQUEST")
if !strings.Contains(
err.Error(),
"invalid PEM type: have FOO, expected one of CERTIFICATE, NEW CERTIFICATE REQUEST",
) {
t.Fatalf("unexpected error message: %q", err.Error())
}
}


@@ -4,43 +4,53 @@ import (
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"errors" "errors"
"io/ioutil" "os"
"git.wntrmute.dev/kyle/goutils/certlib/certerr" "git.wntrmute.dev/kyle/goutils/certlib/certerr"
) )
// ReadCertificate reads a DER or PEM-encoded certificate from the // ReadCertificate reads a DER or PEM-encoded certificate from the
// byte slice. // byte slice.
func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error) { func ReadCertificate(in []byte) (*x509.Certificate, []byte, error) {
if len(in) == 0 { if len(in) == 0 {
err = certerr.ErrEmptyCertificate return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, certerr.ErrEmptyCertificate)
return
} }
if in[0] == '-' { if in[0] == '-' {
p, remaining := pem.Decode(in) p, remaining := pem.Decode(in)
if p == nil { if p == nil {
err = errors.New("certlib: invalid PEM file") return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("invalid PEM file"))
return
} }
rest = remaining rest := remaining
if p.Type != "CERTIFICATE" { if p.Type != "CERTIFICATE" {
err = certerr.ErrInvalidPEMType(p.Type, "CERTIFICATE") return nil, rest, certerr.ParsingError(
return certerr.ErrorSourceCertificate,
certerr.ErrInvalidPEMType(p.Type, "CERTIFICATE"),
)
} }
in = p.Bytes in = p.Bytes
cert, err := x509.ParseCertificate(in)
if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
return cert, rest, nil
} }
cert, err = x509.ParseCertificate(in) cert, err := x509.ParseCertificate(in)
return if err != nil {
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
return cert, nil, nil
} }
// ReadCertificates tries to read all the certificates in a // ReadCertificates tries to read all the certificates in a
// PEM-encoded collection. // PEM-encoded collection.
func ReadCertificates(in []byte) (certs []*x509.Certificate, err error) { func ReadCertificates(in []byte) ([]*x509.Certificate, error) {
var cert *x509.Certificate var cert *x509.Certificate
var certs []*x509.Certificate
var err error
for { for {
cert, in, err = ReadCertificate(in) cert, in, err = ReadCertificate(in)
if err != nil { if err != nil {
@@ -64,9 +74,9 @@ func ReadCertificates(in []byte) (certs []*x509.Certificate, err error) {
// the file contains multiple certificates (e.g. a chain), only the // the file contains multiple certificates (e.g. a chain), only the
// first certificate is returned. // first certificate is returned.
func LoadCertificate(path string) (*x509.Certificate, error) { func LoadCertificate(path string) (*x509.Certificate, error) {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, certerr.LoadingError(certerr.ErrorSourceCertificate, err)
} }
cert, _, err := ReadCertificate(in) cert, _, err := ReadCertificate(in)
@@ -76,9 +86,9 @@ func LoadCertificate(path string) (*x509.Certificate, error) {
// LoadCertificates tries to read all the certificates in a file, // LoadCertificates tries to read all the certificates in a file,
// returning them in the order that it found them in the file. // returning them in the order that it found them in the file.
func LoadCertificates(path string) ([]*x509.Certificate, error) { func LoadCertificates(path string) ([]*x509.Certificate, error) {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, certerr.LoadingError(certerr.ErrorSourceCertificate, err)
} }
return ReadCertificates(in) return ReadCertificates(in)


@@ -1,3 +1,4 @@
+//nolint:testpackage // keep tests in the same package for internal symbol access
 package certlib

 import (


@@ -38,6 +38,7 @@ import (
	"crypto/ed25519"
	"crypto/rsa"
	"crypto/x509"
+	"errors"
	"fmt"

	"git.wntrmute.dev/kyle/goutils/certlib/certerr"
@@ -47,29 +48,36 @@ import (
// private key. The key must not be in PEM format. If an error is returned, it
// may contain information about the private key, so care should be taken when
// displaying it directly.
-func ParsePrivateKeyDER(keyDER []byte) (key crypto.Signer, err error) {
-	generalKey, err := x509.ParsePKCS8PrivateKey(keyDER)
-	if err != nil {
-		generalKey, err = x509.ParsePKCS1PrivateKey(keyDER)
-		if err != nil {
-			generalKey, err = x509.ParseECPrivateKey(keyDER)
-			if err != nil {
-				generalKey, err = ParseEd25519PrivateKey(keyDER)
-				if err != nil {
-					return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err)
-				}
-			}
-		}
-	}
-
-	switch generalKey := generalKey.(type) {
-	case *rsa.PrivateKey:
-		return generalKey, nil
-	case *ecdsa.PrivateKey:
-		return generalKey, nil
-	case ed25519.PrivateKey:
-		return generalKey, nil
-	default:
-		return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %t", generalKey))
-	}
+func ParsePrivateKeyDER(keyDER []byte) (crypto.Signer, error) {
+	// Try common encodings in order without deep nesting.
+	if k, err := x509.ParsePKCS8PrivateKey(keyDER); err == nil {
+		switch kk := k.(type) {
+		case *rsa.PrivateKey:
+			return kk, nil
+		case *ecdsa.PrivateKey:
+			return kk, nil
+		case ed25519.PrivateKey:
+			return kk, nil
+		default:
+			return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
+		}
+	}
+	if k, err := x509.ParsePKCS1PrivateKey(keyDER); err == nil {
+		return k, nil
+	}
+	if k, err := x509.ParseECPrivateKey(keyDER); err == nil {
+		return k, nil
+	}
+	if k, err := ParseEd25519PrivateKey(keyDER); err == nil {
+		if kk, ok := k.(ed25519.PrivateKey); ok {
+			return kk, nil
+		}
+		return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
+	}
+	// If all parsers failed, return the last error from Ed25519 attempt (approximate cause).
+	if _, err := ParseEd25519PrivateKey(keyDER); err != nil {
+		return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err)
+	}
+	// Fallback (should be unreachable)
+	return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, errors.New("unknown key encoding"))
 }


@@ -65,12 +65,14 @@ func MarshalEd25519PublicKey(pk crypto.PublicKey) ([]byte, error) {
		return nil, errEd25519WrongKeyType
	}

+	const bitsPerByte = 8
+
	spki := subjectPublicKeyInfo{
		Algorithm: pkix.AlgorithmIdentifier{
			Algorithm: ed25519OID,
		},
		PublicKey: asn1.BitString{
-			BitLength: len(pub) * 8,
+			BitLength: len(pub) * bitsPerByte,
			Bytes:     pub,
		},
	}
@@ -91,7 +93,8 @@ func ParseEd25519PublicKey(der []byte) (crypto.PublicKey, error) {
		return nil, errEd25519WrongID
	}

-	if spki.PublicKey.BitLength != ed25519.PublicKeySize*8 {
+	const bitsPerByte = 8
+	if spki.PublicKey.BitLength != ed25519.PublicKeySize*bitsPerByte {
		return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch")
	}

certlib/fetch.go

@@ -0,0 +1,175 @@
package certlib
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"net"
"os"
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
"git.wntrmute.dev/kyle/goutils/fileutil"
"git.wntrmute.dev/kyle/goutils/lib"
)
// FetcherOpts are options for fetching certificates. They are only applicable to ServerFetcher.
type FetcherOpts struct {
SkipVerify bool
Roots *x509.CertPool
}
// Fetcher is an interface for fetching a certificate (or chain of certificates)
// from a spec. It currently supports fetching from a TLS server or a file.
type Fetcher interface {
Get() (*x509.Certificate, error)
GetChain() ([]*x509.Certificate, error)
String() string
}
type ServerFetcher struct {
host string
port int
insecure bool
roots *x509.CertPool
}
// WithRoots sets the roots for the ServerFetcher.
func WithRoots(roots *x509.CertPool) func(*ServerFetcher) {
return func(sf *ServerFetcher) {
sf.roots = roots
}
}
// WithSkipVerify sets the insecure flag for the ServerFetcher.
func WithSkipVerify() func(*ServerFetcher) {
return func(sf *ServerFetcher) {
sf.insecure = true
}
}
// ParseServer parses a server string into a ServerFetcher. It accepts either a
// URL or a host:port pair.
func ParseServer(host string) (*ServerFetcher, error) {
target, err := hosts.ParseHost(host)
if err != nil {
return nil, fmt.Errorf("failed to parse server: %w", err)
}
return &ServerFetcher{
host: target.Host,
port: target.Port,
}, nil
}
func (sf *ServerFetcher) String() string {
return fmt.Sprintf("tls://%s", net.JoinHostPort(sf.host, lib.Itoa(sf.port, -1)))
}
func (sf *ServerFetcher) GetChain() ([]*x509.Certificate, error) {
config := &tls.Config{
InsecureSkipVerify: sf.insecure, // #nosec G402 - verification is skipped only when the caller explicitly opts in
RootCAs: sf.roots,
}
dialer := &tls.Dialer{
Config: config,
}
hostSpec := net.JoinHostPort(sf.host, lib.Itoa(sf.port, -1))
netConn, err := dialer.DialContext(context.Background(), "tcp", hostSpec)
if err != nil {
return nil, fmt.Errorf("dialing server: %w", err)
}
conn, ok := netConn.(*tls.Conn)
if !ok {
return nil, errors.New("connection is not TLS")
}
defer conn.Close()
state := conn.ConnectionState()
return state.PeerCertificates, nil
}
func (sf *ServerFetcher) Get() (*x509.Certificate, error) {
certs, err := sf.GetChain()
if err != nil {
return nil, err
}
return certs[0], nil
}
type FileFetcher struct {
path string
}
func NewFileFetcher(path string) *FileFetcher {
return &FileFetcher{
path: path,
}
}
func (ff *FileFetcher) String() string {
return ff.path
}
func (ff *FileFetcher) GetChain() ([]*x509.Certificate, error) {
if ff.path == "-" {
certData, err := io.ReadAll(os.Stdin)
if err != nil {
return nil, fmt.Errorf("failed to read from stdin: %w", err)
}
return ParseCertificatesPEM(certData)
}
certs, err := LoadCertificates(ff.path)
if err != nil {
return nil, fmt.Errorf("failed to load chain: %w", err)
}
return certs, nil
}
func (ff *FileFetcher) Get() (*x509.Certificate, error) {
certs, err := ff.GetChain()
if err != nil {
return nil, err
}
return certs[0], nil
}
// GetCertificateChain fetches a certificate chain from a spec, which may be a file path or a server.
func GetCertificateChain(spec string, opts *FetcherOpts) ([]*x509.Certificate, error) {
if fileutil.FileDoesExist(spec) {
return NewFileFetcher(spec).GetChain()
}
fetcher, err := ParseServer(spec)
if err != nil {
return nil, err
}
if opts != nil {
fetcher.insecure = opts.SkipVerify
fetcher.roots = opts.Roots
}
return fetcher.GetChain()
}
// GetCertificate fetches the first certificate from a certificate chain.
func GetCertificate(spec string, opts *FetcherOpts) (*x509.Certificate, error) {
certs, err := GetCertificateChain(spec, opts)
if err != nil {
return nil, err
}
return certs[0], nil
}
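A minimal usage sketch of GetCertificateChain (the host and port below are illustrative; pass FetcherOpts.SkipVerify only for debugging):
package main

import (
	"fmt"
	"log"

	"git.wntrmute.dev/kyle/goutils/certlib"
)

func main() {
	// The spec may be a file path, a host[:port] pair, or an https:// / tls:// URL.
	chain, err := certlib.GetCertificateChain("www.example.net:443", &certlib.FetcherOpts{})
	if err != nil {
		log.Fatal(err)
	}
	for _, cert := range chain {
		fmt.Println(cert.Subject.CommonName, cert.NotAfter)
	}
}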


@@ -49,14 +49,14 @@ import (
"strings" "strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
ct "github.com/google/certificate-transparency-go" ct "github.com/google/certificate-transparency-go"
cttls "github.com/google/certificate-transparency-go/tls" cttls "github.com/google/certificate-transparency-go/tls"
ctx509 "github.com/google/certificate-transparency-go/x509" ctx509 "github.com/google/certificate-transparency-go/x509"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"golang.org/x/crypto/pkcs12" "golang.org/x/crypto/pkcs12"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
) )
// OneYear is a time.Duration representing a year's worth of seconds. // OneYear is a time.Duration representing a year's worth of seconds.
@@ -65,10 +65,10 @@ const OneYear = 8760 * time.Hour
// OneDay is a time.Duration representing a day's worth of seconds. // OneDay is a time.Duration representing a day's worth of seconds.
const OneDay = 24 * time.Hour const OneDay = 24 * time.Hour
// DelegationUsage is the OID for the DelegationUseage extensions // DelegationUsage is the OID for the DelegationUseage extensions.
var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44} var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44}
// DelegationExtension // DelegationExtension is a non-critical extension marking delegation usage.
var DelegationExtension = pkix.Extension{ var DelegationExtension = pkix.Extension{
Id: DelegationUsage, Id: DelegationUsage,
Critical: false, Critical: false,
@@ -81,41 +81,51 @@ func InclusiveDate(year int, month time.Month, day int) time.Time {
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond) return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond)
} }
const (
year2012 = 2012
year2015 = 2015
day1 = 1
)
// Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop // Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 5 years. // issuing certificates valid for more than 5 years.
var Jul2012 = InclusiveDate(2012, time.July, 01) var Jul2012 = InclusiveDate(year2012, time.July, day1)
// Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop // Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 39 months. // issuing certificates valid for more than 39 months.
var Apr2015 = InclusiveDate(2015, time.April, 01) var Apr2015 = InclusiveDate(year2015, time.April, day1)
// KeyLength returns the bit size of ECDSA or RSA PublicKey // KeyLength returns the bit size of ECDSA or RSA PublicKey.
func KeyLength(key interface{}) int { func KeyLength(key any) int {
if key == nil { switch k := key.(type) {
case *ecdsa.PublicKey:
if k == nil {
return 0
}
return k.Curve.Params().BitSize
case *rsa.PublicKey:
if k == nil {
return 0
}
return k.N.BitLen()
default:
return 0 return 0
} }
if ecdsaKey, ok := key.(*ecdsa.PublicKey); ok {
return ecdsaKey.Curve.Params().BitSize
} else if rsaKey, ok := key.(*rsa.PublicKey); ok {
return rsaKey.N.BitLen()
}
return 0
} }
// ExpiryTime returns the time when the certificate chain is expired. // ExpiryTime returns the time when the certificate chain is expired.
func ExpiryTime(chain []*x509.Certificate) (notAfter time.Time) { func ExpiryTime(chain []*x509.Certificate) time.Time {
var notAfter time.Time
if len(chain) == 0 { if len(chain) == 0 {
return return notAfter
} }
notAfter = chain[0].NotAfter notAfter = chain[0].NotAfter
for _, cert := range chain { for _, cert := range chain {
if notAfter.After(cert.NotAfter) { if notAfter.After(cert.NotAfter) {
notAfter = cert.NotAfter notAfter = cert.NotAfter
} }
} }
return return notAfter
} }
// MonthsValid returns the number of months for which a certificate is valid. // MonthsValid returns the number of months for which a certificate is valid.
@@ -144,109 +154,109 @@ func ValidExpiry(c *x509.Certificate) bool {
maxMonths = 39 maxMonths = 39
case issued.After(Jul2012): case issued.After(Jul2012):
maxMonths = 60 maxMonths = 60
case issued.Before(Jul2012): default:
maxMonths = 120 maxMonths = 120
} }
if MonthsValid(c) > maxMonths { return MonthsValid(c) <= maxMonths
return false }
}
return true // SignatureString returns the TLS signature string corresponding to
// an X509 signature algorithm.
var signatureString = map[x509.SignatureAlgorithm]string{
x509.UnknownSignatureAlgorithm: "Unknown Signature",
x509.MD2WithRSA: "MD2WithRSA",
x509.MD5WithRSA: "MD5WithRSA",
x509.SHA1WithRSA: "SHA1WithRSA",
x509.SHA256WithRSA: "SHA256WithRSA",
x509.SHA384WithRSA: "SHA384WithRSA",
x509.SHA512WithRSA: "SHA512WithRSA",
x509.SHA256WithRSAPSS: "SHA256WithRSAPSS",
x509.SHA384WithRSAPSS: "SHA384WithRSAPSS",
x509.SHA512WithRSAPSS: "SHA512WithRSAPSS",
x509.DSAWithSHA1: "DSAWithSHA1",
x509.DSAWithSHA256: "DSAWithSHA256",
x509.ECDSAWithSHA1: "ECDSAWithSHA1",
x509.ECDSAWithSHA256: "ECDSAWithSHA256",
x509.ECDSAWithSHA384: "ECDSAWithSHA384",
x509.ECDSAWithSHA512: "ECDSAWithSHA512",
x509.PureEd25519: "PureEd25519",
} }
// SignatureString returns the TLS signature string corresponding to // SignatureString returns the TLS signature string corresponding to
// an X509 signature algorithm. // an X509 signature algorithm.
func SignatureString(alg x509.SignatureAlgorithm) string { func SignatureString(alg x509.SignatureAlgorithm) string {
switch alg { if s, ok := signatureString[alg]; ok {
case x509.MD2WithRSA: return s
return "MD2WithRSA"
case x509.MD5WithRSA:
return "MD5WithRSA"
case x509.SHA1WithRSA:
return "SHA1WithRSA"
case x509.SHA256WithRSA:
return "SHA256WithRSA"
case x509.SHA384WithRSA:
return "SHA384WithRSA"
case x509.SHA512WithRSA:
return "SHA512WithRSA"
case x509.DSAWithSHA1:
return "DSAWithSHA1"
case x509.DSAWithSHA256:
return "DSAWithSHA256"
case x509.ECDSAWithSHA1:
return "ECDSAWithSHA1"
case x509.ECDSAWithSHA256:
return "ECDSAWithSHA256"
case x509.ECDSAWithSHA384:
return "ECDSAWithSHA384"
case x509.ECDSAWithSHA512:
return "ECDSAWithSHA512"
default:
return "Unknown Signature"
} }
return "Unknown Signature"
}
// hashAlgoString maps X.509 signature algorithms to the name of the hash
// algorithm contained in the signature method.
var hashAlgoString = map[x509.SignatureAlgorithm]string{
x509.UnknownSignatureAlgorithm: "Unknown Hash Algorithm",
x509.MD2WithRSA: "MD2",
x509.MD5WithRSA: "MD5",
x509.SHA1WithRSA: "SHA1",
x509.SHA256WithRSA: "SHA256",
x509.SHA384WithRSA: "SHA384",
x509.SHA512WithRSA: "SHA512",
x509.SHA256WithRSAPSS: "SHA256",
x509.SHA384WithRSAPSS: "SHA384",
x509.SHA512WithRSAPSS: "SHA512",
x509.DSAWithSHA1: "SHA1",
x509.DSAWithSHA256: "SHA256",
x509.ECDSAWithSHA1: "SHA1",
x509.ECDSAWithSHA256: "SHA256",
x509.ECDSAWithSHA384: "SHA384",
x509.ECDSAWithSHA512: "SHA512",
x509.PureEd25519: "SHA512", // per x509 docs Ed25519 uses SHA-512 internally
} }
// HashAlgoString returns the hash algorithm name contains in the signature // HashAlgoString returns the hash algorithm name contains in the signature
// method. // method.
func HashAlgoString(alg x509.SignatureAlgorithm) string { func HashAlgoString(alg x509.SignatureAlgorithm) string {
switch alg { if s, ok := hashAlgoString[alg]; ok {
case x509.MD2WithRSA: return s
return "MD2"
case x509.MD5WithRSA:
return "MD5"
case x509.SHA1WithRSA:
return "SHA1"
case x509.SHA256WithRSA:
return "SHA256"
case x509.SHA384WithRSA:
return "SHA384"
case x509.SHA512WithRSA:
return "SHA512"
case x509.DSAWithSHA1:
return "SHA1"
case x509.DSAWithSHA256:
return "SHA256"
case x509.ECDSAWithSHA1:
return "SHA1"
case x509.ECDSAWithSHA256:
return "SHA256"
case x509.ECDSAWithSHA384:
return "SHA384"
case x509.ECDSAWithSHA512:
return "SHA512"
default:
return "Unknown Hash Algorithm"
} }
return "Unknown Hash Algorithm"
} }
// StringTLSVersion returns underlying enum values from human names for TLS // StringTLSVersion returns underlying enum values from human names for TLS
// versions, defaults to current golang default of TLS 1.0 // versions, defaults to current golang default of TLS 1.0.
func StringTLSVersion(version string) uint16 { func StringTLSVersion(version string) uint16 {
switch version { switch version {
case "1.3":
return tls.VersionTLS13
case "1.2": case "1.2":
return tls.VersionTLS12 return tls.VersionTLS12
case "1.1": case "1.1":
return tls.VersionTLS11 return tls.VersionTLS11
case "1.0":
return tls.VersionTLS10
default: default:
// Default to Go's historical default of TLS 1.0 for unknown values
return tls.VersionTLS10 return tls.VersionTLS10
} }
} }
// EncodeCertificatesPEM encodes a number of x509 certificates to PEM // EncodeCertificatesPEM encodes a number of x509 certificates to PEM.
func EncodeCertificatesPEM(certs []*x509.Certificate) []byte { func EncodeCertificatesPEM(certs []*x509.Certificate) []byte {
var buffer bytes.Buffer var buffer bytes.Buffer
for _, cert := range certs { for _, cert := range certs {
pem.Encode(&buffer, &pem.Block{ if err := pem.Encode(&buffer, &pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Raw, Bytes: cert.Raw,
}) }); err != nil {
return nil
}
} }
return buffer.Bytes() return buffer.Bytes()
} }
// EncodeCertificatePEM encodes a single x509 certificates to PEM // EncodeCertificatePEM encodes a single x509 certificates to PEM.
func EncodeCertificatePEM(cert *x509.Certificate) []byte { func EncodeCertificatePEM(cert *x509.Certificate) []byte {
return EncodeCertificatesPEM([]*x509.Certificate{cert}) return EncodeCertificatesPEM([]*x509.Certificate{cert})
} }
@@ -269,38 +279,52 @@ func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
certs = append(certs, cert...) certs = append(certs, cert...)
} }
if len(certsPEM) > 0 { if len(certsPEM) > 0 {
return nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("trailing data at end of certificate")) return nil, certerr.DecodeError(
certerr.ErrorSourceCertificate,
errors.New("trailing data at end of certificate"),
)
} }
return certs, nil return certs, nil
} }
// ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key, // ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key,
// either PKCS #7, PKCS #12, or raw x509. // either PKCS #7, PKCS #12, or raw x509.
func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certificate, key crypto.Signer, err error) { func ParseCertificatesDER(certsDER []byte, password string) ([]*x509.Certificate, crypto.Signer, error) {
certsDER = bytes.TrimSpace(certsDER) certsDER = bytes.TrimSpace(certsDER)
pkcs7data, err := pkcs7.ParsePKCS7(certsDER)
if err != nil { // First, try PKCS #7
var pkcs12data interface{} if pkcs7data, err7 := pkcs7.ParsePKCS7(certsDER); err7 == nil {
certs = make([]*x509.Certificate, 1)
pkcs12data, certs[0], err = pkcs12.Decode(certsDER, password)
if err != nil {
certs, err = x509.ParseCertificates(certsDER)
if err != nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err)
}
} else {
key = pkcs12data.(crypto.Signer)
}
} else {
if pkcs7data.ContentInfo != "SignedData" { if pkcs7data.ContentInfo != "SignedData" {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("can only extract certificates from signed data content info")) return nil, nil, certerr.DecodeError(
certerr.ErrorSourceCertificate,
errors.New("can only extract certificates from signed data content info"),
)
} }
certs = pkcs7data.Content.SignedData.Certificates certs := pkcs7data.Content.SignedData.Certificates
if certs == nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded"))
}
return certs, nil, nil
} }
if certs == nil {
return nil, key, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded")) // Next, try PKCS #12
if pkcs12data, cert, err12 := pkcs12.Decode(certsDER, password); err12 == nil {
signer, ok := pkcs12data.(crypto.Signer)
if !ok {
return nil, nil, certerr.DecodeError(
certerr.ErrorSourcePrivateKey,
errors.New("PKCS12 data does not contain a private key"),
)
}
return []*x509.Certificate{cert}, signer, nil
} }
return certs, key, nil
// Finally, attempt to parse raw X.509 certificates
certs, err := x509.ParseCertificates(certsDER)
if err != nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err)
}
return certs, nil, nil
} }
// ParseSelfSignedCertificatePEM parses a PEM-encoded certificate and check if it is self-signed. // ParseSelfSignedCertificatePEM parses a PEM-encoded certificate and check if it is self-signed.
@@ -310,7 +334,8 @@ func ParseSelfSignedCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
return nil, err return nil, err
} }
if err := cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature); err != nil { err = cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature)
if err != nil {
return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err) return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err)
} }
return cert, nil return cert, nil
@@ -320,17 +345,26 @@ func ParseSelfSignedCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
// can handle PEM encoded PKCS #7 structures. // can handle PEM encoded PKCS #7 structures.
func ParseCertificatePEM(certPEM []byte) (*x509.Certificate, error) { func ParseCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
certPEM = bytes.TrimSpace(certPEM) certPEM = bytes.TrimSpace(certPEM)
cert, rest, err := ParseOneCertificateFromPEM(certPEM) certs, rest, err := ParseOneCertificateFromPEM(certPEM)
if err != nil { if err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
} else if cert == nil {
return nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificate decoded"))
} else if len(rest) > 0 {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("the PEM file should contain only one object"))
} else if len(cert) > 1 {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("the PKCS7 object in the PEM file should contain only one certificate"))
} }
return cert[0], nil if certs == nil {
return nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificate decoded"))
}
if len(rest) > 0 {
return nil, certerr.ParsingError(
certerr.ErrorSourceCertificate,
errors.New("the PEM file should contain only one object"),
)
}
if len(certs) > 1 {
return nil, certerr.ParsingError(
certerr.ErrorSourceCertificate,
errors.New("the PKCS7 object in the PEM file should contain only one certificate"),
)
}
return certs[0], nil
} }
// ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object, // ParseOneCertificateFromPEM attempts to parse one PEM encoded certificate object,
@@ -338,7 +372,6 @@ func ParseCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
// multiple certificates, from the top of certsPEM, which itself may // multiple certificates, from the top of certsPEM, which itself may
// contain multiple PEM encoded certificate objects. // contain multiple PEM encoded certificate objects.
func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) { func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, error) {
block, rest := pem.Decode(certsPEM) block, rest := pem.Decode(certsPEM)
if block == nil { if block == nil {
return nil, rest, nil return nil, rest, nil
@@ -346,8 +379,8 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
cert, err := x509.ParseCertificate(block.Bytes) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
pkcs7data, err := pkcs7.ParsePKCS7(block.Bytes) pkcs7data, err2 := pkcs7.ParsePKCS7(block.Bytes)
if err != nil { if err2 != nil {
return nil, rest, err return nil, rest, err
} }
if pkcs7data.ContentInfo != "SignedData" { if pkcs7data.ContentInfo != "SignedData" {
@@ -363,10 +396,49 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
return certs, rest, nil return certs, rest, nil
} }
// LoadFullCertPool returns a certificate pool with roots and intermediates
// from disk. If no roots are provided, the system root pool will be used.
func LoadFullCertPool(roots, intermediates string) (*x509.CertPool, error) {
var err error
pool := x509.NewCertPool()
if roots == "" {
pool, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("loading system cert pool: %w", err)
}
} else {
var rootCerts []*x509.Certificate
rootCerts, err = LoadCertificates(roots)
if err != nil {
return nil, fmt.Errorf("loading roots: %w", err)
}
for _, cert := range rootCerts {
pool.AddCert(cert)
}
}
if intermediates != "" {
var intCerts []*x509.Certificate
intCerts, err = LoadCertificates(intermediates)
if err != nil {
return nil, fmt.Errorf("loading intermediates: %w", err)
}
for _, cert := range intCerts {
pool.AddCert(cert)
}
}
return pool, nil
}
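A short sketch of how this pairs with CreateTLSConfig further down in this file (the file names are placeholders; an empty roots path falls back to the system pool):
package main

import (
	"log"

	"git.wntrmute.dev/kyle/goutils/certlib"
)

func main() {
	pool, err := certlib.LoadFullCertPool("roots.pem", "intermediates.pem")
	if err != nil {
		log.Fatal(err)
	}
	_ = certlib.CreateTLSConfig(pool, nil) // nil: no client certificate
}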
// LoadPEMCertPool loads a pool of PEM certificates from file. // LoadPEMCertPool loads a pool of PEM certificates from file.
func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) { func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
if certsFile == "" { if certsFile == "" {
return nil, nil return nil, nil //nolint:nilnil // no CA file provided -> treat as no pool and no error
} }
pemCerts, err := os.ReadFile(certsFile) pemCerts, err := os.ReadFile(certsFile)
if err != nil { if err != nil {
@@ -379,12 +451,12 @@ func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
// PEMToCertPool concerts PEM certificates to a CertPool. // PEMToCertPool concerts PEM certificates to a CertPool.
func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) { func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
if len(pemCerts) == 0 { if len(pemCerts) == 0 {
return nil, nil return nil, nil //nolint:nilnil // empty input means no pool needed
} }
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
if !certPool.AppendCertsFromPEM(pemCerts) { if !certPool.AppendCertsFromPEM(pemCerts) {
return nil, errors.New("failed to load cert pool") return nil, certerr.LoadingError(certerr.ErrorSourceCertificate, errors.New("failed to load cert pool"))
} }
return certPool, nil return certPool, nil
@@ -393,14 +465,14 @@ func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
// ParsePrivateKeyPEM parses and returns a PEM-encoded private // ParsePrivateKeyPEM parses and returns a PEM-encoded private
// key. The private key may be either an unencrypted PKCS#8, PKCS#1, // key. The private key may be either an unencrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEM(keyPEM []byte) (crypto.Signer, error) {
return ParsePrivateKeyPEMWithPassword(keyPEM, nil) return ParsePrivateKeyPEMWithPassword(keyPEM, nil)
} }
// ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private // ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private
// key. The private key may be a potentially encrypted PKCS#8, PKCS#1, // key. The private key may be a potentially encrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (crypto.Signer, error) {
keyDER, err := GetKeyDERFromPEM(keyPEM, password) keyDER, err := GetKeyDERFromPEM(keyPEM, password)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -420,44 +492,47 @@ func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
break break
} }
} }
if keyDER != nil { if keyDER == nil {
if procType, ok := keyDER.Headers["Proc-Type"]; ok { return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key"))
if strings.Contains(procType, "ENCRYPTED") {
if password != nil {
return x509.DecryptPEMBlock(keyDER, password)
}
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)
}
}
return keyDER.Bytes, nil
} }
if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") {
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key")) if password != nil {
return x509.DecryptPEMBlock(keyDER, password)
}
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)
}
return keyDER.Bytes, nil
} }
// ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request. // ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request.
func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error) { func ParseCSR(in []byte) (*x509.CertificateRequest, []byte, error) {
in = bytes.TrimSpace(in) in = bytes.TrimSpace(in)
p, rest := pem.Decode(in) p, rest := pem.Decode(in)
if p != nil { if p == nil {
if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" { csr, err := x509.ParseCertificateRequest(in)
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, certerr.ErrInvalidPEMType(p.Type, "NEW CERTIFICATE REQUEST", "CERTIFICATE REQUEST")) if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
} }
if sigErr := csr.CheckSignature(); sigErr != nil {
csr, err = x509.ParseCertificateRequest(p.Bytes) return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
} else { }
csr, err = x509.ParseCertificateRequest(in) return csr, rest, nil
} }
if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" {
return nil, rest, certerr.ParsingError(
certerr.ErrorSourceCSR,
certerr.ErrInvalidPEMType(p.Type, "NEW CERTIFICATE REQUEST", "CERTIFICATE REQUEST"),
)
}
csr, err := x509.ParseCertificateRequest(p.Bytes)
if err != nil { if err != nil {
return nil, rest, err return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
} }
if sigErr := csr.CheckSignature(); sigErr != nil {
err = csr.CheckSignature() return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
if err != nil {
return nil, rest, err
} }
return csr, rest, nil return csr, rest, nil
} }
@@ -465,14 +540,14 @@ func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error)
// It does not check the signature. This is useful for dumping data from a CSR // It does not check the signature. This is useful for dumping data from a CSR
// locally. // locally.
func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) { func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
block, _ := pem.Decode([]byte(csrPEM)) block, _ := pem.Decode(csrPEM)
if block == nil { if block == nil {
return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty")) return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty"))
} }
csrObject, err := x509.ParseCertificateRequest(block.Bytes) csrObject, err := x509.ParseCertificateRequest(block.Bytes)
if err != nil { if err != nil {
return nil, err return nil, certerr.ParsingError(certerr.ErrorSourceCSR, err)
} }
return csrObject, nil return csrObject, nil
@@ -480,15 +555,20 @@ func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
// SignerAlgo returns an X.509 signature algorithm from a crypto.Signer. // SignerAlgo returns an X.509 signature algorithm from a crypto.Signer.
func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm { func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm {
const (
rsaBits2048 = 2048
rsaBits3072 = 3072
rsaBits4096 = 4096
)
switch pub := priv.Public().(type) { switch pub := priv.Public().(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
bitLength := pub.N.BitLen() bitLength := pub.N.BitLen()
switch { switch {
case bitLength >= 4096: case bitLength >= rsaBits4096:
return x509.SHA512WithRSA return x509.SHA512WithRSA
case bitLength >= 3072: case bitLength >= rsaBits3072:
return x509.SHA384WithRSA return x509.SHA384WithRSA
case bitLength >= 2048: case bitLength >= rsaBits2048:
return x509.SHA256WithRSA return x509.SHA256WithRSA
default: default:
return x509.SHA1WithRSA return x509.SHA1WithRSA
@@ -509,7 +589,7 @@ func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm {
} }
} }
// LoadClientCertificate load key/certificate from pem files // LoadClientCertificate load key/certificate from pem files.
func LoadClientCertificate(certFile string, keyFile string) (*tls.Certificate, error) { func LoadClientCertificate(certFile string, keyFile string) (*tls.Certificate, error) {
if certFile != "" && keyFile != "" { if certFile != "" && keyFile != "" {
cert, err := tls.LoadX509KeyPair(certFile, keyFile) cert, err := tls.LoadX509KeyPair(certFile, keyFile)
@@ -518,10 +598,10 @@ func LoadClientCertificate(certFile string, keyFile string) (*tls.Certificate, e
} }
return &cert, nil return &cert, nil
} }
return nil, nil return nil, nil //nolint:nilnil // absence of client cert is not an error
} }
// CreateTLSConfig creates a tls.Config object from certs and roots // CreateTLSConfig creates a tls.Config object from certs and roots.
func CreateTLSConfig(remoteCAs *x509.CertPool, cert *tls.Certificate) *tls.Config { func CreateTLSConfig(remoteCAs *x509.CertPool, cert *tls.Certificate) *tls.Config {
var certs []tls.Certificate var certs []tls.Certificate
if cert != nil { if cert != nil {
@@ -530,6 +610,7 @@ func CreateTLSConfig(remoteCAs *x509.CertPool, cert *tls.Certificate) *tls.Confi
return &tls.Config{ return &tls.Config{
Certificates: certs, Certificates: certs,
RootCAs: remoteCAs, RootCAs: remoteCAs,
MinVersion: tls.VersionTLS12, // secure default
} }
} }
@@ -554,18 +635,24 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
return nil, err return nil, err
} }
if len(rest) != 0 { if len(rest) != 0 {
return nil, certerr.ParsingError(certerr.ErrorSourceSCTList, errors.New("serialized SCT list contained trailing garbage")) return nil, certerr.ParsingError(
certerr.ErrorSourceSCTList,
errors.New("serialized SCT list contained trailing garbage"),
)
} }
list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList)) list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList))
for i, serializedSCT := range sctList.SCTList { for i, serializedSCT := range sctList.SCTList {
var sct ct.SignedCertificateTimestamp var sct ct.SignedCertificateTimestamp
rest, err := cttls.Unmarshal(serializedSCT.Val, &sct) rest2, err2 := cttls.Unmarshal(serializedSCT.Val, &sct)
if err != nil { if err2 != nil {
return nil, err return nil, err2
} }
if len(rest) != 0 { if len(rest2) != 0 {
return nil, certerr.ParsingError(certerr.ErrorSourceSCTList, errors.New("serialized SCT list contained trailing garbage")) return nil, certerr.ParsingError(
certerr.ErrorSourceSCTList,
errors.New("serialized SCT list contained trailing garbage"),
)
} }
list[i] = sct list[i] = sct
} }
@@ -577,12 +664,12 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
// unmarshalled. // unmarshalled.
func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) { func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) {
// This loop finds the SCTListExtension in the OCSP response. // This loop finds the SCTListExtension in the OCSP response.
var SCTListExtension, ext pkix.Extension var sctListExtension, ext pkix.Extension
for _, ext = range response.Extensions { for _, ext = range response.Extensions {
// sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp. // sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp.
sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5} sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5}
if ext.Id.Equal(sctExtOid) { if ext.Id.Equal(sctExtOid) {
SCTListExtension = ext sctListExtension = ext
break break
} }
} }
@@ -590,10 +677,10 @@ func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTim
// This code block extracts the sctList from the SCT extension. // This code block extracts the sctList from the SCT extension.
var sctList []ct.SignedCertificateTimestamp var sctList []ct.SignedCertificateTimestamp
var err error var err error
if numBytes := len(SCTListExtension.Value); numBytes != 0 { if numBytes := len(sctListExtension.Value); numBytes != 0 {
var serializedSCTList []byte var serializedSCTList []byte
rest := make([]byte, numBytes) rest := make([]byte, numBytes)
copy(rest, SCTListExtension.Value) copy(rest, sctListExtension.Value)
for len(rest) != 0 { for len(rest) != 0 {
rest, err = asn1.Unmarshal(rest, &serializedSCTList) rest, err = asn1.Unmarshal(rest, &serializedSCTList)
if err != nil { if err != nil {
@@ -611,20 +698,16 @@ func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTim
// the subsequent file. If no prefix is provided, valFile is assumed to be a
// file path.
func ReadBytes(valFile string) ([]byte, error) {
-	switch splitVal := strings.SplitN(valFile, ":", 2); len(splitVal) {
-	case 1:
+	prefix, rest, found := strings.Cut(valFile, ":")
+	if !found {
		return os.ReadFile(valFile)
-	case 2:
-		switch splitVal[0] {
-		case "env":
-			return []byte(os.Getenv(splitVal[1])), nil
-		case "file":
-			return os.ReadFile(splitVal[1])
-		default:
-			return nil, fmt.Errorf("unknown prefix: %s", splitVal[0])
-		}
+	}
+
+	switch prefix {
+	case "env":
+		return []byte(os.Getenv(rest)), nil
+	case "file":
+		return os.ReadFile(rest)
	default:
-		return nil, fmt.Errorf("multiple prefixes: %s",
-			strings.Join(splitVal[:len(splitVal)-1], ", "))
+		return nil, fmt.Errorf("unknown prefix: %s", prefix)
	}
}
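A small sketch of the three accepted spec forms (the environment variable and paths are examples only; ReadBytes is assumed to be exported from the certlib package, as the surrounding file suggests):
package main

import (
	"log"

	"git.wntrmute.dev/kyle/goutils/certlib"
)

func main() {
	for _, spec := range []string{"env:CA_BUNDLE", "file:/etc/ssl/ca.pem", "/etc/ssl/ca.pem"} {
		if _, err := certlib.ReadBytes(spec); err != nil {
			log.Printf("%s: %v", spec, err)
		}
	}
}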


@@ -9,6 +9,8 @@ import (
	"strings"
)

+const defaultHTTPSPort = 443
+
type Target struct {
	Host string
	Port int
@@ -24,45 +26,50 @@ func parseURL(host string) (string, int, error) {
		return "", 0, fmt.Errorf("certlib/hosts: invalid host: %s", host)
	}

-	if strings.ToLower(url.Scheme) != "https" {
+	switch strings.ToLower(url.Scheme) {
+	case "https":
+		// OK
+	case "tls":
+		// OK
+	default:
		return "", 0, errors.New("certlib/hosts: only https scheme supported")
	}

	if url.Port() == "" {
-		return url.Hostname(), 443, nil
+		return url.Hostname(), defaultHTTPSPort, nil
	}

-	port, err := strconv.ParseInt(url.Port(), 10, 16)
-	if err != nil {
+	portInt, err2 := strconv.ParseInt(url.Port(), 10, 16)
+	if err2 != nil {
		return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", url.Port())
	}

-	return url.Hostname(), int(port), nil
+	return url.Hostname(), int(portInt), nil
}

func parseHostPort(host string) (string, int, error) {
-	host, sport, err := net.SplitHostPort(host)
+	shost, sport, err := net.SplitHostPort(host)
	if err == nil {
-		port, err := strconv.ParseInt(sport, 10, 16)
-		if err != nil {
+		portInt, err2 := strconv.ParseInt(sport, 10, 16)
+		if err2 != nil {
			return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport)
		}

-		return host, int(port), nil
+		return shost, int(portInt), nil
	}

-	return host, 443, nil
+	return host, defaultHTTPSPort, nil
}

func ParseHost(host string) (*Target, error) {
-	host, port, err := parseURL(host)
+	uhost, port, err := parseURL(host)
	if err == nil {
-		return &Target{Host: host, Port: port}, nil
+		return &Target{Host: uhost, Port: port}, nil
	}

-	host, port, err = parseHostPort(host)
+	shost, port, err := parseHostPort(host)
	if err == nil {
-		return &Target{Host: host, Port: port}, nil
+		return &Target{Host: shost, Port: port}, nil
	}

	return nil, fmt.Errorf("certlib/hosts: invalid host: %s", host)


@@ -0,0 +1,35 @@
package hosts_test
import (
"testing"
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
)
type testCase struct {
Host string
Target hosts.Target
}
var testCases = []testCase{
{Host: "server-name", Target: hosts.Target{Host: "server-name", Port: 443}},
{Host: "server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}},
{Host: "tls://server-name", Target: hosts.Target{Host: "server-name", Port: 443}},
{Host: "https://server-name", Target: hosts.Target{Host: "server-name", Port: 443}},
{Host: "https://server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}},
{Host: "tls://server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}},
{Host: "https://server-name/something/else", Target: hosts.Target{Host: "server-name", Port: 443}},
}
func TestParseHost(t *testing.T) {
for i, tc := range testCases {
target, err := hosts.ParseHost(tc.Host)
if err != nil {
t.Fatalf("test case %d: %s", i+1, err)
}
if target.Host != tc.Target.Host {
t.Fatalf("test case %d: got host '%s', want host '%s'", i+1, target.Host, tc.Target.Host)
}
}
}


@@ -93,7 +93,7 @@ type signedData struct {
	Version          int
	DigestAlgorithms asn1.RawValue
	ContentInfo      asn1.RawValue
-	Certificates     asn1.RawValue `asn1:"optional" asn1:"tag:0"`
+	Certificates     asn1.RawValue `asn1:"optional"`
	Crls             asn1.RawValue `asn1:"optional"`
	SignerInfos      asn1.RawValue
}
@@ -158,63 +158,95 @@ type EncryptedContentInfo struct {
EncryptedContent []byte `asn1:"tag:0,optional"` EncryptedContent []byte `asn1:"tag:0,optional"`
} }
func unmarshalInit(raw []byte) (initPKCS7, error) {
var init initPKCS7
if _, err := asn1.Unmarshal(raw, &init); err != nil {
return initPKCS7{}, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
return init, nil
}
func populateData(msg *PKCS7, content asn1.RawValue) error {
msg.ContentInfo = "Data"
_, err := asn1.Unmarshal(content.Bytes, &msg.Content.Data)
if err != nil {
return certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
return nil
}
func populateSignedData(msg *PKCS7, contentBytes []byte) error {
msg.ContentInfo = "SignedData"
var sd signedData
if _, err := asn1.Unmarshal(contentBytes, &sd); err != nil {
return certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
if len(sd.Certificates.Bytes) != 0 {
certs, err := x509.ParseCertificates(sd.Certificates.Bytes)
if err != nil {
return certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
msg.Content.SignedData.Certificates = certs
}
if len(sd.Crls.Bytes) != 0 {
crl, err := x509.ParseRevocationList(sd.Crls.Bytes)
if err != nil {
return certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
msg.Content.SignedData.Crl = crl
}
msg.Content.SignedData.Version = sd.Version
msg.Content.SignedData.Raw = contentBytes
return nil
}
func populateEncryptedData(msg *PKCS7, contentBytes []byte) error {
msg.ContentInfo = "EncryptedData"
var ed EncryptedData
if _, err := asn1.Unmarshal(contentBytes, &ed); err != nil {
return certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
if ed.Version != 0 {
return certerr.ParsingError(
certerr.ErrorSourceCertificate,
errors.New("only PKCS #7 encryptedData version 0 is supported"),
)
}
msg.Content.EncryptedData = ed
return nil
}
// ParsePKCS7 attempts to parse the DER encoded bytes of a // ParsePKCS7 attempts to parse the DER encoded bytes of a
// PKCS7 structure. // PKCS7 structure.
func ParsePKCS7(raw []byte) (msg *PKCS7, err error) { func ParsePKCS7(raw []byte) (*PKCS7, error) {
pkcs7, err := unmarshalInit(raw)
var pkcs7 initPKCS7
_, err = asn1.Unmarshal(raw, &pkcs7)
if err != nil { if err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return nil, err
} }
msg = new(PKCS7) msg := new(PKCS7)
msg.Raw = pkcs7.Raw msg.Raw = pkcs7.Raw
msg.ContentInfo = pkcs7.ContentType.String() msg.ContentInfo = pkcs7.ContentType.String()
switch {
case msg.ContentInfo == ObjIDData:
msg.ContentInfo = "Data"
_, err = asn1.Unmarshal(pkcs7.Content.Bytes, &msg.Content.Data)
if err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
case msg.ContentInfo == ObjIDSignedData:
msg.ContentInfo = "SignedData"
var signedData signedData
_, err = asn1.Unmarshal(pkcs7.Content.Bytes, &signedData)
if err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
if len(signedData.Certificates.Bytes) != 0 {
msg.Content.SignedData.Certificates, err = x509.ParseCertificates(signedData.Certificates.Bytes)
if err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
}
if len(signedData.Crls.Bytes) != 0 {
msg.Content.SignedData.Crl, err = x509.ParseRevocationList(signedData.Crls.Bytes)
if err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
}
msg.Content.SignedData.Version = signedData.Version
msg.Content.SignedData.Raw = pkcs7.Content.Bytes
case msg.ContentInfo == ObjIDEncryptedData:
msg.ContentInfo = "EncryptedData"
var encryptedData EncryptedData
_, err = asn1.Unmarshal(pkcs7.Content.Bytes, &encryptedData)
if err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
if encryptedData.Version != 0 {
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("only PKCS #7 encryptedData version 0 is supported"))
}
msg.Content.EncryptedData = encryptedData
switch msg.ContentInfo {
case ObjIDData:
if e := populateData(msg, pkcs7.Content); e != nil {
return nil, e
}
case ObjIDSignedData:
if e := populateSignedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, e
}
case ObjIDEncryptedData:
if e := populateEncryptedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, e
}
default: default:
return nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("only PKCS# 7 content of type data, signed data or encrypted data can be parsed")) return nil, certerr.ParsingError(
certerr.ErrorSourceCertificate,
errors.New("only PKCS# 7 content of type data, signed data or encrypted data can be parsed"),
)
} }
return msg, nil return msg, nil
} }
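For orientation, here is a minimal sketch of driving the refactored parser to pull certificates out of a DER-encoded PKCS #7 blob. The import path and the input file name are assumptions (the hunk does not show the package clause); ParsePKCS7, ContentInfo, and Content.SignedData.Certificates are taken from the diff above.

``` go
// Sketch only: the pkcs7 import path is assumed; the ParsePKCS7 API and the
// Content.SignedData.Certificates field are as shown in the diff above.
package main

import (
	"fmt"
	"os"

	"git.wntrmute.dev/kyle/goutils/certlib/pkcs7" // assumed import path
)

func main() {
	der, err := os.ReadFile("bundle.p7b") // hypothetical input file
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	msg, err := pkcs7.ParsePKCS7(der)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// For SignedData content, the embedded certificates are already parsed.
	if msg.ContentInfo == "SignedData" {
		for _, cert := range msg.Content.SignedData.Certificates {
			fmt.Println(cert.Subject.CommonName)
		}
	}
}
```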

View File

@@ -5,6 +5,7 @@ package revoke
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@@ -89,35 +90,35 @@ func ldapURL(url string) bool {
// - false, false: an error was encountered while checking revocations. // - false, false: an error was encountered while checking revocations.
// - false, true: the certificate was checked successfully, and it is not revoked. // - false, true: the certificate was checked successfully, and it is not revoked.
// - true, true: the certificate was checked successfully, and it is revoked. // - true, true: the certificate was checked successfully, and it is revoked.
// - true, false: failure to check revocation status causes verification to fail // - true, false: failure to check revocation status causes verification to fail.
func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) { func revCheck(cert *x509.Certificate) (bool, bool, error) {
for _, url := range cert.CRLDistributionPoints { for _, url := range cert.CRLDistributionPoints {
if ldapURL(url) { if ldapURL(url) {
log.Infof("skipping LDAP CRL: %s", url) log.Infof("skipping LDAP CRL: %s", url)
continue continue
} }
if revoked, ok, err := certIsRevokedCRL(cert, url); !ok { if rvk, ok2, err2 := certIsRevokedCRL(cert, url); !ok2 {
log.Warning("error checking revocation via CRL") log.Warning("error checking revocation via CRL")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via CRL") log.Info("certificate is revoked via CRL")
return true, true, err return true, true, err2
} }
} }
if revoked, ok, err := certIsRevokedOCSP(cert, HardFail); !ok { if rvk, ok2, err2 := certIsRevokedOCSP(cert, HardFail); !ok2 {
log.Warning("error checking revocation via OCSP") log.Warning("error checking revocation via OCSP")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via OCSP") log.Info("certificate is revoked via OCSP")
return true, true, err return true, true, err2
} }
return false, true, nil return false, true, nil
@@ -125,13 +126,17 @@ func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) {
// fetchCRL fetches and parses a CRL. // fetchCRL fetches and parses a CRL.
func fetchCRL(url string) (*x509.RevocationList, error) { func fetchCRL(url string) (*x509.RevocationList, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 300 { if resp.StatusCode >= http.StatusMultipleChoices {
return nil, errors.New("failed to retrieve CRL") return nil, errors.New("failed to retrieve CRL")
} }
@@ -154,12 +159,11 @@ func getIssuer(cert *x509.Certificate) *x509.Certificate {
} }
return issuer return issuer
} }
// check a cert against a specific CRL. Returns the same bool pair // check a cert against a specific CRL. Returns the same bool pair
// as revCheck, plus an error if one occurred. // as revCheck, plus an error if one occurred.
func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err error) { func certIsRevokedCRL(cert *x509.Certificate, url string) (bool, bool, error) {
crlLock.Lock() crlLock.Lock()
crl, ok := CRLSet[url] crl, ok := CRLSet[url]
if ok && crl == nil { if ok && crl == nil {
@@ -187,10 +191,9 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
// check CRL signature // check CRL signature
if issuer != nil { if issuer != nil {
err = crl.CheckSignatureFrom(issuer) if sigErr := crl.CheckSignatureFrom(issuer); sigErr != nil {
if err != nil { log.Warningf("failed to verify CRL: %v", sigErr)
log.Warningf("failed to verify CRL: %v", err) return false, false, sigErr
return false, false, err
} }
} }
@@ -199,40 +202,44 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
crlLock.Unlock() crlLock.Unlock()
} }
for _, revoked := range crl.RevokedCertificates { for _, entry := range crl.RevokedCertificateEntries {
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 { if cert.SerialNumber.Cmp(entry.SerialNumber) == 0 {
log.Info("Serial number match: intermediate is revoked.") log.Info("Serial number match: intermediate is revoked.")
return true, true, err return true, true, nil
} }
} }
return false, true, err return false, true, nil
} }
// VerifyCertificate ensures that the certificate passed in hasn't // VerifyCertificate ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificate(cert *x509.Certificate) (revoked, ok bool) { func VerifyCertificate(cert *x509.Certificate) (bool, bool) {
revoked, ok, _ = VerifyCertificateError(cert) revoked, ok, _ := VerifyCertificateError(cert)
return revoked, ok return revoked, ok
} }
// VerifyCertificateError ensures that the certificate passed in hasn't // VerifyCertificateError ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificateError(cert *x509.Certificate) (revoked, ok bool, err error) { func VerifyCertificateError(cert *x509.Certificate) (bool, bool, error) {
if !time.Now().Before(cert.NotAfter) { if !time.Now().Before(cert.NotAfter) {
msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter) msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter)
log.Info(msg) log.Info(msg)
return true, true, fmt.Errorf(msg) return true, true, errors.New(msg)
} else if !time.Now().After(cert.NotBefore) { } else if !time.Now().After(cert.NotBefore) {
msg := fmt.Sprintf("Certificate isn't valid until %s\n", cert.NotBefore) msg := fmt.Sprintf("Certificate isn't valid until %s\n", cert.NotBefore)
log.Info(msg) log.Info(msg)
return true, true, fmt.Errorf(msg) return true, true, errors.New(msg)
} }
return revCheck(cert) return revCheck(cert)
} }
func fetchRemote(url string) (*x509.Certificate, error) { func fetchRemote(url string) (*x509.Certificate, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -255,8 +262,12 @@ var ocspOpts = ocsp.RequestOptions{
Hash: crypto.SHA1, Hash: crypto.SHA1,
} }
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e error) { const ocspGetURLMaxLen = 256
var err error
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (bool, bool, error) {
var revoked bool
var ok bool
var lastErr error
ocspURLs := leaf.OCSPServer ocspURLs := leaf.OCSPServer
if len(ocspURLs) == 0 { if len(ocspURLs) == 0 {
@@ -272,15 +283,16 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts) ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts)
if err != nil { if err != nil {
return revoked, ok, err return false, false, err
} }
for _, server := range ocspURLs { for _, server := range ocspURLs {
resp, err := sendOCSPRequest(server, ocspRequest, leaf, issuer) resp, e := sendOCSPRequest(server, ocspRequest, leaf, issuer)
if err != nil { if e != nil {
if strict { if strict {
return revoked, ok, err return false, false, e
} }
lastErr = e
continue continue
} }
@@ -292,9 +304,9 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
revoked = true revoked = true
} }
return revoked, ok, err return revoked, ok, nil
} }
return revoked, ok, err return revoked, ok, lastErr
} }
// sendOCSPRequest attempts to request an OCSP response from the // sendOCSPRequest attempts to request an OCSP response from the
@@ -303,12 +315,21 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) { func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) {
var resp *http.Response var resp *http.Response
var err error var err error
if len(req) > 256 { if len(req) > ocspGetURLMaxLen {
buf := bytes.NewBuffer(req) buf := bytes.NewBuffer(req)
resp, err = HTTPClient.Post(server, "application/ocsp-request", buf) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodPost, server, buf)
if e != nil {
return nil, e
}
httpReq.Header.Set("Content-Type", "application/ocsp-request")
resp, err = HTTPClient.Do(httpReq)
} else { } else {
reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req)) reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req))
resp, err = HTTPClient.Get(reqURL) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodGet, reqURL, nil)
if e != nil {
return nil, e
}
resp, err = HTTPClient.Do(httpReq)
} }
if err != nil { if err != nil {
@@ -343,21 +364,21 @@ func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate)
var crlRead = io.ReadAll var crlRead = io.ReadAll
// SetCRLFetcher sets the function to use to read from the http response body // SetCRLFetcher sets the function to use to read from the http response body.
func SetCRLFetcher(fn func(io.Reader) ([]byte, error)) { func SetCRLFetcher(fn func(io.Reader) ([]byte, error)) {
crlRead = fn crlRead = fn
} }
var remoteRead = io.ReadAll var remoteRead = io.ReadAll
// SetRemoteFetcher sets the function to use to read from the http response body // SetRemoteFetcher sets the function to use to read from the http response body.
func SetRemoteFetcher(fn func(io.Reader) ([]byte, error)) { func SetRemoteFetcher(fn func(io.Reader) ([]byte, error)) {
remoteRead = fn remoteRead = fn
} }
var ocspRead = io.ReadAll var ocspRead = io.ReadAll
// SetOCSPFetcher sets the function to use to read from the http response body // SetOCSPFetcher sets the function to use to read from the http response body.
func SetOCSPFetcher(fn func(io.Reader) ([]byte, error)) { func SetOCSPFetcher(fn func(io.Reader) ([]byte, error)) {
ocspRead = fn ocspRead = fn
} }
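A hedged usage sketch of the updated API: the (revoked, ok) contract documented at the top of this hunk is unchanged, and SetCRLFetcher still takes a func(io.Reader) ([]byte, error). The revoke import path, the program name, and the 1 MiB read cap are assumptions; only VerifyCertificateError and SetCRLFetcher come from the diff above.

``` go
// Sketch: interpreting the (revoked, ok, err) triple from
// VerifyCertificateError, per the comment block in the diff above.
package main

import (
	"crypto/x509"
	"encoding/pem"
	"fmt"
	"io"
	"os"

	"git.wntrmute.dev/kyle/goutils/certlib/revoke" // assumed import path
)

func main() {
	if len(os.Args) != 2 {
		fmt.Fprintln(os.Stderr, "usage: revcheck-sketch <cert.pem>")
		os.Exit(2)
	}

	// Optional: cap how much of a CRL response body is read (hypothetical limit).
	revoke.SetCRLFetcher(func(r io.Reader) ([]byte, error) {
		return io.ReadAll(io.LimitReader(r, 1<<20))
	})

	data, err := os.ReadFile(os.Args[1])
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	block, _ := pem.Decode(data)
	if block == nil {
		fmt.Fprintln(os.Stderr, "no PEM block found")
		os.Exit(1)
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	revoked, ok, err := revoke.VerifyCertificateError(cert)
	switch {
	case !ok:
		fmt.Printf("revocation status unknown: %v\n", err) // error while checking
	case revoked:
		fmt.Printf("revoked (or expired): %v\n", err)
	default:
		fmt.Println("certificate is valid and not revoked")
	}
}
```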

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package revoke package revoke
import ( import (
@@ -50,7 +51,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// to indicate that this is the case. // to indicate that this is the case.
// 2014/05/22 14:18:17 Certificate expired 2014-04-04 14:14:20 +0000 UTC // 2014/05/22 14:18:17 Certificate expired 2014-04-04 14:14:20 +0000 UTC
// 2014/05/22 14:18:17 Revoked certificate: misc/intermediate_ca/ActalisServerAuthenticationCA.crt // 2014/05/22 14:18:17 Revoked certificate: misc/intermediate_ca/ActalisServerAuthenticationCA.crt.
var expiredCert = mustParse(`-----BEGIN CERTIFICATE----- var expiredCert = mustParse(`-----BEGIN CERTIFICATE-----
MIIEXTCCA8agAwIBAgIEBycURTANBgkqhkiG9w0BAQUFADB1MQswCQYDVQQGEwJV MIIEXTCCA8agAwIBAgIEBycURTANBgkqhkiG9w0BAQUFADB1MQswCQYDVQQGEwJV
UzEYMBYGA1UEChMPR1RFIENvcnBvcmF0aW9uMScwJQYDVQQLEx5HVEUgQ3liZXJU UzEYMBYGA1UEChMPR1RFIENvcnBvcmF0aW9uMScwJQYDVQQLEx5HVEUgQ3liZXJU
@@ -80,7 +81,7 @@ sESPRwHkcMUNdAp37FLweUw=
// 2014/05/22 14:18:31 Serial number match: intermediate is revoked. // 2014/05/22 14:18:31 Serial number match: intermediate is revoked.
// 2014/05/22 14:18:31 certificate is revoked via CRL // 2014/05/22 14:18:31 certificate is revoked via CRL
// 2014/05/22 14:18:31 Revoked certificate: misc/intermediate_ca/MobileArmorEnterpriseCA.crt // 2014/05/22 14:18:31 Revoked certificate: misc/intermediate_ca/MobileArmorEnterpriseCA.crt.
var revokedCert = mustParse(`-----BEGIN CERTIFICATE----- var revokedCert = mustParse(`-----BEGIN CERTIFICATE-----
MIIEEzCCAvugAwIBAgILBAAAAAABGMGjftYwDQYJKoZIhvcNAQEFBQAwcTEoMCYG MIIEEzCCAvugAwIBAgILBAAAAAABGMGjftYwDQYJKoZIhvcNAQEFBQAwcTEoMCYG
A1UEAxMfR2xvYmFsU2lnbiBSb290U2lnbiBQYXJ0bmVycyBDQTEdMBsGA1UECxMU A1UEAxMfR2xvYmFsU2lnbiBSb290U2lnbiBQYXJ0bmVycyBDQTEdMBsGA1UECxMU
@@ -106,7 +107,7 @@ Kz5vh+5tmytUPKA8hUgmLWe94lMb7Uqq2wgZKsqun5DAWleKu81w7wEcOrjiiB+x
jeBHq7OnpWm+ccTOPCE6H4ZN4wWVS7biEBUdop/8HgXBPQHWAdjL jeBHq7OnpWm+ccTOPCE6H4ZN4wWVS7biEBUdop/8HgXBPQHWAdjL
-----END CERTIFICATE-----`) -----END CERTIFICATE-----`)
// A Comodo intermediate CA certificate with issuer url, CRL url and OCSP url // A Comodo intermediate CA certificate with issuer url, CRL url and OCSP url.
var goodComodoCA = (`-----BEGIN CERTIFICATE----- var goodComodoCA = (`-----BEGIN CERTIFICATE-----
MIIGCDCCA/CgAwIBAgIQKy5u6tl1NmwUim7bo3yMBzANBgkqhkiG9w0BAQwFADCB MIIGCDCCA/CgAwIBAgIQKy5u6tl1NmwUim7bo3yMBzANBgkqhkiG9w0BAQwFADCB
hTELMAkGA1UEBhMCR0IxGzAZBgNVBAgTEkdyZWF0ZXIgTWFuY2hlc3RlcjEQMA4G hTELMAkGA1UEBhMCR0IxGzAZBgNVBAgTEkdyZWF0ZXIgTWFuY2hlc3RlcjEQMA4G
@@ -153,7 +154,7 @@ func mustParse(pemData string) *x509.Certificate {
panic("Invalid PEM type.") panic("Invalid PEM type.")
} }
cert, err := x509.ParseCertificate([]byte(block.Bytes)) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }
@@ -182,7 +183,6 @@ func TestGood(t *testing.T) {
} else if revoked { } else if revoked {
t.Fatalf("good certificate should not have been marked as revoked") t.Fatalf("good certificate should not have been marked as revoked")
} }
} }
func TestLdap(t *testing.T) { func TestLdap(t *testing.T) {
@@ -230,7 +230,6 @@ func TestBadCRLSet(t *testing.T) {
t.Fatalf("key emptystring should be deleted from CRLSet") t.Fatalf("key emptystring should be deleted from CRLSet")
} }
delete(CRLSet, "") delete(CRLSet, "")
} }
func TestCachedCRLSet(t *testing.T) { func TestCachedCRLSet(t *testing.T) {
@@ -241,13 +240,11 @@ func TestCachedCRLSet(t *testing.T) {
} }
func TestRemoteFetchError(t *testing.T) { func TestRemoteFetchError(t *testing.T) {
badurl := ":" badurl := ":"
if _, err := fetchRemote(badurl); err == nil { if _, err := fetchRemote(badurl); err == nil {
t.Fatalf("fetching bad url should result in non-nil error") t.Fatalf("fetching bad url should result in non-nil error")
} }
} }
func TestNoOCSPServers(t *testing.T) { func TestNoOCSPServers(t *testing.T) {

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"net" "net"
@@ -28,10 +29,16 @@ func connect(addr string, dport string, six bool, timeout time.Duration) error {
if verbose { if verbose {
fmt.Printf("connecting to %s/%s... ", addr, proto) fmt.Printf("connecting to %s/%s... ", addr, proto)
os.Stdout.Sync() if err = os.Stdout.Sync(); err != nil {
return err
}
} }
conn, err := net.DialTimeout(proto, addr, timeout) dialer := &net.Dialer{
Timeout: timeout,
}
conn, err := dialer.DialContext(context.Background(), proto, addr)
if err != nil { if err != nil {
if verbose { if verbose {
fmt.Println("failed.") fmt.Println("failed.")
@@ -42,8 +49,8 @@ func connect(addr string, dport string, six bool, timeout time.Duration) error {
if verbose { if verbose {
fmt.Println("OK") fmt.Println("OK")
} }
conn.Close()
return nil return conn.Close()
} }
func main() { func main() {

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"crypto/x509" "crypto/x509"
"embed" "embed"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -14,22 +15,22 @@ import (
// loadCertsFromFile attempts to parse certificates from a file that may be in // loadCertsFromFile attempts to parse certificates from a file that may be in
// PEM or DER/PKCS#7 format. Returns the parsed certificates or an error. // PEM or DER/PKCS#7 format. Returns the parsed certificates or an error.
func loadCertsFromFile(path string) ([]*x509.Certificate, error) { func loadCertsFromFile(path string) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Try PEM first if certs, err = certlib.ParseCertificatesPEM(data); err == nil {
if certs, err := certlib.ParseCertificatesPEM(data); err == nil {
return certs, nil return certs, nil
} }
// Try DER/PKCS7/PKCS12 (with no password) if certs, _, err = certlib.ParseCertificatesDER(data, ""); err == nil {
if certs, _, err := certlib.ParseCertificatesDER(data, ""); err == nil {
return certs, nil return certs, nil
} else {
return nil, err
} }
return nil, err
} }
func makePoolFromFile(path string) (*x509.CertPool, error) { func makePoolFromFile(path string) (*x509.CertPool, error) {
@@ -56,49 +57,50 @@ var embeddedTestdata embed.FS
// loadCertsFromBytes attempts to parse certificates from bytes that may be in // loadCertsFromBytes attempts to parse certificates from bytes that may be in
// PEM or DER/PKCS#7 format. // PEM or DER/PKCS#7 format.
func loadCertsFromBytes(data []byte) ([]*x509.Certificate, error) { func loadCertsFromBytes(data []byte) ([]*x509.Certificate, error) {
// Try PEM first certs, err := certlib.ParseCertificatesPEM(data)
if certs, err := certlib.ParseCertificatesPEM(data); err == nil { if err == nil {
return certs, nil return certs, nil
} }
// Try DER/PKCS7/PKCS12 (with no password)
if certs, _, err := certlib.ParseCertificatesDER(data, ""); err == nil { certs, _, err = certlib.ParseCertificatesDER(data, "")
if err == nil {
return certs, nil return certs, nil
} else {
return nil, err
} }
return nil, err
} }
func makePoolFromBytes(data []byte) (*x509.CertPool, error) { func makePoolFromBytes(data []byte) (*x509.CertPool, error) {
certs, err := loadCertsFromBytes(data) certs, err := loadCertsFromBytes(data)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return nil, fmt.Errorf("failed to load CA certificates from embedded bytes") return nil, errors.New("failed to load CA certificates from embedded bytes")
} }
pool := x509.NewCertPool() pool := x509.NewCertPool()
for _, c := range certs { for _, c := range certs {
pool.AddCert(c) pool.AddCert(c)
} }
return pool, nil return pool, nil
} }
// isSelfSigned returns true if the given certificate is self-signed. // isSelfSigned returns true if the given certificate is self-signed.
// It checks that the subject and issuer match and that the certificate's // It checks that the subject and issuer match and that the certificate's
// signature verifies against its own public key. // signature verifies against its own public key.
func isSelfSigned(cert *x509.Certificate) bool { func isSelfSigned(cert *x509.Certificate) bool {
if cert == nil { if cert == nil {
return false return false
} }
// Quick check: subject and issuer match // Quick check: subject and issuer match
if cert.Subject.String() != cert.Issuer.String() { if cert.Subject.String() != cert.Issuer.String() {
return false return false
} }
// Cryptographic check: the certificate is signed by itself // Cryptographic check: the certificate is signed by itself
if err := cert.CheckSignatureFrom(cert); err != nil { if err := cert.CheckSignatureFrom(cert); err != nil {
return false return false
} }
return true return true
} }
func verifyAgainstCA(caPool *x509.CertPool, path string) (ok bool, expiry string) { func verifyAgainstCA(caPool *x509.CertPool, path string) (bool, string) {
certs, err := loadCertsFromFile(path) certs, err := loadCertsFromFile(path)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return false, "" return false, ""
@@ -117,14 +119,14 @@ func verifyAgainstCA(caPool *x509.CertPool, path string) (ok bool, expiry string
Intermediates: ints, Intermediates: ints,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
if _, err := leaf.Verify(opts); err != nil { if _, err = leaf.Verify(opts); err != nil {
return false, "" return false, ""
} }
return true, leaf.NotAfter.Format("2006-01-02") return true, leaf.NotAfter.Format("2006-01-02")
} }
func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (ok bool, expiry string) { func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (bool, string) {
certs, err := loadCertsFromBytes(certData) certs, err := loadCertsFromBytes(certData)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return false, "" return false, ""
@@ -143,92 +145,159 @@ func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (ok bool, expi
Intermediates: ints, Intermediates: ints,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
if _, err := leaf.Verify(opts); err != nil { if _, err = leaf.Verify(opts); err != nil {
return false, "" return false, ""
} }
return true, leaf.NotAfter.Format("2006-01-02") return true, leaf.NotAfter.Format("2006-01-02")
} }
// selftest runs built-in validation using embedded certificates. type testCase struct {
func selftest() int { name string
type testCase struct { caFile string
name string certFile string
caFile string expectOK bool
certFile string }
expectOK bool
func (tc testCase) Run() error {
caBytes, err := embeddedTestdata.ReadFile(tc.caFile)
if err != nil {
return fmt.Errorf("selftest: failed to read embedded %s: %w", tc.caFile, err)
} }
cases := []testCase{ certBytes, err := embeddedTestdata.ReadFile(tc.certFile)
{name: "ISRG Root X1 validates LE E7", caFile: "testdata/isrg-root-x1.pem", certFile: "testdata/le-e7.pem", expectOK: true}, if err != nil {
{name: "ISRG Root X1 does NOT validate Google WR2", caFile: "testdata/isrg-root-x1.pem", certFile: "testdata/goog-wr2.pem", expectOK: false}, return fmt.Errorf("selftest: failed to read embedded %s: %w", tc.certFile, err)
{name: "GTS R1 validates Google WR2", caFile: "testdata/gts-r1.pem", certFile: "testdata/goog-wr2.pem", expectOK: true}, }
{name: "GTS R1 does NOT validate LE E7", caFile: "testdata/gts-r1.pem", certFile: "testdata/le-e7.pem", expectOK: false},
}
failures := 0 pool, err := makePoolFromBytes(caBytes)
for _, tc := range cases { if err != nil || pool == nil {
caBytes, err := embeddedTestdata.ReadFile(tc.caFile) return fmt.Errorf("selftest: failed to build CA pool for %s: %w", tc.caFile, err)
}
ok, exp := verifyAgainstCABytes(pool, certBytes)
if ok != tc.expectOK {
return fmt.Errorf("%s: unexpected result: got %v, want %v", tc.name, ok, tc.expectOK)
}
if ok {
fmt.Printf("%s: OK (expires %s)\n", tc.name, exp)
} else {
fmt.Printf("%s: INVALID (as expected)\n", tc.name)
}
return nil
}
var cases = []testCase{
{
name: "ISRG Root X1 validates LE E7",
caFile: "testdata/isrg-root-x1.pem",
certFile: "testdata/le-e7.pem",
expectOK: true,
},
{
name: "ISRG Root X1 does NOT validate Google WR2",
caFile: "testdata/isrg-root-x1.pem",
certFile: "testdata/goog-wr2.pem",
expectOK: false,
},
{
name: "GTS R1 validates Google WR2",
caFile: "testdata/gts-r1.pem",
certFile: "testdata/goog-wr2.pem",
expectOK: true,
},
{
name: "GTS R1 does NOT validate LE E7",
caFile: "testdata/gts-r1.pem",
certFile: "testdata/le-e7.pem",
expectOK: false,
},
}
// selftest runs built-in validation using embedded certificates.
func selftest() int {
failures := 0
for _, tc := range cases {
err := tc.Run()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", tc.caFile, err) fmt.Fprintln(os.Stderr, err)
failures++ failures++
continue continue
} }
certBytes, err := embeddedTestdata.ReadFile(tc.certFile) }
// Verify that both embedded root CAs are detected as self-signed
roots := []string{"testdata/gts-r1.pem", "testdata/isrg-root-x1.pem"}
for _, root := range roots {
b, err := embeddedTestdata.ReadFile(root)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", tc.certFile, err) fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", root, err)
failures++ failures++
continue continue
} }
pool, err := makePoolFromBytes(caBytes) certs, err := loadCertsFromBytes(b)
if err != nil || pool == nil { if err != nil || len(certs) == 0 {
fmt.Fprintf(os.Stderr, "selftest: failed to build CA pool for %s: %v\n", tc.caFile, err) fmt.Fprintf(os.Stderr, "selftest: failed to parse cert(s) from %s: %v\n", root, err)
failures++ failures++
continue continue
} }
ok, exp := verifyAgainstCABytes(pool, certBytes) leaf := certs[0]
if ok != tc.expectOK { if isSelfSigned(leaf) {
fmt.Printf("%s: unexpected result: got %v, want %v\n", tc.name, ok, tc.expectOK) fmt.Printf("%s: SELF-SIGNED (as expected)\n", root)
failures++
} else { } else {
if ok { fmt.Printf("%s: expected SELF-SIGNED, but was not detected as such\n", root)
fmt.Printf("%s: OK (expires %s)\n", tc.name, exp) failures++
} else {
fmt.Printf("%s: INVALID (as expected)\n", tc.name)
}
} }
} }
// Verify that both embedded root CAs are detected as self-signed if failures == 0 {
roots := []string{"testdata/gts-r1.pem", "testdata/isrg-root-x1.pem"} fmt.Println("selftest: PASS")
for _, root := range roots { return 0
b, err := embeddedTestdata.ReadFile(root) }
if err != nil { fmt.Fprintf(os.Stderr, "selftest: FAIL (%d failure(s))\n", failures)
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", root, err) return 1
failures++ }
continue
}
certs, err := loadCertsFromBytes(b)
if err != nil || len(certs) == 0 {
fmt.Fprintf(os.Stderr, "selftest: failed to parse cert(s) from %s: %v\n", root, err)
failures++
continue
}
leaf := certs[0]
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED (as expected)\n", root)
} else {
fmt.Printf("%s: expected SELF-SIGNED, but was not detected as such\n", root)
failures++
}
}
if failures == 0 { // expiryString returns a YYYY-MM-DD date string to display for certificate
fmt.Println("selftest: PASS") // expiry. If an explicit exp string is provided, it is used. Otherwise, if a
return 0 // leaf certificate is available, its NotAfter is formatted. As a last resort,
} // it falls back to today's date (should not normally happen).
fmt.Fprintf(os.Stderr, "selftest: FAIL (%d failure(s))\n", failures) func expiryString(leaf *x509.Certificate, exp string) string {
return 1 if exp != "" {
return exp
}
if leaf != nil {
return leaf.NotAfter.Format("2006-01-02")
}
return time.Now().Format("2006-01-02")
}
// processCert verifies a single certificate file against the provided CA pool
// and prints the result in the required format, handling self-signed
// certificates specially.
func processCert(caPool *x509.CertPool, certPath string) {
ok, exp := verifyAgainstCA(caPool, certPath)
name := filepath.Base(certPath)
// Try to load the leaf cert for self-signed detection and expiry fallback
var leaf *x509.Certificate
if certs, err := loadCertsFromFile(certPath); err == nil && len(certs) > 0 {
leaf = certs[0]
}
// Prefer the SELF-SIGNED label if applicable
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED\n", name)
return
}
if ok {
fmt.Printf("%s: OK (expires %s)\n", name, expiryString(leaf, exp))
return
}
fmt.Printf("%s: INVALID\n", name)
} }
func main() { func main() {
@@ -250,38 +319,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
for _, certPath := range os.Args[2:] { for _, certPath := range os.Args[2:] {
ok, exp := verifyAgainstCA(caPool, certPath) processCert(caPool, certPath)
name := filepath.Base(certPath) }
// Load the leaf once for self-signed detection and potential expiry fallback
var leaf *x509.Certificate
if certs, err := loadCertsFromFile(certPath); err == nil && len(certs) > 0 {
leaf = certs[0]
}
// If the certificate is self-signed, prefer the SELF-SIGNED label
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED\n", name)
continue
}
if ok {
// Display with the requested format
// Example: file: OK (expires 2031-01-01)
// Ensure deterministic date formatting
// Note: no timezone displayed; date only as per example
// If exp ended up empty for some reason, recompute safely
if exp == "" {
if leaf != nil {
exp = leaf.NotAfter.Format("2006-01-02")
} else {
// fallback to the current date to avoid empty; though shouldn't happen
exp = time.Now().Format("2006-01-02")
}
}
fmt.Printf("%s: OK (expires %s)\n", name, exp)
} else {
fmt.Printf("%s: INVALID\n", name)
}
}
} }

View File

@@ -0,0 +1,28 @@
# Build and runtime image for cert-bundler
# Usage (from repo root or cmd/cert-bundler directory):
# docker build -t cert-bundler:latest -f cmd/cert-bundler/Dockerfile .
# docker run --rm -v "$PWD":/work cert-bundler:latest
# This expects a /work/bundle.yaml file in the mounted directory and
# will write generated bundles to /work/bundle.
# Build stage
FROM golang:1.24.3-alpine AS build
WORKDIR /src
# Install the pinned cert-bundler release from the module proxy; no local sources are copied into this stage
RUN go install git.wntrmute.dev/kyle/goutils/cmd/cert-bundler@v1.13.2 && \
mv /go/bin/cert-bundler /usr/local/bin/cert-bundler
# Runtime stage (kept as golang:alpine per requirement)
FROM golang:1.24.3-alpine
# Create a work directory that users will typically mount into
WORKDIR /work
VOLUME ["/work"]
# Copy the built binary from the builder stage
COPY --from=build /usr/local/bin/cert-bundler /usr/local/bin/cert-bundler
# Default command: read bundle.yaml from current directory and output to ./bundle
ENTRYPOINT ["/usr/local/bin/cert-bundler"]
CMD ["-c", "/work/bundle.yaml", "-o", "/work/bundle"]

View File

@@ -1,64 +1,19 @@
package main package main
import ( import (
"archive/tar"
"archive/zip"
"compress/gzip"
"crypto/sha256"
"crypto/x509"
_ "embed" _ "embed"
"encoding/pem"
"flag" "flag"
"fmt" "fmt"
"os" "os"
"path/filepath"
"strings"
"time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib/bundler"
"gopkg.in/yaml.v2"
) )
// Config represents the top-level YAML configuration
type Config struct {
Config struct {
Hashes string `yaml:"hashes"`
Expiry string `yaml:"expiry"`
} `yaml:"config"`
Chains map[string]ChainGroup `yaml:"chains"`
}
// ChainGroup represents a named group of certificate chains
type ChainGroup struct {
Certs []CertChain `yaml:"certs"`
Outputs Outputs `yaml:"outputs"`
}
// CertChain represents a root certificate and its intermediates
type CertChain struct {
Root string `yaml:"root"`
Intermediates []string `yaml:"intermediates"`
}
// Outputs defines output format options
type Outputs struct {
IncludeSingle bool `yaml:"include_single"`
IncludeIndividual bool `yaml:"include_individual"`
Manifest bool `yaml:"manifest"`
Formats []string `yaml:"formats"`
Encoding string `yaml:"encoding"`
}
var ( var (
configFile string configFile string
outputDir string outputDir string
) )
var formatExtensions = map[string]string{
"zip": ".zip",
"tgz": ".tar.gz",
}
//go:embed README.txt //go:embed README.txt
var readmeContent string var readmeContent string
@@ -77,452 +32,10 @@ func main() {
os.Exit(1) os.Exit(1)
} }
// Load and parse configuration if err := bundler.Run(configFile, outputDir); err != nil {
cfg, err := loadConfig(configFile) fmt.Fprintf(os.Stderr, "Error: %v\n", err)
if err != nil {
fmt.Fprintf(os.Stderr, "Error loading config: %v\n", err)
os.Exit(1) os.Exit(1)
} }
// Parse expiry duration (default 1 year)
expiryDuration := 365 * 24 * time.Hour
if cfg.Config.Expiry != "" {
expiryDuration, err = parseDuration(cfg.Config.Expiry)
if err != nil {
fmt.Fprintf(os.Stderr, "Error parsing expiry: %v\n", err)
os.Exit(1)
}
}
// Create output directory if it doesn't exist
if err := os.MkdirAll(outputDir, 0755); err != nil {
fmt.Fprintf(os.Stderr, "Error creating output directory: %v\n", err)
os.Exit(1)
}
// Process each chain group
// Pre-allocate createdFiles based on total number of formats across all groups
totalFormats := 0
for _, group := range cfg.Chains {
totalFormats += len(group.Outputs.Formats)
}
createdFiles := make([]string, 0, totalFormats)
for groupName, group := range cfg.Chains {
files, err := processChainGroup(groupName, group, expiryDuration)
if err != nil {
fmt.Fprintf(os.Stderr, "Error processing chain group %s: %v\n", groupName, err)
os.Exit(1)
}
createdFiles = append(createdFiles, files...)
}
// Generate hash file for all created archives
if cfg.Config.Hashes != "" {
hashFile := filepath.Join(outputDir, cfg.Config.Hashes)
if err := generateHashFile(hashFile, createdFiles); err != nil {
fmt.Fprintf(os.Stderr, "Error generating hash file: %v\n", err)
os.Exit(1)
}
}
fmt.Println("Certificate bundling completed successfully") fmt.Println("Certificate bundling completed successfully")
} }
func loadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
return &cfg, nil
}
func parseDuration(s string) (time.Duration, error) {
// Support simple formats like "1y", "6m", "30d"
if len(s) < 2 {
return 0, fmt.Errorf("invalid duration format: %s", s)
}
unit := s[len(s)-1]
value := s[:len(s)-1]
var multiplier time.Duration
switch unit {
case 'y', 'Y':
multiplier = 365 * 24 * time.Hour
case 'm', 'M':
multiplier = 30 * 24 * time.Hour
case 'd', 'D':
multiplier = 24 * time.Hour
default:
return time.ParseDuration(s)
}
var num int
_, err := fmt.Sscanf(value, "%d", &num)
if err != nil {
return 0, fmt.Errorf("invalid duration value: %s", s)
}
return time.Duration(num) * multiplier, nil
}
func processChainGroup(groupName string, group ChainGroup, expiryDuration time.Duration) ([]string, error) {
// Default encoding to "pem" if not specified
encoding := group.Outputs.Encoding
if encoding == "" {
encoding = "pem"
}
// Collect certificates from all chains in the group
singleFileCerts, individualCerts, err := loadAndCollectCerts(group.Certs, group.Outputs, expiryDuration)
if err != nil {
return nil, err
}
// Prepare files for inclusion in archives
archiveFiles, err := prepareArchiveFiles(singleFileCerts, individualCerts, group.Outputs, encoding)
if err != nil {
return nil, err
}
// Create archives for the entire group
createdFiles, err := createArchiveFiles(groupName, group.Outputs.Formats, archiveFiles)
if err != nil {
return nil, err
}
return createdFiles, nil
}
// loadAndCollectCerts loads all certificates from chains and collects them for processing
func loadAndCollectCerts(chains []CertChain, outputs Outputs, expiryDuration time.Duration) ([]*x509.Certificate, []certWithPath, error) {
var singleFileCerts []*x509.Certificate
var individualCerts []certWithPath
for _, chain := range chains {
// Load root certificate
rootCert, err := certlib.LoadCertificate(chain.Root)
if err != nil {
return nil, nil, fmt.Errorf("failed to load root certificate %s: %v", chain.Root, err)
}
// Check expiry for root
checkExpiry(chain.Root, rootCert, expiryDuration)
// Add root to collections if needed
if outputs.IncludeSingle {
singleFileCerts = append(singleFileCerts, rootCert)
}
if outputs.IncludeIndividual {
individualCerts = append(individualCerts, certWithPath{
cert: rootCert,
path: chain.Root,
})
}
// Load and validate intermediates
for _, intPath := range chain.Intermediates {
intCert, err := certlib.LoadCertificate(intPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %v", intPath, err)
}
// Validate that intermediate is signed by root
if err := intCert.CheckSignatureFrom(rootCert); err != nil {
return nil, nil, fmt.Errorf("intermediate %s is not properly signed by root %s: %v", intPath, chain.Root, err)
}
// Check expiry for intermediate
checkExpiry(intPath, intCert, expiryDuration)
// Add intermediate to collections if needed
if outputs.IncludeSingle {
singleFileCerts = append(singleFileCerts, intCert)
}
if outputs.IncludeIndividual {
individualCerts = append(individualCerts, certWithPath{
cert: intCert,
path: intPath,
})
}
}
}
return singleFileCerts, individualCerts, nil
}
// prepareArchiveFiles prepares all files to be included in archives
func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []certWithPath, outputs Outputs, encoding string) ([]fileEntry, error) {
var archiveFiles []fileEntry
// Handle a single bundle file
if outputs.IncludeSingle && len(singleFileCerts) > 0 {
files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true)
if err != nil {
return nil, fmt.Errorf("failed to encode single bundle: %v", err)
}
archiveFiles = append(archiveFiles, files...)
}
// Handle individual files
if outputs.IncludeIndividual {
for _, cp := range individualCerts {
baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path))
files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false)
if err != nil {
return nil, fmt.Errorf("failed to encode individual cert %s: %v", cp.path, err)
}
archiveFiles = append(archiveFiles, files...)
}
}
// Generate manifest if requested
if outputs.Manifest {
manifestContent := generateManifest(archiveFiles)
archiveFiles = append(archiveFiles, fileEntry{
name: "MANIFEST",
content: manifestContent,
})
}
return archiveFiles, nil
}
// createArchiveFiles creates archive files in the specified formats
func createArchiveFiles(groupName string, formats []string, archiveFiles []fileEntry) ([]string, error) {
createdFiles := make([]string, 0, len(formats))
for _, format := range formats {
ext, ok := formatExtensions[format]
if !ok {
return nil, fmt.Errorf("unsupported format: %s", format)
}
archivePath := filepath.Join(outputDir, groupName+ext)
switch format {
case "zip":
if err := createZipArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create zip archive: %v", err)
}
case "tgz":
if err := createTarGzArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create tar.gz archive: %v", err)
}
default:
return nil, fmt.Errorf("unsupported format: %s", format)
}
createdFiles = append(createdFiles, archivePath)
}
return createdFiles, nil
}
func checkExpiry(path string, cert *x509.Certificate, expiryDuration time.Duration) {
now := time.Now()
expiryThreshold := now.Add(expiryDuration)
if cert.NotAfter.Before(expiryThreshold) {
daysUntilExpiry := int(cert.NotAfter.Sub(now).Hours() / 24)
if daysUntilExpiry < 0 {
fmt.Fprintf(os.Stderr, "WARNING: Certificate %s has EXPIRED (expired %d days ago)\n", path, -daysUntilExpiry)
} else {
fmt.Fprintf(os.Stderr, "WARNING: Certificate %s will expire in %d days (on %s)\n", path, daysUntilExpiry, cert.NotAfter.Format("2006-01-02"))
}
}
}
type fileEntry struct {
name string
content []byte
}
type certWithPath struct {
cert *x509.Certificate
path string
}
// encodeCertsToFiles converts certificates to file entries based on encoding type
// If isSingle is true, certs are concatenated into a single file; otherwise one cert per file
func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding string, isSingle bool) ([]fileEntry, error) {
var files []fileEntry
switch encoding {
case "pem":
pemContent := encodeCertsToPEM(certs)
files = append(files, fileEntry{
name: baseName + ".pem",
content: pemContent,
})
case "der":
if isSingle {
// For single file in DER, concatenate all cert DER bytes
var derContent []byte
for _, cert := range certs {
derContent = append(derContent, cert.Raw...)
}
files = append(files, fileEntry{
name: baseName + ".crt",
content: derContent,
})
} else {
// Individual DER file (should only have one cert)
if len(certs) > 0 {
files = append(files, fileEntry{
name: baseName + ".crt",
content: certs[0].Raw,
})
}
}
case "both":
// Add PEM version
pemContent := encodeCertsToPEM(certs)
files = append(files, fileEntry{
name: baseName + ".pem",
content: pemContent,
})
// Add DER version
if isSingle {
var derContent []byte
for _, cert := range certs {
derContent = append(derContent, cert.Raw...)
}
files = append(files, fileEntry{
name: baseName + ".crt",
content: derContent,
})
} else {
if len(certs) > 0 {
files = append(files, fileEntry{
name: baseName + ".crt",
content: certs[0].Raw,
})
}
}
default:
return nil, fmt.Errorf("unsupported encoding: %s (must be 'pem', 'der', or 'both')", encoding)
}
return files, nil
}
// encodeCertsToPEM encodes certificates to PEM format
func encodeCertsToPEM(certs []*x509.Certificate) []byte {
var pemContent []byte
for _, cert := range certs {
pemBlock := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
pemContent = append(pemContent, pem.EncodeToMemory(pemBlock)...)
}
return pemContent
}
func generateManifest(files []fileEntry) []byte {
var manifest strings.Builder
for _, file := range files {
if file.name == "MANIFEST" {
continue
}
hash := sha256.Sum256(file.content)
manifest.WriteString(fmt.Sprintf("%x %s\n", hash, file.name))
}
return []byte(manifest.String())
}
func createZipArchive(path string, files []fileEntry) error {
f, err := os.Create(path)
if err != nil {
return err
}
w := zip.NewWriter(f)
for _, file := range files {
fw, err := w.Create(file.name)
if err != nil {
w.Close()
f.Close()
return err
}
if _, err := fw.Write(file.content); err != nil {
w.Close()
f.Close()
return err
}
}
// Check errors on close operations
if err := w.Close(); err != nil {
f.Close()
return err
}
return f.Close()
}
func createTarGzArchive(path string, files []fileEntry) error {
f, err := os.Create(path)
if err != nil {
return err
}
gw := gzip.NewWriter(f)
tw := tar.NewWriter(gw)
for _, file := range files {
hdr := &tar.Header{
Name: file.name,
Mode: 0644,
Size: int64(len(file.content)),
}
if err := tw.WriteHeader(hdr); err != nil {
tw.Close()
gw.Close()
f.Close()
return err
}
if _, err := tw.Write(file.content); err != nil {
tw.Close()
gw.Close()
f.Close()
return err
}
}
// Check errors on close operations in the correct order
if err := tw.Close(); err != nil {
gw.Close()
f.Close()
return err
}
if err := gw.Close(); err != nil {
f.Close()
return err
}
return f.Close()
}
func generateHashFile(path string, files []string) error {
f, err := os.Create(path)
if err != nil {
return err
}
defer f.Close()
for _, file := range files {
data, err := os.ReadFile(file)
if err != nil {
return err
}
hash := sha256.Sum256(data)
fmt.Fprintf(f, "%x %s\n", hash, filepath.Base(file))
}
return nil
}

View File

@@ -1,194 +0,0 @@
Task: build a certificate bundling tool in cmd/cert-bundler. It
creates archives of certificate chains.
A YAML file for this looks something like:
``` yaml
config:
hashes: bundle.sha256
expiry: 1y
chains:
core_certs:
certs:
- root: roots/core-ca.pem
intermediates:
- int/cca1.pem
- int/cca2.pem
- int/cca3.pem
- root: roots/ssh-ca.pem
intermediates:
- ssh/ssh_dmz1.pem
- ssh/ssh_internal.pem
outputs:
include_single: true
include_individual: true
manifest: true
formats:
- zip
- tgz
```
Some requirements:
1. First, all the certificates should be loaded.
2. For each root, each of the individual intermediates should be
checked to make sure they are properly signed by the root CA.
3. The program should optionally take an expiration period (defaulting
to one year), specified in config.expiry, and if any certificate
is within that expiration period, a warning should be printed.
4. If outputs.include_single is true, all certificates under chains
should be concatenated into a single file.
5. If outputs.include_individual is true, all certificates under
chains should be included at the root level (e.g. int/cca2.pem
would be cca2.pem in the archive).
6. If outputs.manifest is true, a "MANIFEST" file is created with
SHA256 sums of each file included in the archive.
7. For each of the formats, create an archive file in the output
directory (specified with `-o`) with that format.
- If zip is included, create a .zip file.
- If tgz is included, create a .tar.gz file with default compression
levels.
- All archive files should include any generated files (single
and/or individual) in the top-level directory.
8. In the output directory, create a file with the same name as
config.hashes that contains the SHA256 sum of all files created.
-----
The outputs.include_single and outputs.include_individual fields describe
what should go in the final archive. If both are specified, the output
archive should include both a single bundle.pem and each individual
certificate, for example.
-----
As it stands, given the following `bundle.yaml`:
``` yaml
config:
hashes: bundle.sha256
expiry: 1y
chains:
core_certs:
certs:
- root: pems/gts-r1.pem
intermediates:
- pems/goog-wr2.pem
outputs:
include_single: true
include_individual: true
manifest: true
formats:
- zip
- tgz
- root: pems/isrg-root-x1.pem
intermediates:
- pems/le-e7.pem
outputs:
include_single: true
include_individual: false
manifest: true
formats:
- zip
- tgz
google_certs:
certs:
- root: pems/gts-r1.pem
intermediates:
- pems/goog-wr2.pem
outputs:
include_single: true
include_individual: false
manifest: true
formats:
- tgz
lets_encrypt:
certs:
- root: pems/isrg-root-x1.pem
intermediates:
- pems/le-e7.pem
outputs:
include_single: false
include_individual: true
manifest: false
formats:
- zip
```
The program outputs the following files:
- bundle.sha256
- core_certs_0.tgz (contains individual certs)
- core_certs_0.zip (contains individual certs)
- core_certs_1.tgz (contains core_certs.pem)
- core_certs_1.zip (contains core_certs.pem)
- google_certs_0.tgz
- lets_encrypt_0.zip
It should output
- bundle.sha256
- core_certs.tgz
- core_certs.zip
- google_certs.tgz
- lets_encrypt.zip
core_certs.* should contain `bundle.pem` and all the individual
certs. There should be no numbered _N variants of the archives.
-----
Add an additional field to outputs: encoding. It should accept one of
`der`, `pem`, or `both`. If `der`, certificates should be output as a
`.crt` file containing a DER-encoded certificate. If `pem`, certificates
should be output as a `.pem` file containing a PEM-encoded certificate.
If both, both the `.crt` and `.pem` certificate should be included.
For example, given the previous config, if `encoding` is der, the
google_certs.tgz archive should contain
- bundle.crt
- MANIFEST
Or with lets_encrypt.zip:
- isrg-root-x1.crt
- le-e7.crt
However, if `encoding` is pem, the lets_encrypt.zip archive should contain:
- isrg-root-x1.pem
- le-e7.pem
And if `encoding` is both, the lets_encrypt.zip archive should contain:
- isrg-root-x1.crt
- isrg-root-x1.pem
- le-e7.crt
- le-e7.pem
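Put together, an outputs block carrying the new field might read as follows. This is a sketch reusing the lets_encrypt group from the earlier example; only the `encoding` key is new.

``` yaml
lets_encrypt:
  certs:
    - root: pems/isrg-root-x1.pem
      intermediates:
        - pems/le-e7.pem
  outputs:
    include_single: false
    include_individual: true
    manifest: false
    encoding: both
    formats:
      - zip
```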
-----
The tgz format should output a `.tar.gz` file instead of a `.tgz` file.
-----
Move the format extensions to a global variable.
-----
Write a README.txt with a description of the bundle.yaml format.
Additionally, update the help text for the program (e.g. with `-h`)
to provide the same detailed information.
-----
It may be easier to embed the README.txt in the program on build.
-----
For the archive (tar.gz and zip) writers, make sure errors are
checked at the end, and don't just defer the close operations.
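A minimal sketch of the close ordering this note asks for, assuming a fileEntry type like the one in the removed main.go above: the tar, gzip, and file writers are closed explicitly, in reverse order of creation, and the first error encountered is returned instead of being dropped by a defer.

``` go
// Sketch only: fileEntry mirrors the struct in the removed main.go; the point
// is the explicit close ordering (tar writer, then gzip writer, then file).
package main

import (
	"archive/tar"
	"compress/gzip"
	"fmt"
	"os"
)

type fileEntry struct {
	name    string
	content []byte
}

func writeTarGz(path string, files []fileEntry) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	gw := gzip.NewWriter(f)
	tw := tar.NewWriter(gw)

	for _, file := range files {
		hdr := &tar.Header{Name: file.name, Mode: 0o644, Size: int64(len(file.content))}
		if err = tw.WriteHeader(hdr); err != nil {
			break
		}
		if _, err = tw.Write(file.content); err != nil {
			break
		}
	}

	// Close in reverse order of creation, keeping the first error seen.
	for _, c := range []func() error{tw.Close, gw.Close, f.Close} {
		if cerr := c(); err == nil {
			err = cerr
		}
	}
	return err
}

func main() {
	files := []fileEntry{{name: "MANIFEST", content: []byte("example\n")}}
	if err := writeTarGz("example.tar.gz", files); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}
```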

View File

@@ -2,6 +2,19 @@ config:
hashes: bundle.sha256 hashes: bundle.sha256
expiry: 1y expiry: 1y
chains: chains:
weird:
certs:
- root: pems/gts-r1.pem
intermediates:
- pems/goog-wr2.pem
- root: pems/isrg-root-x1.pem
outputs:
include_single: true
include_individual: true
manifest: true
formats:
- zip
- tgz
core_certs: core_certs:
certs: certs:
- root: pems/gts-r1.pem - root: pems/gts-r1.pem

View File

@@ -1,4 +0,0 @@
5ed8bf9ed693045faa8a5cb0edc4a870052e56aef6291ce8b1604565affbc2a4 core_certs.zip
e59eddc590d2f7b790a87c5b56e81697088ab54be382c0e2c51b82034006d308 core_certs.tgz
51b9b63b1335118079e90700a3a5b847c363808e9116e576ca84f301bc433289 google_certs.tgz
3d1910ca8835c3ded1755a8c7d6c48083c2f3ff68b2bfbf932aaf27e29d0a232 lets_encrypt.zip

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,13 +1,15 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"errors"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
@@ -22,6 +24,13 @@ var (
verbose bool verbose bool
) )
var (
strOK = "OK"
strExpired = "EXPIRED"
strRevoked = "REVOKED"
strUnknown = "UNKNOWN"
)
func main() { func main() {
flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal") flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal")
flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects") flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects")
@@ -41,16 +50,16 @@ func main() {
for _, target := range flag.Args() { for _, target := range flag.Args() {
status, err := processTarget(target) status, err := processTarget(target)
switch status { switch status {
case "OK": case strOK:
fmt.Printf("%s: OK\n", target) fmt.Printf("%s: %s\n", target, strOK)
case "EXPIRED": case strExpired:
fmt.Printf("%s: EXPIRED: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strExpired, err)
exitCode = 1 exitCode = 1
case "REVOKED": case strRevoked:
fmt.Printf("%s: REVOKED\n", target) fmt.Printf("%s: %s\n", target, strRevoked)
exitCode = 1 exitCode = 1
case "UNKNOWN": case strUnknown:
fmt.Printf("%s: UNKNOWN: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strUnknown, err)
if hardfail { if hardfail {
// In hardfail, treat unknown as failure // In hardfail, treat unknown as failure
exitCode = 1 exitCode = 1
@@ -66,74 +75,77 @@ func processTarget(target string) (string, error) {
return checkFile(target) return checkFile(target)
} }
// Not a file; treat as site
return checkSite(target) return checkSite(target)
} }
func checkFile(path string) (string, error) { func checkFile(path string) (string, error) {
in, err := ioutil.ReadFile(path) // Prefer high-level helpers from certlib to load certificates from disk
if err != nil { if certs, err := certlib.LoadCertificates(path); err == nil && len(certs) > 0 {
return "UNKNOWN", err // Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0])
} }
// Try PEM first; if that fails, try single DER cert cert, err := certlib.LoadCertificate(path)
certs, err := certlib.ReadCertificates(in) if err != nil || cert == nil {
if err != nil || len(certs) == 0 { return strUnknown, err
cert, _, derr := certlib.ReadCertificate(in)
if derr != nil || cert == nil {
if err == nil {
err = derr
}
return "UNKNOWN", err
}
return evaluateCert(cert)
} }
return evaluateCert(cert)
// Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0])
} }
func checkSite(hostport string) (string, error) { func checkSite(hostport string) (string, error) {
// Use certlib/hosts to parse host/port (supports https URLs and host:port) // Use certlib/hosts to parse host/port (supports https URLs and host:port)
target, err := hosts.ParseHost(hostport) target, err := hosts.ParseHost(hostport)
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
d := &net.Dialer{Timeout: timeout} d := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(d, "tcp", target.String(), &tls.Config{InsecureSkipVerify: true, ServerName: target.Host}) tcfg := &tls.Config{
InsecureSkipVerify: true,
ServerName: target.Host,
} // #nosec G402 -- CLI tool only verifies revocation
td := &tls.Dialer{NetDialer: d, Config: tcfg}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := td.DialContext(ctx, "tcp", target.String())
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() tconn, ok := conn.(*tls.Conn)
if !ok {
return strUnknown, errors.New("connection is not TLS")
}
state := tconn.ConnectionState()
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
return "UNKNOWN", fmt.Errorf("no peer certificates presented") return strUnknown, errors.New("no peer certificates presented")
} }
return evaluateCert(state.PeerCertificates[0]) return evaluateCert(state.PeerCertificates[0])
} }
func evaluateCert(cert *x509.Certificate) (string, error) { func evaluateCert(cert *x509.Certificate) (string, error) {
// Expiry check // Delegate validity and revocation checks to certlib/revoke helper.
now := time.Now() // It returns revoked=true for both revoked and expired/not-yet-valid.
if !now.Before(cert.NotAfter) { // Map those cases back to our statuses using the returned error text.
return "EXPIRED", fmt.Errorf("expired at %s", cert.NotAfter)
}
if !now.After(cert.NotBefore) {
return "EXPIRED", fmt.Errorf("not valid until %s", cert.NotBefore)
}
// Revocation check using certlib/revoke
revoked, ok, err := revoke.VerifyCertificateError(cert) revoked, ok, err := revoke.VerifyCertificateError(cert)
if revoked { if revoked {
// If revoked is true, ok will be true per implementation, err may describe why if err != nil {
return "REVOKED", err msg := err.Error()
if strings.Contains(msg, "expired") || strings.Contains(msg, "isn't valid until") ||
strings.Contains(msg, "not valid until") {
return strExpired, err
}
}
return strRevoked, err
} }
if !ok { if !ok {
// Revocation status could not be determined // Revocation status could not be determined
return "UNKNOWN", err return strUnknown, err
} }
return "OK", nil return strOK, nil
} }

View File

@@ -1,11 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"os"
"regexp" "regexp"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -20,20 +23,26 @@ func main() {
server += ":443" server += ":443"
} }
var chain string d := &tls.Dialer{Config: &tls.Config{}} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", server)
conn, err := tls.Dial("tcp", server, nil)
die.If(err) die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
}
defer conn.Close()
details := conn.ConnectionState() details := conn.ConnectionState()
var chain strings.Builder
for _, cert := range details.PeerCertificates { for _, cert := range details.PeerCertificates {
p := pem.Block{ p := pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Raw, Bytes: cert.Raw,
} }
chain += string(pem.EncodeToMemory(&p)) chain.Write(pem.EncodeToMemory(&p))
} }
fmt.Println(chain) fmt.Fprintln(os.Stdout, chain.String())
} }
} }

View File

@@ -1,328 +0,0 @@
package main
import (
"bytes"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"flag"
"fmt"
"io"
"os"
"sort"
"strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/lib"
)
func certPublic(cert *x509.Certificate) string {
switch pub := cert.PublicKey.(type) {
case *rsa.PublicKey:
return fmt.Sprintf("RSA-%d", pub.N.BitLen())
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
return "ECDSA-prime256v1"
case elliptic.P384():
return "ECDSA-secp384r1"
case elliptic.P521():
return "ECDSA-secp521r1"
default:
return "ECDSA (unknown curve)"
}
case *dsa.PublicKey:
return "DSA"
default:
return "Unknown"
}
}
func displayName(name pkix.Name) string {
var ns []string
if name.CommonName != "" {
ns = append(ns, name.CommonName)
}
for i := range name.Country {
ns = append(ns, fmt.Sprintf("C=%s", name.Country[i]))
}
for i := range name.Organization {
ns = append(ns, fmt.Sprintf("O=%s", name.Organization[i]))
}
for i := range name.OrganizationalUnit {
ns = append(ns, fmt.Sprintf("OU=%s", name.OrganizationalUnit[i]))
}
for i := range name.Locality {
ns = append(ns, fmt.Sprintf("L=%s", name.Locality[i]))
}
for i := range name.Province {
ns = append(ns, fmt.Sprintf("ST=%s", name.Province[i]))
}
if len(ns) > 0 {
return "/" + strings.Join(ns, "/")
}
return "*** no subject information ***"
}
func keyUsages(ku x509.KeyUsage) string {
var uses []string
for u, s := range keyUsage {
if (ku & u) != 0 {
uses = append(uses, s)
}
}
sort.Strings(uses)
return strings.Join(uses, ", ")
}
func extUsage(ext []x509.ExtKeyUsage) string {
ns := make([]string, 0, len(ext))
for i := range ext {
ns = append(ns, extKeyUsages[ext[i]])
}
sort.Strings(ns)
return strings.Join(ns, ", ")
}
func showBasicConstraints(cert *x509.Certificate) {
fmt.Printf("\tBasic constraints: ")
if cert.BasicConstraintsValid {
fmt.Printf("valid")
} else {
fmt.Printf("invalid")
}
if cert.IsCA {
fmt.Printf(", is a CA certificate")
if !cert.BasicConstraintsValid {
fmt.Printf(" (basic constraint failure)")
}
} else {
fmt.Printf("is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Printf(" (key encipherment usage enabled!)")
}
}
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Printf(", max path length %d", cert.MaxPathLen)
}
fmt.Printf("\n")
}
const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
var (
dateFormat string
showHash bool // if true, print a SHA256 hash of the certificate's Raw field
)
func wrapPrint(text string, indent int) {
tabs := ""
for i := 0; i < indent; i++ {
tabs += "\t"
}
fmt.Printf(tabs+"%s\n", wrap(text, indent))
}
func displayCert(cert *x509.Certificate) {
fmt.Println("CERTIFICATE")
if showHash {
fmt.Println(wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
}
fmt.Println(wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Println(wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Printf("\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm))
fmt.Println("Details:")
wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Printf("\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
}
if len(cert.SubjectKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
}
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Printf("\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Printf("\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 {
fmt.Printf("\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
}
showBasicConstraints(cert)
validNames := make([]string, 0, len(cert.DNSNames)+len(cert.EmailAddresses)+len(cert.IPAddresses))
for i := range cert.DNSNames {
validNames = append(validNames, "dns:"+cert.DNSNames[i])
}
for i := range cert.EmailAddresses {
validNames = append(validNames, "email:"+cert.EmailAddresses[i])
}
for i := range cert.IPAddresses {
validNames = append(validNames, "ip:"+cert.IPAddresses[i].String())
}
sans := fmt.Sprintf("SANs (%d): %s\n", len(validNames), strings.Join(validNames, ", "))
wrapPrint(sans, 1)
l := len(cert.IssuingCertificateURL)
if l != 0 {
var aia string
if l == 1 {
aia = "AIA"
} else {
aia = "AIAs"
}
wrapPrint(fmt.Sprintf("%d %s:", l, aia), 1)
for _, url := range cert.IssuingCertificateURL {
wrapPrint(url, 2)
}
}
l = len(cert.OCSPServer)
if l > 0 {
title := "OCSP server"
if l > 1 {
title += "s"
}
wrapPrint(title+":\n", 1)
for _, ocspServer := range cert.OCSPServer {
wrapPrint(fmt.Sprintf("- %s\n", ocspServer), 2)
}
}
}
func displayAllCerts(in []byte, leafOnly bool) {
certs, err := certlib.ParseCertificatesPEM(in)
if err != nil {
certs, _, err = certlib.ParseCertificatesDER(in, "")
if err != nil {
lib.Warn(err, "failed to parse certificates")
return
}
}
if len(certs) == 0 {
lib.Warnx("no certificates found")
return
}
if leafOnly {
displayCert(certs[0])
return
}
for i := range certs {
displayCert(certs[i])
}
}
func displayAllCertsWeb(uri string, leafOnly bool) {
ci := getConnInfo(uri)
conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig())
if err != nil {
lib.Warn(err, "couldn't connect to %s", ci.Addr)
return
}
defer conn.Close()
state := conn.ConnectionState()
conn.Close()
conn, err = tls.Dial("tcp", ci.Addr, verifyConfig(ci.Host))
if err == nil {
err = conn.VerifyHostname(ci.Host)
if err == nil {
state = conn.ConnectionState()
}
conn.Close()
} else {
lib.Warn(err, "TLS verification error with server name %s", ci.Host)
}
if len(state.PeerCertificates) == 0 {
lib.Warnx("no certificates found")
return
}
if leafOnly {
displayCert(state.PeerCertificates[0])
return
}
if len(state.VerifiedChains) == 0 {
lib.Warnx("no verified chains found; using peer chain")
for i := range state.PeerCertificates {
displayCert(state.PeerCertificates[i])
}
} else {
fmt.Println("TLS chain verified successfully.")
for i := range state.VerifiedChains {
fmt.Printf("--- Verified certificate chain %d ---\n", i+1)
for j := range state.VerifiedChains[i] {
displayCert(state.VerifiedChains[i][j])
}
}
}
}
func main() {
var leafOnly bool
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
flag.StringVar(&dateFormat, "s", oneTrueDateFormat, "date `format` in Go time format")
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.Parse()
if flag.NArg() == 0 || (flag.NArg() == 1 && flag.Arg(0) == "-") {
certs, err := io.ReadAll(os.Stdin)
if err != nil {
lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
}
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.Replace(certs, []byte(`\n`), []byte{0xa}, -1)
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
} else {
for _, filename := range flag.Args() {
fmt.Printf("--%s ---\n", filename)
if strings.HasPrefix(filename, "https://") {
displayAllCertsWeb(filename, leafOnly)
} else {
in, err := os.ReadFile(filename)
if err != nil {
lib.Warn(err, "couldn't read certificate")
continue
}
displayAllCerts(in, leafOnly)
}
}
}
}

376
cmd/certdump/main.go Normal file
View File

@@ -0,0 +1,376 @@
//lint:file-ignore SA1019 allow strict compatibility for old certs
package main
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"flag"
"fmt"
"os"
"sort"
"strings"
"github.com/kr/text"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/lib"
)
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
// "\2: \1,")
const (
sSHA256 = "SHA256"
sSHA512 = "SHA512"
)
var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content committment",
x509.KeyUsageKeyEncipherment: "key encipherment",
x509.KeyUsageKeyAgreement: "key agreement",
x509.KeyUsageDataEncipherment: "data encipherment",
x509.KeyUsageCertSign: "cert sign",
x509.KeyUsageCRLSign: "crl sign",
x509.KeyUsageEncipherOnly: "encipher only",
x509.KeyUsageDecipherOnly: "decipher only",
}
var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageAny: "any",
x509.ExtKeyUsageServerAuth: "server auth",
x509.ExtKeyUsageClientAuth: "client auth",
x509.ExtKeyUsageCodeSigning: "code signing",
x509.ExtKeyUsageEmailProtection: "s/mime",
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
x509.ExtKeyUsageIPSECUser: "ipsec user",
x509.ExtKeyUsageTimeStamping: "timestamping",
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
}
func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA"
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
return "RSA-PSS"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA"
case x509.PureEd25519:
return "Ed25519"
case x509.UnknownSignatureAlgorithm:
return "unknown public key algorithm"
default:
return "unknown public key algorithm"
}
}
func sigAlgoHash(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA:
return "MD2"
case x509.MD5WithRSA:
return "MD5"
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return sSHA256
case x509.SHA256WithRSAPSS:
return sSHA256
case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384"
case x509.SHA384WithRSAPSS:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return sSHA512
case x509.SHA512WithRSAPSS:
return sSHA512
case x509.PureEd25519:
return sSHA512
case x509.UnknownSignatureAlgorithm:
return "unknown hash algorithm"
default:
return "unknown hash algorithm"
}
}
const maxLine = 78
func makeIndent(n int) string {
s := " "
var sSb97 strings.Builder
for range n {
sSb97.WriteString(" ")
}
s += sSb97.String()
return s
}
func indentLen(n int) int {
return 4 + (8 * n)
}
// this isn't real efficient, but that's not a problem here.
func wrap(s string, indent int) string {
if indent > 3 {
indent = 3
}
wrapped := text.Wrap(s, maxLine)
lines := strings.SplitN(wrapped, "\n", 2)
if len(lines) == 1 {
return lines[0]
}
if (maxLine - indentLen(indent)) <= 0 {
panic("too much indentation")
}
rest := strings.Join(lines[1:], " ")
wrapped = text.Wrap(rest, maxLine-indentLen(indent))
return lines[0] + "\n" + text.Indent(wrapped, makeIndent(indent))
}
func dumpHex(in []byte) string {
return lib.HexEncode(in, lib.HexEncodeUpperColon)
}
func certPublic(cert *x509.Certificate) string {
switch pub := cert.PublicKey.(type) {
case *rsa.PublicKey:
return fmt.Sprintf("RSA-%d", pub.N.BitLen())
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
return "ECDSA-prime256v1"
case elliptic.P384():
return "ECDSA-secp384r1"
case elliptic.P521():
return "ECDSA-secp521r1"
default:
return "ECDSA (unknown curve)"
}
case *dsa.PublicKey:
return "DSA"
default:
return "Unknown"
}
}
func displayName(name pkix.Name) string {
var ns []string
if name.CommonName != "" {
ns = append(ns, name.CommonName)
}
for i := range name.Country {
ns = append(ns, fmt.Sprintf("C=%s", name.Country[i]))
}
for i := range name.Organization {
ns = append(ns, fmt.Sprintf("O=%s", name.Organization[i]))
}
for i := range name.OrganizationalUnit {
ns = append(ns, fmt.Sprintf("OU=%s", name.OrganizationalUnit[i]))
}
for i := range name.Locality {
ns = append(ns, fmt.Sprintf("L=%s", name.Locality[i]))
}
for i := range name.Province {
ns = append(ns, fmt.Sprintf("ST=%s", name.Province[i]))
}
if len(ns) > 0 {
return "/" + strings.Join(ns, "/")
}
return "*** no subject information ***"
}
func keyUsages(ku x509.KeyUsage) string {
var uses []string
for u, s := range keyUsage {
if (ku & u) != 0 {
uses = append(uses, s)
}
}
sort.Strings(uses)
return strings.Join(uses, ", ")
}
func extUsage(ext []x509.ExtKeyUsage) string {
ns := make([]string, 0, len(ext))
for i := range ext {
ns = append(ns, extKeyUsages[ext[i]])
}
sort.Strings(ns)
return strings.Join(ns, ", ")
}
func showBasicConstraints(cert *x509.Certificate) {
fmt.Fprint(os.Stdout, "\tBasic constraints: ")
if cert.BasicConstraintsValid {
fmt.Fprint(os.Stdout, "valid")
} else {
fmt.Fprint(os.Stdout, "invalid")
}
if cert.IsCA {
fmt.Fprint(os.Stdout, ", is a CA certificate")
if !cert.BasicConstraintsValid {
fmt.Fprint(os.Stdout, " (basic constraint failure)")
}
} else {
fmt.Fprint(os.Stdout, ", is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
}
}
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
}
fmt.Fprintln(os.Stdout)
}
const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
var (
dateFormat string
showHash bool // if true, print a SHA256 hash of the certificate's Raw field
)
func wrapPrint(text string, indent int) {
tabs := ""
var tabsSb140 strings.Builder
for range indent {
tabsSb140.WriteString("\t")
}
tabs += tabsSb140.String()
fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
}
func displayCert(cert *x509.Certificate) {
fmt.Fprintln(os.Stdout, "CERTIFICATE")
if showHash {
fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
}
fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm))
fmt.Fprintln(os.Stdout, "Details:")
wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 {
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
}
if len(cert.SubjectKeyId) > 0 {
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
}
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 {
fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
}
showBasicConstraints(cert)
validNames := make([]string, 0, len(cert.DNSNames)+len(cert.EmailAddresses)+len(cert.IPAddresses))
for i := range cert.DNSNames {
validNames = append(validNames, "dns:"+cert.DNSNames[i])
}
for i := range cert.EmailAddresses {
validNames = append(validNames, "email:"+cert.EmailAddresses[i])
}
for i := range cert.IPAddresses {
validNames = append(validNames, "ip:"+cert.IPAddresses[i].String())
}
sans := fmt.Sprintf("SANs (%d): %s\n", len(validNames), strings.Join(validNames, ", "))
wrapPrint(sans, 1)
l := len(cert.IssuingCertificateURL)
if l != 0 {
var aia string
if l == 1 {
aia = "AIA"
} else {
aia = "AIAs"
}
wrapPrint(fmt.Sprintf("%d %s:", l, aia), 1)
for _, url := range cert.IssuingCertificateURL {
wrapPrint(url, 2)
}
}
l = len(cert.OCSPServer)
if l > 0 {
title := "OCSP server"
if l > 1 {
title += "s"
}
wrapPrint(title+":\n", 1)
for _, ocspServer := range cert.OCSPServer {
wrapPrint(fmt.Sprintf("- %s\n", ocspServer), 2)
}
}
}
func main() {
var leafOnly bool
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
flag.StringVar(&dateFormat, "s", oneTrueDateFormat, "date `format` in Go time format")
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.Parse()
opts := &certlib.FetcherOpts{
SkipVerify: true,
Roots: nil,
}
for _, filename := range flag.Args() {
fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
certs, err := certlib.GetCertificateChain(filename, opts)
if err != nil {
_, _ = lib.Warn(err, "couldn't read certificate")
continue
}
if leafOnly {
displayCert(certs[0])
continue
}
for i := range certs {
displayCert(certs[i])
}
}
}

View File

@@ -1,176 +0,0 @@
package main
import (
"crypto/tls"
"crypto/x509"
"fmt"
"net"
"strings"
"github.com/kr/text"
)
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
// "\2: \1,")
var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content committment",
x509.KeyUsageKeyEncipherment: "key encipherment",
x509.KeyUsageKeyAgreement: "key agreement",
x509.KeyUsageDataEncipherment: "data encipherment",
x509.KeyUsageCertSign: "cert sign",
x509.KeyUsageCRLSign: "crl sign",
x509.KeyUsageEncipherOnly: "encipher only",
x509.KeyUsageDecipherOnly: "decipher only",
}
var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageAny: "any",
x509.ExtKeyUsageServerAuth: "server auth",
x509.ExtKeyUsageClientAuth: "client auth",
x509.ExtKeyUsageCodeSigning: "code signing",
x509.ExtKeyUsageEmailProtection: "s/mime",
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
x509.ExtKeyUsageIPSECUser: "ipsec user",
x509.ExtKeyUsageTimeStamping: "timestamping",
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
}
func pubKeyAlgo(a x509.PublicKeyAlgorithm) string {
switch a {
case x509.RSA:
return "RSA"
case x509.ECDSA:
return "ECDSA"
case x509.DSA:
return "DSA"
default:
return "unknown public key algorithm"
}
}
func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA"
default:
return "unknown public key algorithm"
}
}
func sigAlgoHash(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA:
return "MD2"
case x509.MD5WithRSA:
return "MD5"
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return "SHA256"
case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return "SHA512"
default:
return "unknown hash algorithm"
}
}
const maxLine = 78
func makeIndent(n int) string {
s := " "
for i := 0; i < n; i++ {
s += " "
}
return s
}
func indentLen(n int) int {
return 4 + (8 * n)
}
// this isn't real efficient, but that's not a problem here
func wrap(s string, indent int) string {
if indent > 3 {
indent = 3
}
wrapped := text.Wrap(s, maxLine)
lines := strings.SplitN(wrapped, "\n", 2)
if len(lines) == 1 {
return lines[0]
}
if (maxLine - indentLen(indent)) <= 0 {
panic("too much indentation")
}
rest := strings.Join(lines[1:], " ")
wrapped = text.Wrap(rest, maxLine-indentLen(indent))
return lines[0] + "\n" + text.Indent(wrapped, makeIndent(indent))
}
func dumpHex(in []byte) string {
var s string
for i := range in {
s += fmt.Sprintf("%02X:", in[i])
}
return strings.Trim(s, ":")
}
// permissiveConfig returns a maximally-accepting TLS configuration;
// the purpose is to look at the cert, not verify the security properties
// of the connection.
func permissiveConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true,
}
}
// verifyConfig returns a config that will verify the connection.
func verifyConfig(hostname string) *tls.Config {
return &tls.Config{
ServerName: hostname,
}
}
type connInfo struct {
// The original URI provided.
URI string
// The hostname of the server.
Host string
// The port to connect on.
Port string
// The address to connect to.
Addr string
}
func getConnInfo(uri string) *connInfo {
ci := &connInfo{URI: uri}
ci.Host = uri[len("https://"):]
host, port, err := net.SplitHostPort(ci.Host)
if err != nil {
ci.Port = "443"
} else {
ci.Host = host
ci.Port = port
}
ci.Addr = net.JoinHostPort(ci.Host, ci.Port)
return ci
}

View File

@@ -5,7 +5,6 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strings" "strings"
"time" "time"
@@ -54,7 +53,7 @@ func displayName(name pkix.Name) string {
} }
func expires(cert *x509.Certificate) time.Duration { func expires(cert *x509.Certificate) time.Duration {
return cert.NotAfter.Sub(time.Now()) return time.Until(cert.NotAfter)
} }
func inDanger(cert *x509.Certificate) bool { func inDanger(cert *x509.Certificate) bool {
@@ -76,20 +75,17 @@ func checkCert(cert *x509.Certificate) {
} }
func main() { func main() {
opts := &certlib.FetcherOpts{}
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification")
flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs") flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs")
flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring") flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring")
flag.Parse() flag.Parse()
for _, file := range flag.Args() { for _, file := range flag.Args() {
in, err := ioutil.ReadFile(file) certs, err := certlib.GetCertificateChain(file, opts)
if err != nil { if err != nil {
lib.Warn(err, "failed to read file") _, _ = lib.Warn(err, "while parsing certificates")
continue
}
certs, err := certlib.ParseCertificatesPEM(in)
if err != nil {
lib.Warn(err, "while parsing certificates")
continue continue
} }

53
cmd/certser/main.go Normal file
View File

@@ -0,0 +1,53 @@
package main
import (
"crypto/x509"
"flag"
"fmt"
"strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
const displayInt lib.HexEncodeMode = iota
func parseDisplayMode(mode string) lib.HexEncodeMode {
mode = strings.ToLower(mode)
if mode == "int" {
return displayInt
}
return lib.ParseHexEncodeMode(mode)
}
func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string {
if mode == displayInt {
return cert.SerialNumber.String()
}
return lib.HexEncode(cert.SerialNumber.Bytes(), mode)
}
func main() {
opts := &certlib.FetcherOpts{}
displayAs := flag.String("d", "int", "display mode (int, hex, uhex)")
showExpiry := flag.Bool("e", false, "show expiry date")
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification")
flag.Parse()
displayMode := parseDisplayMode(*displayAs)
for _, arg := range flag.Args() {
cert, err := certlib.GetCertificate(arg, opts)
die.If(err)
fmt.Printf("%s: %s", arg, serialString(cert, displayMode))
if *showExpiry {
fmt.Printf(" (%s)", cert.NotAfter.Format("2006-01-02"))
}
fmt.Println()
}
}

View File

@@ -4,13 +4,11 @@ import (
"crypto/x509" "crypto/x509"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/certlib/revoke" "git.wntrmute.dev/kyle/goutils/certlib/revoke"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
@@ -30,83 +28,122 @@ func printRevocation(cert *x509.Certificate) {
} }
} }
func main() { type appConfig struct {
var caFile, intFile string caFile, intFile string
var forceIntermediateBundle, revexp, verbose bool forceIntermediateBundle bool
flag.StringVar(&caFile, "ca", "", "CA certificate `bundle`") revexp, skipVerify, verbose bool
flag.StringVar(&intFile, "i", "", "intermediate `bundle`") }
flag.BoolVar(&forceIntermediateBundle, "f", false,
func parseFlags() appConfig {
var cfg appConfig
flag.StringVar(&cfg.caFile, "ca", "", "CA certificate `bundle`")
flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`")
flag.BoolVar(&cfg.forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate") "force the use of the intermediate bundle, ignoring any intermediates bundled with certificate")
flag.BoolVar(&revexp, "r", false, "print revocation and expiry information") flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification")
flag.BoolVar(&verbose, "v", false, "verbose") flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&cfg.verbose, "v", false, "verbose")
flag.Parse() flag.Parse()
return cfg
}
var roots *x509.CertPool func loadRoots(caFile string, verbose bool) (*x509.CertPool, error) {
if caFile != "" { if caFile == "" {
var err error return x509.SystemCertPool()
if verbose {
fmt.Println("[+] loading root certificates from", caFile)
}
roots, err = certlib.LoadPEMCertPool(caFile)
die.If(err)
} }
var ints *x509.CertPool
if intFile != "" {
var err error
if verbose {
fmt.Println("[+] loading intermediate certificates from", intFile)
}
ints, err = certlib.LoadPEMCertPool(caFile)
die.If(err)
} else {
ints = x509.NewCertPool()
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert",
lib.ProgName())
}
fileData, err := ioutil.ReadFile(flag.Arg(0))
die.If(err)
chain, err := certlib.ParseCertificatesPEM(fileData)
die.If(err)
if verbose { if verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain)) fmt.Println("[+] loading root certificates from", caFile)
} }
return certlib.LoadPEMCertPool(caFile)
}
cert := chain[0] func loadIntermediates(intFile string, verbose bool) (*x509.CertPool, error) {
if len(chain) > 1 { if intFile == "" {
if !forceIntermediateBundle { return x509.NewCertPool(), nil
for _, intermediate := range chain[1:] { }
if verbose { if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId) fmt.Println("[+] loading intermediate certificates from", intFile)
} }
// Note: use intFile here (previously used caFile mistakenly)
return certlib.LoadPEMCertPool(intFile)
}
ints.AddCert(intermediate) func addBundledIntermediates(chain []*x509.Certificate, pool *x509.CertPool, verbose bool) {
} for _, intermediate := range chain[1:] {
if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
} }
pool.AddCert(intermediate)
} }
}
func verifyCert(cert *x509.Certificate, roots, ints *x509.CertPool) error {
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Intermediates: ints, Intermediates: ints,
Roots: roots, Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
_, err := cert.Verify(opts)
return err
}
_, err = cert.Verify(opts) func run(cfg appConfig) error {
roots, err := loadRoots(cfg.caFile, cfg.verbose)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err) return err
os.Exit(1)
} }
if verbose { ints, err := loadIntermediates(cfg.intFile, cfg.verbose)
if err != nil {
return err
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
}
combinedPool, err := certlib.LoadFullCertPool(cfg.caFile, cfg.intFile)
if err != nil {
return fmt.Errorf("failed to build combined pool: %w", err)
}
opts := &certlib.FetcherOpts{
Roots: combinedPool,
SkipVerify: cfg.skipVerify,
}
chain, err := certlib.GetCertificateChain(flag.Arg(0), opts)
if err != nil {
return err
}
if cfg.verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 && !cfg.forceIntermediateBundle {
addBundledIntermediates(chain, ints, cfg.verbose)
}
if err = verifyCert(cert, roots, ints); err != nil {
return fmt.Errorf("certificate verification failed: %w", err)
}
if cfg.verbose {
fmt.Println("OK") fmt.Println("OK")
} }
if revexp { if cfg.revexp {
printRevocation(cert) printRevocation(cert)
} }
return nil
}
func main() {
cfg := parseFlags()
if err := run(cfg); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
} }

View File

@@ -2,6 +2,8 @@ package main
import (
"bufio"
+ "context"
+ "errors"
"flag"
"fmt"
"io"
@@ -56,7 +58,7 @@ var modes = ssh.TerminalModes{
}
func sshAgent() ssh.AuthMethod {
- a, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
+ a, err := (&net.Dialer{}).DialContext(context.Background(), "unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil {
return ssh.PublicKeysCallback(agent.NewClient(a).Signers)
}
@@ -82,7 +84,7 @@ func scanner(host string, in io.Reader, out io.Writer) {
}
}
- func logError(host string, err error, format string, args ...interface{}) {
+ func logError(host string, err error, format string, args ...any) {
msg := fmt.Sprintf(format, args...)
log.Printf("[%s] FAILED: %s: %v\n", host, msg, err)
}
@@ -93,7 +95,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
defer func() {
for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]()
- if err != nil && err != io.EOF {
+ if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down")
}
}
@@ -115,7 +117,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
}
shutdown = append(shutdown, session.Close)
- if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
+ if err = session.RequestPty("xterm", 80, 40, modes); err != nil {
session.Close()
logError(host, err, "request for pty failed")
return
@@ -150,7 +152,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() {
for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]()
- if err != nil && err != io.EOF {
+ if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down")
}
}
@@ -199,7 +201,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
}
- if err == io.EOF {
+ if errors.Is(err, io.EOF) {
break
} else if err != nil {
logError(host, err, "reading chunk")
@@ -215,7 +217,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() {
for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]()
- if err != nil && err != io.EOF {
+ if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down")
}
}
@@ -265,7 +267,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
}
- if err == io.EOF {
+ if errors.Is(err, io.EOF) {
break
} else if err != nil {
logError(host, err, "reading chunk")

View File

@@ -10,6 +10,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/fileutil" "git.wntrmute.dev/kyle/goutils/fileutil"
@@ -26,7 +27,7 @@ func setupFile(hdr *tar.Header, file *os.File) error {
if verbose { if verbose {
fmt.Printf("\tchmod %0#o\n", hdr.Mode) fmt.Printf("\tchmod %0#o\n", hdr.Mode)
} }
err := file.Chmod(os.FileMode(hdr.Mode)) err := file.Chmod(os.FileMode(hdr.Mode & 0xFFFFFFFF)) // #nosec G115
if err != nil { if err != nil {
return err return err
} }
@@ -48,73 +49,105 @@ func linkTarget(target, top string) string {
return target return target
} }
return filepath.Clean(filepath.Join(target, top)) return filepath.Clean(filepath.Join(top, target))
}
// safeJoin joins base and elem and ensures the resulting path does not escape base.
func safeJoin(base, elem string) (string, error) {
cleanBase := filepath.Clean(base)
joined := filepath.Clean(filepath.Join(cleanBase, elem))
absBase, err := filepath.Abs(cleanBase)
if err != nil {
return "", err
}
absJoined, err := filepath.Abs(joined)
if err != nil {
return "", err
}
rel, err := filepath.Rel(absBase, absJoined)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected: %s escapes %s", elem, base)
}
return joined, nil
}
func handleTypeReg(tfr *tar.Reader, hdr *tar.Header, filePath string) error {
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
if _, err = io.Copy(file, tfr); err != nil {
return err
}
return setupFile(hdr, file)
}
func handleTypeLink(hdr *tar.Header, top, filePath string) error {
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
srcPath, err := safeJoin(top, hdr.Linkname)
if err != nil {
return err
}
source, err := os.Open(srcPath)
if err != nil {
return err
}
defer source.Close()
if _, err = io.Copy(file, source); err != nil {
return err
}
return setupFile(hdr, file)
}
func handleTypeSymlink(hdr *tar.Header, top, filePath string) error {
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s", hdr.Linkname, top)
}
path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
} else if err != nil {
return err
}
return os.Symlink(linkTarget(hdr.Linkname, top), filePath)
}
func handleTypeDir(hdr *tar.Header, filePath string) error {
return os.MkdirAll(filePath, os.FileMode(hdr.Mode&0xFFFFFFFF)) // #nosec G115
} }
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error { func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
if verbose { if verbose {
fmt.Println(hdr.Name) fmt.Println(hdr.Name)
} }
filePath := filepath.Clean(filepath.Join(top, hdr.Name))
switch hdr.Typeflag {
case tar.TypeReg:
file, err := os.Create(filePath)
if err != nil {
return err
}
_, err = io.Copy(file, tfr) filePath, err := safeJoin(top, hdr.Name)
if err != nil { if err != nil {
return err return err
}
err = setupFile(hdr, file)
if err != nil {
return err
}
case tar.TypeLink:
file, err := os.Create(filePath)
if err != nil {
return err
}
source, err := os.Open(hdr.Linkname)
if err != nil {
return err
}
_, err = io.Copy(file, source)
if err != nil {
return err
}
err = setupFile(hdr, file)
if err != nil {
return err
}
case tar.TypeSymlink:
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s",
hdr.Linkname, top)
}
path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
} else if err != nil {
return err
}
err := os.Symlink(linkTarget(hdr.Linkname, top), filePath)
if err != nil {
return err
}
case tar.TypeDir:
err := os.MkdirAll(filePath, os.FileMode(hdr.Mode))
if err != nil {
return err
}
} }
switch hdr.Typeflag {
case tar.TypeReg:
return handleTypeReg(tfr, hdr, filePath)
case tar.TypeLink:
return handleTypeLink(hdr, top, filePath)
case tar.TypeSymlink:
return handleTypeSymlink(hdr, top, filePath)
case tar.TypeDir:
return handleTypeDir(hdr, filePath)
}
return nil return nil
} }
@@ -261,16 +294,16 @@ func main() {
die.If(err) die.If(err)
tfr := tar.NewReader(r) tfr := tar.NewReader(r)
var hdr *tar.Header
for { for {
hdr, err := tfr.Next() hdr, err = tfr.Next()
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} }
die.If(err) die.If(err)
err = processFile(tfr, hdr, top) err = processFile(tfr, hdr, top)
die.If(err) die.If(err)
} }
r.Close() r.Close()
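A small usage sketch for the safeJoin helper added above (illustrative only; it assumes the snippet sits in the same package as the extractor, and the paths are made up):

// demoSafeJoin exercises safeJoin's traversal check.
func demoSafeJoin() {
	// A member name that tries to climb out of the extraction root is rejected.
	if _, err := safeJoin("/tmp/extract", "../../etc/passwd"); err != nil {
		fmt.Println("rejected:", err)
	}
	// A normal member name stays under the root.
	if p, err := safeJoin("/tmp/extract", "docs/readme.txt"); err == nil {
		fmt.Println("accepted:", p) // /tmp/extract/docs/readme.txt
	}
}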

View File

@@ -7,9 +7,9 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "os"
"log"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -17,17 +17,10 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
in, err := ioutil.ReadFile(fileName) in, err := os.ReadFile(fileName)
die.If(err) die.If(err)
if p, _ := pem.Decode(in); p != nil { csr, _, err := certlib.ParseCSR(in)
if p.Type != "CERTIFICATE REQUEST" {
log.Fatal("INVALID FILE TYPE")
}
in = p.Bytes
}
csr, err := x509.ParseCertificateRequest(in)
die.If(err) die.If(err)
out, err := x509.MarshalPKIXPublicKey(csr.PublicKey) out, err := x509.MarshalPKIXPublicKey(csr.PublicKey)
@@ -48,8 +41,8 @@ func main() {
Bytes: out, Bytes: out,
} }
err = ioutil.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0644) err = os.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0o644) // #nosec G306
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.\n", fileName+".pub") fmt.Fprintf(os.Stdout, "[+] wrote %s.\n", fileName+".pub")
} }
} }

View File

@@ -1,6 +1,7 @@
package main
import (
+ "context"
"flag"
"fmt"
"io"
@@ -152,7 +153,7 @@ func rsync(syncDir, target, excludeFile string, verboseRsync bool) error {
return err
}
- cmd := exec.Command(path, args...)
+ cmd := exec.CommandContext(context.Background(), path, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
@@ -163,7 +164,6 @@ func init() {
}
func main() {
var logLevel, mountDir, syncDir, target string
var dryRun, quietMode, noSyslog, verboseRsync bool
@@ -219,7 +219,7 @@ func main() {
if excludeFile != "" {
defer func() {
log.Infof("removing exclude file %s", excludeFile)
- if err := os.Remove(excludeFile); err != nil {
+ if rmErr := os.Remove(excludeFile); rmErr != nil {
log.Warningf("failed to remove temp file %s", excludeFile)
}
}()

View File

@@ -15,43 +15,41 @@ import (
const defaultHashAlgorithm = "sha256"
var (
hAlgo string
debug = dbg.New()
)
- func openImage(imageFile string) (image *os.File, hash []byte, err error) {
- image, err = os.Open(imageFile)
+ func openImage(imageFile string) (*os.File, []byte, error) {
+ f, err := os.Open(imageFile)
if err != nil {
- return
+ return nil, nil, err
}
- hash, err = ahash.SumReader(hAlgo, image)
+ h, err := ahash.SumReader(hAlgo, f)
if err != nil {
- return
+ return nil, nil, err
}
- _, err = image.Seek(0, 0)
- if err != nil {
- return
+ if _, err = f.Seek(0, 0); err != nil {
+ return nil, nil, err
}
- debug.Printf("%s %x\n", imageFile, hash)
- return
+ debug.Printf("%s %x\n", imageFile, h)
+ return f, h, nil
}
- func openDevice(devicePath string) (device *os.File, err error) {
+ func openDevice(devicePath string) (*os.File, error) {
fi, err := os.Stat(devicePath)
if err != nil {
- return
+ return nil, err
}
- device, err = os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
+ device, err := os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
if err != nil {
- return
+ return nil, err
}
- return
+ return device, nil
}
func main() {
@@ -105,12 +103,12 @@ func main() {
die.If(err)
if !bytes.Equal(deviceHash, hash) {
- fmt.Fprintln(os.Stderr, "Hash mismatch:")
- fmt.Fprintf(os.Stderr, "\t%s: %s\n", imageFile, hash)
- fmt.Fprintf(os.Stderr, "\t%s: %s\n", devicePath, deviceHash)
- os.Exit(1)
+ buf := &bytes.Buffer{}
+ fmt.Fprintln(buf, "Hash mismatch:")
+ fmt.Fprintf(buf, "\t%s: %s\n", imageFile, hash)
+ fmt.Fprintf(buf, "\t%s: %s\n", devicePath, deviceHash)
+ die.With(buf.String())
}
debug.Println("OK")
- os.Exit(0)
}

View File

@@ -1,30 +1,33 @@
package main
import (
+ "errors"
"flag"
"fmt"
- "git.wntrmute.dev/kyle/goutils/die"
"io"
"os"
+ "strings"
+ "git.wntrmute.dev/kyle/goutils/die"
)
func usage(w io.Writer, exc int) {
- fmt.Fprintln(w, `usage: dumpbytes <file>`)
+ fmt.Fprintln(w, `usage: dumpbytes -n tabs <file>`)
os.Exit(exc)
}
func printBytes(buf []byte) {
fmt.Printf("\t")
- for i := 0; i < len(buf); i++ {
+ for i := range buf {
fmt.Printf("0x%02x, ", buf[i])
}
fmt.Println()
}
func dumpFile(path string, indentLevel int) error {
- indent := ""
- for i := 0; i < indentLevel; i++ {
- indent += "\t"
+ var indent strings.Builder
+ for range indentLevel {
+ indent.WriteByte('\t')
}
file, err := os.Open(path)
@@ -34,13 +37,14 @@ func dumpFile(path string, indentLevel int) error {
defer file.Close()
- fmt.Printf("%svar buffer = []byte{\n", indent)
+ fmt.Printf("%svar buffer = []byte{\n", indent.String())
+ var n int
for {
buf := make([]byte, 8)
- n, err := file.Read(buf)
- if err == io.EOF {
+ n, err = file.Read(buf)
+ if errors.Is(err, io.EOF) {
if n > 0 {
- fmt.Printf("%s", indent)
+ fmt.Printf("%s", indent.String())
printBytes(buf[:n])
}
break
@@ -50,11 +54,11 @@ func dumpFile(path string, indentLevel int) error {
return err
}
- fmt.Printf("%s", indent)
+ fmt.Printf("%s", indent.String())
printBytes(buf[:n])
}
- fmt.Printf("%s}\n", indent)
+ fmt.Printf("%s}\n", indent.String())
return nil
}

View File

@@ -7,7 +7,7 @@ import (
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
// size of a kilobit in bytes // size of a kilobit in bytes.
const kilobit = 128 const kilobit = 128
const pageSize = 4096 const pageSize = 4096
@@ -26,10 +26,10 @@ func main() {
path = flag.Arg(0) path = flag.Arg(0)
} }
fillByte := uint8(*fill) fillByte := uint8(*fill & 0xff) // #nosec G115 clearing out of bounds bits
buf := make([]byte, pageSize) buf := make([]byte, pageSize)
for i := 0; i < pageSize; i++ { for i := range pageSize {
buf[i] = fillByte buf[i] = fillByte
} }
@@ -40,7 +40,7 @@ func main() {
die.If(err) die.If(err)
defer file.Close() defer file.Close()
for i := 0; i < pages; i++ { for range pages {
_, err = file.Write(buf) _, err = file.Write(buf)
die.If(err) die.If(err)
} }

View File

@@ -72,15 +72,13 @@ func main() {
if end < start {
fmt.Fprintln(os.Stderr, "[!] end < start, swapping values")
- tmp := end
- end = start
- start = tmp
+ start, end = end, start
}
var fmtStr string
if !*quiet {
- maxLine := fmt.Sprintf("%d", len(lines))
+ maxLine := strconv.Itoa(len(lines))
fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine))
}
@@ -98,9 +96,9 @@ func main() {
fmtStr += "\n"
for i := start; !endFunc(i); i++ {
if *quiet {
- fmt.Println(lines[i])
+ fmt.Fprintln(os.Stdout, lines[i])
} else {
- fmt.Printf(fmtStr, i, lines[i])
+ fmt.Fprintf(os.Stdout, fmtStr, i, lines[i])
}
}
}

View File

@@ -1,6 +1,7 @@
package main
import (
+ "context"
"flag"
"fmt"
"log"
@@ -8,7 +9,8 @@ import (
)
func lookupHost(host string) error {
- cname, err := net.LookupCNAME(host)
+ r := &net.Resolver{}
+ cname, err := r.LookupCNAME(context.Background(), host)
if err != nil {
return err
}
@@ -18,7 +20,7 @@ func lookupHost(host string) error {
host = cname
}
- addrs, err := net.LookupHost(host)
+ addrs, err := r.LookupHost(context.Background(), host)
if err != nil {
return err
}

View File

@@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
@@ -16,20 +16,20 @@ func prettify(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Indent(buf, in, "", " ") err = json.Indent(buf, in, "", " ")
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -40,11 +40,11 @@ func prettify(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -55,20 +55,20 @@ func compact(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Compact(buf, in) err = json.Compact(buf, in)
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -79,11 +79,11 @@ func compact(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -91,7 +91,7 @@ func compact(file string, validateOnly bool) error {
func usage() { func usage() {
progname := lib.ProgName() progname := lib.ProgName()
fmt.Printf(`Usage: %s [-h] files... fmt.Fprintf(os.Stdout, `Usage: %s [-h] files...
%s is used to lint and prettify (or compact) JSON files. The %s is used to lint and prettify (or compact) JSON files. The
files will be updated in-place. files will be updated in-place.
@@ -100,7 +100,6 @@ func usage() {
-h Print this help message. -h Print this help message.
-n Don't prettify; only perform validation. -n Don't prettify; only perform validation.
`, progname, progname) `, progname, progname)
} }
func init() { func init() {

View File

@@ -11,7 +11,10 @@ based on whether the source filename ends in ".gz".
Flags: Flags:
-l level Compression level (0-9). Only meaninful when -l level Compression level (0-9). Only meaninful when
compressing a file. compressing a file.
-u Do not restrict the size during decompression. As
a safeguard against gzip bombs, the maximum size
allowed is 32 * the compressed file size.
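The safeguard described in the -u flag text amounts to wrapping the gzip reader in a size-limited reader. A minimal, standalone sketch of the idea (not the tool itself; the 32x cap follows the flag description, and the program name and error handling here are illustrative):

package main

import (
	"compress/gzip"
	"io"
	"os"
)

// Decompress the gzip file named on the command line to stdout, refusing
// to expand to more than 32x the compressed size.
func main() {
	src, err := os.Open(os.Args[1])
	if err != nil {
		panic(err)
	}
	defer src.Close()
	fi, err := src.Stat()
	if err != nil {
		panic(err)
	}
	gz, err := gzip.NewReader(src)
	if err != nil {
		panic(err)
	}
	defer gz.Close()
	// Cap the output at 32x the compressed size as a gzip-bomb guard.
	limited := &io.LimitedReader{R: gz, N: fi.Size() * 32}
	if _, err := io.Copy(os.Stdout, limited); err != nil {
		panic(err)
	}
}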

View File

@@ -9,8 +9,6 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
"github.com/pkg/errors"
) )
const gzipExt = ".gz" const gzipExt = ".gz"
@@ -18,56 +16,68 @@ const gzipExt = ".gz"
func compress(path, target string, level int) error { func compress(path, target string, level int) error {
sourceFile, err := os.Open(path) sourceFile, err := os.Open(path)
if err != nil { if err != nil {
return errors.Wrap(err, "opening file for read") return fmt.Errorf("opening file for read: %w", err)
} }
defer sourceFile.Close() defer sourceFile.Close()
destFile, err := os.Create(target) destFile, err := os.Create(target)
if err != nil { if err != nil {
return errors.Wrap(err, "opening file for write") return fmt.Errorf("opening file for write: %w", err)
} }
defer destFile.Close() defer destFile.Close()
gzipCompressor, err := gzip.NewWriterLevel(destFile, level) gzipCompressor, err := gzip.NewWriterLevel(destFile, level)
if err != nil { if err != nil {
return errors.Wrap(err, "invalid compression level") return fmt.Errorf("invalid compression level: %w", err)
} }
defer gzipCompressor.Close() defer gzipCompressor.Close()
_, err = io.Copy(gzipCompressor, sourceFile) _, err = io.Copy(gzipCompressor, sourceFile)
if err != nil { if err != nil {
return errors.Wrap(err, "compressing file") return fmt.Errorf("compressing file: %w", err)
}
if err != nil {
return errors.Wrap(err, "stat(2)ing destination file")
} }
return nil return nil
} }
func uncompress(path, target string) error { func uncompress(path, target string, unrestrict bool) error {
sourceFile, err := os.Open(path) sourceFile, err := os.Open(path)
if err != nil { if err != nil {
return errors.Wrap(err, "opening file for read") return fmt.Errorf("opening file for read: %w", err)
} }
defer sourceFile.Close() defer sourceFile.Close()
fi, err := sourceFile.Stat()
if err != nil {
return fmt.Errorf("reading file stats: %w", err)
}
maxDecompressionSize := fi.Size() * 32
gzipUncompressor, err := gzip.NewReader(sourceFile) gzipUncompressor, err := gzip.NewReader(sourceFile)
if err != nil { if err != nil {
return errors.Wrap(err, "reading gzip headers") return fmt.Errorf("reading gzip headers: %w", err)
} }
defer gzipUncompressor.Close() defer gzipUncompressor.Close()
var reader io.Reader = &io.LimitedReader{
R: gzipUncompressor,
N: maxDecompressionSize,
}
if unrestrict {
reader = gzipUncompressor
}
destFile, err := os.Create(target) destFile, err := os.Create(target)
if err != nil { if err != nil {
return errors.Wrap(err, "opening file for write") return fmt.Errorf("opening file for write: %w", err)
} }
defer destFile.Close() defer destFile.Close()
_, err = io.Copy(destFile, gzipUncompressor) _, err = io.Copy(destFile, reader)
if err != nil { if err != nil {
return errors.Wrap(err, "uncompressing file") return fmt.Errorf("uncompressing file: %w", err)
} }
return nil return nil
@@ -93,8 +103,8 @@ func isDir(path string) bool {
file, err := os.Open(path) file, err := os.Open(path)
if err == nil { if err == nil {
defer file.Close() defer file.Close()
stat, err := file.Stat() stat, err2 := file.Stat()
if err != nil { if err2 != nil {
return false return false
} }
@@ -113,7 +123,7 @@ func pathForUncompressing(source, dest string) (string, error) {
source = filepath.Base(source) source = filepath.Base(source)
if !strings.HasSuffix(source, gzipExt) { if !strings.HasSuffix(source, gzipExt) {
return "", errors.Errorf("%s is a not gzip-compressed file", source) return "", fmt.Errorf("%s is a not gzip-compressed file", source)
} }
outFile := source[:len(source)-len(gzipExt)] outFile := source[:len(source)-len(gzipExt)]
outFile = filepath.Join(dest, outFile) outFile = filepath.Join(dest, outFile)
@@ -127,7 +137,7 @@ func pathForCompressing(source, dest string) (string, error) {
source = filepath.Base(source) source = filepath.Base(source)
if strings.HasSuffix(source, gzipExt) { if strings.HasSuffix(source, gzipExt) {
return "", errors.Errorf("%s is a gzip-compressed file", source) return "", fmt.Errorf("%s is a gzip-compressed file", source)
} }
dest = filepath.Join(dest, source+gzipExt) dest = filepath.Join(dest, source+gzipExt)
@@ -138,8 +148,11 @@ func main() {
var level int var level int
var path string var path string
var target = "." var target = "."
var err error
var unrestrict bool
flag.IntVar(&level, "l", flate.DefaultCompression, "compression level") flag.IntVar(&level, "l", flate.DefaultCompression, "compression level")
flag.BoolVar(&unrestrict, "u", false, "do not restrict decompression")
flag.Parse() flag.Parse()
if flag.NArg() < 1 || flag.NArg() > 2 { if flag.NArg() < 1 || flag.NArg() > 2 {
@@ -153,30 +166,31 @@ func main() {
} }
if strings.HasSuffix(path, gzipExt) { if strings.HasSuffix(path, gzipExt) {
target, err := pathForUncompressing(path, target) target, err = pathForUncompressing(path, target)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
err = uncompress(path, target) err = uncompress(path, target, unrestrict)
if err != nil { if err != nil {
os.Remove(target) os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
} else { return
target, err := pathForCompressing(path, target) }
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
err = compress(path, target, level) target, err = pathForCompressing(path, target)
if err != nil { if err != nil {
os.Remove(target) fmt.Fprintf(os.Stderr, "%s\n", err)
fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1)
os.Exit(1) }
}
err = compress(path, target, level)
if err != nil {
os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
} }
} }

View File

@@ -40,14 +40,14 @@ func main() {
usage()
}
- min, err := strconv.Atoi(flag.Arg(1))
+ minVal, err := strconv.Atoi(flag.Arg(1))
dieIf(err)
- max, err := strconv.Atoi(flag.Arg(2))
+ maxVal, err := strconv.Atoi(flag.Arg(2))
dieIf(err)
code := kind << 6
- code += (min << 3)
- code += max
- fmt.Printf("%0o\n", code)
+ code += (minVal << 3)
+ code += maxVal
+ fmt.Fprintf(os.Stdout, "%0o\n", code)
}

View File

@@ -5,7 +5,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
@@ -47,7 +46,7 @@ func help(w io.Writer) {
} }
func loadDatabase() { func loadDatabase() {
data, err := ioutil.ReadFile(dbFile) data, err := os.ReadFile(dbFile)
if err != nil && os.IsNotExist(err) { if err != nil && os.IsNotExist(err) {
partsDB = &database{ partsDB = &database{
Version: dbVersion, Version: dbVersion,
@@ -74,7 +73,7 @@ func writeDB() {
data, err := json.Marshal(partsDB) data, err := json.Marshal(partsDB)
die.If(err) die.If(err)
err = ioutil.WriteFile(dbFile, data, 0644) err = os.WriteFile(dbFile, data, 0644)
die.If(err) die.If(err)
} }

View File

@@ -4,14 +4,13 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
var ext = ".bin" var ext = ".bin"
func stripPEM(path string) error { func stripPEM(path string) error {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return err return err
} }
@@ -22,7 +21,7 @@ func stripPEM(path string) error {
fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n") fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n")
} }
return ioutil.WriteFile(path+ext, p.Bytes, 0644) return os.WriteFile(path+ext, p.Bytes, 0644)
} }
func main() { func main() {

View File

@@ -3,8 +3,7 @@ package main
import (
"encoding/pem"
"flag"
- "fmt"
- "io/ioutil"
+ "io"
"os"
"git.wntrmute.dev/kyle/goutils/lib"
@@ -21,9 +20,9 @@ func main() {
path := flag.Arg(0)
if path == "-" {
- in, err = ioutil.ReadAll(os.Stdin)
+ in, err = io.ReadAll(os.Stdin)
} else {
- in, err = ioutil.ReadFile(flag.Arg(0))
+ in, err = os.ReadFile(flag.Arg(0))
}
if err != nil {
lib.Err(lib.ExitFailure, err, "couldn't read file")
@@ -33,5 +32,7 @@ func main() {
if p == nil {
lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0))
}
- fmt.Printf("%s", p.Bytes)
+ if _, err = os.Stdout.Write(p.Bytes); err != nil {
+ lib.Err(lib.ExitFailure, err, "writing body")
+ }
}

View File

@@ -70,7 +70,7 @@ func main() {
lib.Err(lib.ExitFailure, err, "failed to read input")
}
case argc > 1:
- for i := 0; i < argc; i++ {
+ for i := range argc {
path := flag.Arg(i)
err = copyFile(path, buf)
if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
@@ -13,14 +12,14 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
data, err := ioutil.ReadFile(fileName) data, err := os.ReadFile(fileName)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
continue continue
} }
fmt.Printf("[+] %s:\n", fileName) fmt.Fprintf(os.Stdout, "[+] %s:\n", fileName)
rest := data[:] rest := data
for { for {
var p *pem.Block var p *pem.Block
p, rest = pem.Decode(rest) p, rest = pem.Decode(rest)
@@ -28,13 +27,14 @@ func main() {
break break
} }
cert, err := x509.ParseCertificate(p.Bytes) var cert *x509.Certificate
cert, err = x509.ParseCertificate(p.Bytes)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
break break
} }
fmt.Printf("\t%+v\n", cert.Subject.CommonName) fmt.Fprintf(os.Stdout, "\t%+v\n", cert.Subject.CommonName)
} }
} }
} }

View File

@@ -43,7 +43,7 @@ func newName(path string) (string, error) {
return hashName(path, encodedHash), nil return hashName(path, encodedHash), nil
} }
func move(dst, src string, force bool) (err error) { func move(dst, src string, force bool) error {
if fileutil.FileDoesExist(dst) && !force { if fileutil.FileDoesExist(dst) && !force {
return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst) return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst)
} }
@@ -52,21 +52,23 @@ func move(dst, src string, force bool) (err error) {
return err return err
} }
defer func(e error) { var retErr error
defer func(e *error) {
dstFile.Close() dstFile.Close()
if e != nil { if *e != nil {
os.Remove(dst) os.Remove(dst)
} }
}(err) }(&retErr)
srcFile, err := os.Open(src) srcFile, err := os.Open(src)
if err != nil { if err != nil {
retErr = err
return err return err
} }
defer srcFile.Close() defer srcFile.Close()
_, err = io.Copy(dstFile, srcFile) if _, err = io.Copy(dstFile, srcFile); err != nil {
if err != nil { retErr = err
return err return err
} }
@@ -94,6 +96,44 @@ func init() {
flag.Usage = func() { usage(os.Stdout) } flag.Usage = func() { usage(os.Stdout) }
} }
type options struct {
dryRun, force, printChanged, verbose bool
}
func processOne(file string, opt options) error {
renamed, err := newName(file)
if err != nil {
_, _ = lib.Warn(err, "failed to get new file name")
return err
}
if opt.verbose && !opt.printChanged {
fmt.Fprintln(os.Stdout, file)
}
if renamed == file {
return nil
}
if !opt.dryRun {
if err = move(renamed, file, opt.force); err != nil {
_, _ = lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
return err
}
}
if opt.printChanged && !opt.verbose {
fmt.Fprintln(os.Stdout, file, "->", renamed)
}
return nil
}
func run(dryRun, force, printChanged, verbose bool, files []string) {
if verbose && printChanged {
printChanged = false
}
opt := options{dryRun: dryRun, force: force, printChanged: printChanged, verbose: verbose}
for _, file := range files {
_ = processOne(file, opt)
}
}
func main() { func main() {
var dryRun, force, printChanged, verbose bool var dryRun, force, printChanged, verbose bool
flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision") flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision")
@@ -102,34 +142,5 @@ func main() {
flag.BoolVar(&verbose, "v", false, "list all processed files") flag.BoolVar(&verbose, "v", false, "list all processed files")
flag.Parse() flag.Parse()
run(dryRun, force, printChanged, verbose, flag.Args())
if verbose && printChanged {
printChanged = false
}
for _, file := range flag.Args() {
renamed, err := newName(file)
if err != nil {
lib.Warn(err, "failed to get new file name")
continue
}
if verbose && !printChanged {
fmt.Println(file)
}
if renamed != file {
if !dryRun {
err = move(renamed, file, force)
if err != nil {
lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
continue
}
}
if printChanged && !verbose {
fmt.Println(file, "->", renamed)
}
}
}
} }
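The rewritten move passes a pointer to an error into its deferred cleanup, so the close/remove logic sees the error assigned later in the function and only deletes the destination on failure. A standalone sketch of that pattern with hypothetical paths; a named return value with a plain closure would work equally well, the pointer form just keeps the signature unchanged:

```go
package main

import (
	"fmt"
	"io"
	"os"
)

// copyTo copies src to dst and removes a partially written dst on failure.
func copyTo(dst, src string) error {
	out, err := os.Create(dst)
	if err != nil {
		return err
	}

	var retErr error
	defer func(e *error) {
		out.Close()
		if *e != nil {
			os.Remove(dst) // don't leave a partial file behind
		}
	}(&retErr)

	in, err := os.Open(src)
	if err != nil {
		retErr = err
		return err
	}
	defer in.Close()

	if _, err = io.Copy(out, in); err != nil {
		retErr = err
		return err
	}
	return nil
}

func main() {
	if err := copyTo("dst.txt", "src.txt"); err != nil { // placeholder paths
		fmt.Fprintln(os.Stderr, err)
	}
}
```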


@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -66,24 +67,25 @@ func main() {
for _, remote := range flag.Args() { for _, remote := range flag.Args() {
u, err := url.Parse(remote) u, err := url.Parse(remote)
if err != nil { if err != nil {
lib.Warn(err, "parsing %s", remote) _, _ = lib.Warn(err, "parsing %s", remote)
continue continue
} }
name := filepath.Base(u.Path) name := filepath.Base(u.Path)
if name == "" { if name == "" {
lib.Warnx("source URL doesn't appear to name a file") _, _ = lib.Warnx("source URL doesn't appear to name a file")
continue continue
} }
resp, err := http.Get(remote) req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, remote, nil)
if err != nil { if reqErr != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(reqErr, "building request for %s", remote)
continue continue
} }
client := &http.Client{}
resp, err := client.Do(req)
if err != nil { if err != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(err, "fetching %s", remote)
continue continue
} }
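Replacing http.Get with http.NewRequestWithContext plus an explicit client makes room for cancellation and timeouts. A hedged sketch of the same shape with a timeout attached; the 30-second value and the URL are assumptions, not values from the command above:

```go
package main

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"os"
	"time"
)

func fetch(url string) error {
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
	if err != nil {
		return fmt.Errorf("building request for %s: %w", url, err)
	}

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return fmt.Errorf("fetching %s: %w", url, err)
	}
	defer resp.Body.Close()

	_, err = io.Copy(os.Stdout, resp.Body)
	return err
}

func main() {
	if err := fetch("https://example.com/file.txt"); err != nil { // placeholder URL
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}
```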


@@ -3,7 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
@@ -17,8 +17,8 @@ func rollDie(count, sides int) []int {
sum := 0 sum := 0
var rolls []int var rolls []int
for i := 0; i < count; i++ { for range count {
roll := rand.Intn(sides) + 1 roll := rand.IntN(sides) + 1 // #nosec G404
sum += roll sum += roll
rolls = append(rolls, roll) rolls = append(rolls, roll)
} }
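The dice roller moves from math/rand to math/rand/v2: rand.Intn becomes rand.IntN, no explicit seeding is needed, and the #nosec G404 marker records that a non-cryptographic generator is acceptable here. A small sketch of the same idea:

```go
package main

import (
	"fmt"
	"math/rand/v2"
)

// rollDice returns count rolls of a sides-sided die plus their sum.
func rollDice(count, sides int) ([]int, int) {
	rolls := make([]int, 0, count)
	sum := 0
	for range count {
		r := rand.IntN(sides) + 1 // #nosec G404 -- dice, not crypto
		rolls = append(rolls, r)
		sum += r
	}
	return rolls, sum
}

func main() {
	rolls, sum := rollDice(3, 6)
	fmt.Println(rolls, "=", sum)
}
```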


@@ -53,7 +53,7 @@ func init() {
project = wd[len(gopath):] project = wd[len(gopath):]
} }
func walkFile(path string, info os.FileInfo, err error) error { func walkFile(path string, _ os.FileInfo, err error) error {
if ignores[path] { if ignores[path] {
return filepath.SkipDir return filepath.SkipDir
} }
@@ -62,22 +62,27 @@ func walkFile(path string, info os.FileInfo, err error) error {
return nil return nil
} }
debug.Println(path)
f, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err != nil { if err != nil {
return err return err
} }
debug.Println(path)
f, err2 := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err2 != nil {
return err2
}
for _, importSpec := range f.Imports { for _, importSpec := range f.Imports {
importPath := strings.Trim(importSpec.Path.Value, `"`) importPath := strings.Trim(importSpec.Path.Value, `"`)
if stdLibRegexp.MatchString(importPath) { switch {
case stdLibRegexp.MatchString(importPath):
debug.Println("standard lib:", importPath) debug.Println("standard lib:", importPath)
continue continue
} else if strings.HasPrefix(importPath, project) { case strings.HasPrefix(importPath, project):
debug.Println("internal import:", importPath) debug.Println("internal import:", importPath)
continue continue
} else if strings.HasPrefix(importPath, "golang.org/") { case strings.HasPrefix(importPath, "golang.org/"):
debug.Println("extended lib:", importPath) debug.Println("extended lib:", importPath)
continue continue
} }
@@ -102,7 +107,7 @@ func main() {
ignores["vendor"] = true ignores["vendor"] = true
} }
for _, word := range strings.Split(ignoreLine, ",") { for word := range strings.SplitSeq(ignoreLine, ",") {
ignores[strings.TrimSpace(word)] = true ignores[strings.TrimSpace(word)] = true
} }


@@ -2,10 +2,9 @@ package main
import ( import (
"bytes" "bytes"
"crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1" // #nosec G505
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
@@ -13,14 +12,18 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
const (
keyTypeRSA = "RSA"
keyTypeECDSA = "ECDSA"
)
func usage(w io.Writer) { func usage(w io.Writer) {
fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files
@@ -28,10 +31,10 @@ Usage:
ski [-hm] files... ski [-hm] files...
Flags: Flags:
-d Hex encoding mode.
-h Print this help message. -h Print this help message.
-m All SKIs should match; as soon as an SKI mismatch is found, -m All SKIs should match; as soon as an SKI mismatch is found,
it is reported. it is reported.
`) `)
} }
@@ -39,14 +42,14 @@ func init() {
flag.Usage = func() { usage(os.Stderr) } flag.Usage = func() { usage(os.Stderr) }
} }
func parse(path string) (public []byte, kt, ft string) { func parse(path string) ([]byte, string, string) {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
die.If(err) die.If(err)
data = bytes.TrimSpace(data) data = bytes.TrimSpace(data)
p, rest := pem.Decode(data) p, rest := pem.Decode(data)
if len(rest) > 0 { if len(rest) > 0 {
lib.Warnx("trailing data in PEM file") _, _ = lib.Warnx("trailing data in PEM file")
} }
if p == nil { if p == nil {
@@ -55,6 +58,12 @@ func parse(path string) (public []byte, kt, ft string) {
data = p.Bytes data = p.Bytes
var (
public []byte
kt string
ft string
)
switch p.Type { switch p.Type {
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY": case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
public, kt = parseKey(data) public, kt = parseKey(data)
@@ -69,84 +78,74 @@ func parse(path string) (public []byte, kt, ft string) {
die.With("unknown PEM type %s", p.Type) die.With("unknown PEM type %s", p.Type)
} }
return return public, kt, ft
} }
func parseKey(data []byte) (public []byte, kt string) { func parseKey(data []byte) ([]byte, string) {
privInterface, err := x509.ParsePKCS8PrivateKey(data) priv, err := certlib.ParsePrivateKeyDER(data)
if err != nil { if err != nil {
privInterface, err = x509.ParsePKCS1PrivateKey(data) die.If(err)
if err != nil {
privInterface, err = x509.ParseECPrivateKey(data)
if err != nil {
die.With("couldn't parse private key.")
}
}
} }
var priv crypto.Signer var kt string
switch privInterface.(type) { switch priv.Public().(type) {
case *rsa.PrivateKey: case *rsa.PublicKey:
priv = privInterface.(*rsa.PrivateKey) kt = keyTypeRSA
kt = "RSA" case *ecdsa.PublicKey:
case *ecdsa.PrivateKey: kt = keyTypeECDSA
priv = privInterface.(*ecdsa.PrivateKey)
kt = "ECDSA"
default: default:
die.With("unknown private key type %T", privInterface) die.With("unknown private key type %T", priv)
} }
public, err = x509.MarshalPKIXPublicKey(priv.Public()) public, err := x509.MarshalPKIXPublicKey(priv.Public())
die.If(err) die.If(err)
return return public, kt
} }
func parseCertificate(data []byte) (public []byte, kt string) { func parseCertificate(data []byte) ([]byte, string) {
cert, err := x509.ParseCertificate(data) cert, err := x509.ParseCertificate(data)
die.If(err) die.If(err)
pub := cert.PublicKey pub := cert.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func parseCSR(data []byte) (public []byte, kt string) { func parseCSR(data []byte) ([]byte, string) {
csr, err := x509.ParseCertificateRequest(data) // Use certlib to support both PEM and DER and to centralize validation.
csr, _, err := certlib.ParseCSR(data)
die.If(err) die.If(err)
pub := csr.PublicKey pub := csr.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func dumpHex(in []byte) string { func dumpHex(in []byte, mode lib.HexEncodeMode) string {
var s string return lib.HexEncode(in, mode)
for i := range in {
s += fmt.Sprintf("%02X:", in[i])
}
return strings.Trim(s, ":")
} }
type subjectPublicKeyInfo struct { type subjectPublicKeyInfo struct {
@@ -156,10 +155,14 @@ type subjectPublicKeyInfo struct {
func main() { func main() {
var help, shouldMatch bool var help, shouldMatch bool
var displayModeString string
flag.StringVar(&displayModeString, "d", "lower", "hex encoding mode")
flag.BoolVar(&help, "h", false, "print a help message and exit") flag.BoolVar(&help, "h", false, "print a help message and exit")
flag.BoolVar(&shouldMatch, "m", false, "all SKIs should match") flag.BoolVar(&shouldMatch, "m", false, "all SKIs should match")
flag.Parse() flag.Parse()
displayMode := lib.ParseHexEncodeMode(displayModeString)
if help { if help {
usage(os.Stdout) usage(os.Stdout)
os.Exit(0) os.Exit(0)
@@ -172,18 +175,18 @@ func main() {
var subPKI subjectPublicKeyInfo var subPKI subjectPublicKeyInfo
_, err := asn1.Unmarshal(public, &subPKI) _, err := asn1.Unmarshal(public, &subPKI)
if err != nil { if err != nil {
lib.Warn(err, "failed to get subject PKI") _, _ = lib.Warn(err, "failed to get subject PKI")
continue continue
} }
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
pubHashString := dumpHex(pubHash[:]) pubHashString := dumpHex(pubHash[:], displayMode)
if ski == "" { if ski == "" {
ski = pubHashString ski = pubHashString
} }
if shouldMatch && ski != pubHashString { if shouldMatch && ski != pubHashString {
lib.Warnx("%s: SKI mismatch (%s != %s)", _, _ = lib.Warnx("%s: SKI mismatch (%s != %s)",
path, ski, pubHashString) path, ski, pubHashString)
} }
fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft) fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft)
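The SKI printed here is the classic RFC 5280 method: the SHA-1 digest of the subjectPublicKey BIT STRING inside the SubjectPublicKeyInfo structure. A self-contained sketch of that computation for a certificate; the path is a placeholder, and plain encoding/hex stands in for the lib.HexEncode display modes used by the command:

```go
package main

import (
	"crypto/sha1" // #nosec G505 -- SKI is defined over SHA-1
	"crypto/x509"
	"encoding/asn1"
	"encoding/hex"
	"encoding/pem"
	"fmt"
	"os"
)

// subjectPublicKeyInfo mirrors the PKIX SubjectPublicKeyInfo structure.
type subjectPublicKeyInfo struct {
	Algorithm        asn1.RawValue
	SubjectPublicKey asn1.BitString
}

func main() {
	raw, err := os.ReadFile("cert.pem") // placeholder path
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	block, _ := pem.Decode(raw)
	if block == nil {
		fmt.Fprintln(os.Stderr, "not a PEM-encoded file")
		os.Exit(1)
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	// Re-encode the public key as a SubjectPublicKeyInfo, then hash only the
	// BIT STRING portion, per RFC 5280 section 4.2.1.2 method (1).
	spkiDER, err := x509.MarshalPKIXPublicKey(cert.PublicKey)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	var spki subjectPublicKeyInfo
	if _, err := asn1.Unmarshal(spkiDER, &spki); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}

	ski := sha1.Sum(spki.SubjectPublicKey.Bytes) // #nosec G401
	fmt.Println(hex.EncodeToString(ski[:]))
}
```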


@@ -1,16 +1,17 @@
package main package main
import ( import (
"context"
"flag" "flag"
"io" "io"
"log"
"net" "net"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
) )
func proxy(conn net.Conn, inside string) error { func proxy(conn net.Conn, inside string) error {
proxyConn, err := net.Dial("tcp", inside) proxyConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", inside)
if err != nil { if err != nil {
return err return err
} }
@@ -19,7 +20,7 @@ func proxy(conn net.Conn, inside string) error {
defer conn.Close() defer conn.Close()
go func() { go func() {
io.Copy(conn, proxyConn) _, _ = io.Copy(conn, proxyConn)
}() }()
_, err = io.Copy(proxyConn, conn) _, err = io.Copy(proxyConn, conn)
return err return err
@@ -31,16 +32,22 @@ func main() {
flag.StringVar(&inside, "p", "4000", "inside port") flag.StringVar(&inside, "p", "4000", "inside port")
flag.Parse() flag.Parse()
l, err := net.Listen("tcp", "0.0.0.0:"+outside) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:"+outside)
die.If(err) die.If(err)
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
log.Println(err) _, _ = lib.Warn(err, "accept failed")
continue continue
} }
go proxy(conn, "127.0.0.1:"+inside) go func() {
if err = proxy(conn, "127.0.0.1:"+inside); err != nil {
_, _ = lib.Warn(err, "proxy error")
}
}()
} }
} }
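The proxy swaps net.Listen and net.Dial for net.ListenConfig.Listen and net.Dialer.DialContext; with context.Background() the behaviour is unchanged, but timeouts and cancellation can now be threaded through. A minimal context-aware variant as a sketch, with placeholder addresses and an assumed 10-second dial timeout:

```go
package main

import (
	"context"
	"io"
	"log"
	"net"
	"time"
)

func main() {
	lc := &net.ListenConfig{}
	l, err := lc.Listen(context.Background(), "tcp", "127.0.0.1:4000") // placeholder
	if err != nil {
		log.Fatal(err)
	}
	defer l.Close()

	for {
		conn, err := l.Accept()
		if err != nil {
			log.Println("accept failed:", err)
			continue
		}
		go func(c net.Conn) {
			defer c.Close()
			// Bound the backend dial instead of using bare net.Dial.
			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
			defer cancel()
			backend, err := (&net.Dialer{}).DialContext(ctx, "tcp", "127.0.0.1:9000")
			if err != nil {
				log.Println("dial failed:", err)
				return
			}
			defer backend.Close()
			go func() { _, _ = io.Copy(c, backend) }()
			_, _ = io.Copy(backend, c)
		}(conn)
	}
}
```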


@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@@ -8,7 +9,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -16,7 +16,7 @@ import (
) )
func main() { func main() {
cfg := &tls.Config{} cfg := &tls.Config{} // #nosec G402
var sysRoot, listenAddr, certFile, keyFile string var sysRoot, listenAddr, certFile, keyFile string
var verify bool var verify bool
@@ -47,7 +47,8 @@ func main() {
} }
cfg.Certificates = append(cfg.Certificates, cert) cfg.Certificates = append(cfg.Certificates, cert)
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) var pemList []byte
pemList, err = os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -59,48 +60,54 @@ func main() {
cfg.RootCAs = roots cfg.RootCAs = roots
} }
l, err := net.Listen("tcp", listenAddr) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", listenAddr)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(1) os.Exit(1)
} }
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
}
raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg)
err = tconn.Handshake()
if err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
continue continue
} }
cs := tconn.ConnectionState() handleConn(conn, cfg)
if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr)
continue
}
var chain []byte
for _, cert := range cs.PeerCertificates {
p := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
chain = append(chain, pem.EncodeToMemory(p)...)
}
var nonce [16]byte
_, err = rand.Read(nonce[:])
if err != nil {
panic(err)
}
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
err = ioutil.WriteFile(fname, chain, 0644)
die.If(err)
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
} }
} }
// handleConn performs a TLS handshake, extracts the peer chain, and writes it to a file.
func handleConn(conn net.Conn, cfg *tls.Config) {
defer conn.Close()
raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg)
if err := tconn.HandshakeContext(context.Background()); err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
return
}
cs := tconn.ConnectionState()
if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr)
return
}
var chain []byte
for _, cert := range cs.PeerCertificates {
p := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
chain = append(chain, pem.EncodeToMemory(p)...)
}
var nonce [16]byte
if _, err := rand.Read(nonce[:]); err != nil {
fmt.Printf("[+] %v: failed to generate filename nonce: %v\n", raddr, err)
return
}
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
if err := os.WriteFile(fname, chain, 0o644); err != nil {
fmt.Printf("[+] %v: failed to write %v: %v\n", raddr, fname, err)
return
}
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
}


@@ -1,12 +1,12 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -14,7 +14,7 @@ import (
) )
func main() { func main() {
var cfg = &tls.Config{} var cfg = &tls.Config{} // #nosec G402
var sysRoot, serverName string var sysRoot, serverName string
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle") flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
@@ -23,7 +23,7 @@ func main() {
flag.Parse() flag.Parse()
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) pemList, err := os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -44,10 +44,13 @@ func main() {
if err != nil { if err != nil {
site += ":443" site += ":443"
} }
conn, err := tls.Dial("tcp", site, cfg) d := &tls.Dialer{Config: cfg}
if err != nil { nc, err := d.DialContext(context.Background(), "tcp", site)
fmt.Println(err.Error()) die.If(err)
os.Exit(1)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
cs := conn.ConnectionState() cs := conn.ConnectionState()
@@ -61,8 +64,9 @@ func main() {
chain = append(chain, pem.EncodeToMemory(p)...) chain = append(chain, pem.EncodeToMemory(p)...)
} }
err = ioutil.WriteFile(site+".pem", chain, 0644) err = os.WriteFile(site+".pem", chain, 0644)
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.pem.\n", site) fmt.Printf("[+] wrote %s.pem.\n", site)
} }
} }
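tls.Dialer.DialContext returns a plain net.Conn, so the code asserts it to *tls.Conn before inspecting the connection state; the stock dialer documents that the returned connection is always a *tls.Conn, but checking keeps the failure mode explicit. A short sketch with a placeholder host and an assumed timeout:

```go
package main

import (
	"context"
	"crypto/tls"
	"fmt"
	"os"
	"time"
)

func main() {
	d := &tls.Dialer{Config: &tls.Config{MinVersion: tls.VersionTLS12}}

	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	nc, err := d.DialContext(ctx, "tcp", "example.com:443") // placeholder host
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer nc.Close()

	conn, ok := nc.(*tls.Conn)
	if !ok {
		fmt.Fprintln(os.Stderr, "not a *tls.Conn")
		os.Exit(1)
	}

	cs := conn.ConnectionState()
	fmt.Printf("negotiated version %x with %d peer certificate(s)\n",
		cs.Version, len(cs.PeerCertificates))
}
```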


@@ -60,7 +60,7 @@ func printDigests(paths []string, issuer bool) {
for _, path := range paths { for _, path := range paths {
cert, err := certlib.LoadCertificate(path) cert, err := certlib.LoadCertificate(path)
if err != nil { if err != nil {
lib.Warn(err, "failed to load certificate from %s", path) _, _ = lib.Warn(err, "failed to load certificate from %s", path)
continue continue
} }
@@ -75,20 +75,19 @@ func matchDigests(paths []string, issuer bool) {
} }
var invalid int var invalid int
for { for len(paths) > 0 {
if len(paths) == 0 {
break
}
fst := paths[0] fst := paths[0]
snd := paths[1] snd := paths[1]
paths = paths[2:] paths = paths[2:]
fstCert, err := certlib.LoadCertificate(fst) fstCert, err := certlib.LoadCertificate(fst)
die.If(err) die.If(err)
sndCert, err := certlib.LoadCertificate(snd) sndCert, err := certlib.LoadCertificate(snd)
die.If(err) die.If(err)
if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) { if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) {
lib.Warnx("certificates don't match: %s and %s", fst, snd) _, _ = lib.Warnx("certificates don't match: %s and %s", fst, snd)
invalid++ invalid++
} }
} }


@@ -1,10 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
"git.wntrmute.dev/kyle/goutils/die"
) )
func main() { func main() {
@@ -13,16 +17,23 @@ func main() {
os.Exit(1) os.Exit(1)
} }
hostPort := os.Args[1] hostPort, err := hosts.ParseHost(os.Args[1])
conn, err := tls.Dial("tcp", hostPort, &tls.Config{ die.If(err)
InsecureSkipVerify: true,
})
if err != nil { d := &tls.Dialer{Config: &tls.Config{
fmt.Printf("Failed to connect to the TLS server: %v\n", err) InsecureSkipVerify: true,
os.Exit(1) }} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", hostPort.String())
die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() state := conn.ConnectionState()
printConnectionDetails(state) printConnectionDetails(state)
} }
@@ -37,7 +48,6 @@ func printConnectionDetails(state tls.ConnectionState) {
func tlsVersion(version uint16) string { func tlsVersion(version uint16) string {
switch version { switch version {
case tls.VersionTLS13: case tls.VersionTLS13:
return "TLS 1.3" return "TLS 1.3"
case tls.VersionTLS12: case tls.VersionTLS12:


@@ -11,10 +11,9 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -32,7 +31,7 @@ const (
curveP521 curveP521
) )
func getECCurve(pub interface{}) int { func getECCurve(pub any) int {
switch pub := pub.(type) { switch pub := pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
switch pub.Curve { switch pub.Curve {
@@ -52,42 +51,88 @@ func getECCurve(pub interface{}) int {
} }
} }
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
// It returns true on match.
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
}
// matchECDSA compares ECDSA public keys for equality and compatible curve.
// It returns match=true when they are on the same curve and have the same X/Y.
// If curves mismatch, match is false.
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
if getECCurve(certPub) != getECCurve(keyPub) {
return false
}
if keyPub.X.Cmp(certPub.X) != 0 {
return false
}
if keyPub.Y.Cmp(certPub.Y) != 0 {
return false
}
return true
}
// matchKeys determines whether the certificate's public key matches the given private key.
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
switch keyPub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if matchRSA(certPub, keyPub) {
return true, ""
}
return false, "public keys don't match"
case *ecdsa.PublicKey:
return false, "RSA private key, EC public key"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
case *ecdsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *ecdsa.PublicKey:
if matchECDSA(certPub, keyPub) {
return true, ""
}
// Determine a more precise reason
kc := getECCurve(keyPub)
cc := getECCurve(certPub)
if kc == curveInvalid {
return false, "invalid private key curve"
}
if cc == curveRSA {
return false, "private key is EC, certificate is RSA"
}
if kc != cc {
return false, "EC curves don't match"
}
return false, "public keys don't match"
case *rsa.PublicKey:
return false, "private key is EC, certificate is RSA"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
default:
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
}
}
func loadKey(path string) (crypto.Signer, error) { func loadKey(path string) (crypto.Signer, error) {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
in = bytes.TrimSpace(in) in = bytes.TrimSpace(in)
p, _ := pem.Decode(in) if p, _ := pem.Decode(in); p != nil {
if p != nil {
if !validPEMs[p.Type] { if !validPEMs[p.Type] {
return nil, errors.New("invalid private key file type " + p.Type) return nil, errors.New("invalid private key file type " + p.Type)
} }
in = p.Bytes return certlib.ParsePrivateKeyPEM(in)
} }
priv, err := x509.ParsePKCS8PrivateKey(in) return certlib.ParsePrivateKeyDER(in)
if err != nil {
priv, err = x509.ParsePKCS1PrivateKey(in)
if err != nil {
priv, err = x509.ParseECPrivateKey(in)
if err != nil {
return nil, err
}
}
}
switch priv.(type) {
case *rsa.PrivateKey:
return priv.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return priv.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, errors.New("invalid private key")
} }
func main() { func main() {
@@ -96,7 +141,7 @@ func main() {
flag.StringVar(&certFile, "c", "", "TLS `certificate` file") flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
flag.Parse() flag.Parse()
in, err := ioutil.ReadFile(certFile) in, err := os.ReadFile(certFile)
die.If(err) die.If(err)
p, _ := pem.Decode(in) p, _ := pem.Decode(in)
@@ -112,50 +157,11 @@ func main() {
priv, err := loadKey(keyFile) priv, err := loadKey(keyFile)
die.If(err) die.If(err)
switch pub := priv.Public().(type) { matched, reason := matchKeys(cert, priv)
case *rsa.PublicKey: if matched {
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.")
return
case *ecdsa.PublicKey:
fmt.Println("No match (RSA private key, EC public key).")
os.Exit(1)
}
case *ecdsa.PublicKey:
privCurve := getECCurve(pub)
certCurve := getECCurve(cert.PublicKey)
log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve)
if certCurve == curveRSA {
fmt.Println("No match (private key is EC, certificate is RSA).")
os.Exit(1)
} else if privCurve == curveInvalid {
fmt.Println("No match (invalid private key curve).")
os.Exit(1)
} else if privCurve != certCurve {
fmt.Println("No match (EC curves don't match).")
os.Exit(1)
}
certPub := cert.PublicKey.(*ecdsa.PublicKey)
if pub.X.Cmp(certPub.X) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
if pub.Y.Cmp(certPub.Y) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.") fmt.Println("Match.")
default: return
fmt.Printf("Unrecognised private key type: %T\n", priv.Public())
os.Exit(1)
} }
fmt.Printf("No match (%s).\n", reason)
os.Exit(1)
} }
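matchKeys compares the certificate's public key with the private key's public half field by field, which also lets it report why a pair doesn't match. When only a yes/no answer is needed, the Equal methods the standard key types have had since Go 1.15 can do the comparison instead. The sketch below is an alternative illustration, not the command's logic; it generates a throwaway key and self-signed certificate so it runs standalone:

```go
package main

import (
	"crypto"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/x509"
	"crypto/x509/pkix"
	"fmt"
	"math/big"
	"time"
)

// publicKeysMatch reports whether the certificate's public key is the public
// half of priv, using the Equal method implemented by the standard key types.
func publicKeysMatch(cert *x509.Certificate, priv crypto.Signer) (bool, error) {
	pub, ok := cert.PublicKey.(interface{ Equal(crypto.PublicKey) bool })
	if !ok {
		return false, fmt.Errorf("unsupported certificate public key type %T", cert.PublicKey)
	}
	return pub.Equal(priv.Public()), nil
}

func main() {
	// Generate a throwaway key and a self-signed certificate for it.
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: "keymatch example"},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(time.Hour),
	}
	der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		panic(err)
	}
	cert, err := x509.ParseCertificate(der)
	if err != nil {
		panic(err)
	}

	ok, err := publicKeysMatch(cert, key)
	if err != nil {
		panic(err)
	}
	fmt.Println("match:", ok)
}
```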

View File

@@ -201,10 +201,6 @@ func init() {
os.Exit(1) os.Exit(1)
} }
if fromLoc == time.UTC {
}
toLoc = time.UTC toLoc = time.UTC
} }
@@ -257,15 +253,16 @@ func main() {
showTime(time.Now()) showTime(time.Now())
os.Exit(0) os.Exit(0)
case 1: case 1:
if flag.Arg(0) == "-" { switch {
case flag.Arg(0) == "-":
s := bufio.NewScanner(os.Stdin) s := bufio.NewScanner(os.Stdin)
for s.Scan() { for s.Scan() {
times = append(times, s.Text()) times = append(times, s.Text())
} }
} else if flag.Arg(0) == "help" { case flag.Arg(0) == "help":
usageExamples() usageExamples()
} else { default:
times = flag.Args() times = flag.Args()
} }
default: default:


@@ -4,7 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@@ -12,9 +11,8 @@ import (
type empty struct{} type empty struct{}
func errorf(format string, args ...interface{}) { func errorf(path string, err error) {
format += "\n" fmt.Fprintf(os.Stderr, "%s FAILED: %s\n", path, err)
fmt.Fprintf(os.Stderr, format, args...)
} }
func usage(w io.Writer) { func usage(w io.Writer) {
@@ -44,16 +42,16 @@ func main() {
if flag.NArg() == 1 && flag.Arg(0) == "-" { if flag.NArg() == 1 && flag.Arg(0) == "-" {
path := "stdin" path := "stdin"
in, err := ioutil.ReadAll(os.Stdin) in, err := io.ReadAll(os.Stdin)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
@@ -65,16 +63,16 @@ func main() {
} }
for _, path := range flag.Args() { for _, path := range flag.Args() {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }


@@ -14,16 +14,16 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"git.wntrmute.dev/kyle/goutils/lib"
) )
const defaultDirectory = ".git/objects" const defaultDirectory = ".git/objects"
func errorf(format string, a ...interface{}) { // maxDecompressedSize limits how many bytes we will decompress from a zlib
fmt.Fprintf(os.Stderr, format, a...) // stream to mitigate decompression bombs (gosec G110).
if format[len(format)-1] != '\n' { // Increase this if you expect larger objects.
fmt.Fprintf(os.Stderr, "\n") const maxDecompressedSize int64 = 64 << 30 // 64 GiB
}
}
func isDir(path string) bool { func isDir(path string) bool {
fi, err := os.Stat(path) fi, err := os.Stat(path)
@@ -48,17 +48,21 @@ func loadFile(path string) ([]byte, error) {
} }
defer zread.Close() defer zread.Close()
_, err = io.Copy(buf, zread) // Protect against decompression bombs by limiting how much we read.
if err != nil { lr := io.LimitReader(zread, maxDecompressedSize+1)
if _, err = buf.ReadFrom(lr); err != nil {
return nil, err return nil, err
} }
if int64(buf.Len()) > maxDecompressedSize {
return nil, fmt.Errorf("decompressed size exceeds limit (%d bytes)", maxDecompressedSize)
}
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func showFile(path string) { func showFile(path string) {
fileData, err := loadFile(path) fileData, err := loadFile(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to load %s", path)
return return
} }
@@ -68,39 +72,71 @@ func showFile(path string) {
func searchFile(path string, search *regexp.Regexp) error { func searchFile(path string, search *regexp.Regexp) error {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to open %s", path)
return err return err
} }
defer file.Close() defer file.Close()
zread, err := zlib.NewReader(file) zread, err := zlib.NewReader(file)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to decompress %s", path)
return err return err
} }
defer zread.Close() defer zread.Close()
zbuf := bufio.NewReader(zread) // Limit how much we scan to avoid DoS via huge decompression.
if search.MatchReader(zbuf) { lr := io.LimitReader(zread, maxDecompressedSize+1)
fileData, err := loadFile(path) zbuf := bufio.NewReader(lr)
if err != nil { if !search.MatchReader(zbuf) {
errorf("%v", err) return nil
return err
}
fmt.Printf("%s:\n%s\n", path, fileData)
} }
fileData, err := loadFile(path)
if err != nil {
lib.Warn(err, "failed to load %s", path)
return err
}
fmt.Printf("%s:\n%s\n", path, fileData)
return nil return nil
} }
func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc { func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc {
return func(path string, info os.FileInfo, err error) error { return func(path string, info os.FileInfo, _ error) error {
if info.Mode().IsRegular() { if !info.Mode().IsRegular() {
return searchFile(path, searchExpr) return nil
} }
return nil return searchFile(path, searchExpr)
} }
} }
// runSearch compiles the search expression and processes the provided paths.
// It returns an error for fatal conditions; per-file errors are logged.
func runSearch(expr string) error {
search, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("invalid regexp: %w", err)
}
pathList := flag.Args()
if len(pathList) == 0 {
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
if err2 := filepath.Walk(path, buildWalker(search)); err2 != nil {
return err2
}
continue
}
if err2 := searchFile(path, search); err2 != nil {
// Non-fatal: keep going, but report it.
lib.Warn(err2, "non-fatal error while searching files")
}
}
return nil
}
func main() { func main() {
flSearch := flag.String("s", "", "search string (should be an RE2 regular expression)") flSearch := flag.String("s", "", "search string (should be an RE2 regular expression)")
flag.Parse() flag.Parse()
@@ -109,28 +145,10 @@ func main() {
for _, path := range flag.Args() { for _, path := range flag.Args() {
showFile(path) showFile(path)
} }
} else { return
search, err := regexp.Compile(*flSearch) }
if err != nil {
errorf("Bad regexp: %v", err)
return
}
pathList := flag.Args() if err := runSearch(*flSearch); err != nil {
if len(pathList) == 0 { lib.Err(lib.ExitFailure, err, "failed to run search")
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
err := filepath.Walk(path, buildWalker(search))
if err != nil {
errorf("%v", err)
return
}
} else {
searchFile(path, search)
}
}
} }
} }
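The git-object reader guards against decompression bombs by reading through io.LimitReader with a one-byte margin: if more than the limit comes out, the extra byte proves the stream exceeded it. A standalone sketch of the same guard; the 64 MiB limit and the object path are examples, not the command's values:

```go
package main

import (
	"bytes"
	"compress/zlib"
	"fmt"
	"io"
	"os"
)

const maxDecompressed int64 = 64 << 20 // 64 MiB, example limit

// readZlibBounded decompresses r but refuses to return more than
// maxDecompressed bytes, guarding against decompression bombs.
func readZlibBounded(r io.Reader) ([]byte, error) {
	zr, err := zlib.NewReader(r)
	if err != nil {
		return nil, err
	}
	defer zr.Close()

	var buf bytes.Buffer
	// Read one byte past the limit so "at the limit" and "over it" differ.
	if _, err := buf.ReadFrom(io.LimitReader(zr, maxDecompressed+1)); err != nil {
		return nil, err
	}
	if int64(buf.Len()) > maxDecompressed {
		return nil, fmt.Errorf("decompressed size exceeds %d bytes", maxDecompressed)
	}
	return buf.Bytes(), nil
}

func main() {
	f, err := os.Open(".git/objects/ab/cdef") // placeholder object path
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	defer f.Close()

	data, err := readZlibBounded(f)
	if err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Printf("%s\n", data)
}
```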


@@ -11,7 +11,7 @@ package config
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"log" "maps"
"os" "os"
"sort" "sort"
"strings" "strings"
@@ -33,14 +33,15 @@ func SetEnvPrefix(pfx string) {
prefix = pfx prefix = pfx
} }
const keyValueSplitLength = 2
func addLine(line string) { func addLine(line string) {
if strings.HasPrefix(line, "#") || line == "" { if strings.HasPrefix(line, "#") || line == "" {
return return
} }
lineParts := strings.SplitN(line, "=", 2) lineParts := strings.SplitN(line, "=", keyValueSplitLength)
if len(lineParts) != 2 { if len(lineParts) != keyValueSplitLength {
log.Print("skipping line: ", line)
return // silently ignore empty keys return // silently ignore empty keys
} }
@@ -49,7 +50,7 @@ func addLine(line string) {
vars[lineParts[0]] = lineParts[1] vars[lineParts[0]] = lineParts[1]
} }
// LoadFile scans the file at path for key=value pairs and adds them // LoadFile scans the file at 'path' for key=value pairs and adds them
// to the configuration. // to the configuration.
func LoadFile(path string) error { func LoadFile(path string) error {
file, err := os.Open(path) file, err := os.Open(path)
@@ -67,18 +68,16 @@ func LoadFile(path string) error {
return scanner.Err() return scanner.Err()
} }
// LoadFileFor scans the ini file at path, loading the default section // LoadFileFor scans the ini file at 'path', loading the default section
// and overriding any keys found under section. If strict is true, the // and overriding any keys found under 'section'. If strict is true, the
// named section must exist (i.e. to catch typos in the section name). // named section must exist (i.e., to catch typos in the section name).
func LoadFileFor(path, section string, strict bool) error { func LoadFileFor(path, section string, strict bool) error {
cmap, err := iniconf.ParseFile(path) cmap, err := iniconf.ParseFile(path)
if err != nil { if err != nil {
return err return err
} }
for key, value := range cmap[iniconf.DefaultSection] { maps.Copy(vars, cmap[iniconf.DefaultSection])
vars[key] = value
}
smap, ok := cmap[section] smap, ok := cmap[section]
if !ok { if !ok {
@@ -88,9 +87,7 @@ func LoadFileFor(path, section string, strict bool) error {
return nil return nil
} }
for key, value := range smap { maps.Copy(vars, smap)
vars[key] = value
}
return nil return nil
} }
@@ -107,7 +104,7 @@ func Get(key string) string {
// GetDefault retrieves a value from either a configuration file or // GetDefault retrieves a value from either a configuration file or
// the environment. Note that value from a file will override // the environment. Note that value from a file will override
// environment variables. If a value isn't found (e.g. Get returns an // environment variables. If a value isn't found (e.g., Get returns an
// empty string), the default value will be used. // empty string), the default value will be used.
func GetDefault(key, def string) string { func GetDefault(key, def string) string {
if v := Get(key); v != "" { if v := Get(key); v != "" {
@@ -117,8 +114,7 @@ func GetDefault(key, def string) string {
} }
// Require retrieves a value from either a configuration file or the // Require retrieves a value from either a configuration file or the
// environment. If the key isn't present, it will call log.Fatal, printing // environment. If the key isn't present, it will panic.
// the missing key.
func Require(key string) string { func Require(key string) string {
if v, ok := vars[key]; ok { if v, ok := vars[key]; ok {
return v return v
@@ -131,7 +127,7 @@ func Require(key string) string {
envMessage = " (note: looked for the key " + prefix + key envMessage = " (note: looked for the key " + prefix + key
envMessage += " in the local env)" envMessage += " in the local env)"
} }
log.Fatalf("missing required configuration value %s%s", key, envMessage) panic(fmt.Sprintf("missing required configuration value %s%s", key, envMessage))
} }
return v return v
@@ -139,7 +135,8 @@ func Require(key string) string {
// ListKeys returns a slice of the currently known keys. // ListKeys returns a slice of the currently known keys.
func ListKeys() []string { func ListKeys() []string {
keyList := []string{} var keyList []string
for k := range vars { for k := range vars {
keyList = append(keyList, k) keyList = append(keyList, k)
} }
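maps.Copy (Go 1.21) replaces the manual key-by-key loops; later copies overwrite earlier keys, which gives exactly the section-overrides-default behaviour wanted here. A small illustration with made-up keys:

```go
package main

import (
	"fmt"
	"maps"
)

func main() {
	vars := map[string]string{}
	defaults := map[string]string{"addr": "127.0.0.1", "port": "8080"}
	section := map[string]string{"port": "9090"}

	// Equivalent to copying defaults key by key, then section key by key;
	// keys in section override those already set from defaults.
	maps.Copy(vars, defaults)
	maps.Copy(vars, section)

	fmt.Println(vars["addr"], vars["port"]) // 127.0.0.1 9090
}
```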


@@ -1,27 +1,26 @@
package config package config_test
import ( import (
"os" "os"
"testing" "testing"
"git.wntrmute.dev/kyle/goutils/config"
) )
const ( const (
testFilePath = "testdata/test.env" testFilePath = "testdata/test.env"
// Keys // Key constants.
kOrder = "ORDER" kOrder = "ORDER"
kSpecies = "SPECIES" kSpecies = "SPECIES"
kName = "COMMON_NAME" kName = "COMMON_NAME"
// Env
eOrder = "corvus" eOrder = "corvus"
eSpecies = "corvus corax" eSpecies = "corvus corax"
eName = "northern raven" eName = "northern raven"
// File
fOrder = "stringiformes" fOrder = "stringiformes"
fSpecies = "strix aluco" fSpecies = "strix aluco"
// Name isn't set in the file to test fall through.
) )
func init() { func init() {
@@ -31,8 +30,8 @@ func init() {
} }
func TestLoadEnvOnly(t *testing.T) { func TestLoadEnvOnly(t *testing.T) {
order := Get(kOrder) order := config.Get(kOrder)
species := Get(kSpecies) species := config.Get(kSpecies)
if order != eOrder { if order != eOrder {
t.Errorf("want %s, have %s", eOrder, order) t.Errorf("want %s, have %s", eOrder, order)
} }
@@ -43,14 +42,14 @@ func TestLoadEnvOnly(t *testing.T) {
} }
func TestLoadFile(t *testing.T) { func TestLoadFile(t *testing.T) {
err := LoadFile(testFilePath) err := config.LoadFile(testFilePath)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
order := Get(kOrder) order := config.Get(kOrder)
species := Get(kSpecies) species := config.Get(kSpecies)
name := Get(kName) name := config.Get(kName)
if order != fOrder { if order != fOrder {
t.Errorf("want %s, have %s", fOrder, order) t.Errorf("want %s, have %s", fOrder, order)


@@ -2,6 +2,7 @@ package iniconf
import ( import (
"bufio" "bufio"
"errors"
"fmt" "fmt"
"io" "io"
"os" "os"
@@ -23,30 +24,31 @@ var (
var DefaultSection = "default" var DefaultSection = "default"
// ParseFile attempts to load the named config file. // ParseFile attempts to load the named config file.
func ParseFile(fileName string) (cfg ConfigMap, err error) { func ParseFile(fileName string) (ConfigMap, error) {
var file *os.File file, err := os.Open(fileName)
file, err = os.Open(fileName)
if err != nil { if err != nil {
return return nil, err
} }
defer file.Close() defer file.Close()
return ParseReader(file) return ParseReader(file)
} }
// ParseReader reads a configuration from an io.Reader. // ParseReader reads a configuration from an io.Reader.
func ParseReader(r io.Reader) (cfg ConfigMap, err error) { func ParseReader(r io.Reader) (ConfigMap, error) {
cfg = ConfigMap{} cfg := ConfigMap{}
buf := bufio.NewReader(r) buf := bufio.NewReader(r)
var ( var (
line string line string
longLine bool longLine bool
currentSection string currentSection string
err error
) )
for { for {
line, longLine, err = readConfigLine(buf, line, longLine) line, longLine, err = readConfigLine(buf, line, longLine)
if err == io.EOF { if errors.Is(err, io.EOF) {
err = nil err = nil
break break
} else if err != nil { } else if err != nil {
@@ -62,11 +64,12 @@ func ParseReader(r io.Reader) (cfg ConfigMap, err error) {
break break
} }
} }
return
return cfg, err
} }
// readConfigLine reads and assembles a complete configuration line, handling long lines. // readConfigLine reads and assembles a complete configuration line, handling long lines.
func readConfigLine(buf *bufio.Reader, currentLine string, longLine bool) (line string, stillLong bool, err error) { func readConfigLine(buf *bufio.Reader, currentLine string, longLine bool) (string, bool, error) {
lineBytes, isPrefix, err := buf.ReadLine() lineBytes, isPrefix, err := buf.ReadLine()
if err != nil { if err != nil {
return "", false, err return "", false, err
@@ -94,14 +97,14 @@ func processConfigLine(cfg ConfigMap, line string, currentSection string) (strin
return handleConfigLine(cfg, line, currentSection) return handleConfigLine(cfg, line, currentSection)
} }
return currentSection, fmt.Errorf("invalid config file") return currentSection, errors.New("invalid config file")
} }
// handleSectionLine processes a section header line. // handleSectionLine processes a section header line.
func handleSectionLine(cfg ConfigMap, line string) (string, error) { func handleSectionLine(cfg ConfigMap, line string) (string, error) {
section := configSection.ReplaceAllString(line, "$1") section := configSection.ReplaceAllString(line, "$1")
if section == "" { if section == "" {
return "", fmt.Errorf("invalid structure in file") return "", errors.New("invalid structure in file")
} }
if !cfg.SectionInConfig(section) { if !cfg.SectionInConfig(section) {
cfg[section] = make(map[string]string, 0) cfg[section] = make(map[string]string, 0)
@@ -139,41 +142,39 @@ func (c ConfigMap) SectionInConfig(section string) bool {
} }
// ListSections returns the list of sections in the config map. // ListSections returns the list of sections in the config map.
func (c ConfigMap) ListSections() (sections []string) { func (c ConfigMap) ListSections() []string {
sections := make([]string, 0, len(c))
for section := range c { for section := range c {
sections = append(sections, section) sections = append(sections, section)
} }
return return sections
} }
// WriteFile writes out the configuration to a file. // WriteFile writes out the configuration to a file.
func (c ConfigMap) WriteFile(filename string) (err error) { func (c ConfigMap) WriteFile(filename string) error {
file, err := os.Create(filename) file, err := os.Create(filename)
if err != nil { if err != nil {
return return err
} }
defer file.Close() defer file.Close()
for _, section := range c.ListSections() { for _, section := range c.ListSections() {
sName := fmt.Sprintf("[ %s ]\n", section) sName := fmt.Sprintf("[ %s ]\n", section)
_, err = file.Write([]byte(sName)) if _, err = file.WriteString(sName); err != nil {
if err != nil { return err
return
} }
for k, v := range c[section] { for k, v := range c[section] {
line := fmt.Sprintf("%s = %s\n", k, v) line := fmt.Sprintf("%s = %s\n", k, v)
_, err = file.Write([]byte(line)) if _, err = file.WriteString(line); err != nil {
if err != nil { return err
return
} }
} }
_, err = file.Write([]byte{0x0a}) if _, err = file.Write([]byte{0x0a}); err != nil {
if err != nil { return err
return
} }
} }
return return nil
} }
// AddSection creates a new section in the config map. // AddSection creates a new section in the config map.
@@ -197,27 +198,26 @@ func (c ConfigMap) AddKeyVal(section, key, val string) {
} }
// GetValue retrieves the value from a key map. // GetValue retrieves the value from a key map.
func (c ConfigMap) GetValue(section, key string) (val string, present bool) { func (c ConfigMap) GetValue(section, key string) (string, bool) {
if c == nil { if c == nil {
return return "", false
} }
if section == "" { if section == "" {
section = DefaultSection section = DefaultSection
} }
_, ok := c[section] if _, ok := c[section]; !ok {
if !ok { return "", false
return
} }
val, present = c[section][key] val, present := c[section][key]
return return val, present
} }
// GetValueDefault retrieves the value from a key map if present, // GetValueDefault retrieves the value from a key map if present,
// otherwise the default value. // otherwise the default value.
func (c ConfigMap) GetValueDefault(section, key, value string) (val string) { func (c ConfigMap) GetValueDefault(section, key, value string) string {
kval, ok := c.GetValue(section, key) kval, ok := c.GetValue(section, key)
if !ok { if !ok {
return value return value
@@ -226,7 +226,7 @@ func (c ConfigMap) GetValueDefault(section, key, value string) (val string) {
} }
// SectionKeys returns the sections in the config map. // SectionKeys returns the sections in the config map.
func (c ConfigMap) SectionKeys(section string) (keys []string, present bool) { func (c ConfigMap) SectionKeys(section string) ([]string, bool) {
if c == nil { if c == nil {
return nil, false return nil, false
} }
@@ -235,13 +235,12 @@ func (c ConfigMap) SectionKeys(section string) (keys []string, present bool) {
section = DefaultSection section = DefaultSection
} }
cm := c s, ok := c[section]
s, ok := cm[section]
if !ok { if !ok {
return nil, false return nil, false
} }
keys = make([]string, 0, len(s)) keys := make([]string, 0, len(s))
for key := range s { for key := range s {
keys = append(keys, key) keys = append(keys, key)
} }


@@ -1,18 +1,19 @@
package iniconf package iniconf_test
import ( import (
"errors" "errors"
"fmt"
"os" "os"
"sort" "sort"
"testing" "testing"
"git.wntrmute.dev/kyle/goutils/config/iniconf"
) )
// FailWithError is a utility for dumping errors and failing the test. // FailWithError is a utility for dumping errors and failing the test.
func FailWithError(t *testing.T, err error) { func FailWithError(t *testing.T, err error) {
fmt.Println("failed") t.Log("failed")
if err != nil { if err != nil {
fmt.Println("[!] ", err.Error()) t.Log("[!] ", err.Error())
} }
t.FailNow() t.FailNow()
} }
@@ -49,47 +50,50 @@ func stringSlicesEqual(slice1, slice2 []string) bool {
func TestGoodConfig(t *testing.T) { func TestGoodConfig(t *testing.T) {
testFile := "testdata/test.conf" testFile := "testdata/test.conf"
fmt.Printf("[+] validating known-good config... ") t.Logf("[+] validating known-good config... ")
cmap, err := ParseFile(testFile) cmap, err := iniconf.ParseFile(testFile)
if err != nil { if err != nil {
FailWithError(t, err) FailWithError(t, err)
} else if len(cmap) != 2 { } else if len(cmap) != 2 {
FailWithError(t, err) FailWithError(t, err)
} }
fmt.Println("ok") t.Log("ok")
} }
func TestGoodConfig2(t *testing.T) { func TestGoodConfig2(t *testing.T) {
testFile := "testdata/test2.conf" testFile := "testdata/test2.conf"
fmt.Printf("[+] validating second known-good config... ") t.Logf("[+] validating second known-good config... ")
cmap, err := ParseFile(testFile) cmap, err := iniconf.ParseFile(testFile)
if err != nil { switch {
case err != nil:
FailWithError(t, err) FailWithError(t, err)
} else if len(cmap) != 1 { case len(cmap) != 1:
FailWithError(t, err) FailWithError(t, err)
} else if len(cmap["default"]) != 3 { case len(cmap["default"]) != 3:
FailWithError(t, err) FailWithError(t, err)
default:
// nothing to do here
} }
fmt.Println("ok") t.Log("ok")
} }
func TestBadConfig(t *testing.T) { func TestBadConfig(t *testing.T) {
testFile := "testdata/bad.conf" testFile := "testdata/bad.conf"
fmt.Printf("[+] ensure invalid config file fails... ") t.Logf("[+] ensure invalid config file fails... ")
_, err := ParseFile(testFile) _, err := iniconf.ParseFile(testFile)
if err == nil { if err == nil {
err = fmt.Errorf("invalid config file should fail") err = errors.New("invalid config file should fail")
FailWithError(t, err) FailWithError(t, err)
} }
fmt.Println("ok") t.Log("ok")
} }
func TestWriteConfigFile(t *testing.T) { func TestWriteConfigFile(t *testing.T) {
fmt.Printf("[+] ensure config file is written properly... ") t.Logf("[+] ensure config file is written properly... ")
const testFile = "testdata/test.conf" const testFile = "testdata/test.conf"
const testOut = "testdata/test.out" const testOut = "testdata/test.out"
cmap, err := ParseFile(testFile) cmap, err := iniconf.ParseFile(testFile)
if err != nil { if err != nil {
FailWithError(t, err) FailWithError(t, err)
} }
@@ -100,7 +104,7 @@ func TestWriteConfigFile(t *testing.T) {
FailWithError(t, err) FailWithError(t, err)
} }
cmap2, err := ParseFile(testOut) cmap2, err := iniconf.ParseFile(testOut)
if err != nil { if err != nil {
FailWithError(t, err) FailWithError(t, err)
} }
@@ -110,25 +114,25 @@ func TestWriteConfigFile(t *testing.T) {
sort.Strings(sectionList1) sort.Strings(sectionList1)
sort.Strings(sectionList2) sort.Strings(sectionList2)
if !stringSlicesEqual(sectionList1, sectionList2) { if !stringSlicesEqual(sectionList1, sectionList2) {
err = fmt.Errorf("section lists don't match") err = errors.New("section lists don't match")
FailWithError(t, err) FailWithError(t, err)
} }
for _, section := range sectionList1 { for _, section := range sectionList1 {
for _, k := range cmap[section] { for _, k := range cmap[section] {
if cmap[section][k] != cmap2[section][k] { if cmap[section][k] != cmap2[section][k] {
err = fmt.Errorf("config key doesn't match") err = errors.New("config key doesn't match")
FailWithError(t, err) FailWithError(t, err)
} }
} }
} }
fmt.Println("ok") t.Log("ok")
} }
func TestQuotedValue(t *testing.T) { func TestQuotedValue(t *testing.T) {
testFile := "testdata/test.conf" testFile := "testdata/test.conf"
fmt.Printf("[+] validating quoted value... ") t.Logf("[+] validating quoted value... ")
cmap, _ := ParseFile(testFile) cmap, _ := iniconf.ParseFile(testFile)
val := cmap["sectionName"]["key4"] val := cmap["sectionName"]["key4"]
if val != " space at beginning and end " { if val != " space at beginning and end " {
FailWithError(t, errors.New("Wrong value in double quotes ["+val+"]")) FailWithError(t, errors.New("Wrong value in double quotes ["+val+"]"))
@@ -138,5 +142,5 @@ func TestQuotedValue(t *testing.T) {
if val != " is quoted with single quotes " { if val != " is quoted with single quotes " {
FailWithError(t, errors.New("Wrong value in single quotes ["+val+"]")) FailWithError(t, errors.New("Wrong value in single quotes ["+val+"]"))
} }
fmt.Println("ok") t.Log("ok")
} }


@@ -1,5 +1,4 @@
//go:build !linux //go:build !linux
// +build !linux
package config package config


@@ -1,7 +1,11 @@
package config package config_test
import "testing" import (
"testing"
"git.wntrmute.dev/kyle/goutils/config"
)
func TestDefaultPath(t *testing.T) { func TestDefaultPath(t *testing.T) {
t.Log(DefaultConfigPath("demoapp", "app.conf")) t.Log(config.DefaultConfigPath("demoapp", "app.conf"))
} }


@@ -47,7 +47,7 @@ func ToFile(path string) (*DebugPrinter, error) {
}, nil }, nil
} }
// To sets up a new DebugPrint to an io.WriteCloser. // To will set up a new DebugPrint to an io.WriteCloser.
func To(w io.WriteCloser) *DebugPrinter { func To(w io.WriteCloser) *DebugPrinter {
return &DebugPrinter{ return &DebugPrinter{
out: w, out: w,
@@ -55,21 +55,21 @@ func To(w io.WriteCloser) *DebugPrinter {
} }
// Print calls fmt.Print if Enabled is true. // Print calls fmt.Print if Enabled is true.
func (dbg *DebugPrinter) Print(v ...interface{}) { func (dbg *DebugPrinter) Print(v ...any) {
if dbg.Enabled { if dbg.Enabled {
fmt.Fprint(dbg.out, v...) fmt.Fprint(dbg.out, v...)
} }
} }
// Println calls fmt.Println if Enabled is true. // Println calls fmt.Println if Enabled is true.
func (dbg *DebugPrinter) Println(v ...interface{}) { func (dbg *DebugPrinter) Println(v ...any) {
if dbg.Enabled { if dbg.Enabled {
fmt.Fprintln(dbg.out, v...) fmt.Fprintln(dbg.out, v...)
} }
} }
// Printf calls fmt.Printf if Enabled is true. // Printf calls fmt.Printf if Enabled is true.
func (dbg *DebugPrinter) Printf(format string, v ...interface{}) { func (dbg *DebugPrinter) Printf(format string, v ...any) {
if dbg.Enabled { if dbg.Enabled {
fmt.Fprintf(dbg.out, format, v...) fmt.Fprintf(dbg.out, format, v...)
} }


@@ -2,7 +2,6 @@ package dbg
import ( import (
"fmt" "fmt"
"io/ioutil"
"os" "os"
"testing" "testing"
@@ -50,7 +49,7 @@ func TestTo(t *testing.T) {
} }
func TestToFile(t *testing.T) { func TestToFile(t *testing.T) {
testFile, err := ioutil.TempFile("", "dbg") testFile, err := os.CreateTemp(t.TempDir(), "dbg")
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
err = testFile.Close() err = testFile.Close()
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
@@ -103,7 +102,7 @@ func TestWriting(t *testing.T) {
} }
func TestToFileError(t *testing.T) { func TestToFileError(t *testing.T) {
testFile, err := ioutil.TempFile("", "dbg") testFile, err := os.CreateTemp(t.TempDir(), "dbg")
assert.NoErrorT(t, err) assert.NoErrorT(t, err)
err = testFile.Chmod(0400) err = testFile.Chmod(0400)
assert.NoErrorT(t, err) assert.NoErrorT(t, err)


@@ -1,12 +0,0 @@
Simple fatal utilities for Go programs.
```
result, err := doSomething()
die.If(err)
ok := processResult(result)
if !ok {
die.With("failed to process result %s", result.Name)
}
```


@@ -1,4 +1,5 @@
// Package die contains utilities for fatal error handling. // Package die contains utilities for fatal error handling. It
// presents simple fatal utilities for Go programs.
package die package die
import ( import (
@@ -15,14 +16,14 @@ func If(err error) {
} }
// With prints the message to stderr, appending a newline, and exits. // With prints the message to stderr, appending a newline, and exits.
func With(fstr string, args ...interface{}) { func With(fstr string, args ...any) {
out := fmt.Sprintf("[!] %s\n", fstr) out := fmt.Sprintf("[!] %s\n", fstr)
fmt.Fprintf(os.Stderr, out, args...) fmt.Fprintf(os.Stderr, out, args...)
os.Exit(1) os.Exit(1)
} }
// When prints the error to stderr and exits if cond is true. // When prints the error to stderr and exits if cond is true.
func When(cond bool, fstr string, args ...interface{}) { func When(cond bool, fstr string, args ...any) {
if cond { if cond {
With(fstr, args...) With(fstr, args...)
} }


@@ -1,10 +1,10 @@
//go:build !windows //go:build !windows
// +build !windows
// Package fileutil contains common file functions. // Package fileutil contains common file functions.
package fileutil package fileutil
import ( import (
"math"
"os" "os"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
@@ -46,5 +46,9 @@ const (
// Access returns a boolean indicating whether the mode being checked // Access returns a boolean indicating whether the mode being checked
// for is valid. // for is valid.
func Access(path string, mode int) error { func Access(path string, mode int) error {
return unix.Access(path, uint32(mode)) // Validate the conversion to avoid potential integer overflow (gosec G115).
if mode < 0 || uint64(mode) > uint64(math.MaxUint32) {
return unix.EINVAL
}
return unix.Access(path, uint32(mode)) // #nosec G115 - handled above.
} }
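The Access wrapper now rejects mode values that cannot fit in a uint32 before converting, which addresses gosec's G115 integer-overflow warning without masking a real bug. The same guard in isolation, as a generic sketch:

```go
package main

import (
	"errors"
	"fmt"
	"math"
)

var errOutOfRange = errors.New("value out of range for uint32")

// toUint32 converts v to uint32, returning an error instead of silently
// truncating values that don't fit.
func toUint32(v int) (uint32, error) {
	if v < 0 || uint64(v) > uint64(math.MaxUint32) {
		return 0, errOutOfRange
	}
	return uint32(v), nil // #nosec G115 -- range checked above
}

func main() {
	for _, v := range []int{7, -1} {
		u, err := toUint32(v)
		fmt.Println(v, "->", u, err)
	}
}
```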


@@ -1,5 +1,4 @@
//go:build windows //go:build windows
// +build windows
// Package fileutil contains common file functions. // Package fileutil contains common file functions.
package fileutil package fileutil

Some files were not shown because too many files have changed in this diff.