Compare commits

...

50 Commits

Author SHA1 Message Date
83f88c49fe Import twofactor. 2025-11-16 18:45:34 -08:00
7c437ac45f Add 'twofactor/' from commit 'c999bf35b0e47de4f63d59abbe0d7efc76c13ced'
git-subtree-dir: twofactor
git-subtree-mainline: 4dc135cfe0
git-subtree-split: c999bf35b0
2025-11-16 18:43:03 -08:00
c999bf35b0 linter fixes. 2025-11-16 18:39:18 -08:00
4dc135cfe0 Update CHANGELOG for v1.11.2. 2025-11-16 13:18:38 -08:00
790113e189 cmd: refactor for code reuse. 2025-11-16 13:15:08 -08:00
8348c5fd65 Update CHANGELOG. 2025-11-16 11:09:02 -08:00
1eafb638a8 cmd: finish linting fixes 2025-11-16 11:03:12 -08:00
3ad562b6fa cmd: continuing linter fixes 2025-11-16 02:54:02 -08:00
0f77bd49dc cmd: continue lint fixes. 2025-11-16 01:32:19 -08:00
f31d74243f cmd: start linting fixes. 2025-11-16 00:36:19 -08:00
a573f1cd20 Update CHANGELOG. 2025-11-15 23:48:54 -08:00
f93cf5fa9c adding lru/mru cache. 2025-11-15 23:48:00 -08:00
b879d62384 cert-bundler: lint fixes 2025-11-15 23:27:50 -08:00
c99ffd4394 cmd: cleaning up programs 2025-11-15 23:17:40 -08:00
ed8c07c1c5 Add 'mru/' from commit '2899885c4220560df4f60e4c052a6ab9773a0386'
git-subtree-dir: mru
git-subtree-mainline: cf2b016433
git-subtree-split: 2899885c42
2025-11-15 22:54:26 -08:00
cf2b016433 certlib: complete overhaul. 2025-11-15 22:54:12 -08:00
2899885c42 linter fixes 2025-11-15 22:46:42 -08:00
0dcd18c6f1 clean up code
- travisci is long dead
- golangci-lint the repo
2024-12-02 13:47:43 -08:00
024d552293 add circle ci config 2024-12-02 13:26:34 -08:00
9cd2ced695 There are different keys for different hash sizes. 2024-12-02 13:16:32 -08:00
b92e16fa4d Handle evictions properly when cache is empty. 2023-08-27 18:01:16 -07:00
6fbdece4be Initial import. 2022-02-24 21:39:10 -08:00
619c08a13f Update travis to latest Go versions. 2020-10-31 08:06:11 -07:00
944a57bf0e Switch to go modules. 2020-10-31 07:41:31 -07:00
0857b29624 Actually support clock mocking. 2020-10-31 07:24:34 -07:00
CodeLingo Bot
e95404bfc5 Fix function comments based on best practices from Effective Go
Signed-off-by: CodeLingo Bot <bot@codelingo.io>
2020-10-30 14:57:11 -07:00
ujjwalsh
924654e7c4 Added Support for Linux on Power 2020-10-30 07:50:32 -07:00
9e0979e07f Support clock mocking.
This addresses #15.
2018-12-07 08:23:01 -08:00
Aaron Bieber
bbc82ff8de Pad non-padded secrets. This lets us continue building on <= go1.8.
- Add tests for secrets using various padding methods.
- Add a new method/test to append padding to non-padded secrets.
2018-04-18 13:39:21 -07:00
Aaron Bieber
5fd928f69a Decode using WithPadding as pointed out by @gl-sergei.
This makes us print the same 6 digits as oathtool for non-padded
secrets like "a6mryljlbufszudtjdt42nh5by".
2018-04-18 13:39:21 -07:00
Aaron Bieber
acefe4a3b9 Don't assume our secret is base32 encoded.
According to https://en.wikipedia.org/wiki/Time-based_One-time_Password_algorithm
secrets are only base32 encoded in gauthenticator and gauth friendly providers.
2018-04-16 13:14:03 -07:00
a1452cebc9 Travis requires a string for Go 1.10. 2018-04-16 13:03:16 -07:00
6e9812e6f5 Vendor dependencies and add more tests. 2018-04-16 13:03:16 -07:00
Aaron Bieber
8c34415c34 add readme 2018-04-16 12:52:39 -07:00
Paul TREHIOU
2cf2c15def Case insensitive algorithm match 2018-04-16 12:43:27 -07:00
Aaron Bieber
eaad1884d4 Make sure our secret is always uppercase
Non-uppercase secrets that are base32 encoded will fial to decode
unless we upper them.
2017-09-17 18:19:23 -07:00
5d57d844d4 Add license (MIT). 2017-04-13 10:02:20 -07:00
Kyle Isom
31b9d175dd Add travis config. 2017-03-20 14:20:49 -07:00
Aaron Bieber
79e106da2e point to new qr location 2017-03-20 13:18:56 -07:00
Kyle Isom
939b1bc272 Updating imports. 2015-08-12 12:29:34 -07:00
Kyle
89e74f390b Add doc.go, finish YubiKey removal. 2014-04-24 20:43:13 -06:00
Kyle
7881b6fdfc Remove test TOTP client. 2014-04-24 20:40:44 -06:00
Kyle
5bef33245f Remove YubiKey (not currently functional). 2014-04-24 20:37:53 -06:00
Kyle
84250b0501 More documentation. 2014-04-24 20:37:00 -06:00
Kyle Isom
459e9f880f Add function to build Google TOTPs from secret 2014-04-23 16:54:16 -07:00
Kyle Isom
0982f47ce3 Add last night's progress.
Basic functionality for HOTP, TOTP, and YubiKey OTP. Still need YubiKey
HMAC, serialisation, check, and scan.
2013-12-20 17:00:01 -07:00
Kyle Isom
1dec15fd11 add missing files
new files are
	oath_test
	totp code
2013-12-19 00:21:26 -07:00
Kyle Isom
2ee9cae5ba Add basic Google Authenticator TOTP client. 2013-12-19 00:20:00 -07:00
Kyle Isom
dc04475120 HOTP and TOTP-SHA-1 working.
why the frak aren't the SHA-256 and SHA-512 variants working
2013-12-19 00:04:26 -07:00
Kyle Isom
dbbd5116b5 Initial import.
Basic HOTP functionality.
2013-12-18 21:48:14 -07:00
87 changed files with 3573 additions and 1253 deletions

View File

@@ -64,4 +64,4 @@ workflows:
testbuild: testbuild:
jobs: jobs:
- testbuild - testbuild
# - lint - lint

View File

@@ -18,6 +18,19 @@ issues:
# Default: 3 # Default: 3
max-same-issues: 50 max-same-issues: 50
# Exclude some lints for CLI programs under cmd/ (package main).
# The project allows fmt.Print* in command-line tools; keep forbidigo for libraries.
exclude-rules:
- path: ^cmd/
linters:
- forbidigo
- path: cmd/.*
linters:
- forbidigo
- path: .*/cmd/.*
linters:
- forbidigo
formatters: formatters:
enable: enable:
- goimports # checks if the code and import statements are formatted according to the 'goimports' command - goimports # checks if the code and import statements are formatted according to the 'goimports' command
@@ -73,7 +86,6 @@ linters:
- godoclint # checks Golang's documentation practice - godoclint # checks Golang's documentation practice
- godot # checks if comments end in a period - godot # checks if comments end in a period
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
- goprintffuncname # checks that printf-like functions are named with f at the end
- gosec # inspects source code for security problems - gosec # inspects source code for security problems
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution - iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
@@ -230,6 +242,10 @@ linters:
check-type-assertions: true check-type-assertions: true
exclude-functions: exclude-functions:
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write - (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
- git.wntrmute.dev/kyle/goutils/lib.Warn
- git.wntrmute.dev/kyle/goutils/lib.Warnx
- git.wntrmute.dev/kyle/goutils/lib.Err
- git.wntrmute.dev/kyle/goutils/lib.Errx
exhaustive: exhaustive:
# Program elements to check for exhaustiveness. # Program elements to check for exhaustiveness.
@@ -321,6 +337,12 @@ linters:
# https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link # https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link
- no-unused-link - no-unused-link
gosec:
excludes:
- G104 # handled by errcheck
- G301
- G306
govet: govet:
# Enable all analyzers. # Enable all analyzers.
# Default: false # Default: false
@@ -356,6 +378,12 @@ linters:
- os.WriteFile - os.WriteFile
- prometheus.ExponentialBuckets.* - prometheus.ExponentialBuckets.*
- prometheus.LinearBuckets - prometheus.LinearBuckets
ignored-numbers:
- 1
- 2
- 3
- 4
- 8
nakedret: nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns. # Make an issue if func has more lines of code than this setting, and it has naked returns.
@@ -424,6 +452,8 @@ linters:
# Omit embedded fields from selector expression. # Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008 # https://staticcheck.dev/docs/checks/#QF1008
- -QF1008 - -QF1008
# We often explicitly enable old/deprecated ciphers for research.
- -SA1019
usetesting: usetesting:
# Enable/disable `os.TempDir()` detections. # Enable/disable `os.TempDir()` detections.
@@ -450,6 +480,10 @@ linters:
linters: [ forbidigo ] linters: [ forbidigo ]
- path: 'logging/example_test.go' - path: 'logging/example_test.go'
linters: [ testableexamples ] linters: [ testableexamples ]
- path: 'main.go'
linters: [ forbidigo, mnd, reassign ]
- path: 'cmd/cruntar/main.go'
linters: [ unparam ]
- source: 'TODO' - source: 'TODO'
linters: [ godot ] linters: [ godot ]
- text: 'should have a package comment' - text: 'should have a package comment'

View File

@@ -1,45 +1,80 @@
Unreleased - 2025-11-15 CHANGELOG
"Error handling modernization" (in progress) v1.12.0 - 2025-11-16
- Introduced typed, wrapped errors via certlib/certerr.Error (Source, Kind, Op, Err) with Unwrap. Added
- Standardized helper constructors: DecodeError, ParsingError, VerifyError, LoadingError. - twofactor: the github.com/kisom/twofactor repo has been subtree'd
- Preserved sentinel errors (e.g., ErrEncryptedPrivateKey, ErrInvalidPEMType, ErrEmptyCertificate) for errors.Is. into this repo.
- Refactored certlib to use certerr in key paths (CSR parsing/verification, PEM cert pool, certificate read/load).
- Migrated logging/file.go and cmd/kgz away from github.com/pkg/errors to stdlib wrapping.
- Removed dependency on github.com/pkg/errors; ran go mod tidy.
- Added package docs for certerr and a README section on error handling and matching.
- Added unit tests for certerr (Is/As and message formatting).
Planned next steps: v1.11.2 - 2025-11-16
- Continue refactoring remaining error paths for consistent wrapping.
- Add focused tests for key flows (encrypted private key, CSR invalid PEM types, etc.).
- Run golangci-lint (errorlint, errcheck) and address findings.
Release 1.2.1 - 2018-09-15 Changed
- cmd/ski, cmd/csrpubdump, cmd/tlskeypair: centralize
certificate/private-key/CSR parsing by reusing certlib helpers.
This reduces duplication and improves consistency across commands.
- csr: CSR parsing in the above commands now uses certlib.ParseCSR,
which verifies CSR signatures (behavioral hardening compared to
prior parsing without signature verification).
+ Add missing format argument to Errorf call in kgz. v1.11.1 - 2025-11-16
Release 1.2.0 - 2018-09-15 Changed
- cmd: complete linting fixes across programs; no functional changes.
+ Adds the kgz command line utility. v1.11.0 - 2025-11-15
Release 1.1.0 - 2017-11-16 Added
- cache/mru: introduce MRU cache implementation with timestamp utilities.
+ A number of new command line utilities were added Changed
- certlib: complete overhaul to simplify APIs and internals.
- repo: widespread linting cleanups across many packages (config, dbg, die,
fileutil, log/logging, mwc, sbuf, seekbuf, tee, testio, etc.).
- cmd: general program cleanups; `cert-bundler` lint fixes.
+ atping Removed
+ cruntar - rand: remove unused package.
+ renfnv - testutil: remove unused code.
+
+ ski
+ subjhash
+ yamll
+ new package: ahash
+ package for loading hashes from an algorithm string
+ new certificate loading functions in the lib package v1.10.1 — 2025-11-15
+ new package: tee Changed
+ emulates tee(1) - certlib: major overhaul and refactor.
- repo: linter autofixes ahead of release.
v1.10.0 — 2025-11-14
Added
- cmd: add `cert-revcheck` command.
Changed
- ci/lint: add golangci-lint stage and initial cleanup.
v1.9.1 — 2025-11-15
Fixed
- die: correct calls to `die.With`.
v1.9.0 — 2025-11-14
Added
- cmd: add `cert-bundler` tool.
Changed
- misc: minor updates and maintenance.
v1.8.1 — 2025-11-14
Added
- cmd: add `tlsinfo` tool.
v1.8.0 — 2025-11-14
Baseline
- Initial baseline for this changelog series.

View File

@@ -2,39 +2,52 @@ GOUTILS
This is a collection of small utility code I've written in Go; the `cmd/` This is a collection of small utility code I've written in Go; the `cmd/`
directory has a number of command-line utilities. Rather than keep all directory has a number of command-line utilities. Rather than keep all
of these in superfluous repositories of their own, or rewriting them of these in superfluous repositories of their own or rewriting them
for each project, I'm putting them here. for each project, I'm putting them here.
The project can be built with the standard Go tooling, or it can be built The project can be built with the standard Go tooling.
with Bazel.
Contents: Contents:
ahash/ Provides hashes from string algorithm specifiers. ahash/ Provides hashes from string algorithm specifiers.
assert/ Error handling, assertion-style. assert/ Error handling, assertion-style.
backoff/ Implementation of an intelligent backoff strategy. backoff/ Implementation of an intelligent backoff strategy.
cache/ Implementations of various caches.
lru/ Least-recently-used cache.
mru/ Most-recently-used cache.
certlib/ Library for working with TLS certificates.
cmd/ cmd/
atping/ Automated TCP ping, meant for putting in cronjobs. atping/ Automated TCP ping, meant for putting in cronjobs.
certchain/ Display the certificate chain from a ca-signed/ Validate whether a certificate is signed by a CA.
TLS connection. cert-bundler/
Create certificate bundles from a source of PEM
certificates.
cert-revcheck/
Check whether a certificate has been revoked or is
expired.
certchain/ Display the certificate chain from a TLS connection.
certdump/ Dump certificate information. certdump/ Dump certificate information.
certexpiry/ Print a list of certificate subjects and expiry times certexpiry/ Print a list of certificate subjects and expiry times
or warn about certificates expiring within a certain or warn about certificates expiring within a certain
window. window.
certverify/ Verify a TLS X.509 certificate, optionally printing certverify/ Verify a TLS X.509 certificate file, optionally printing
the time to expiry and checking for revocations. the time to expiry and checking for revocations.
clustersh/ Run commands or transfer files across multiple clustersh/ Run commands or transfer files across multiple
servers via SSH. servers via SSH.
cruntar/ Untar an archive with hard links, copying instead of cruntar/ (Un)tar an archive with hard links, copying instead of
linking. linking.
csrpubdump/ Dump the public key from an X.509 certificate request. csrpubdump/ Dump the public key from an X.509 certificate request.
data_sync/ Sync the user's homedir to external storage. data_sync/ Sync the user's homedir to external storage.
diskimg/ Write a disk image to a device. diskimg/ Write a disk image to a device.
dumpbytes/ Dump the contents of a file as hex bytes, printing it as
a Go []byte literal.
eig/ EEPROM image generator. eig/ EEPROM image generator.
fragment/ Print a fragment of a file. fragment/ Print a fragment of a file.
host/ Go imlpementation of the host(1) command.
jlp/ JSON linter/prettifier. jlp/ JSON linter/prettifier.
kgz/ Custom gzip compressor / decompressor that handles 99% kgz/ Custom gzip compressor / decompressor that handles 99%
of my use cases. of my use cases.
minmax/ Generate a minmax code for use in uLisp.
parts/ Simple parts database management for my collection of parts/ Simple parts database management for my collection of
electronic components. electronic components.
pem2bin/ Dump the binary body of a PEM-encoded block. pem2bin/ Dump the binary body of a PEM-encoded block.
@@ -44,41 +57,45 @@ Contents:
in a bundle. in a bundle.
renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash. renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash.
rhash/ Compute the digest of remote files. rhash/ Compute the digest of remote files.
rolldie/ Roll some dice.
showimp/ List the external (e.g. non-stdlib and outside the showimp/ List the external (e.g. non-stdlib and outside the
current working directory) imports for a Go file. current working directory) imports for a Go file.
ski Display the SKI for PEM-encoded TLS material. ski Display the SKI for PEM-encoded TLS material.
sprox/ Simple TCP proxy. sprox/ Simple TCP proxy.
stealchain/ Dump the verified chain from a TLS stealchain/ Dump the verified chain from a TLS connection to a
connection to a server. server.
stealchain- Dump the verified chain from a TLS stealchain-server/
server/ connection from a client. Dump the verified chain from a TLS connection from
from a client.
subjhash/ Print or match subject info from a certificate. subjhash/ Print or match subject info from a certificate.
tlsinfo/ Print information about a TLS connection (the TLS version
and cipher suite).
tlskeypair/ Check whether a TLS certificate and key file match. tlskeypair/ Check whether a TLS certificate and key file match.
utc/ Convert times to UTC. utc/ Convert times to UTC.
yamll/ A small YAML linter. yamll/ A small YAML linter.
zsearch/ Search for a string in directory of gzipped files.
config/ A simple global configuration system where configuration config/ A simple global configuration system where configuration
data is pulled from a file or an environment variable data is pulled from a file or an environment variable
transparently. transparently.
iniconf/ A simple INI-style configuration system.
dbg/ A debug printer. dbg/ A debug printer.
die/ Death of a program. die/ Death of a program.
fileutil/ Common file functions. fileutil/ Common file functions.
lib/ Commonly-useful functions for writing Go programs. lib/ Commonly-useful functions for writing Go programs.
log/ A syslog library.
logging/ A logging library. logging/ A logging library.
mwc/ MultiwriteCloser implementation. mwc/ MultiwriteCloser implementation.
rand/ Utilities for working with math/rand.
sbuf/ A byte buffer that can be wiped. sbuf/ A byte buffer that can be wiped.
seekbuf/ A read-seekable byte buffer. seekbuf/ A read-seekable byte buffer.
syslog/ Syslog-type logging. syslog/ Syslog-type logging.
tee/ Emulate tee(1)'s functionality in io.Writers. tee/ Emulate tee(1)'s functionality in io.Writers.
testio/ Various I/O utilities useful during testing. testio/ Various I/O utilities useful during testing.
testutil/ Various utility functions useful during testing. twofactor/ Two-factor authentication.
Each program should have a small README in the directory with more Each program should have a small README in the directory with more
information. information.
All code here is licensed under the ISC license. All code here is licensed under the Apache 2.0 license.
Error handling Error handling
-------------- --------------
@@ -99,7 +116,7 @@ Examples:
``` ```
cert, err := certlib.LoadCertificate(path) cert, err := certlib.LoadCertificate(path)
if err != nil { if err != nil {
// sentinel match // sentinel match:
if errors.Is(err, certerr.ErrEmptyCertificate) { if errors.Is(err, certerr.ErrEmptyCertificate) {
// handle empty input // handle empty input
} }
@@ -116,5 +133,3 @@ if err != nil {
} }
} }
``` ```
Avoid including sensitive data (keys, passwords, tokens) in error messages.

View File

@@ -91,7 +91,7 @@ func TestReset(t *testing.T) {
} }
} }
const decay = 5 * time.Millisecond const decay = 25 * time.Millisecond
const maxDuration = 10 * time.Millisecond const maxDuration = 10 * time.Millisecond
const interval = time.Millisecond const interval = time.Millisecond

179
cache/lru/lru.go vendored Normal file
View File

@@ -0,0 +1,179 @@
// Package lru implements a Least Recently Used cache.
package lru
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/benbjohnson/clock"
)
type item[V any] struct {
V V
access int64
}
// A Cache is a map that retains a limited number of items. It must be
// initialized with New, providing a maximum capacity for the cache.
// Only the least recently used items are retained.
type Cache[K comparable, V any] struct {
store map[K]*item[V]
access *timestamps[K]
cap int
clock clock.Clock
// All public methods that have the possibility of modifying the
// cache should lock it.
mtx *sync.Mutex
}
// New must be used to create a new Cache.
func New[K comparable, V any](icap int) *Cache[K, V] {
return &Cache[K, V]{
store: map[K]*item[V]{},
access: newTimestamps[K](icap),
cap: icap,
clock: clock.New(),
mtx: &sync.Mutex{},
}
}
// StringKeyCache is a convenience wrapper for cache keyed by string.
type StringKeyCache[V any] struct {
*Cache[string, V]
}
// NewStringKeyCache creates a new LRU cache keyed by string.
func NewStringKeyCache[V any](icap int) *StringKeyCache[V] {
return &StringKeyCache[V]{Cache: New[string, V](icap)}
}
func (c *Cache[K, V]) lock() {
c.mtx.Lock()
}
func (c *Cache[K, V]) unlock() {
c.mtx.Unlock()
}
// Len returns the number of items currently in the cache.
func (c *Cache[K, V]) Len() int {
return len(c.store)
}
// evict should remove the least-recently-used cache item.
func (c *Cache[K, V]) evict() {
if c.access.Len() == 0 {
return
}
k := c.access.K(0)
c.evictKey(k)
}
// evictKey should remove the entry given by the key item.
func (c *Cache[K, V]) evictKey(k K) {
delete(c.store, k)
i, ok := c.access.Find(k)
if !ok {
return
}
c.access.Delete(i)
}
func (c *Cache[K, V]) sanityCheck() {
if len(c.store) != c.access.Len() {
panic(fmt.Sprintf("LRU cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len()))
}
}
// ConsistencyCheck runs a series of checks to ensure that the cache's
// data structures are consistent. It is not normally required, and it
// is primarily used in testing.
func (c *Cache[K, V]) ConsistencyCheck() error {
c.lock()
defer c.unlock()
if err := c.access.ConsistencyCheck(); err != nil {
return err
}
if len(c.store) != c.access.Len() {
return fmt.Errorf("lru: cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len())
}
for i := range c.access.ts {
itm, ok := c.store[c.access.K(i)]
if !ok {
return errors.New("lru: key in access is not in store")
}
if c.access.T(i) != itm.access {
return fmt.Errorf("timestamps are out of sync (%d != %d)",
itm.access, c.access.T(i))
}
}
if !sort.IsSorted(c.access) {
return errors.New("lru: timestamps aren't sorted")
}
return nil
}
// Store adds the value v to the cache under the k.
func (c *Cache[K, V]) Store(k K, v V) {
c.lock()
defer c.unlock()
c.sanityCheck()
if len(c.store) == c.cap {
c.evict()
}
if _, ok := c.store[k]; ok {
c.evictKey(k)
}
itm := &item[V]{
V: v,
access: c.clock.Now().UnixNano(),
}
c.store[k] = itm
c.access.Update(k, itm.access)
}
// Get returns the value stored in the cache. If the item isn't present,
// it will return false.
func (c *Cache[K, V]) Get(k K) (V, bool) {
c.lock()
defer c.unlock()
c.sanityCheck()
itm, ok := c.store[k]
if !ok {
var zero V
return zero, false
}
c.store[k].access = c.clock.Now().UnixNano()
c.access.Update(k, itm.access)
return itm.V, true
}
// Has returns true if the cache has an entry for k. It will not update
// the timestamp on the item.
func (c *Cache[K, V]) Has(k K) bool {
// Don't need to lock as we don't modify anything.
c.sanityCheck()
_, ok := c.store[k]
return ok
}

87
cache/lru/lru_internal_test.go vendored Normal file
View File

@@ -0,0 +1,87 @@
package lru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
// These tests mirror the MRU-style behavior present in this LRU package
// implementation (eviction removes the most-recently-used entry).
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
// Since this implementation evicts the most-recently-used item, inserting
// "goat" when full evicts "owl" (the most recent at that time).
mock.Add(time.Second)
if _, ok := c.Get("owl"); ok {
t.Fatal("store should not have an entry for owl (MRU-evicted)")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
// Before storing elk, keys were: raven (older), goat (newer). Evict MRU -> goat.
if !c.Has("raven") {
t.Fatal("store should contain an entry for 'raven'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
}

101
cache/lru/timestamps.go vendored Normal file
View File

@@ -0,0 +1,101 @@
package lru
import (
"errors"
"fmt"
"io"
"sort"
)
// timestamps contains datastructures for maintaining a list of keys sortable
// by timestamp.
type timestamp[K comparable] struct {
t int64
k K
}
type timestamps[K comparable] struct {
ts []timestamp[K]
cap int
}
func newTimestamps[K comparable](icap int) *timestamps[K] {
return &timestamps[K]{
ts: make([]timestamp[K], 0, icap),
cap: icap,
}
}
func (ts *timestamps[K]) K(i int) K {
return ts.ts[i].k
}
func (ts *timestamps[K]) T(i int) int64 {
return ts.ts[i].t
}
func (ts *timestamps[K]) Len() int {
return len(ts.ts)
}
func (ts *timestamps[K]) Less(i, j int) bool {
return ts.ts[i].t > ts.ts[j].t
}
func (ts *timestamps[K]) Swap(i, j int) {
ts.ts[i], ts.ts[j] = ts.ts[j], ts.ts[i]
}
func (ts *timestamps[K]) Find(k K) (int, bool) {
for i := range ts.ts {
if ts.ts[i].k == k {
return i, true
}
}
return -1, false
}
func (ts *timestamps[K]) Update(k K, t int64) bool {
i, ok := ts.Find(k)
if !ok {
ts.ts = append(ts.ts, timestamp[K]{t, k})
sort.Sort(ts)
return false
}
ts.ts[i].t = t
sort.Sort(ts)
return true
}
func (ts *timestamps[K]) ConsistencyCheck() error {
if !sort.IsSorted(ts) {
return errors.New("lru: timestamps are not sorted")
}
keys := map[K]bool{}
for i := range ts.ts {
if keys[ts.ts[i].k] {
return fmt.Errorf("lru: duplicate key %v detected", ts.ts[i].k)
}
keys[ts.ts[i].k] = true
}
if len(keys) != len(ts.ts) {
return fmt.Errorf("lru: timestamp contains %d duplicate keys",
len(ts.ts)-len(keys))
}
return nil
}
func (ts *timestamps[K]) Delete(i int) {
ts.ts = append(ts.ts[:i], ts.ts[i+1:]...)
}
func (ts *timestamps[K]) Dump(w io.Writer) {
for i := range ts.ts {
fmt.Fprintf(w, "%d: %v, %d\n", i, ts.K(i), ts.T(i))
}
}

50
cache/lru/timestamps_internal_test.go vendored Normal file
View File

@@ -0,0 +1,50 @@
package lru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
// These tests validate timestamps ordering semantics for the LRU package.
// Note: The LRU timestamps are sorted with most-recent-first (descending by t).
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// make owl the most recent
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// For LRU timestamps: most recent first. Expected order: owl, goat, raven.
if ts.K(0) != "owl" {
t.Fatalf("first key should be owl, have %s", ts.K(0))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(2) != "raven" {
t.Fatalf("third key should be raven, have %s", ts.K(2))
}
}

178
cache/mru/mru.go vendored Normal file
View File

@@ -0,0 +1,178 @@
package mru
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/benbjohnson/clock"
)
type item[V any] struct {
V V
access int64
}
// A Cache is a map that retains a limited number of items. It must be
// initialized with New, providing a maximum capacity for the cache.
// Only the most recently used items are retained.
type Cache[K comparable, V any] struct {
store map[K]*item[V]
access *timestamps[K]
cap int
clock clock.Clock
// All public methods that have the possibility of modifying the
// cache should lock it.
mtx *sync.Mutex
}
// New must be used to create a new Cache.
func New[K comparable, V any](icap int) *Cache[K, V] {
return &Cache[K, V]{
store: map[K]*item[V]{},
access: newTimestamps[K](icap),
cap: icap,
clock: clock.New(),
mtx: &sync.Mutex{},
}
}
// StringKeyCache is a convenience wrapper for cache keyed by string.
type StringKeyCache[V any] struct {
*Cache[string, V]
}
// NewStringKeyCache creates a new MRU cache keyed by string.
func NewStringKeyCache[V any](icap int) *StringKeyCache[V] {
return &StringKeyCache[V]{Cache: New[string, V](icap)}
}
func (c *Cache[K, V]) lock() {
c.mtx.Lock()
}
func (c *Cache[K, V]) unlock() {
c.mtx.Unlock()
}
// Len returns the number of items currently in the cache.
func (c *Cache[K, V]) Len() int {
return len(c.store)
}
// evict should remove the least-recently-used cache item.
func (c *Cache[K, V]) evict() {
if c.access.Len() == 0 {
return
}
k := c.access.K(0)
c.evictKey(k)
}
// evictKey should remove the entry given by the key item.
func (c *Cache[K, V]) evictKey(k K) {
delete(c.store, k)
i, ok := c.access.Find(k)
if !ok {
return
}
c.access.Delete(i)
}
func (c *Cache[K, V]) sanityCheck() {
if len(c.store) != c.access.Len() {
panic(fmt.Sprintf("MRU cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len()))
}
}
// ConsistencyCheck runs a series of checks to ensure that the cache's
// data structures are consistent. It is not normally required, and it
// is primarily used in testing.
func (c *Cache[K, V]) ConsistencyCheck() error {
c.lock()
defer c.unlock()
if err := c.access.ConsistencyCheck(); err != nil {
return err
}
if len(c.store) != c.access.Len() {
return fmt.Errorf("mru: cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len())
}
for i := range c.access.ts {
itm, ok := c.store[c.access.K(i)]
if !ok {
return errors.New("mru: key in access is not in store")
}
if c.access.T(i) != itm.access {
return fmt.Errorf("timestamps are out of sync (%d != %d)",
itm.access, c.access.T(i))
}
}
if !sort.IsSorted(c.access) {
return errors.New("mru: timestamps aren't sorted")
}
return nil
}
// Store adds the value v to the cache under the k.
func (c *Cache[K, V]) Store(k K, v V) {
c.lock()
defer c.unlock()
c.sanityCheck()
if len(c.store) == c.cap {
c.evict()
}
if _, ok := c.store[k]; ok {
c.evictKey(k)
}
itm := &item[V]{
V: v,
access: c.clock.Now().UnixNano(),
}
c.store[k] = itm
c.access.Update(k, itm.access)
}
// Get returns the value stored in the cache. If the item isn't present,
// it will return false.
func (c *Cache[K, V]) Get(k K) (V, bool) {
c.lock()
defer c.unlock()
c.sanityCheck()
itm, ok := c.store[k]
if !ok {
var zero V
return zero, false
}
c.store[k].access = c.clock.Now().UnixNano()
c.access.Update(k, itm.access)
return itm.V, true
}
// Has returns true if the cache has an entry for k. It will not update
// the timestamp on the item.
func (c *Cache[K, V]) Has(k K) bool {
// Don't need to lock as we don't modify anything.
c.sanityCheck()
_, ok := c.store[k]
return ok
}

92
cache/mru/mru_internal_test.go vendored Normal file
View File

@@ -0,0 +1,92 @@
package mru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
v, ok := c.Get("owl")
if !ok {
t.Fatal("store should have an entry for owl")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
itm := v
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if itm != 2 {
t.Fatalf("stored item should be 2, have %d", itm)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
if !c.Has("owl") {
t.Fatal("store should contain an entry for 'owl'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
}

101
cache/mru/timestamps.go vendored Normal file
View File

@@ -0,0 +1,101 @@
package mru
import (
"errors"
"fmt"
"io"
"sort"
)
// timestamps contains datastructures for maintaining a list of keys sortable
// by timestamp.
type timestamp[K comparable] struct {
t int64
k K
}
type timestamps[K comparable] struct {
ts []timestamp[K]
cap int
}
func newTimestamps[K comparable](icap int) *timestamps[K] {
return &timestamps[K]{
ts: make([]timestamp[K], 0, icap),
cap: icap,
}
}
func (ts *timestamps[K]) K(i int) K {
return ts.ts[i].k
}
func (ts *timestamps[K]) T(i int) int64 {
return ts.ts[i].t
}
func (ts *timestamps[K]) Len() int {
return len(ts.ts)
}
func (ts *timestamps[K]) Less(i, j int) bool {
return ts.ts[i].t < ts.ts[j].t
}
func (ts *timestamps[K]) Swap(i, j int) {
ts.ts[i], ts.ts[j] = ts.ts[j], ts.ts[i]
}
func (ts *timestamps[K]) Find(k K) (int, bool) {
for i := range ts.ts {
if ts.ts[i].k == k {
return i, true
}
}
return -1, false
}
func (ts *timestamps[K]) Update(k K, t int64) bool {
i, ok := ts.Find(k)
if !ok {
ts.ts = append(ts.ts, timestamp[K]{t, k})
sort.Sort(ts)
return false
}
ts.ts[i].t = t
sort.Sort(ts)
return true
}
func (ts *timestamps[K]) ConsistencyCheck() error {
if !sort.IsSorted(ts) {
return errors.New("mru: timestamps are not sorted")
}
keys := map[K]bool{}
for i := range ts.ts {
if keys[ts.ts[i].k] {
return fmt.Errorf("duplicate key %v detected", ts.ts[i].k)
}
keys[ts.ts[i].k] = true
}
if len(keys) != len(ts.ts) {
return fmt.Errorf("mru: timestamp contains %d duplicate keys",
len(ts.ts)-len(keys))
}
return nil
}
func (ts *timestamps[K]) Delete(i int) {
ts.ts = append(ts.ts[:i], ts.ts[i+1:]...)
}
func (ts *timestamps[K]) Dump(w io.Writer) {
for i := range ts.ts {
fmt.Fprintf(w, "%d: %v, %d\n", i, ts.K(i), ts.T(i))
}
}

49
cache/mru/timestamps_internal_test.go vendored Normal file
View File

@@ -0,0 +1,49 @@
package mru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Millisecond)
// raven, goat, owl
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// at this point, the keys should be raven, goat, owl.
if ts.K(0) != "raven" {
t.Fatalf("first key should be raven, have %s", ts.K(0))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(2) != "owl" {
t.Fatalf("third key should be owl, have %s", ts.K(2))
}
}

View File

@@ -79,24 +79,23 @@ func (e *Error) Error() string {
func (e *Error) Unwrap() error { return e.Err } func (e *Error) Unwrap() error { return e.Err }
// InvalidPEMType is used to indicate that we were expecting one type of PEM // InvalidPEMTypeError is used to indicate that we were expecting one type of PEM
// file, but saw another. // file, but saw another.
type InvalidPEMType struct { type InvalidPEMTypeError struct {
have string have string
want []string want []string
} }
func (err *InvalidPEMType) Error() string { func (err *InvalidPEMTypeError) Error() string {
if len(err.want) == 1 { if len(err.want) == 1 {
return fmt.Sprintf("invalid PEM type: have %s, expected %s", err.have, err.want[0]) return fmt.Sprintf("invalid PEM type: have %s, expected %s", err.have, err.want[0])
} else {
return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", "))
} }
return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", "))
} }
// ErrInvalidPEMType returns a new InvalidPEMType error. // ErrInvalidPEMType returns a new InvalidPEMTypeError error.
func ErrInvalidPEMType(have string, want ...string) error { func ErrInvalidPEMType(have string, want ...string) error {
return &InvalidPEMType{ return &InvalidPEMTypeError{
have: have, have: have,
want: want, want: want,
} }

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package certerr package certerr
import ( import (

View File

@@ -11,7 +11,7 @@ import (
// ReadCertificate reads a DER or PEM-encoded certificate from the // ReadCertificate reads a DER or PEM-encoded certificate from the
// byte slice. // byte slice.
func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error) { func ReadCertificate(in []byte) (*x509.Certificate, []byte, error) {
if len(in) == 0 { if len(in) == 0 {
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, certerr.ErrEmptyCertificate) return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, certerr.ErrEmptyCertificate)
} }
@@ -22,7 +22,7 @@ func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error)
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("invalid PEM file")) return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("invalid PEM file"))
} }
rest = remaining rest := remaining
if p.Type != "CERTIFICATE" { if p.Type != "CERTIFICATE" {
return nil, rest, certerr.ParsingError( return nil, rest, certerr.ParsingError(
certerr.ErrorSourceCertificate, certerr.ErrorSourceCertificate,
@@ -31,19 +31,26 @@ func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error)
} }
in = p.Bytes in = p.Bytes
cert, err := x509.ParseCertificate(in)
if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
return cert, rest, nil
} }
cert, err = x509.ParseCertificate(in) cert, err := x509.ParseCertificate(in)
if err != nil { if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
} }
return cert, rest, nil return cert, nil, nil
} }
// ReadCertificates tries to read all the certificates in a // ReadCertificates tries to read all the certificates in a
// PEM-encoded collection. // PEM-encoded collection.
func ReadCertificates(in []byte) (certs []*x509.Certificate, err error) { func ReadCertificates(in []byte) ([]*x509.Certificate, error) {
var cert *x509.Certificate var cert *x509.Certificate
var certs []*x509.Certificate
var err error
for { for {
cert, in, err = ReadCertificate(in) cert, in, err = ReadCertificate(in)
if err != nil { if err != nil {

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package certlib package certlib
import ( import (

View File

@@ -38,6 +38,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"git.wntrmute.dev/kyle/goutils/certlib/certerr" "git.wntrmute.dev/kyle/goutils/certlib/certerr"
@@ -47,29 +48,36 @@ import (
// private key. The key must not be in PEM format. If an error is returned, it // private key. The key must not be in PEM format. If an error is returned, it
// may contain information about the private key, so care should be taken when // may contain information about the private key, so care should be taken when
// displaying it directly. // displaying it directly.
func ParsePrivateKeyDER(keyDER []byte) (key crypto.Signer, err error) { func ParsePrivateKeyDER(keyDER []byte) (crypto.Signer, error) {
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER) // Try common encodings in order without deep nesting.
if err != nil { if k, err := x509.ParsePKCS8PrivateKey(keyDER); err == nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER) switch kk := k.(type) {
if err != nil { case *rsa.PrivateKey:
generalKey, err = x509.ParseECPrivateKey(keyDER) return kk, nil
if err != nil { case *ecdsa.PrivateKey:
generalKey, err = ParseEd25519PrivateKey(keyDER) return kk, nil
if err != nil { case ed25519.PrivateKey:
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err) return kk, nil
} default:
} return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
} }
} }
if k, err := x509.ParsePKCS1PrivateKey(keyDER); err == nil {
switch generalKey := generalKey.(type) { return k, nil
case *rsa.PrivateKey:
return generalKey, nil
case *ecdsa.PrivateKey:
return generalKey, nil
case ed25519.PrivateKey:
return generalKey, nil
default:
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %t", generalKey))
} }
if k, err := x509.ParseECPrivateKey(keyDER); err == nil {
return k, nil
}
if k, err := ParseEd25519PrivateKey(keyDER); err == nil {
if kk, ok := k.(ed25519.PrivateKey); ok {
return kk, nil
}
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
}
// If all parsers failed, return the last error from Ed25519 attempt (approximate cause).
if _, err := ParseEd25519PrivateKey(keyDER); err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err)
}
// Fallback (should be unreachable)
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, errors.New("unknown key encoding"))
} }

View File

@@ -65,12 +65,14 @@ func MarshalEd25519PublicKey(pk crypto.PublicKey) ([]byte, error) {
return nil, errEd25519WrongKeyType return nil, errEd25519WrongKeyType
} }
const bitsPerByte = 8
spki := subjectPublicKeyInfo{ spki := subjectPublicKeyInfo{
Algorithm: pkix.AlgorithmIdentifier{ Algorithm: pkix.AlgorithmIdentifier{
Algorithm: ed25519OID, Algorithm: ed25519OID,
}, },
PublicKey: asn1.BitString{ PublicKey: asn1.BitString{
BitLength: len(pub) * 8, BitLength: len(pub) * bitsPerByte,
Bytes: pub, Bytes: pub,
}, },
} }
@@ -91,7 +93,8 @@ func ParseEd25519PublicKey(der []byte) (crypto.PublicKey, error) {
return nil, errEd25519WrongID return nil, errEd25519WrongID
} }
if spki.PublicKey.BitLength != ed25519.PublicKeySize*8 { const bitsPerByte = 8
if spki.PublicKey.BitLength != ed25519.PublicKeySize*bitsPerByte {
return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch") return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch")
} }

View File

@@ -49,14 +49,14 @@ import (
"strings" "strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
ct "github.com/google/certificate-transparency-go" ct "github.com/google/certificate-transparency-go"
cttls "github.com/google/certificate-transparency-go/tls" cttls "github.com/google/certificate-transparency-go/tls"
ctx509 "github.com/google/certificate-transparency-go/x509" ctx509 "github.com/google/certificate-transparency-go/x509"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"golang.org/x/crypto/pkcs12" "golang.org/x/crypto/pkcs12"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
) )
// OneYear is a time.Duration representing a year's worth of seconds. // OneYear is a time.Duration representing a year's worth of seconds.
@@ -68,7 +68,7 @@ const OneDay = 24 * time.Hour
// DelegationUsage is the OID for the DelegationUseage extensions. // DelegationUsage is the OID for the DelegationUseage extensions.
var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44} var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44}
// DelegationExtension. // DelegationExtension is a non-critical extension marking delegation usage.
var DelegationExtension = pkix.Extension{ var DelegationExtension = pkix.Extension{
Id: DelegationUsage, Id: DelegationUsage,
Critical: false, Critical: false,
@@ -81,13 +81,19 @@ func InclusiveDate(year int, month time.Month, day int) time.Time {
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond) return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond)
} }
const (
year2012 = 2012
year2015 = 2015
day1 = 1
)
// Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop // Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 5 years. // issuing certificates valid for more than 5 years.
var Jul2012 = InclusiveDate(2012, time.July, 01) var Jul2012 = InclusiveDate(year2012, time.July, day1)
// Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop // Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 39 months. // issuing certificates valid for more than 39 months.
var Apr2015 = InclusiveDate(2015, time.April, 01) var Apr2015 = InclusiveDate(year2015, time.April, day1)
// KeyLength returns the bit size of ECDSA or RSA PublicKey. // KeyLength returns the bit size of ECDSA or RSA PublicKey.
func KeyLength(key any) int { func KeyLength(key any) int {
@@ -108,11 +114,11 @@ func KeyLength(key any) int {
} }
// ExpiryTime returns the time when the certificate chain is expired. // ExpiryTime returns the time when the certificate chain is expired.
func ExpiryTime(chain []*x509.Certificate) (notAfter time.Time) { func ExpiryTime(chain []*x509.Certificate) time.Time {
var notAfter time.Time
if len(chain) == 0 { if len(chain) == 0 {
return notAfter return notAfter
} }
notAfter = chain[0].NotAfter notAfter = chain[0].NotAfter
for _, cert := range chain { for _, cert := range chain {
if notAfter.After(cert.NotAfter) { if notAfter.After(cert.NotAfter) {
@@ -158,18 +164,23 @@ func ValidExpiry(c *x509.Certificate) bool {
// SignatureString returns the TLS signature string corresponding to // SignatureString returns the TLS signature string corresponding to
// an X509 signature algorithm. // an X509 signature algorithm.
var signatureString = map[x509.SignatureAlgorithm]string{ var signatureString = map[x509.SignatureAlgorithm]string{
x509.MD2WithRSA: "MD2WithRSA", x509.UnknownSignatureAlgorithm: "Unknown Signature",
x509.MD5WithRSA: "MD5WithRSA", x509.MD2WithRSA: "MD2WithRSA",
x509.SHA1WithRSA: "SHA1WithRSA", x509.MD5WithRSA: "MD5WithRSA",
x509.SHA256WithRSA: "SHA256WithRSA", x509.SHA1WithRSA: "SHA1WithRSA",
x509.SHA384WithRSA: "SHA384WithRSA", x509.SHA256WithRSA: "SHA256WithRSA",
x509.SHA512WithRSA: "SHA512WithRSA", x509.SHA384WithRSA: "SHA384WithRSA",
x509.DSAWithSHA1: "DSAWithSHA1", x509.SHA512WithRSA: "SHA512WithRSA",
x509.DSAWithSHA256: "DSAWithSHA256", x509.SHA256WithRSAPSS: "SHA256WithRSAPSS",
x509.ECDSAWithSHA1: "ECDSAWithSHA1", x509.SHA384WithRSAPSS: "SHA384WithRSAPSS",
x509.ECDSAWithSHA256: "ECDSAWithSHA256", x509.SHA512WithRSAPSS: "SHA512WithRSAPSS",
x509.ECDSAWithSHA384: "ECDSAWithSHA384", x509.DSAWithSHA1: "DSAWithSHA1",
x509.ECDSAWithSHA512: "ECDSAWithSHA512", x509.DSAWithSHA256: "DSAWithSHA256",
x509.ECDSAWithSHA1: "ECDSAWithSHA1",
x509.ECDSAWithSHA256: "ECDSAWithSHA256",
x509.ECDSAWithSHA384: "ECDSAWithSHA384",
x509.ECDSAWithSHA512: "ECDSAWithSHA512",
x509.PureEd25519: "PureEd25519",
} }
// SignatureString returns the TLS signature string corresponding to // SignatureString returns the TLS signature string corresponding to
@@ -184,18 +195,23 @@ func SignatureString(alg x509.SignatureAlgorithm) string {
// HashAlgoString returns the hash algorithm name contains in the signature // HashAlgoString returns the hash algorithm name contains in the signature
// method. // method.
var hashAlgoString = map[x509.SignatureAlgorithm]string{ var hashAlgoString = map[x509.SignatureAlgorithm]string{
x509.MD2WithRSA: "MD2", x509.UnknownSignatureAlgorithm: "Unknown Hash Algorithm",
x509.MD5WithRSA: "MD5", x509.MD2WithRSA: "MD2",
x509.SHA1WithRSA: "SHA1", x509.MD5WithRSA: "MD5",
x509.SHA256WithRSA: "SHA256", x509.SHA1WithRSA: "SHA1",
x509.SHA384WithRSA: "SHA384", x509.SHA256WithRSA: "SHA256",
x509.SHA512WithRSA: "SHA512", x509.SHA384WithRSA: "SHA384",
x509.DSAWithSHA1: "SHA1", x509.SHA512WithRSA: "SHA512",
x509.DSAWithSHA256: "SHA256", x509.SHA256WithRSAPSS: "SHA256",
x509.ECDSAWithSHA1: "SHA1", x509.SHA384WithRSAPSS: "SHA384",
x509.ECDSAWithSHA256: "SHA256", x509.SHA512WithRSAPSS: "SHA512",
x509.ECDSAWithSHA384: "SHA384", x509.DSAWithSHA1: "SHA1",
x509.ECDSAWithSHA512: "SHA512", x509.DSAWithSHA256: "SHA256",
x509.ECDSAWithSHA1: "SHA1",
x509.ECDSAWithSHA256: "SHA256",
x509.ECDSAWithSHA384: "SHA384",
x509.ECDSAWithSHA512: "SHA512",
x509.PureEd25519: "SHA512", // per x509 docs Ed25519 uses SHA-512 internally
} }
// HashAlgoString returns the hash algorithm name contains in the signature // HashAlgoString returns the hash algorithm name contains in the signature
@@ -273,7 +289,7 @@ func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
// ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key, // ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key,
// either PKCS #7, PKCS #12, or raw x509. // either PKCS #7, PKCS #12, or raw x509.
func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certificate, key crypto.Signer, err error) { func ParseCertificatesDER(certsDER []byte, password string) ([]*x509.Certificate, crypto.Signer, error) {
certsDER = bytes.TrimSpace(certsDER) certsDER = bytes.TrimSpace(certsDER)
// First, try PKCS #7 // First, try PKCS #7
@@ -284,7 +300,7 @@ func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certi
errors.New("can only extract certificates from signed data content info"), errors.New("can only extract certificates from signed data content info"),
) )
} }
certs = pkcs7data.Content.SignedData.Certificates certs := pkcs7data.Content.SignedData.Certificates
if certs == nil { if certs == nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded")) return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded"))
} }
@@ -304,7 +320,7 @@ func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certi
} }
// Finally, attempt to parse raw X.509 certificates // Finally, attempt to parse raw X.509 certificates
certs, err = x509.ParseCertificates(certsDER) certs, err := x509.ParseCertificates(certsDER)
if err != nil { if err != nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err) return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err)
} }
@@ -318,7 +334,8 @@ func ParseSelfSignedCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
return nil, err return nil, err
} }
if err := cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature); err != nil { err = cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature)
if err != nil {
return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err) return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err)
} }
return cert, nil return cert, nil
@@ -362,8 +379,8 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
cert, err := x509.ParseCertificate(block.Bytes) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
pkcs7data, err := pkcs7.ParsePKCS7(block.Bytes) pkcs7data, err2 := pkcs7.ParsePKCS7(block.Bytes)
if err != nil { if err2 != nil {
return nil, rest, err return nil, rest, err
} }
if pkcs7data.ContentInfo != "SignedData" { if pkcs7data.ContentInfo != "SignedData" {
@@ -382,7 +399,7 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
// LoadPEMCertPool loads a pool of PEM certificates from file. // LoadPEMCertPool loads a pool of PEM certificates from file.
func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) { func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
if certsFile == "" { if certsFile == "" {
return nil, nil return nil, nil //nolint:nilnil // no CA file provided -> treat as no pool and no error
} }
pemCerts, err := os.ReadFile(certsFile) pemCerts, err := os.ReadFile(certsFile)
if err != nil { if err != nil {
@@ -395,7 +412,7 @@ func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
// PEMToCertPool concerts PEM certificates to a CertPool. // PEMToCertPool concerts PEM certificates to a CertPool.
func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) { func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
if len(pemCerts) == 0 { if len(pemCerts) == 0 {
return nil, nil return nil, nil //nolint:nilnil // empty input means no pool needed
} }
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
@@ -409,14 +426,14 @@ func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
// ParsePrivateKeyPEM parses and returns a PEM-encoded private // ParsePrivateKeyPEM parses and returns a PEM-encoded private
// key. The private key may be either an unencrypted PKCS#8, PKCS#1, // key. The private key may be either an unencrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEM(keyPEM []byte) (crypto.Signer, error) {
return ParsePrivateKeyPEMWithPassword(keyPEM, nil) return ParsePrivateKeyPEMWithPassword(keyPEM, nil)
} }
// ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private // ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private
// key. The private key may be a potentially encrypted PKCS#8, PKCS#1, // key. The private key may be a potentially encrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (crypto.Signer, error) {
keyDER, err := GetKeyDERFromPEM(keyPEM, password) keyDER, err := GetKeyDERFromPEM(keyPEM, password)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -436,47 +453,47 @@ func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
break break
} }
} }
if keyDER != nil { if keyDER == nil {
if procType, ok := keyDER.Headers["Proc-Type"]; ok { return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key"))
if strings.Contains(procType, "ENCRYPTED") {
if password != nil {
return x509.DecryptPEMBlock(keyDER, password)
}
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)
}
}
return keyDER.Bytes, nil
} }
if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") {
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key")) if password != nil {
return x509.DecryptPEMBlock(keyDER, password)
}
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)
}
return keyDER.Bytes, nil
} }
// ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request. // ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request.
func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error) { func ParseCSR(in []byte) (*x509.CertificateRequest, []byte, error) {
in = bytes.TrimSpace(in) in = bytes.TrimSpace(in)
p, rest := pem.Decode(in) p, rest := pem.Decode(in)
if p != nil { if p == nil {
if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" { csr, err := x509.ParseCertificateRequest(in)
return nil, rest, certerr.ParsingError( if err != nil {
certerr.ErrorSourceCSR, return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
certerr.ErrInvalidPEMType(p.Type, "NEW CERTIFICATE REQUEST", "CERTIFICATE REQUEST"),
)
} }
if sigErr := csr.CheckSignature(); sigErr != nil {
csr, err = x509.ParseCertificateRequest(p.Bytes) return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
} else { }
csr, err = x509.ParseCertificateRequest(in) return csr, rest, nil
} }
if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" {
return nil, rest, certerr.ParsingError(
certerr.ErrorSourceCSR,
certerr.ErrInvalidPEMType(p.Type, "NEW CERTIFICATE REQUEST", "CERTIFICATE REQUEST"),
)
}
csr, err := x509.ParseCertificateRequest(p.Bytes)
if err != nil { if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err) return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
} }
if sigErr := csr.CheckSignature(); sigErr != nil {
err = csr.CheckSignature() return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
if err != nil {
return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, err)
} }
return csr, rest, nil return csr, rest, nil
} }
@@ -484,7 +501,7 @@ func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error)
// It does not check the signature. This is useful for dumping data from a CSR // It does not check the signature. This is useful for dumping data from a CSR
// locally. // locally.
func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) { func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
block, _ := pem.Decode([]byte(csrPEM)) block, _ := pem.Decode(csrPEM)
if block == nil { if block == nil {
return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty")) return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty"))
} }
@@ -499,15 +516,20 @@ func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
// SignerAlgo returns an X.509 signature algorithm from a crypto.Signer. // SignerAlgo returns an X.509 signature algorithm from a crypto.Signer.
func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm { func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm {
const (
rsaBits2048 = 2048
rsaBits3072 = 3072
rsaBits4096 = 4096
)
switch pub := priv.Public().(type) { switch pub := priv.Public().(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
bitLength := pub.N.BitLen() bitLength := pub.N.BitLen()
switch { switch {
case bitLength >= 4096: case bitLength >= rsaBits4096:
return x509.SHA512WithRSA return x509.SHA512WithRSA
case bitLength >= 3072: case bitLength >= rsaBits3072:
return x509.SHA384WithRSA return x509.SHA384WithRSA
case bitLength >= 2048: case bitLength >= rsaBits2048:
return x509.SHA256WithRSA return x509.SHA256WithRSA
default: default:
return x509.SHA1WithRSA return x509.SHA1WithRSA
@@ -537,7 +559,7 @@ func LoadClientCertificate(certFile string, keyFile string) (*tls.Certificate, e
} }
return &cert, nil return &cert, nil
} }
return nil, nil return nil, nil //nolint:nilnil // absence of client cert is not an error
} }
// CreateTLSConfig creates a tls.Config object from certs and roots. // CreateTLSConfig creates a tls.Config object from certs and roots.
@@ -549,6 +571,7 @@ func CreateTLSConfig(remoteCAs *x509.CertPool, cert *tls.Certificate) *tls.Confi
return &tls.Config{ return &tls.Config{
Certificates: certs, Certificates: certs,
RootCAs: remoteCAs, RootCAs: remoteCAs,
MinVersion: tls.VersionTLS12, // secure default
} }
} }
@@ -582,11 +605,11 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList)) list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList))
for i, serializedSCT := range sctList.SCTList { for i, serializedSCT := range sctList.SCTList {
var sct ct.SignedCertificateTimestamp var sct ct.SignedCertificateTimestamp
rest, err := cttls.Unmarshal(serializedSCT.Val, &sct) rest2, err2 := cttls.Unmarshal(serializedSCT.Val, &sct)
if err != nil { if err2 != nil {
return nil, err return nil, err2
} }
if len(rest) != 0 { if len(rest2) != 0 {
return nil, certerr.ParsingError( return nil, certerr.ParsingError(
certerr.ErrorSourceSCTList, certerr.ErrorSourceSCTList,
errors.New("serialized SCT list contained trailing garbage"), errors.New("serialized SCT list contained trailing garbage"),
@@ -602,12 +625,12 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
// unmarshalled. // unmarshalled.
func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) { func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) {
// This loop finds the SCTListExtension in the OCSP response. // This loop finds the SCTListExtension in the OCSP response.
var SCTListExtension, ext pkix.Extension var sctListExtension, ext pkix.Extension
for _, ext = range response.Extensions { for _, ext = range response.Extensions {
// sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp. // sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp.
sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5} sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5}
if ext.Id.Equal(sctExtOid) { if ext.Id.Equal(sctExtOid) {
SCTListExtension = ext sctListExtension = ext
break break
} }
} }
@@ -615,10 +638,10 @@ func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTim
// This code block extracts the sctList from the SCT extension. // This code block extracts the sctList from the SCT extension.
var sctList []ct.SignedCertificateTimestamp var sctList []ct.SignedCertificateTimestamp
var err error var err error
if numBytes := len(SCTListExtension.Value); numBytes != 0 { if numBytes := len(sctListExtension.Value); numBytes != 0 {
var serializedSCTList []byte var serializedSCTList []byte
rest := make([]byte, numBytes) rest := make([]byte, numBytes)
copy(rest, SCTListExtension.Value) copy(rest, sctListExtension.Value)
for len(rest) != 0 { for len(rest) != 0 {
rest, err = asn1.Unmarshal(rest, &serializedSCTList) rest, err = asn1.Unmarshal(rest, &serializedSCTList)
if err != nil { if err != nil {

View File

@@ -9,6 +9,8 @@ import (
"strings" "strings"
) )
const defaultHTTPSPort = 443
type Target struct { type Target struct {
Host string Host string
Port int Port int
@@ -29,29 +31,29 @@ func parseURL(host string) (string, int, error) {
} }
if url.Port() == "" { if url.Port() == "" {
return url.Hostname(), 443, nil return url.Hostname(), defaultHTTPSPort, nil
} }
port, err := strconv.ParseInt(url.Port(), 10, 16) portInt, err2 := strconv.ParseInt(url.Port(), 10, 16)
if err != nil { if err2 != nil {
return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", url.Port()) return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", url.Port())
} }
return url.Hostname(), int(port), nil return url.Hostname(), int(portInt), nil
} }
func parseHostPort(host string) (string, int, error) { func parseHostPort(host string) (string, int, error) {
host, sport, err := net.SplitHostPort(host) host, sport, err := net.SplitHostPort(host)
if err == nil { if err == nil {
port, err := strconv.ParseInt(sport, 10, 16) portInt, err2 := strconv.ParseInt(sport, 10, 16)
if err != nil { if err2 != nil {
return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport) return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport)
} }
return host, int(port), nil return host, int(portInt), nil
} }
return host, 443, nil return host, defaultHTTPSPort, nil
} }
func ParseHost(host string) (*Target, error) { func ParseHost(host string) (*Target, error) {

View File

@@ -158,9 +158,9 @@ type EncryptedContentInfo struct {
EncryptedContent []byte `asn1:"tag:0,optional"` EncryptedContent []byte `asn1:"tag:0,optional"`
} }
func unmarshalInit(raw []byte) (init initPKCS7, err error) { func unmarshalInit(raw []byte) (initPKCS7, error) {
_, err = asn1.Unmarshal(raw, &init) var init initPKCS7
if err != nil { if _, err := asn1.Unmarshal(raw, &init); err != nil {
return initPKCS7{}, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return initPKCS7{}, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
} }
return init, nil return init, nil
@@ -218,28 +218,28 @@ func populateEncryptedData(msg *PKCS7, contentBytes []byte) error {
// ParsePKCS7 attempts to parse the DER encoded bytes of a // ParsePKCS7 attempts to parse the DER encoded bytes of a
// PKCS7 structure. // PKCS7 structure.
func ParsePKCS7(raw []byte) (msg *PKCS7, err error) { func ParsePKCS7(raw []byte) (*PKCS7, error) {
pkcs7, err := unmarshalInit(raw) pkcs7, err := unmarshalInit(raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg = new(PKCS7) msg := new(PKCS7)
msg.Raw = pkcs7.Raw msg.Raw = pkcs7.Raw
msg.ContentInfo = pkcs7.ContentType.String() msg.ContentInfo = pkcs7.ContentType.String()
switch msg.ContentInfo { switch msg.ContentInfo {
case ObjIDData: case ObjIDData:
if err := populateData(msg, pkcs7.Content); err != nil { if e := populateData(msg, pkcs7.Content); e != nil {
return nil, err return nil, e
} }
case ObjIDSignedData: case ObjIDSignedData:
if err := populateSignedData(msg, pkcs7.Content.Bytes); err != nil { if e := populateSignedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, err return nil, e
} }
case ObjIDEncryptedData: case ObjIDEncryptedData:
if err := populateEncryptedData(msg, pkcs7.Content.Bytes); err != nil { if e := populateEncryptedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, err return nil, e
} }
default: default:
return nil, certerr.ParsingError( return nil, certerr.ParsingError(

View File

@@ -5,6 +5,7 @@ package revoke
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@@ -90,34 +91,34 @@ func ldapURL(url string) bool {
// - false, true: the certificate was checked successfully, and it is not revoked. // - false, true: the certificate was checked successfully, and it is not revoked.
// - true, true: the certificate was checked successfully, and it is revoked. // - true, true: the certificate was checked successfully, and it is revoked.
// - true, false: failure to check revocation status causes verification to fail. // - true, false: failure to check revocation status causes verification to fail.
func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) { func revCheck(cert *x509.Certificate) (bool, bool, error) {
for _, url := range cert.CRLDistributionPoints { for _, url := range cert.CRLDistributionPoints {
if ldapURL(url) { if ldapURL(url) {
log.Infof("skipping LDAP CRL: %s", url) log.Infof("skipping LDAP CRL: %s", url)
continue continue
} }
if revoked, ok, err := certIsRevokedCRL(cert, url); !ok { if rvk, ok2, err2 := certIsRevokedCRL(cert, url); !ok2 {
log.Warning("error checking revocation via CRL") log.Warning("error checking revocation via CRL")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via CRL") log.Info("certificate is revoked via CRL")
return true, true, err return true, true, err2
} }
} }
if revoked, ok, err := certIsRevokedOCSP(cert, HardFail); !ok { if rvk, ok2, err2 := certIsRevokedOCSP(cert, HardFail); !ok2 {
log.Warning("error checking revocation via OCSP") log.Warning("error checking revocation via OCSP")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via OCSP") log.Info("certificate is revoked via OCSP")
return true, true, err return true, true, err2
} }
return false, true, nil return false, true, nil
@@ -125,13 +126,17 @@ func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) {
// fetchCRL fetches and parses a CRL. // fetchCRL fetches and parses a CRL.
func fetchCRL(url string) (*x509.RevocationList, error) { func fetchCRL(url string) (*x509.RevocationList, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 300 { if resp.StatusCode >= http.StatusMultipleChoices {
return nil, errors.New("failed to retrieve CRL") return nil, errors.New("failed to retrieve CRL")
} }
@@ -158,7 +163,7 @@ func getIssuer(cert *x509.Certificate) *x509.Certificate {
// check a cert against a specific CRL. Returns the same bool pair // check a cert against a specific CRL. Returns the same bool pair
// as revCheck, plus an error if one occurred. // as revCheck, plus an error if one occurred.
func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err error) { func certIsRevokedCRL(cert *x509.Certificate, url string) (bool, bool, error) {
crlLock.Lock() crlLock.Lock()
crl, ok := CRLSet[url] crl, ok := CRLSet[url]
if ok && crl == nil { if ok && crl == nil {
@@ -186,10 +191,9 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
// check CRL signature // check CRL signature
if issuer != nil { if issuer != nil {
err = crl.CheckSignatureFrom(issuer) if sigErr := crl.CheckSignatureFrom(issuer); sigErr != nil {
if err != nil { log.Warningf("failed to verify CRL: %v", sigErr)
log.Warningf("failed to verify CRL: %v", err) return false, false, sigErr
return false, false, err
} }
} }
@@ -198,26 +202,26 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
crlLock.Unlock() crlLock.Unlock()
} }
for _, revoked := range crl.RevokedCertificates { for _, entry := range crl.RevokedCertificateEntries {
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 { if cert.SerialNumber.Cmp(entry.SerialNumber) == 0 {
log.Info("Serial number match: intermediate is revoked.") log.Info("Serial number match: intermediate is revoked.")
return true, true, err return true, true, nil
} }
} }
return false, true, err return false, true, nil
} }
// VerifyCertificate ensures that the certificate passed in hasn't // VerifyCertificate ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificate(cert *x509.Certificate) (revoked, ok bool) { func VerifyCertificate(cert *x509.Certificate) (bool, bool) {
revoked, ok, _ = VerifyCertificateError(cert) revoked, ok, _ := VerifyCertificateError(cert)
return revoked, ok return revoked, ok
} }
// VerifyCertificateError ensures that the certificate passed in hasn't // VerifyCertificateError ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificateError(cert *x509.Certificate) (revoked, ok bool, err error) { func VerifyCertificateError(cert *x509.Certificate) (bool, bool, error) {
if !time.Now().Before(cert.NotAfter) { if !time.Now().Before(cert.NotAfter) {
msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter) msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter)
log.Info(msg) log.Info(msg)
@@ -231,7 +235,11 @@ func VerifyCertificateError(cert *x509.Certificate) (revoked, ok bool, err error
} }
func fetchRemote(url string) (*x509.Certificate, error) { func fetchRemote(url string) (*x509.Certificate, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -254,8 +262,12 @@ var ocspOpts = ocsp.RequestOptions{
Hash: crypto.SHA1, Hash: crypto.SHA1,
} }
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e error) { const ocspGetURLMaxLen = 256
var err error
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (bool, bool, error) {
var revoked bool
var ok bool
var lastErr error
ocspURLs := leaf.OCSPServer ocspURLs := leaf.OCSPServer
if len(ocspURLs) == 0 { if len(ocspURLs) == 0 {
@@ -271,15 +283,16 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts) ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts)
if err != nil { if err != nil {
return revoked, ok, err return false, false, err
} }
for _, server := range ocspURLs { for _, server := range ocspURLs {
resp, err := sendOCSPRequest(server, ocspRequest, leaf, issuer) resp, e := sendOCSPRequest(server, ocspRequest, leaf, issuer)
if err != nil { if e != nil {
if strict { if strict {
return revoked, ok, err return false, false, e
} }
lastErr = e
continue continue
} }
@@ -291,9 +304,9 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
revoked = true revoked = true
} }
return revoked, ok, err return revoked, ok, nil
} }
return revoked, ok, err return revoked, ok, lastErr
} }
// sendOCSPRequest attempts to request an OCSP response from the // sendOCSPRequest attempts to request an OCSP response from the
@@ -302,12 +315,21 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) { func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) {
var resp *http.Response var resp *http.Response
var err error var err error
if len(req) > 256 { if len(req) > ocspGetURLMaxLen {
buf := bytes.NewBuffer(req) buf := bytes.NewBuffer(req)
resp, err = HTTPClient.Post(server, "application/ocsp-request", buf) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodPost, server, buf)
if e != nil {
return nil, e
}
httpReq.Header.Set("Content-Type", "application/ocsp-request")
resp, err = HTTPClient.Do(httpReq)
} else { } else {
reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req)) reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req))
resp, err = HTTPClient.Get(reqURL) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodGet, reqURL, nil)
if e != nil {
return nil, e
}
resp, err = HTTPClient.Do(httpReq)
} }
if err != nil { if err != nil {

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package revoke package revoke
import ( import (
@@ -153,7 +154,7 @@ func mustParse(pemData string) *x509.Certificate {
panic("Invalid PEM type.") panic("Invalid PEM type.")
} }
cert, err := x509.ParseCertificate([]byte(block.Bytes)) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"net" "net"
@@ -28,10 +29,16 @@ func connect(addr string, dport string, six bool, timeout time.Duration) error {
if verbose { if verbose {
fmt.Printf("connecting to %s/%s... ", addr, proto) fmt.Printf("connecting to %s/%s... ", addr, proto)
os.Stdout.Sync() if err = os.Stdout.Sync(); err != nil {
return err
}
} }
conn, err := net.DialTimeout(proto, addr, timeout) dialer := &net.Dialer{
Timeout: timeout,
}
conn, err := dialer.DialContext(context.Background(), proto, addr)
if err != nil { if err != nil {
if verbose { if verbose {
fmt.Println("failed.") fmt.Println("failed.")
@@ -42,8 +49,8 @@ func connect(addr string, dport string, six bool, timeout time.Duration) error {
if verbose { if verbose {
fmt.Println("OK") fmt.Println("OK")
} }
conn.Close()
return nil return conn.Close()
} }
func main() { func main() {

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"crypto/x509" "crypto/x509"
"embed" "embed"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -14,22 +15,22 @@ import (
// loadCertsFromFile attempts to parse certificates from a file that may be in // loadCertsFromFile attempts to parse certificates from a file that may be in
// PEM or DER/PKCS#7 format. Returns the parsed certificates or an error. // PEM or DER/PKCS#7 format. Returns the parsed certificates or an error.
func loadCertsFromFile(path string) ([]*x509.Certificate, error) { func loadCertsFromFile(path string) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Try PEM first if certs, err = certlib.ParseCertificatesPEM(data); err == nil {
if certs, err := certlib.ParseCertificatesPEM(data); err == nil {
return certs, nil return certs, nil
} }
// Try DER/PKCS7/PKCS12 (with no password) if certs, _, err = certlib.ParseCertificatesDER(data, ""); err == nil {
if certs, _, err := certlib.ParseCertificatesDER(data, ""); err == nil {
return certs, nil return certs, nil
} else {
return nil, err
} }
return nil, err
} }
func makePoolFromFile(path string) (*x509.CertPool, error) { func makePoolFromFile(path string) (*x509.CertPool, error) {
@@ -56,49 +57,50 @@ var embeddedTestdata embed.FS
// loadCertsFromBytes attempts to parse certificates from bytes that may be in // loadCertsFromBytes attempts to parse certificates from bytes that may be in
// PEM or DER/PKCS#7 format. // PEM or DER/PKCS#7 format.
func loadCertsFromBytes(data []byte) ([]*x509.Certificate, error) { func loadCertsFromBytes(data []byte) ([]*x509.Certificate, error) {
// Try PEM first certs, err := certlib.ParseCertificatesPEM(data)
if certs, err := certlib.ParseCertificatesPEM(data); err == nil { if err == nil {
return certs, nil return certs, nil
} }
// Try DER/PKCS7/PKCS12 (with no password)
if certs, _, err := certlib.ParseCertificatesDER(data, ""); err == nil { certs, _, err = certlib.ParseCertificatesDER(data, "")
if err == nil {
return certs, nil return certs, nil
} else {
return nil, err
} }
return nil, err
} }
func makePoolFromBytes(data []byte) (*x509.CertPool, error) { func makePoolFromBytes(data []byte) (*x509.CertPool, error) {
certs, err := loadCertsFromBytes(data) certs, err := loadCertsFromBytes(data)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return nil, fmt.Errorf("failed to load CA certificates from embedded bytes") return nil, errors.New("failed to load CA certificates from embedded bytes")
} }
pool := x509.NewCertPool() pool := x509.NewCertPool()
for _, c := range certs { for _, c := range certs {
pool.AddCert(c) pool.AddCert(c)
} }
return pool, nil return pool, nil
} }
// isSelfSigned returns true if the given certificate is self-signed. // isSelfSigned returns true if the given certificate is self-signed.
// It checks that the subject and issuer match and that the certificate's // It checks that the subject and issuer match and that the certificate's
// signature verifies against its own public key. // signature verifies against its own public key.
func isSelfSigned(cert *x509.Certificate) bool { func isSelfSigned(cert *x509.Certificate) bool {
if cert == nil { if cert == nil {
return false return false
} }
// Quick check: subject and issuer match // Quick check: subject and issuer match
if cert.Subject.String() != cert.Issuer.String() { if cert.Subject.String() != cert.Issuer.String() {
return false return false
} }
// Cryptographic check: the certificate is signed by itself // Cryptographic check: the certificate is signed by itself
if err := cert.CheckSignatureFrom(cert); err != nil { if err := cert.CheckSignatureFrom(cert); err != nil {
return false return false
} }
return true return true
} }
func verifyAgainstCA(caPool *x509.CertPool, path string) (ok bool, expiry string) { func verifyAgainstCA(caPool *x509.CertPool, path string) (bool, string) {
certs, err := loadCertsFromFile(path) certs, err := loadCertsFromFile(path)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return false, "" return false, ""
@@ -117,14 +119,14 @@ func verifyAgainstCA(caPool *x509.CertPool, path string) (ok bool, expiry string
Intermediates: ints, Intermediates: ints,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
if _, err := leaf.Verify(opts); err != nil { if _, err = leaf.Verify(opts); err != nil {
return false, "" return false, ""
} }
return true, leaf.NotAfter.Format("2006-01-02") return true, leaf.NotAfter.Format("2006-01-02")
} }
func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (ok bool, expiry string) { func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (bool, string) {
certs, err := loadCertsFromBytes(certData) certs, err := loadCertsFromBytes(certData)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return false, "" return false, ""
@@ -143,92 +145,159 @@ func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (ok bool, expi
Intermediates: ints, Intermediates: ints,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
if _, err := leaf.Verify(opts); err != nil { if _, err = leaf.Verify(opts); err != nil {
return false, "" return false, ""
} }
return true, leaf.NotAfter.Format("2006-01-02") return true, leaf.NotAfter.Format("2006-01-02")
} }
// selftest runs built-in validation using embedded certificates. type testCase struct {
func selftest() int { name string
type testCase struct { caFile string
name string certFile string
caFile string expectOK bool
certFile string }
expectOK bool
func (tc testCase) Run() error {
caBytes, err := embeddedTestdata.ReadFile(tc.caFile)
if err != nil {
return fmt.Errorf("selftest: failed to read embedded %s: %w", tc.caFile, err)
} }
cases := []testCase{ certBytes, err := embeddedTestdata.ReadFile(tc.certFile)
{name: "ISRG Root X1 validates LE E7", caFile: "testdata/isrg-root-x1.pem", certFile: "testdata/le-e7.pem", expectOK: true}, if err != nil {
{name: "ISRG Root X1 does NOT validate Google WR2", caFile: "testdata/isrg-root-x1.pem", certFile: "testdata/goog-wr2.pem", expectOK: false}, return fmt.Errorf("selftest: failed to read embedded %s: %w", tc.certFile, err)
{name: "GTS R1 validates Google WR2", caFile: "testdata/gts-r1.pem", certFile: "testdata/goog-wr2.pem", expectOK: true}, }
{name: "GTS R1 does NOT validate LE E7", caFile: "testdata/gts-r1.pem", certFile: "testdata/le-e7.pem", expectOK: false},
}
failures := 0 pool, err := makePoolFromBytes(caBytes)
for _, tc := range cases { if err != nil || pool == nil {
caBytes, err := embeddedTestdata.ReadFile(tc.caFile) return fmt.Errorf("selftest: failed to build CA pool for %s: %w", tc.caFile, err)
}
ok, exp := verifyAgainstCABytes(pool, certBytes)
if ok != tc.expectOK {
return fmt.Errorf("%s: unexpected result: got %v, want %v", tc.name, ok, tc.expectOK)
}
if ok {
fmt.Printf("%s: OK (expires %s)\n", tc.name, exp)
}
fmt.Printf("%s: INVALID (as expected)\n", tc.name)
return nil
}
var cases = []testCase{
{
name: "ISRG Root X1 validates LE E7",
caFile: "testdata/isrg-root-x1.pem",
certFile: "testdata/le-e7.pem",
expectOK: true,
},
{
name: "ISRG Root X1 does NOT validate Google WR2",
caFile: "testdata/isrg-root-x1.pem",
certFile: "testdata/goog-wr2.pem",
expectOK: false,
},
{
name: "GTS R1 validates Google WR2",
caFile: "testdata/gts-r1.pem",
certFile: "testdata/goog-wr2.pem",
expectOK: true,
},
{
name: "GTS R1 does NOT validate LE E7",
caFile: "testdata/gts-r1.pem",
certFile: "testdata/le-e7.pem",
expectOK: false,
},
}
// selftest runs built-in validation using embedded certificates.
func selftest() int {
failures := 0
for _, tc := range cases {
err := tc.Run()
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", tc.caFile, err) fmt.Fprintln(os.Stderr, err)
failures++ failures++
continue continue
} }
certBytes, err := embeddedTestdata.ReadFile(tc.certFile) }
// Verify that both embedded root CAs are detected as self-signed
roots := []string{"testdata/gts-r1.pem", "testdata/isrg-root-x1.pem"}
for _, root := range roots {
b, err := embeddedTestdata.ReadFile(root)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", tc.certFile, err) fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", root, err)
failures++ failures++
continue continue
} }
pool, err := makePoolFromBytes(caBytes) certs, err := loadCertsFromBytes(b)
if err != nil || pool == nil { if err != nil || len(certs) == 0 {
fmt.Fprintf(os.Stderr, "selftest: failed to build CA pool for %s: %v\n", tc.caFile, err) fmt.Fprintf(os.Stderr, "selftest: failed to parse cert(s) from %s: %v\n", root, err)
failures++ failures++
continue continue
} }
ok, exp := verifyAgainstCABytes(pool, certBytes) leaf := certs[0]
if ok != tc.expectOK { if isSelfSigned(leaf) {
fmt.Printf("%s: unexpected result: got %v, want %v\n", tc.name, ok, tc.expectOK) fmt.Printf("%s: SELF-SIGNED (as expected)\n", root)
failures++
} else { } else {
if ok { fmt.Printf("%s: expected SELF-SIGNED, but was not detected as such\n", root)
fmt.Printf("%s: OK (expires %s)\n", tc.name, exp) failures++
} else {
fmt.Printf("%s: INVALID (as expected)\n", tc.name)
}
} }
} }
// Verify that both embedded root CAs are detected as self-signed if failures == 0 {
roots := []string{"testdata/gts-r1.pem", "testdata/isrg-root-x1.pem"} fmt.Println("selftest: PASS")
for _, root := range roots { return 0
b, err := embeddedTestdata.ReadFile(root) }
if err != nil { fmt.Fprintf(os.Stderr, "selftest: FAIL (%d failure(s))\n", failures)
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", root, err) return 1
failures++ }
continue
}
certs, err := loadCertsFromBytes(b)
if err != nil || len(certs) == 0 {
fmt.Fprintf(os.Stderr, "selftest: failed to parse cert(s) from %s: %v\n", root, err)
failures++
continue
}
leaf := certs[0]
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED (as expected)\n", root)
} else {
fmt.Printf("%s: expected SELF-SIGNED, but was not detected as such\n", root)
failures++
}
}
if failures == 0 { // expiryString returns a YYYY-MM-DD date string to display for certificate
fmt.Println("selftest: PASS") // expiry. If an explicit exp string is provided, it is used. Otherwise, if a
return 0 // leaf certificate is available, its NotAfter is formatted. As a last resort,
} // it falls back to today's date (should not normally happen).
fmt.Fprintf(os.Stderr, "selftest: FAIL (%d failure(s))\n", failures) func expiryString(leaf *x509.Certificate, exp string) string {
return 1 if exp != "" {
return exp
}
if leaf != nil {
return leaf.NotAfter.Format("2006-01-02")
}
return time.Now().Format("2006-01-02")
}
// processCert verifies a single certificate file against the provided CA pool
// and prints the result in the required format, handling self-signed
// certificates specially.
func processCert(caPool *x509.CertPool, certPath string) {
ok, exp := verifyAgainstCA(caPool, certPath)
name := filepath.Base(certPath)
// Try to load the leaf cert for self-signed detection and expiry fallback
var leaf *x509.Certificate
if certs, err := loadCertsFromFile(certPath); err == nil && len(certs) > 0 {
leaf = certs[0]
}
// Prefer the SELF-SIGNED label if applicable
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED\n", name)
return
}
if ok {
fmt.Printf("%s: OK (expires %s)\n", name, expiryString(leaf, exp))
return
}
fmt.Printf("%s: INVALID\n", name)
} }
func main() { func main() {
@@ -250,38 +319,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
for _, certPath := range os.Args[2:] { for _, certPath := range os.Args[2:] {
ok, exp := verifyAgainstCA(caPool, certPath) processCert(caPool, certPath)
name := filepath.Base(certPath) }
// Load the leaf once for self-signed detection and potential expiry fallback
var leaf *x509.Certificate
if certs, err := loadCertsFromFile(certPath); err == nil && len(certs) > 0 {
leaf = certs[0]
}
// If the certificate is self-signed, prefer the SELF-SIGNED label
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED\n", name)
continue
}
if ok {
// Display with the requested format
// Example: file: OK (expires 2031-01-01)
// Ensure deterministic date formatting
// Note: no timezone displayed; date only as per example
// If exp ended up empty for some reason, recompute safely
if exp == "" {
if leaf != nil {
exp = leaf.NotAfter.Format("2006-01-02")
} else {
// fallback to the current date to avoid empty; though shouldn't happen
exp = time.Now().Format("2006-01-02")
}
}
fmt.Printf("%s: OK (expires %s)\n", name, exp)
} else {
fmt.Printf("%s: INVALID\n", name)
}
}
} }

View File

@@ -8,8 +8,10 @@ import (
"crypto/x509" "crypto/x509"
_ "embed" _ "embed"
"encoding/pem" "encoding/pem"
"errors"
"flag" "flag"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -19,7 +21,7 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// Config represents the top-level YAML configuration // Config represents the top-level YAML configuration.
type Config struct { type Config struct {
Config struct { Config struct {
Hashes string `yaml:"hashes"` Hashes string `yaml:"hashes"`
@@ -28,19 +30,19 @@ type Config struct {
Chains map[string]ChainGroup `yaml:"chains"` Chains map[string]ChainGroup `yaml:"chains"`
} }
// ChainGroup represents a named group of certificate chains // ChainGroup represents a named group of certificate chains.
type ChainGroup struct { type ChainGroup struct {
Certs []CertChain `yaml:"certs"` Certs []CertChain `yaml:"certs"`
Outputs Outputs `yaml:"outputs"` Outputs Outputs `yaml:"outputs"`
} }
// CertChain represents a root certificate and its intermediates // CertChain represents a root certificate and its intermediates.
type CertChain struct { type CertChain struct {
Root string `yaml:"root"` Root string `yaml:"root"`
Intermediates []string `yaml:"intermediates"` Intermediates []string `yaml:"intermediates"`
} }
// Outputs defines output format options // Outputs defines output format options.
type Outputs struct { type Outputs struct {
IncludeSingle bool `yaml:"include_single"` IncludeSingle bool `yaml:"include_single"`
IncludeIndividual bool `yaml:"include_individual"` IncludeIndividual bool `yaml:"include_individual"`
@@ -95,7 +97,8 @@ func main() {
} }
// Create output directory if it doesn't exist // Create output directory if it doesn't exist
if err := os.MkdirAll(outputDir, 0755); err != nil { err = os.MkdirAll(outputDir, 0750)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating output directory: %v\n", err) fmt.Fprintf(os.Stderr, "Error creating output directory: %v\n", err)
os.Exit(1) os.Exit(1)
} }
@@ -108,9 +111,9 @@ func main() {
} }
createdFiles := make([]string, 0, totalFormats) createdFiles := make([]string, 0, totalFormats)
for groupName, group := range cfg.Chains { for groupName, group := range cfg.Chains {
files, err := processChainGroup(groupName, group, expiryDuration) files, perr := processChainGroup(groupName, group, expiryDuration)
if err != nil { if perr != nil {
fmt.Fprintf(os.Stderr, "Error processing chain group %s: %v\n", groupName, err) fmt.Fprintf(os.Stderr, "Error processing chain group %s: %v\n", groupName, perr)
os.Exit(1) os.Exit(1)
} }
createdFiles = append(createdFiles, files...) createdFiles = append(createdFiles, files...)
@@ -119,8 +122,8 @@ func main() {
// Generate hash file for all created archives // Generate hash file for all created archives
if cfg.Config.Hashes != "" { if cfg.Config.Hashes != "" {
hashFile := filepath.Join(outputDir, cfg.Config.Hashes) hashFile := filepath.Join(outputDir, cfg.Config.Hashes)
if err := generateHashFile(hashFile, createdFiles); err != nil { if gerr := generateHashFile(hashFile, createdFiles); gerr != nil {
fmt.Fprintf(os.Stderr, "Error generating hash file: %v\n", err) fmt.Fprintf(os.Stderr, "Error generating hash file: %v\n", gerr)
os.Exit(1) os.Exit(1)
} }
} }
@@ -135,8 +138,8 @@ func loadConfig(path string) (*Config, error) {
} }
var cfg Config var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil { if uerr := yaml.Unmarshal(data, &cfg); uerr != nil {
return nil, err return nil, uerr
} }
return &cfg, nil return &cfg, nil
@@ -200,72 +203,107 @@ func processChainGroup(groupName string, group ChainGroup, expiryDuration time.D
return createdFiles, nil return createdFiles, nil
} }
// loadAndCollectCerts loads all certificates from chains and collects them for processing // loadAndCollectCerts loads all certificates from chains and collects them for processing.
func loadAndCollectCerts(chains []CertChain, outputs Outputs, expiryDuration time.Duration) ([]*x509.Certificate, []certWithPath, error) { func loadAndCollectCerts(
chains []CertChain,
outputs Outputs,
expiryDuration time.Duration,
) ([]*x509.Certificate, []certWithPath, error) {
var singleFileCerts []*x509.Certificate var singleFileCerts []*x509.Certificate
var individualCerts []certWithPath var individualCerts []certWithPath
for _, chain := range chains { for _, chain := range chains {
// Load root certificate s, i, cerr := collectFromChain(chain, outputs, expiryDuration)
rootCert, err := certlib.LoadCertificate(chain.Root) if cerr != nil {
if err != nil { return nil, nil, cerr
return nil, nil, fmt.Errorf("failed to load root certificate %s: %v", chain.Root, err)
} }
if len(s) > 0 {
// Check expiry for root singleFileCerts = append(singleFileCerts, s...)
checkExpiry(chain.Root, rootCert, expiryDuration)
// Add root to collections if needed
if outputs.IncludeSingle {
singleFileCerts = append(singleFileCerts, rootCert)
} }
if outputs.IncludeIndividual { if len(i) > 0 {
individualCerts = append(individualCerts, certWithPath{ individualCerts = append(individualCerts, i...)
cert: rootCert,
path: chain.Root,
})
}
// Load and validate intermediates
for _, intPath := range chain.Intermediates {
intCert, err := certlib.LoadCertificate(intPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %v", intPath, err)
}
// Validate that intermediate is signed by root
if err := intCert.CheckSignatureFrom(rootCert); err != nil {
return nil, nil, fmt.Errorf("intermediate %s is not properly signed by root %s: %v", intPath, chain.Root, err)
}
// Check expiry for intermediate
checkExpiry(intPath, intCert, expiryDuration)
// Add intermediate to collections if needed
if outputs.IncludeSingle {
singleFileCerts = append(singleFileCerts, intCert)
}
if outputs.IncludeIndividual {
individualCerts = append(individualCerts, certWithPath{
cert: intCert,
path: intPath,
})
}
} }
} }
return singleFileCerts, individualCerts, nil return singleFileCerts, individualCerts, nil
} }
// prepareArchiveFiles prepares all files to be included in archives // collectFromChain loads a single chain, performs checks, and returns the certs to include.
func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []certWithPath, outputs Outputs, encoding string) ([]fileEntry, error) { func collectFromChain(
chain CertChain,
outputs Outputs,
expiryDuration time.Duration,
) (
[]*x509.Certificate,
[]certWithPath,
error,
) {
var single []*x509.Certificate
var indiv []certWithPath
// Load root certificate
rootCert, rerr := certlib.LoadCertificate(chain.Root)
if rerr != nil {
return nil, nil, fmt.Errorf("failed to load root certificate %s: %w", chain.Root, rerr)
}
// Check expiry for root
checkExpiry(chain.Root, rootCert, expiryDuration)
// Add root to collections if needed
if outputs.IncludeSingle {
single = append(single, rootCert)
}
if outputs.IncludeIndividual {
indiv = append(indiv, certWithPath{cert: rootCert, path: chain.Root})
}
// Load and validate intermediates
for _, intPath := range chain.Intermediates {
intCert, lerr := certlib.LoadCertificate(intPath)
if lerr != nil {
return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %w", intPath, lerr)
}
// Validate that intermediate is signed by root
if sigErr := intCert.CheckSignatureFrom(rootCert); sigErr != nil {
return nil, nil, fmt.Errorf(
"intermediate %s is not properly signed by root %s: %w",
intPath,
chain.Root,
sigErr,
)
}
// Check expiry for intermediate
checkExpiry(intPath, intCert, expiryDuration)
// Add intermediate to collections if needed
if outputs.IncludeSingle {
single = append(single, intCert)
}
if outputs.IncludeIndividual {
indiv = append(indiv, certWithPath{cert: intCert, path: intPath})
}
}
return single, indiv, nil
}
// prepareArchiveFiles prepares all files to be included in archives.
func prepareArchiveFiles(
singleFileCerts []*x509.Certificate,
individualCerts []certWithPath,
outputs Outputs,
encoding string,
) ([]fileEntry, error) {
var archiveFiles []fileEntry var archiveFiles []fileEntry
// Handle a single bundle file // Handle a single bundle file
if outputs.IncludeSingle && len(singleFileCerts) > 0 { if outputs.IncludeSingle && len(singleFileCerts) > 0 {
files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true) files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encode single bundle: %v", err) return nil, fmt.Errorf("failed to encode single bundle: %w", err)
} }
archiveFiles = append(archiveFiles, files...) archiveFiles = append(archiveFiles, files...)
} }
@@ -276,7 +314,7 @@ func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []
baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path)) baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path))
files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false) files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encode individual cert %s: %v", cp.path, err) return nil, fmt.Errorf("failed to encode individual cert %s: %w", cp.path, err)
} }
archiveFiles = append(archiveFiles, files...) archiveFiles = append(archiveFiles, files...)
} }
@@ -294,7 +332,7 @@ func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []
return archiveFiles, nil return archiveFiles, nil
} }
// createArchiveFiles creates archive files in the specified formats // createArchiveFiles creates archive files in the specified formats.
func createArchiveFiles(groupName string, formats []string, archiveFiles []fileEntry) ([]string, error) { func createArchiveFiles(groupName string, formats []string, archiveFiles []fileEntry) ([]string, error) {
createdFiles := make([]string, 0, len(formats)) createdFiles := make([]string, 0, len(formats))
@@ -307,11 +345,11 @@ func createArchiveFiles(groupName string, formats []string, archiveFiles []fileE
switch format { switch format {
case "zip": case "zip":
if err := createZipArchive(archivePath, archiveFiles); err != nil { if err := createZipArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create zip archive: %v", err) return nil, fmt.Errorf("failed to create zip archive: %w", err)
} }
case "tgz": case "tgz":
if err := createTarGzArchive(archivePath, archiveFiles); err != nil { if err := createTarGzArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create tar.gz archive: %v", err) return nil, fmt.Errorf("failed to create tar.gz archive: %w", err)
} }
default: default:
return nil, fmt.Errorf("unsupported format: %s", format) return nil, fmt.Errorf("unsupported format: %s", format)
@@ -329,7 +367,12 @@ func checkExpiry(path string, cert *x509.Certificate, expiryDuration time.Durati
if cert.NotAfter.Before(expiryThreshold) { if cert.NotAfter.Before(expiryThreshold) {
daysUntilExpiry := int(cert.NotAfter.Sub(now).Hours() / 24) daysUntilExpiry := int(cert.NotAfter.Sub(now).Hours() / 24)
if daysUntilExpiry < 0 { if daysUntilExpiry < 0 {
fmt.Fprintf(os.Stderr, "WARNING: Certificate %s has EXPIRED (expired %d days ago)\n", path, -daysUntilExpiry) fmt.Fprintf(
os.Stderr,
"WARNING: Certificate %s has EXPIRED (expired %d days ago)\n",
path,
-daysUntilExpiry,
)
} else { } else {
fmt.Fprintf(os.Stderr, "WARNING: Certificate %s will expire in %d days (on %s)\n", path, daysUntilExpiry, cert.NotAfter.Format("2006-01-02")) fmt.Fprintf(os.Stderr, "WARNING: Certificate %s will expire in %d days (on %s)\n", path, daysUntilExpiry, cert.NotAfter.Format("2006-01-02"))
} }
@@ -347,8 +390,13 @@ type certWithPath struct {
} }
// encodeCertsToFiles converts certificates to file entries based on encoding type // encodeCertsToFiles converts certificates to file entries based on encoding type
// If isSingle is true, certs are concatenated into a single file; otherwise one cert per file // If isSingle is true, certs are concatenated into a single file; otherwise one cert per file.
func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding string, isSingle bool) ([]fileEntry, error) { func encodeCertsToFiles(
certs []*x509.Certificate,
baseName string,
encoding string,
isSingle bool,
) ([]fileEntry, error) {
var files []fileEntry var files []fileEntry
switch encoding { switch encoding {
@@ -369,14 +417,12 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str
name: baseName + ".crt", name: baseName + ".crt",
content: derContent, content: derContent,
}) })
} else { } else if len(certs) > 0 {
// Individual DER file (should only have one cert) // Individual DER file (should only have one cert)
if len(certs) > 0 { files = append(files, fileEntry{
files = append(files, fileEntry{ name: baseName + ".crt",
name: baseName + ".crt", content: certs[0].Raw,
content: certs[0].Raw, })
})
}
} }
case "both": case "both":
// Add PEM version // Add PEM version
@@ -395,13 +441,11 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str
name: baseName + ".crt", name: baseName + ".crt",
content: derContent, content: derContent,
}) })
} else { } else if len(certs) > 0 {
if len(certs) > 0 { files = append(files, fileEntry{
files = append(files, fileEntry{ name: baseName + ".crt",
name: baseName + ".crt", content: certs[0].Raw,
content: certs[0].Raw, })
})
}
} }
default: default:
return nil, fmt.Errorf("unsupported encoding: %s (must be 'pem', 'der', or 'both')", encoding) return nil, fmt.Errorf("unsupported encoding: %s (must be 'pem', 'der', or 'both')", encoding)
@@ -410,7 +454,7 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str
return files, nil return files, nil
} }
// encodeCertsToPEM encodes certificates to PEM format // encodeCertsToPEM encodes certificates to PEM format.
func encodeCertsToPEM(certs []*x509.Certificate) []byte { func encodeCertsToPEM(certs []*x509.Certificate) []byte {
var pemContent []byte var pemContent []byte
for _, cert := range certs { for _, cert := range certs {
@@ -435,40 +479,49 @@ func generateManifest(files []fileEntry) []byte {
return []byte(manifest.String()) return []byte(manifest.String())
} }
// closeWithErr attempts to close all provided closers, joining any close errors with baseErr.
func closeWithErr(baseErr error, closers ...io.Closer) error {
for _, c := range closers {
if c == nil {
continue
}
if cerr := c.Close(); cerr != nil {
baseErr = errors.Join(baseErr, cerr)
}
}
return baseErr
}
func createZipArchive(path string, files []fileEntry) error { func createZipArchive(path string, files []fileEntry) error {
f, err := os.Create(path) f, zerr := os.Create(path)
if err != nil { if zerr != nil {
return err return zerr
} }
w := zip.NewWriter(f) w := zip.NewWriter(f)
for _, file := range files { for _, file := range files {
fw, err := w.Create(file.name) fw, werr := w.Create(file.name)
if err != nil { if werr != nil {
w.Close() return closeWithErr(werr, w, f)
f.Close()
return err
} }
if _, err := fw.Write(file.content); err != nil { if _, werr = fw.Write(file.content); werr != nil {
w.Close() return closeWithErr(werr, w, f)
f.Close()
return err
} }
} }
// Check errors on close operations // Check errors on close operations
if err := w.Close(); err != nil { if cerr := w.Close(); cerr != nil {
f.Close() _ = f.Close()
return err return cerr
} }
return f.Close() return f.Close()
} }
func createTarGzArchive(path string, files []fileEntry) error { func createTarGzArchive(path string, files []fileEntry) error {
f, err := os.Create(path) f, terr := os.Create(path)
if err != nil { if terr != nil {
return err return terr
} }
gw := gzip.NewWriter(f) gw := gzip.NewWriter(f)
@@ -480,29 +533,23 @@ func createTarGzArchive(path string, files []fileEntry) error {
Mode: 0644, Mode: 0644,
Size: int64(len(file.content)), Size: int64(len(file.content)),
} }
if err := tw.WriteHeader(hdr); err != nil { if herr := tw.WriteHeader(hdr); herr != nil {
tw.Close() return closeWithErr(herr, tw, gw, f)
gw.Close()
f.Close()
return err
} }
if _, err := tw.Write(file.content); err != nil { if _, werr := tw.Write(file.content); werr != nil {
tw.Close() return closeWithErr(werr, tw, gw, f)
gw.Close()
f.Close()
return err
} }
} }
// Check errors on close operations in the correct order // Check errors on close operations in the correct order
if err := tw.Close(); err != nil { if cerr := tw.Close(); cerr != nil {
gw.Close() _ = gw.Close()
f.Close() _ = f.Close()
return err return cerr
} }
if err := gw.Close(); err != nil { if cerr := gw.Close(); cerr != nil {
f.Close() _ = f.Close()
return err return cerr
} }
return f.Close() return f.Close()
} }
@@ -515,9 +562,9 @@ func generateHashFile(path string, files []string) error {
defer f.Close() defer f.Close()
for _, file := range files { for _, file := range files {
data, err := os.ReadFile(file) data, rerr := os.ReadFile(file)
if err != nil { if rerr != nil {
return err return rerr
} }
hash := sha256.Sum256(data) hash := sha256.Sum256(data)

View File

@@ -1,14 +1,15 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"flag"
"errors" "errors"
"flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
@@ -23,6 +24,13 @@ var (
verbose bool verbose bool
) )
var (
strOK = "OK"
strExpired = "EXPIRED"
strRevoked = "REVOKED"
strUnknown = "UNKNOWN"
)
func main() { func main() {
flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal") flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal")
flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects") flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects")
@@ -42,16 +50,16 @@ func main() {
for _, target := range flag.Args() { for _, target := range flag.Args() {
status, err := processTarget(target) status, err := processTarget(target)
switch status { switch status {
case "OK": case strOK:
fmt.Printf("%s: OK\n", target) fmt.Printf("%s: %s\n", target, strOK)
case "EXPIRED": case strExpired:
fmt.Printf("%s: EXPIRED: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strExpired, err)
exitCode = 1 exitCode = 1
case "REVOKED": case strRevoked:
fmt.Printf("%s: REVOKED\n", target) fmt.Printf("%s: %s\n", target, strRevoked)
exitCode = 1 exitCode = 1
case "UNKNOWN": case strUnknown:
fmt.Printf("%s: UNKNOWN: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strUnknown, err)
if hardfail { if hardfail {
// In hardfail, treat unknown as failure // In hardfail, treat unknown as failure
exitCode = 1 exitCode = 1
@@ -67,74 +75,77 @@ func processTarget(target string) (string, error) {
return checkFile(target) return checkFile(target)
} }
// Not a file; treat as site
return checkSite(target) return checkSite(target)
} }
func checkFile(path string) (string, error) { func checkFile(path string) (string, error) {
in, err := ioutil.ReadFile(path) // Prefer high-level helpers from certlib to load certificates from disk
if err != nil { if certs, err := certlib.LoadCertificates(path); err == nil && len(certs) > 0 {
return "UNKNOWN", err // Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0])
} }
// Try PEM first; if that fails, try single DER cert cert, err := certlib.LoadCertificate(path)
certs, err := certlib.ReadCertificates(in) if err != nil || cert == nil {
if err != nil || len(certs) == 0 { return strUnknown, err
cert, _, derr := certlib.ReadCertificate(in)
if derr != nil || cert == nil {
if err == nil {
err = derr
}
return "UNKNOWN", err
}
return evaluateCert(cert)
} }
return evaluateCert(cert)
// Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0])
} }
func checkSite(hostport string) (string, error) { func checkSite(hostport string) (string, error) {
// Use certlib/hosts to parse host/port (supports https URLs and host:port) // Use certlib/hosts to parse host/port (supports https URLs and host:port)
target, err := hosts.ParseHost(hostport) target, err := hosts.ParseHost(hostport)
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
d := &net.Dialer{Timeout: timeout} d := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(d, "tcp", target.String(), &tls.Config{InsecureSkipVerify: true, ServerName: target.Host}) tcfg := &tls.Config{
InsecureSkipVerify: true,
ServerName: target.Host,
} // #nosec G402 -- CLI tool only verifies revocation
td := &tls.Dialer{NetDialer: d, Config: tcfg}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := td.DialContext(ctx, "tcp", target.String())
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() tconn, ok := conn.(*tls.Conn)
if !ok {
return strUnknown, errors.New("connection is not TLS")
}
state := tconn.ConnectionState()
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
return "UNKNOWN", errors.New("no peer certificates presented") return strUnknown, errors.New("no peer certificates presented")
} }
return evaluateCert(state.PeerCertificates[0]) return evaluateCert(state.PeerCertificates[0])
} }
func evaluateCert(cert *x509.Certificate) (string, error) { func evaluateCert(cert *x509.Certificate) (string, error) {
// Expiry check // Delegate validity and revocation checks to certlib/revoke helper.
now := time.Now() // It returns revoked=true for both revoked and expired/not-yet-valid.
if !now.Before(cert.NotAfter) { // Map those cases back to our statuses using the returned error text.
return "EXPIRED", fmt.Errorf("expired at %s", cert.NotAfter)
}
if !now.After(cert.NotBefore) {
return "EXPIRED", fmt.Errorf("not valid until %s", cert.NotBefore)
}
// Revocation check using certlib/revoke
revoked, ok, err := revoke.VerifyCertificateError(cert) revoked, ok, err := revoke.VerifyCertificateError(cert)
if revoked { if revoked {
// If revoked is true, ok will be true per implementation, err may describe why if err != nil {
return "REVOKED", err msg := err.Error()
if strings.Contains(msg, "expired") || strings.Contains(msg, "isn't valid until") ||
strings.Contains(msg, "not valid until") {
return strExpired, err
}
}
return strRevoked, err
} }
if !ok { if !ok {
// Revocation status could not be determined // Revocation status could not be determined
return "UNKNOWN", err return strUnknown, err
} }
return "OK", nil return strOK, nil
} }

View File

@@ -1,11 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"os"
"regexp" "regexp"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -20,20 +23,26 @@ func main() {
server += ":443" server += ":443"
} }
var chain string d := &tls.Dialer{Config: &tls.Config{}} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", server)
conn, err := tls.Dial("tcp", server, nil)
die.If(err) die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
}
defer conn.Close()
details := conn.ConnectionState() details := conn.ConnectionState()
var chain strings.Builder
for _, cert := range details.PeerCertificates { for _, cert := range details.PeerCertificates {
p := pem.Block{ p := pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Raw, Bytes: cert.Raw,
} }
chain += string(pem.EncodeToMemory(&p)) chain.Write(pem.EncodeToMemory(&p))
} }
fmt.Println(chain) fmt.Fprintln(os.Stdout, chain.String())
} }
} }

View File

@@ -1,7 +1,9 @@
//lint:file-ignore SA1019 allow strict compatibility for old certs
package main package main
import ( import (
"bytes" "bytes"
"context"
"crypto/dsa" "crypto/dsa"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
@@ -101,30 +103,30 @@ func extUsage(ext []x509.ExtKeyUsage) string {
} }
func showBasicConstraints(cert *x509.Certificate) { func showBasicConstraints(cert *x509.Certificate) {
fmt.Printf("\tBasic constraints: ") fmt.Fprint(os.Stdout, "\tBasic constraints: ")
if cert.BasicConstraintsValid { if cert.BasicConstraintsValid {
fmt.Printf("valid") fmt.Fprint(os.Stdout, "valid")
} else { } else {
fmt.Printf("invalid") fmt.Fprint(os.Stdout, "invalid")
} }
if cert.IsCA { if cert.IsCA {
fmt.Printf(", is a CA certificate") fmt.Fprint(os.Stdout, ", is a CA certificate")
if !cert.BasicConstraintsValid { if !cert.BasicConstraintsValid {
fmt.Printf(" (basic constraint failure)") fmt.Fprint(os.Stdout, " (basic constraint failure)")
} }
} else { } else {
fmt.Printf("is not a CA certificate") fmt.Fprint(os.Stdout, "is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 { if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Printf(" (key encipherment usage enabled!)") fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
} }
} }
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) { if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Printf(", max path length %d", cert.MaxPathLen) fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
} }
fmt.Printf("\n") fmt.Fprintln(os.Stdout)
} }
const oneTrueDateFormat = "2006-01-02T15:04:05-0700" const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
@@ -136,39 +138,41 @@ var (
func wrapPrint(text string, indent int) { func wrapPrint(text string, indent int) {
tabs := "" tabs := ""
for i := 0; i < indent; i++ { var tabsSb140 strings.Builder
tabs += "\t" for range indent {
tabsSb140.WriteString("\t")
} }
tabs += tabsSb140.String()
fmt.Printf(tabs+"%s\n", wrap(text, indent)) fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
} }
func displayCert(cert *x509.Certificate) { func displayCert(cert *x509.Certificate) {
fmt.Println("CERTIFICATE") fmt.Fprintln(os.Stdout, "CERTIFICATE")
if showHash { if showHash {
fmt.Println(wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0)) fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
} }
fmt.Println(wrap("Subject: "+displayName(cert.Subject), 0)) fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Println(wrap("Issuer: "+displayName(cert.Issuer), 0)) fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Printf("\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm), fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm)) sigAlgoHash(cert.SignatureAlgorithm))
fmt.Println("Details:") fmt.Fprintln(os.Stdout, "Details:")
wrapPrint("Public key: "+certPublic(cert), 1) wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Printf("\tSerial number: %s\n", cert.SerialNumber) fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 { if len(cert.AuthorityKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1)) fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
} }
if len(cert.SubjectKeyId) > 0 { if len(cert.SubjectKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1)) fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
} }
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1) wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Printf("\t until: %s\n", cert.NotAfter.Format(dateFormat)) fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Printf("\tKey usages: %s\n", keyUsages(cert.KeyUsage)) fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 { if len(cert.ExtKeyUsage) > 0 {
fmt.Printf("\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage)) fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
} }
showBasicConstraints(cert) showBasicConstraints(cert)
@@ -221,13 +225,13 @@ func displayAllCerts(in []byte, leafOnly bool) {
if err != nil { if err != nil {
certs, _, err = certlib.ParseCertificatesDER(in, "") certs, _, err = certlib.ParseCertificatesDER(in, "")
if err != nil { if err != nil {
lib.Warn(err, "failed to parse certificates") _, _ = lib.Warn(err, "failed to parse certificates")
return return
} }
} }
if len(certs) == 0 { if len(certs) == 0 {
lib.Warnx("no certificates found") _, _ = lib.Warnx("no certificates found")
return return
} }
@@ -243,29 +247,45 @@ func displayAllCerts(in []byte, leafOnly bool) {
func displayAllCertsWeb(uri string, leafOnly bool) { func displayAllCertsWeb(uri string, leafOnly bool) {
ci := getConnInfo(uri) ci := getConnInfo(uri)
conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig()) d := &tls.Dialer{Config: permissiveConfig()}
nc, err := d.DialContext(context.Background(), "tcp", ci.Addr)
if err != nil { if err != nil {
lib.Warn(err, "couldn't connect to %s", ci.Addr) _, _ = lib.Warn(err, "couldn't connect to %s", ci.Addr)
return
}
conn, ok := nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return return
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() state := conn.ConnectionState()
conn.Close() if err = conn.Close(); err != nil {
_, _ = lib.Warn(err, "couldn't close TLS connection")
}
conn, err = tls.Dial("tcp", ci.Addr, verifyConfig(ci.Host)) d = &tls.Dialer{Config: verifyConfig(ci.Host)}
nc, err = d.DialContext(context.Background(), "tcp", ci.Addr)
if err == nil { if err == nil {
conn, ok = nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return
}
err = conn.VerifyHostname(ci.Host) err = conn.VerifyHostname(ci.Host)
if err == nil { if err == nil {
state = conn.ConnectionState() state = conn.ConnectionState()
} }
conn.Close() conn.Close()
} else { } else {
lib.Warn(err, "TLS verification error with server name %s", ci.Host) _, _ = lib.Warn(err, "TLS verification error with server name %s", ci.Host)
} }
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
lib.Warnx("no certificates found") _, _ = lib.Warnx("no certificates found")
return return
} }
@@ -275,14 +295,14 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
} }
if len(state.VerifiedChains) == 0 { if len(state.VerifiedChains) == 0 {
lib.Warnx("no verified chains found; using peer chain") _, _ = lib.Warnx("no verified chains found; using peer chain")
for i := range state.PeerCertificates { for i := range state.PeerCertificates {
displayCert(state.PeerCertificates[i]) displayCert(state.PeerCertificates[i])
} }
} else { } else {
fmt.Println("TLS chain verified successfully.") fmt.Fprintln(os.Stdout, "TLS chain verified successfully.")
for i := range state.VerifiedChains { for i := range state.VerifiedChains {
fmt.Printf("--- Verified certificate chain %d ---\n", i+1) fmt.Fprintf(os.Stdout, "--- Verified certificate chain %d ---%s", i+1, "\n")
for j := range state.VerifiedChains[i] { for j := range state.VerifiedChains[i] {
displayCert(state.VerifiedChains[i][j]) displayCert(state.VerifiedChains[i][j])
} }
@@ -290,6 +310,32 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
} }
} }
func shouldReadStdin(argc int, argv []string) bool {
if argc == 0 {
return true
}
if argc == 1 && argv[0] == "-" {
return true
}
return false
}
func readStdin(leafOnly bool) {
certs, err := io.ReadAll(os.Stdin)
if err != nil {
_, _ = lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
}
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.ReplaceAll(certs, []byte(`\n`), []byte{0xa})
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
}
func main() { func main() {
var leafOnly bool var leafOnly bool
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents") flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
@@ -297,32 +343,23 @@ func main() {
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate") flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.Parse() flag.Parse()
if flag.NArg() == 0 || (flag.NArg() == 1 && flag.Arg(0) == "-") { if shouldReadStdin(flag.NArg(), flag.Args()) {
certs, err := io.ReadAll(os.Stdin) readStdin(leafOnly)
if err != nil { return
lib.Warn(err, "couldn't read certificates from standard input") }
os.Exit(1)
}
// This is needed for getting certs from JSON/jq. for _, filename := range flag.Args() {
certs = bytes.TrimSpace(certs) fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
certs = bytes.Replace(certs, []byte(`\n`), []byte{0xa}, -1) if strings.HasPrefix(filename, "https://") {
certs = bytes.Trim(certs, `"`) displayAllCertsWeb(filename, leafOnly)
displayAllCerts(certs, leafOnly) } else {
} else { in, err := os.ReadFile(filename)
for _, filename := range flag.Args() { if err != nil {
fmt.Printf("--%s ---\n", filename) _, _ = lib.Warn(err, "couldn't read certificate")
if strings.HasPrefix(filename, "https://") { continue
displayAllCertsWeb(filename, leafOnly)
} else {
in, err := os.ReadFile(filename)
if err != nil {
lib.Warn(err, "couldn't read certificate")
continue
}
displayAllCerts(in, leafOnly)
} }
displayAllCerts(in, leafOnly)
} }
} }
} }

View File

@@ -13,6 +13,11 @@ import (
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\)," // following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
// "\2: \1,") // "\2: \1,")
const (
sSHA256 = "SHA256"
sSHA512 = "SHA512"
)
var keyUsage = map[x509.KeyUsage]string{ var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature", x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content committment", x509.KeyUsageContentCommitment: "content committment",
@@ -26,42 +31,36 @@ var keyUsage = map[x509.KeyUsage]string{
} }
var extKeyUsages = map[x509.ExtKeyUsage]string{ var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageAny: "any", x509.ExtKeyUsageAny: "any",
x509.ExtKeyUsageServerAuth: "server auth", x509.ExtKeyUsageServerAuth: "server auth",
x509.ExtKeyUsageClientAuth: "client auth", x509.ExtKeyUsageClientAuth: "client auth",
x509.ExtKeyUsageCodeSigning: "code signing", x509.ExtKeyUsageCodeSigning: "code signing",
x509.ExtKeyUsageEmailProtection: "s/mime", x509.ExtKeyUsageEmailProtection: "s/mime",
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system", x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel", x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
x509.ExtKeyUsageIPSECUser: "ipsec user", x509.ExtKeyUsageIPSECUser: "ipsec user",
x509.ExtKeyUsageTimeStamping: "timestamping", x509.ExtKeyUsageTimeStamping: "timestamping",
x509.ExtKeyUsageOCSPSigning: "ocsp signing", x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc", x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc", x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
} x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
func pubKeyAlgo(a x509.PublicKeyAlgorithm) string {
switch a {
case x509.RSA:
return "RSA"
case x509.ECDSA:
return "ECDSA"
case x509.DSA:
return "DSA"
default:
return "unknown public key algorithm"
}
} }
func sigAlgoPK(a x509.SignatureAlgorithm) string { func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a { switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA" return "RSA"
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
return "RSA-PSS"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512: case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA" return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256: case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA" return "DSA"
case x509.PureEd25519:
return "Ed25519"
case x509.UnknownSignatureAlgorithm:
return "unknown public key algorithm"
default: default:
return "unknown public key algorithm" return "unknown public key algorithm"
} }
@@ -76,11 +75,21 @@ func sigAlgoHash(a x509.SignatureAlgorithm) string {
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1: case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1" return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256: case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return "SHA256" return sSHA256
case x509.SHA256WithRSAPSS:
return sSHA256
case x509.SHA384WithRSA, x509.ECDSAWithSHA384: case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384" return "SHA384"
case x509.SHA384WithRSAPSS:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512: case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return "SHA512" return sSHA512
case x509.SHA512WithRSAPSS:
return sSHA512
case x509.PureEd25519:
return sSHA512
case x509.UnknownSignatureAlgorithm:
return "unknown hash algorithm"
default: default:
return "unknown hash algorithm" return "unknown hash algorithm"
} }
@@ -90,9 +99,11 @@ const maxLine = 78
func makeIndent(n int) string { func makeIndent(n int) string {
s := " " s := " "
for i := 0; i < n; i++ { var sSb97 strings.Builder
s += " " for range n {
sSb97.WriteString(" ")
} }
s += sSb97.String()
return s return s
} }
@@ -100,7 +111,7 @@ func indentLen(n int) int {
return 4 + (8 * n) return 4 + (8 * n)
} }
// this isn't real efficient, but that's not a problem here // this isn't real efficient, but that's not a problem here.
func wrap(s string, indent int) string { func wrap(s string, indent int) string {
if indent > 3 { if indent > 3 {
indent = 3 indent = 3
@@ -123,9 +134,11 @@ func wrap(s string, indent int) string {
func dumpHex(in []byte) string { func dumpHex(in []byte) string {
var s string var s string
var sSb130 strings.Builder
for i := range in { for i := range in {
s += fmt.Sprintf("%02X:", in[i]) sSb130.WriteString(fmt.Sprintf("%02X:", in[i]))
} }
s += sSb130.String()
return strings.Trim(s, ":") return strings.Trim(s, ":")
} }
@@ -136,14 +149,14 @@ func dumpHex(in []byte) string {
func permissiveConfig() *tls.Config { func permissiveConfig() *tls.Config {
return &tls.Config{ return &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} } // #nosec G402
} }
// verifyConfig returns a config that will verify the connection. // verifyConfig returns a config that will verify the connection.
func verifyConfig(hostname string) *tls.Config { func verifyConfig(hostname string) *tls.Config {
return &tls.Config{ return &tls.Config{
ServerName: hostname, ServerName: hostname,
} } // #nosec G402
} }
type connInfo struct { type connInfo struct {

View File

@@ -5,7 +5,6 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strings" "strings"
"time" "time"
@@ -54,7 +53,7 @@ func displayName(name pkix.Name) string {
} }
func expires(cert *x509.Certificate) time.Duration { func expires(cert *x509.Certificate) time.Duration {
return cert.NotAfter.Sub(time.Now()) return time.Until(cert.NotAfter)
} }
func inDanger(cert *x509.Certificate) bool { func inDanger(cert *x509.Certificate) bool {
@@ -81,15 +80,15 @@ func main() {
flag.Parse() flag.Parse()
for _, file := range flag.Args() { for _, file := range flag.Args() {
in, err := ioutil.ReadFile(file) in, err := os.ReadFile(file)
if err != nil { if err != nil {
lib.Warn(err, "failed to read file") _, _ = lib.Warn(err, "failed to read file")
continue continue
} }
certs, err := certlib.ParseCertificatesPEM(in) certs, err := certlib.ParseCertificatesPEM(in)
if err != nil { if err != nil {
lib.Warn(err, "while parsing certificates") _, _ = lib.Warn(err, "while parsing certificates")
continue continue
} }

View File

@@ -4,13 +4,11 @@ import (
"crypto/x509" "crypto/x509"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/certlib/revoke" "git.wntrmute.dev/kyle/goutils/certlib/revoke"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
@@ -30,83 +28,116 @@ func printRevocation(cert *x509.Certificate) {
} }
} }
func main() { type appConfig struct {
var caFile, intFile string caFile, intFile string
var forceIntermediateBundle, revexp, verbose bool forceIntermediateBundle bool
flag.StringVar(&caFile, "ca", "", "CA certificate `bundle`") revexp, verbose bool
flag.StringVar(&intFile, "i", "", "intermediate `bundle`") }
flag.BoolVar(&forceIntermediateBundle, "f", false,
func parseFlags() appConfig {
var cfg appConfig
flag.StringVar(&cfg.caFile, "ca", "", "CA certificate `bundle`")
flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`")
flag.BoolVar(&cfg.forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate") "force the use of the intermediate bundle, ignoring any intermediates bundled with certificate")
flag.BoolVar(&revexp, "r", false, "print revocation and expiry information") flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&verbose, "v", false, "verbose") flag.BoolVar(&cfg.verbose, "v", false, "verbose")
flag.Parse() flag.Parse()
return cfg
}
var roots *x509.CertPool func loadRoots(caFile string, verbose bool) (*x509.CertPool, error) {
if caFile != "" { if caFile == "" {
var err error return x509.SystemCertPool()
if verbose {
fmt.Println("[+] loading root certificates from", caFile)
}
roots, err = certlib.LoadPEMCertPool(caFile)
die.If(err)
} }
var ints *x509.CertPool
if intFile != "" {
var err error
if verbose {
fmt.Println("[+] loading intermediate certificates from", intFile)
}
ints, err = certlib.LoadPEMCertPool(caFile)
die.If(err)
} else {
ints = x509.NewCertPool()
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert",
lib.ProgName())
}
fileData, err := ioutil.ReadFile(flag.Arg(0))
die.If(err)
chain, err := certlib.ParseCertificatesPEM(fileData)
die.If(err)
if verbose { if verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain)) fmt.Println("[+] loading root certificates from", caFile)
} }
return certlib.LoadPEMCertPool(caFile)
}
cert := chain[0] func loadIntermediates(intFile string, verbose bool) (*x509.CertPool, error) {
if len(chain) > 1 { if intFile == "" {
if !forceIntermediateBundle { return x509.NewCertPool(), nil
for _, intermediate := range chain[1:] { }
if verbose { if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId) fmt.Println("[+] loading intermediate certificates from", intFile)
} }
// Note: use intFile here (previously used caFile mistakenly)
return certlib.LoadPEMCertPool(intFile)
}
ints.AddCert(intermediate) func addBundledIntermediates(chain []*x509.Certificate, pool *x509.CertPool, verbose bool) {
} for _, intermediate := range chain[1:] {
if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
} }
pool.AddCert(intermediate)
} }
}
func verifyCert(cert *x509.Certificate, roots, ints *x509.CertPool) error {
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Intermediates: ints, Intermediates: ints,
Roots: roots, Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
_, err := cert.Verify(opts)
return err
}
_, err = cert.Verify(opts) func run(cfg appConfig) error {
roots, err := loadRoots(cfg.caFile, cfg.verbose)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err) return err
os.Exit(1)
} }
if verbose { ints, err := loadIntermediates(cfg.intFile, cfg.verbose)
if err != nil {
return err
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
}
fileData, err := os.ReadFile(flag.Arg(0))
if err != nil {
return err
}
chain, err := certlib.ParseCertificatesPEM(fileData)
if err != nil {
return err
}
if cfg.verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 && !cfg.forceIntermediateBundle {
addBundledIntermediates(chain, ints, cfg.verbose)
}
if err = verifyCert(cert, roots, ints); err != nil {
return fmt.Errorf("certificate verification failed: %w", err)
}
if cfg.verbose {
fmt.Println("OK") fmt.Println("OK")
} }
if revexp { if cfg.revexp {
printRevocation(cert) printRevocation(cert)
} }
return nil
}
func main() {
cfg := parseFlags()
if err := run(cfg); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
} }

View File

@@ -2,6 +2,8 @@ package main
import ( import (
"bufio" "bufio"
"context"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -56,7 +58,7 @@ var modes = ssh.TerminalModes{
} }
func sshAgent() ssh.AuthMethod { func sshAgent() ssh.AuthMethod {
a, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) a, err := (&net.Dialer{}).DialContext(context.Background(), "unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil { if err == nil {
return ssh.PublicKeysCallback(agent.NewClient(a).Signers) return ssh.PublicKeysCallback(agent.NewClient(a).Signers)
} }
@@ -82,7 +84,7 @@ func scanner(host string, in io.Reader, out io.Writer) {
} }
} }
func logError(host string, err error, format string, args ...interface{}) { func logError(host string, err error, format string, args ...any) {
msg := fmt.Sprintf(format, args...) msg := fmt.Sprintf(format, args...)
log.Printf("[%s] FAILED: %s: %v\n", host, msg, err) log.Printf("[%s] FAILED: %s: %v\n", host, msg, err)
} }
@@ -93,7 +95,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -115,7 +117,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
} }
shutdown = append(shutdown, session.Close) shutdown = append(shutdown, session.Close)
if err := session.RequestPty("xterm", 80, 40, modes); err != nil { if err = session.RequestPty("xterm", 80, 40, modes); err != nil {
session.Close() session.Close()
logError(host, err, "request for pty failed") logError(host, err, "request for pty failed")
return return
@@ -150,7 +152,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -199,7 +201,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
} }
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
logError(host, err, "reading chunk") logError(host, err, "reading chunk")
@@ -215,7 +217,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -265,7 +267,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
} }
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
logError(host, err, "reading chunk") logError(host, err, "reading chunk")

View File

@@ -10,6 +10,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/fileutil" "git.wntrmute.dev/kyle/goutils/fileutil"
@@ -26,7 +27,7 @@ func setupFile(hdr *tar.Header, file *os.File) error {
if verbose { if verbose {
fmt.Printf("\tchmod %0#o\n", hdr.Mode) fmt.Printf("\tchmod %0#o\n", hdr.Mode)
} }
err := file.Chmod(os.FileMode(hdr.Mode)) err := file.Chmod(os.FileMode(hdr.Mode & 0xFFFFFFFF)) // #nosec G115
if err != nil { if err != nil {
return err return err
} }
@@ -48,73 +49,105 @@ func linkTarget(target, top string) string {
return target return target
} }
return filepath.Clean(filepath.Join(target, top)) return filepath.Clean(filepath.Join(top, target))
}
// safeJoin joins base and elem and ensures the resulting path does not escape base.
func safeJoin(base, elem string) (string, error) {
cleanBase := filepath.Clean(base)
joined := filepath.Clean(filepath.Join(cleanBase, elem))
absBase, err := filepath.Abs(cleanBase)
if err != nil {
return "", err
}
absJoined, err := filepath.Abs(joined)
if err != nil {
return "", err
}
rel, err := filepath.Rel(absBase, absJoined)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected: %s escapes %s", elem, base)
}
return joined, nil
}
func handleTypeReg(tfr *tar.Reader, hdr *tar.Header, filePath string) error {
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
if _, err = io.Copy(file, tfr); err != nil {
return err
}
return setupFile(hdr, file)
}
func handleTypeLink(hdr *tar.Header, top, filePath string) error {
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
srcPath, err := safeJoin(top, hdr.Linkname)
if err != nil {
return err
}
source, err := os.Open(srcPath)
if err != nil {
return err
}
defer source.Close()
if _, err = io.Copy(file, source); err != nil {
return err
}
return setupFile(hdr, file)
}
func handleTypeSymlink(hdr *tar.Header, top, filePath string) error {
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s", hdr.Linkname, top)
}
path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
} else if err != nil {
return err
}
return os.Symlink(linkTarget(hdr.Linkname, top), filePath)
}
func handleTypeDir(hdr *tar.Header, filePath string) error {
return os.MkdirAll(filePath, os.FileMode(hdr.Mode&0xFFFFFFFF)) // #nosec G115
} }
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error { func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
if verbose { if verbose {
fmt.Println(hdr.Name) fmt.Println(hdr.Name)
} }
filePath := filepath.Clean(filepath.Join(top, hdr.Name))
switch hdr.Typeflag {
case tar.TypeReg:
file, err := os.Create(filePath)
if err != nil {
return err
}
_, err = io.Copy(file, tfr) filePath, err := safeJoin(top, hdr.Name)
if err != nil { if err != nil {
return err return err
}
err = setupFile(hdr, file)
if err != nil {
return err
}
case tar.TypeLink:
file, err := os.Create(filePath)
if err != nil {
return err
}
source, err := os.Open(hdr.Linkname)
if err != nil {
return err
}
_, err = io.Copy(file, source)
if err != nil {
return err
}
err = setupFile(hdr, file)
if err != nil {
return err
}
case tar.TypeSymlink:
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s",
hdr.Linkname, top)
}
path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
} else if err != nil {
return err
}
err := os.Symlink(linkTarget(hdr.Linkname, top), filePath)
if err != nil {
return err
}
case tar.TypeDir:
err := os.MkdirAll(filePath, os.FileMode(hdr.Mode))
if err != nil {
return err
}
} }
switch hdr.Typeflag {
case tar.TypeReg:
return handleTypeReg(tfr, hdr, filePath)
case tar.TypeLink:
return handleTypeLink(hdr, top, filePath)
case tar.TypeSymlink:
return handleTypeSymlink(hdr, top, filePath)
case tar.TypeDir:
return handleTypeDir(hdr, filePath)
}
return nil return nil
} }
@@ -261,16 +294,16 @@ func main() {
die.If(err) die.If(err)
tfr := tar.NewReader(r) tfr := tar.NewReader(r)
var hdr *tar.Header
for { for {
hdr, err := tfr.Next() hdr, err = tfr.Next()
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} }
die.If(err) die.If(err)
err = processFile(tfr, hdr, top) err = processFile(tfr, hdr, top)
die.If(err) die.If(err)
} }
r.Close() r.Close()

View File

@@ -7,9 +7,9 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "os"
"log"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -17,17 +17,10 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
in, err := ioutil.ReadFile(fileName) in, err := os.ReadFile(fileName)
die.If(err) die.If(err)
if p, _ := pem.Decode(in); p != nil { csr, _, err := certlib.ParseCSR(in)
if p.Type != "CERTIFICATE REQUEST" {
log.Fatal("INVALID FILE TYPE")
}
in = p.Bytes
}
csr, err := x509.ParseCertificateRequest(in)
die.If(err) die.If(err)
out, err := x509.MarshalPKIXPublicKey(csr.PublicKey) out, err := x509.MarshalPKIXPublicKey(csr.PublicKey)
@@ -48,8 +41,8 @@ func main() {
Bytes: out, Bytes: out,
} }
err = ioutil.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0644) err = os.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0o644) // #nosec G306
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.\n", fileName+".pub") fmt.Fprintf(os.Stdout, "[+] wrote %s.\n", fileName+".pub")
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -152,7 +153,7 @@ func rsync(syncDir, target, excludeFile string, verboseRsync bool) error {
return err return err
} }
cmd := exec.Command(path, args...) cmd := exec.CommandContext(context.Background(), path, args...)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
return cmd.Run() return cmd.Run()
@@ -163,7 +164,6 @@ func init() {
} }
func main() { func main() {
var logLevel, mountDir, syncDir, target string var logLevel, mountDir, syncDir, target string
var dryRun, quietMode, noSyslog, verboseRsync bool var dryRun, quietMode, noSyslog, verboseRsync bool
@@ -219,7 +219,7 @@ func main() {
if excludeFile != "" { if excludeFile != "" {
defer func() { defer func() {
log.Infof("removing exclude file %s", excludeFile) log.Infof("removing exclude file %s", excludeFile)
if err := os.Remove(excludeFile); err != nil { if rmErr := os.Remove(excludeFile); rmErr != nil {
log.Warningf("failed to remove temp file %s", excludeFile) log.Warningf("failed to remove temp file %s", excludeFile)
} }
}() }()

View File

@@ -15,43 +15,41 @@ import (
const defaultHashAlgorithm = "sha256" const defaultHashAlgorithm = "sha256"
var ( var (
hAlgo string hAlgo string
debug = dbg.New() debug = dbg.New()
) )
func openImage(imageFile string) (*os.File, []byte, error) {
func openImage(imageFile string) (image *os.File, hash []byte, err error) { f, err := os.Open(imageFile)
image, err = os.Open(imageFile)
if err != nil { if err != nil {
return return nil, nil, err
} }
hash, err = ahash.SumReader(hAlgo, image) h, err := ahash.SumReader(hAlgo, f)
if err != nil { if err != nil {
return return nil, nil, err
} }
_, err = image.Seek(0, 0) if _, err = f.Seek(0, 0); err != nil {
if err != nil { return nil, nil, err
return
} }
debug.Printf("%s %x\n", imageFile, hash) debug.Printf("%s %x\n", imageFile, h)
return return f, h, nil
} }
func openDevice(devicePath string) (device *os.File, err error) { func openDevice(devicePath string) (*os.File, error) {
fi, err := os.Stat(devicePath) fi, err := os.Stat(devicePath)
if err != nil { if err != nil {
return return nil, err
} }
device, err = os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode()) device, err := os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
if err != nil { if err != nil {
return return nil, err
} }
return return device, nil
} }
func main() { func main() {
@@ -105,12 +103,12 @@ func main() {
die.If(err) die.If(err)
if !bytes.Equal(deviceHash, hash) { if !bytes.Equal(deviceHash, hash) {
fmt.Fprintln(os.Stderr, "Hash mismatch:") buf := &bytes.Buffer{}
fmt.Fprintf(os.Stderr, "\t%s: %s\n", imageFile, hash) fmt.Fprintln(buf, "Hash mismatch:")
fmt.Fprintf(os.Stderr, "\t%s: %s\n", devicePath, deviceHash) fmt.Fprintf(buf, "\t%s: %s\n", imageFile, hash)
os.Exit(1) fmt.Fprintf(buf, "\t%s: %s\n", devicePath, deviceHash)
die.With(buf.String())
} }
debug.Println("OK") debug.Println("OK")
os.Exit(0)
} }

View File

@@ -1,30 +1,33 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"git.wntrmute.dev/kyle/goutils/die"
"io" "io"
"os" "os"
"strings"
"git.wntrmute.dev/kyle/goutils/die"
) )
func usage(w io.Writer, exc int) { func usage(w io.Writer, exc int) {
fmt.Fprintln(w, `usage: dumpbytes <file>`) fmt.Fprintln(w, `usage: dumpbytes -n tabs <file>`)
os.Exit(exc) os.Exit(exc)
} }
func printBytes(buf []byte) { func printBytes(buf []byte) {
fmt.Printf("\t") fmt.Printf("\t")
for i := 0; i < len(buf); i++ { for i := range buf {
fmt.Printf("0x%02x, ", buf[i]) fmt.Printf("0x%02x, ", buf[i])
} }
fmt.Println() fmt.Println()
} }
func dumpFile(path string, indentLevel int) error { func dumpFile(path string, indentLevel int) error {
indent := "" var indent strings.Builder
for i := 0; i < indentLevel; i++ { for range indentLevel {
indent += "\t" indent.WriteByte('\t')
} }
file, err := os.Open(path) file, err := os.Open(path)
@@ -34,13 +37,14 @@ func dumpFile(path string, indentLevel int) error {
defer file.Close() defer file.Close()
fmt.Printf("%svar buffer = []byte{\n", indent) fmt.Printf("%svar buffer = []byte{\n", indent.String())
var n int
for { for {
buf := make([]byte, 8) buf := make([]byte, 8)
n, err := file.Read(buf) n, err = file.Read(buf)
if err == io.EOF { if errors.Is(err, io.EOF) {
if n > 0 { if n > 0 {
fmt.Printf("%s", indent) fmt.Printf("%s", indent.String())
printBytes(buf[:n]) printBytes(buf[:n])
} }
break break
@@ -50,11 +54,11 @@ func dumpFile(path string, indentLevel int) error {
return err return err
} }
fmt.Printf("%s", indent) fmt.Printf("%s", indent.String())
printBytes(buf[:n]) printBytes(buf[:n])
} }
fmt.Printf("%s}\n", indent) fmt.Printf("%s}\n", indent.String())
return nil return nil
} }

View File

@@ -7,7 +7,7 @@ import (
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
// size of a kilobit in bytes // size of a kilobit in bytes.
const kilobit = 128 const kilobit = 128
const pageSize = 4096 const pageSize = 4096
@@ -26,10 +26,10 @@ func main() {
path = flag.Arg(0) path = flag.Arg(0)
} }
fillByte := uint8(*fill) fillByte := uint8(*fill & 0xff) // #nosec G115 clearing out of bounds bits
buf := make([]byte, pageSize) buf := make([]byte, pageSize)
for i := 0; i < pageSize; i++ { for i := range pageSize {
buf[i] = fillByte buf[i] = fillByte
} }
@@ -40,7 +40,7 @@ func main() {
die.If(err) die.If(err)
defer file.Close() defer file.Close()
for i := 0; i < pages; i++ { for range pages {
_, err = file.Write(buf) _, err = file.Write(buf)
die.If(err) die.If(err)
} }

View File

@@ -72,15 +72,13 @@ func main() {
if end < start { if end < start {
fmt.Fprintln(os.Stderr, "[!] end < start, swapping values") fmt.Fprintln(os.Stderr, "[!] end < start, swapping values")
tmp := end start, end = end, start
end = start
start = tmp
} }
var fmtStr string var fmtStr string
if !*quiet { if !*quiet {
maxLine := fmt.Sprintf("%d", len(lines)) maxLine := strconv.Itoa(len(lines))
fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine)) fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine))
} }
@@ -98,9 +96,9 @@ func main() {
fmtStr += "\n" fmtStr += "\n"
for i := start; !endFunc(i); i++ { for i := start; !endFunc(i); i++ {
if *quiet { if *quiet {
fmt.Println(lines[i]) fmt.Fprintln(os.Stdout, lines[i])
} else { } else {
fmt.Printf(fmtStr, i, lines[i]) fmt.Fprintf(os.Stdout, fmtStr, i, lines[i])
} }
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"log" "log"
@@ -8,7 +9,8 @@ import (
) )
func lookupHost(host string) error { func lookupHost(host string) error {
cname, err := net.LookupCNAME(host) r := &net.Resolver{}
cname, err := r.LookupCNAME(context.Background(), host)
if err != nil { if err != nil {
return err return err
} }
@@ -18,7 +20,7 @@ func lookupHost(host string) error {
host = cname host = cname
} }
addrs, err := net.LookupHost(host) addrs, err := r.LookupHost(context.Background(), host)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
@@ -16,20 +16,20 @@ func prettify(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Indent(buf, in, "", " ") err = json.Indent(buf, in, "", " ")
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -40,11 +40,11 @@ func prettify(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -55,20 +55,20 @@ func compact(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Compact(buf, in) err = json.Compact(buf, in)
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -79,11 +79,11 @@ func compact(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -91,7 +91,7 @@ func compact(file string, validateOnly bool) error {
func usage() { func usage() {
progname := lib.ProgName() progname := lib.ProgName()
fmt.Printf(`Usage: %s [-h] files... fmt.Fprintf(os.Stdout, `Usage: %s [-h] files...
%s is used to lint and prettify (or compact) JSON files. The %s is used to lint and prettify (or compact) JSON files. The
files will be updated in-place. files will be updated in-place.
@@ -100,7 +100,6 @@ func usage() {
-h Print this help message. -h Print this help message.
-n Don't prettify; only perform validation. -n Don't prettify; only perform validation.
`, progname, progname) `, progname, progname)
} }
func init() { func init() {

View File

@@ -11,7 +11,10 @@ based on whether the source filename ends in ".gz".
Flags: Flags:
-l level Compression level (0-9). Only meaninful when -l level Compression level (0-9). Only meaninful when
compressing a file. compressing a file.
-u Do not restrict the size during decompression. As
a safeguard against gzip bombs, the maximum size
allowed is 32 * the compressed file size.

View File

@@ -1,68 +1,84 @@
package main package main
import ( import (
"compress/flate" "compress/flate"
"compress/gzip" "compress/gzip"
"flag" "flag"
"fmt" "fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
) )
const gzipExt = ".gz" const gzipExt = ".gz"
func compress(path, target string, level int) error { func compress(path, target string, level int) error {
sourceFile, err := os.Open(path) sourceFile, err := os.Open(path)
if err != nil { if err != nil {
return fmt.Errorf("opening file for read: %w", err) return fmt.Errorf("opening file for read: %w", err)
} }
defer sourceFile.Close() defer sourceFile.Close()
destFile, err := os.Create(target) destFile, err := os.Create(target)
if err != nil { if err != nil {
return fmt.Errorf("opening file for write: %w", err) return fmt.Errorf("opening file for write: %w", err)
} }
defer destFile.Close() defer destFile.Close()
gzipCompressor, err := gzip.NewWriterLevel(destFile, level) gzipCompressor, err := gzip.NewWriterLevel(destFile, level)
if err != nil { if err != nil {
return fmt.Errorf("invalid compression level: %w", err) return fmt.Errorf("invalid compression level: %w", err)
} }
defer gzipCompressor.Close() defer gzipCompressor.Close()
_, err = io.Copy(gzipCompressor, sourceFile) _, err = io.Copy(gzipCompressor, sourceFile)
if err != nil { if err != nil {
return fmt.Errorf("compressing file: %w", err) return fmt.Errorf("compressing file: %w", err)
} }
return nil return nil
} }
func uncompress(path, target string) error { func uncompress(path, target string, unrestrict bool) error {
sourceFile, err := os.Open(path) sourceFile, err := os.Open(path)
if err != nil { if err != nil {
return fmt.Errorf("opening file for read: %w", err) return fmt.Errorf("opening file for read: %w", err)
} }
defer sourceFile.Close() defer sourceFile.Close()
gzipUncompressor, err := gzip.NewReader(sourceFile) fi, err := sourceFile.Stat()
if err != nil { if err != nil {
return fmt.Errorf("reading gzip headers: %w", err) return fmt.Errorf("reading file stats: %w", err)
} }
maxDecompressionSize := fi.Size() * 32
gzipUncompressor, err := gzip.NewReader(sourceFile)
if err != nil {
return fmt.Errorf("reading gzip headers: %w", err)
}
defer gzipUncompressor.Close() defer gzipUncompressor.Close()
destFile, err := os.Create(target) var reader io.Reader = &io.LimitedReader{
if err != nil { R: gzipUncompressor,
return fmt.Errorf("opening file for write: %w", err) N: maxDecompressionSize,
} }
if unrestrict {
reader = gzipUncompressor
}
destFile, err := os.Create(target)
if err != nil {
return fmt.Errorf("opening file for write: %w", err)
}
defer destFile.Close() defer destFile.Close()
_, err = io.Copy(destFile, gzipUncompressor) _, err = io.Copy(destFile, reader)
if err != nil { if err != nil {
return fmt.Errorf("uncompressing file: %w", err) return fmt.Errorf("uncompressing file: %w", err)
} }
return nil return nil
} }
@@ -87,8 +103,8 @@ func isDir(path string) bool {
file, err := os.Open(path) file, err := os.Open(path)
if err == nil { if err == nil {
defer file.Close() defer file.Close()
stat, err := file.Stat() stat, err2 := file.Stat()
if err != nil { if err2 != nil {
return false return false
} }
@@ -106,9 +122,9 @@ func pathForUncompressing(source, dest string) (string, error) {
} }
source = filepath.Base(source) source = filepath.Base(source)
if !strings.HasSuffix(source, gzipExt) { if !strings.HasSuffix(source, gzipExt) {
return "", fmt.Errorf("%s is a not gzip-compressed file", source) return "", fmt.Errorf("%s is a not gzip-compressed file", source)
} }
outFile := source[:len(source)-len(gzipExt)] outFile := source[:len(source)-len(gzipExt)]
outFile = filepath.Join(dest, outFile) outFile = filepath.Join(dest, outFile)
return outFile, nil return outFile, nil
@@ -120,9 +136,9 @@ func pathForCompressing(source, dest string) (string, error) {
} }
source = filepath.Base(source) source = filepath.Base(source)
if strings.HasSuffix(source, gzipExt) { if strings.HasSuffix(source, gzipExt) {
return "", fmt.Errorf("%s is a gzip-compressed file", source) return "", fmt.Errorf("%s is a gzip-compressed file", source)
} }
dest = filepath.Join(dest, source+gzipExt) dest = filepath.Join(dest, source+gzipExt)
return dest, nil return dest, nil
@@ -132,8 +148,11 @@ func main() {
var level int var level int
var path string var path string
var target = "." var target = "."
var err error
var unrestrict bool
flag.IntVar(&level, "l", flate.DefaultCompression, "compression level") flag.IntVar(&level, "l", flate.DefaultCompression, "compression level")
flag.BoolVar(&unrestrict, "u", false, "do not restrict decompression")
flag.Parse() flag.Parse()
if flag.NArg() < 1 || flag.NArg() > 2 { if flag.NArg() < 1 || flag.NArg() > 2 {
@@ -147,30 +166,31 @@ func main() {
} }
if strings.HasSuffix(path, gzipExt) { if strings.HasSuffix(path, gzipExt) {
target, err := pathForUncompressing(path, target) target, err = pathForUncompressing(path, target)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
err = uncompress(path, target) err = uncompress(path, target, unrestrict)
if err != nil { if err != nil {
os.Remove(target) os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
} else { return
target, err := pathForCompressing(path, target) }
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
err = compress(path, target, level) target, err = pathForCompressing(path, target)
if err != nil { if err != nil {
os.Remove(target) fmt.Fprintf(os.Stderr, "%s\n", err)
fmt.Fprintf(os.Stderr, "%s\n", err) os.Exit(1)
os.Exit(1) }
}
err = compress(path, target, level)
if err != nil {
os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
} }
} }

View File

@@ -40,14 +40,14 @@ func main() {
usage() usage()
} }
min, err := strconv.Atoi(flag.Arg(1)) minVal, err := strconv.Atoi(flag.Arg(1))
dieIf(err) dieIf(err)
max, err := strconv.Atoi(flag.Arg(2)) maxVal, err := strconv.Atoi(flag.Arg(2))
dieIf(err) dieIf(err)
code := kind << 6 code := kind << 6
code += (min << 3) code += (minVal << 3)
code += max code += maxVal
fmt.Printf("%0o\n", code) fmt.Fprintf(os.Stdout, "%0o\n", code)
} }

View File

@@ -5,7 +5,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
@@ -47,7 +46,7 @@ func help(w io.Writer) {
} }
func loadDatabase() { func loadDatabase() {
data, err := ioutil.ReadFile(dbFile) data, err := os.ReadFile(dbFile)
if err != nil && os.IsNotExist(err) { if err != nil && os.IsNotExist(err) {
partsDB = &database{ partsDB = &database{
Version: dbVersion, Version: dbVersion,
@@ -74,7 +73,7 @@ func writeDB() {
data, err := json.Marshal(partsDB) data, err := json.Marshal(partsDB)
die.If(err) die.If(err)
err = ioutil.WriteFile(dbFile, data, 0644) err = os.WriteFile(dbFile, data, 0644)
die.If(err) die.If(err)
} }

View File

@@ -4,14 +4,13 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
var ext = ".bin" var ext = ".bin"
func stripPEM(path string) error { func stripPEM(path string) error {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return err return err
} }
@@ -22,7 +21,7 @@ func stripPEM(path string) error {
fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n") fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n")
} }
return ioutil.WriteFile(path+ext, p.Bytes, 0644) return os.WriteFile(path+ext, p.Bytes, 0644)
} }
func main() { func main() {

View File

@@ -3,8 +3,7 @@ package main
import ( import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "io"
"io/ioutil"
"os" "os"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
@@ -21,9 +20,9 @@ func main() {
path := flag.Arg(0) path := flag.Arg(0)
if path == "-" { if path == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(flag.Arg(0)) in, err = os.ReadFile(flag.Arg(0))
} }
if err != nil { if err != nil {
lib.Err(lib.ExitFailure, err, "couldn't read file") lib.Err(lib.ExitFailure, err, "couldn't read file")
@@ -33,5 +32,7 @@ func main() {
if p == nil { if p == nil {
lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0)) lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0))
} }
fmt.Printf("%s", p.Bytes) if _, err = os.Stdout.Write(p.Bytes); err != nil {
lib.Err(lib.ExitFailure, err, "writing body")
}
} }

View File

@@ -70,7 +70,7 @@ func main() {
lib.Err(lib.ExitFailure, err, "failed to read input") lib.Err(lib.ExitFailure, err, "failed to read input")
} }
case argc > 1: case argc > 1:
for i := 0; i < argc; i++ { for i := range argc {
path := flag.Arg(i) path := flag.Arg(i)
err = copyFile(path, buf) err = copyFile(path, buf)
if err != nil { if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
@@ -13,14 +12,14 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
data, err := ioutil.ReadFile(fileName) data, err := os.ReadFile(fileName)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
continue continue
} }
fmt.Printf("[+] %s:\n", fileName) fmt.Fprintf(os.Stdout, "[+] %s:\n", fileName)
rest := data[:] rest := data
for { for {
var p *pem.Block var p *pem.Block
p, rest = pem.Decode(rest) p, rest = pem.Decode(rest)
@@ -28,13 +27,14 @@ func main() {
break break
} }
cert, err := x509.ParseCertificate(p.Bytes) var cert *x509.Certificate
cert, err = x509.ParseCertificate(p.Bytes)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
break break
} }
fmt.Printf("\t%+v\n", cert.Subject.CommonName) fmt.Fprintf(os.Stdout, "\t%+v\n", cert.Subject.CommonName)
} }
} }
} }

View File

@@ -43,7 +43,7 @@ func newName(path string) (string, error) {
return hashName(path, encodedHash), nil return hashName(path, encodedHash), nil
} }
func move(dst, src string, force bool) (err error) { func move(dst, src string, force bool) error {
if fileutil.FileDoesExist(dst) && !force { if fileutil.FileDoesExist(dst) && !force {
return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst) return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst)
} }
@@ -52,21 +52,23 @@ func move(dst, src string, force bool) (err error) {
return err return err
} }
defer func(e error) { var retErr error
defer func(e *error) {
dstFile.Close() dstFile.Close()
if e != nil { if *e != nil {
os.Remove(dst) os.Remove(dst)
} }
}(err) }(&retErr)
srcFile, err := os.Open(src) srcFile, err := os.Open(src)
if err != nil { if err != nil {
retErr = err
return err return err
} }
defer srcFile.Close() defer srcFile.Close()
_, err = io.Copy(dstFile, srcFile) if _, err = io.Copy(dstFile, srcFile); err != nil {
if err != nil { retErr = err
return err return err
} }
@@ -94,6 +96,44 @@ func init() {
flag.Usage = func() { usage(os.Stdout) } flag.Usage = func() { usage(os.Stdout) }
} }
type options struct {
dryRun, force, printChanged, verbose bool
}
func processOne(file string, opt options) error {
renamed, err := newName(file)
if err != nil {
_, _ = lib.Warn(err, "failed to get new file name")
return err
}
if opt.verbose && !opt.printChanged {
fmt.Fprintln(os.Stdout, file)
}
if renamed == file {
return nil
}
if !opt.dryRun {
if err = move(renamed, file, opt.force); err != nil {
_, _ = lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
return err
}
}
if opt.printChanged && !opt.verbose {
fmt.Fprintln(os.Stdout, file, "->", renamed)
}
return nil
}
func run(dryRun, force, printChanged, verbose bool, files []string) {
if verbose && printChanged {
printChanged = false
}
opt := options{dryRun: dryRun, force: force, printChanged: printChanged, verbose: verbose}
for _, file := range files {
_ = processOne(file, opt)
}
}
func main() { func main() {
var dryRun, force, printChanged, verbose bool var dryRun, force, printChanged, verbose bool
flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision") flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision")
@@ -102,34 +142,5 @@ func main() {
flag.BoolVar(&verbose, "v", false, "list all processed files") flag.BoolVar(&verbose, "v", false, "list all processed files")
flag.Parse() flag.Parse()
run(dryRun, force, printChanged, verbose, flag.Args())
if verbose && printChanged {
printChanged = false
}
for _, file := range flag.Args() {
renamed, err := newName(file)
if err != nil {
lib.Warn(err, "failed to get new file name")
continue
}
if verbose && !printChanged {
fmt.Println(file)
}
if renamed != file {
if !dryRun {
err = move(renamed, file, force)
if err != nil {
lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
continue
}
}
if printChanged && !verbose {
fmt.Println(file, "->", renamed)
}
}
}
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -66,24 +67,25 @@ func main() {
for _, remote := range flag.Args() { for _, remote := range flag.Args() {
u, err := url.Parse(remote) u, err := url.Parse(remote)
if err != nil { if err != nil {
lib.Warn(err, "parsing %s", remote) _, _ = lib.Warn(err, "parsing %s", remote)
continue continue
} }
name := filepath.Base(u.Path) name := filepath.Base(u.Path)
if name == "" { if name == "" {
lib.Warnx("source URL doesn't appear to name a file") _, _ = lib.Warnx("source URL doesn't appear to name a file")
continue continue
} }
resp, err := http.Get(remote) req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, remote, nil)
if err != nil { if reqErr != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(reqErr, "building request for %s", remote)
continue continue
} }
client := &http.Client{}
resp, err := client.Do(req)
if err != nil { if err != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(err, "fetching %s", remote)
continue continue
} }

View File

@@ -3,7 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
@@ -17,8 +17,8 @@ func rollDie(count, sides int) []int {
sum := 0 sum := 0
var rolls []int var rolls []int
for i := 0; i < count; i++ { for range count {
roll := rand.Intn(sides) + 1 roll := rand.IntN(sides) + 1 // #nosec G404
sum += roll sum += roll
rolls = append(rolls, roll) rolls = append(rolls, roll)
} }

View File

@@ -53,7 +53,7 @@ func init() {
project = wd[len(gopath):] project = wd[len(gopath):]
} }
func walkFile(path string, info os.FileInfo, err error) error { func walkFile(path string, _ os.FileInfo, err error) error {
if ignores[path] { if ignores[path] {
return filepath.SkipDir return filepath.SkipDir
} }
@@ -62,22 +62,27 @@ func walkFile(path string, info os.FileInfo, err error) error {
return nil return nil
} }
debug.Println(path)
f, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err != nil { if err != nil {
return err return err
} }
debug.Println(path)
f, err2 := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err2 != nil {
return err2
}
for _, importSpec := range f.Imports { for _, importSpec := range f.Imports {
importPath := strings.Trim(importSpec.Path.Value, `"`) importPath := strings.Trim(importSpec.Path.Value, `"`)
if stdLibRegexp.MatchString(importPath) { switch {
case stdLibRegexp.MatchString(importPath):
debug.Println("standard lib:", importPath) debug.Println("standard lib:", importPath)
continue continue
} else if strings.HasPrefix(importPath, project) { case strings.HasPrefix(importPath, project):
debug.Println("internal import:", importPath) debug.Println("internal import:", importPath)
continue continue
} else if strings.HasPrefix(importPath, "golang.org/") { case strings.HasPrefix(importPath, "golang.org/"):
debug.Println("extended lib:", importPath) debug.Println("extended lib:", importPath)
continue continue
} }
@@ -102,7 +107,7 @@ func main() {
ignores["vendor"] = true ignores["vendor"] = true
} }
for _, word := range strings.Split(ignoreLine, ",") { for word := range strings.SplitSeq(ignoreLine, ",") {
ignores[strings.TrimSpace(word)] = true ignores[strings.TrimSpace(word)] = true
} }

View File

@@ -2,10 +2,9 @@ package main
import ( import (
"bytes" "bytes"
"crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1" // #nosec G505
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
@@ -13,14 +12,19 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"strings" "strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
const (
keyTypeRSA = "RSA"
keyTypeECDSA = "ECDSA"
)
func usage(w io.Writer) { func usage(w io.Writer) {
fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files
@@ -39,14 +43,14 @@ func init() {
flag.Usage = func() { usage(os.Stderr) } flag.Usage = func() { usage(os.Stderr) }
} }
func parse(path string) (public []byte, kt, ft string) { func parse(path string) ([]byte, string, string) {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
die.If(err) die.If(err)
data = bytes.TrimSpace(data) data = bytes.TrimSpace(data)
p, rest := pem.Decode(data) p, rest := pem.Decode(data)
if len(rest) > 0 { if len(rest) > 0 {
lib.Warnx("trailing data in PEM file") _, _ = lib.Warnx("trailing data in PEM file")
} }
if p == nil { if p == nil {
@@ -55,6 +59,12 @@ func parse(path string) (public []byte, kt, ft string) {
data = p.Bytes data = p.Bytes
var (
public []byte
kt string
ft string
)
switch p.Type { switch p.Type {
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY": case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
public, kt = parseKey(data) public, kt = parseKey(data)
@@ -69,82 +79,79 @@ func parse(path string) (public []byte, kt, ft string) {
die.With("unknown PEM type %s", p.Type) die.With("unknown PEM type %s", p.Type)
} }
return return public, kt, ft
} }
func parseKey(data []byte) (public []byte, kt string) { func parseKey(data []byte) ([]byte, string) {
privInterface, err := x509.ParsePKCS8PrivateKey(data) priv, err := certlib.ParsePrivateKeyDER(data)
if err != nil { if err != nil {
privInterface, err = x509.ParsePKCS1PrivateKey(data) die.If(err)
if err != nil {
privInterface, err = x509.ParseECPrivateKey(data)
if err != nil {
die.With("couldn't parse private key.")
}
}
} }
var priv crypto.Signer var kt string
switch privInterface.(type) { switch priv.Public().(type) {
case *rsa.PrivateKey: case *rsa.PublicKey:
priv = privInterface.(*rsa.PrivateKey) kt = keyTypeRSA
kt = "RSA" case *ecdsa.PublicKey:
case *ecdsa.PrivateKey: kt = keyTypeECDSA
priv = privInterface.(*ecdsa.PrivateKey)
kt = "ECDSA"
default: default:
die.With("unknown private key type %T", privInterface) die.With("unknown private key type %T", priv)
} }
public, err = x509.MarshalPKIXPublicKey(priv.Public()) public, err := x509.MarshalPKIXPublicKey(priv.Public())
die.If(err) die.If(err)
return return public, kt
} }
func parseCertificate(data []byte) (public []byte, kt string) { func parseCertificate(data []byte) ([]byte, string) {
cert, err := x509.ParseCertificate(data) cert, err := x509.ParseCertificate(data)
die.If(err) die.If(err)
pub := cert.PublicKey pub := cert.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func parseCSR(data []byte) (public []byte, kt string) { func parseCSR(data []byte) ([]byte, string) {
csr, err := x509.ParseCertificateRequest(data) // Use certlib to support both PEM and DER and to centralize validation.
csr, _, err := certlib.ParseCSR(data)
die.If(err) die.If(err)
pub := csr.PublicKey pub := csr.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func dumpHex(in []byte) string { func dumpHex(in []byte) string {
var s string var s string
var sSb153 strings.Builder
for i := range in { for i := range in {
s += fmt.Sprintf("%02X:", in[i]) sSb153.WriteString(fmt.Sprintf("%02X:", in[i]))
} }
s += sSb153.String()
return strings.Trim(s, ":") return strings.Trim(s, ":")
} }
@@ -172,18 +179,18 @@ func main() {
var subPKI subjectPublicKeyInfo var subPKI subjectPublicKeyInfo
_, err := asn1.Unmarshal(public, &subPKI) _, err := asn1.Unmarshal(public, &subPKI)
if err != nil { if err != nil {
lib.Warn(err, "failed to get subject PKI") _, _ = lib.Warn(err, "failed to get subject PKI")
continue continue
} }
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
pubHashString := dumpHex(pubHash[:]) pubHashString := dumpHex(pubHash[:])
if ski == "" { if ski == "" {
ski = pubHashString ski = pubHashString
} }
if shouldMatch && ski != pubHashString { if shouldMatch && ski != pubHashString {
lib.Warnx("%s: SKI mismatch (%s != %s)", _, _ = lib.Warnx("%s: SKI mismatch (%s != %s)",
path, ski, pubHashString) path, ski, pubHashString)
} }
fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft) fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft)

View File

@@ -1,16 +1,17 @@
package main package main
import ( import (
"context"
"flag" "flag"
"io" "io"
"log"
"net" "net"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
) )
func proxy(conn net.Conn, inside string) error { func proxy(conn net.Conn, inside string) error {
proxyConn, err := net.Dial("tcp", inside) proxyConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", inside)
if err != nil { if err != nil {
return err return err
} }
@@ -19,7 +20,7 @@ func proxy(conn net.Conn, inside string) error {
defer conn.Close() defer conn.Close()
go func() { go func() {
io.Copy(conn, proxyConn) _, _ = io.Copy(conn, proxyConn)
}() }()
_, err = io.Copy(proxyConn, conn) _, err = io.Copy(proxyConn, conn)
return err return err
@@ -31,16 +32,22 @@ func main() {
flag.StringVar(&inside, "p", "4000", "inside port") flag.StringVar(&inside, "p", "4000", "inside port")
flag.Parse() flag.Parse()
l, err := net.Listen("tcp", "0.0.0.0:"+outside) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:"+outside)
die.If(err) die.If(err)
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
log.Println(err) _, _ = lib.Warn(err, "accept failed")
continue continue
} }
go proxy(conn, "127.0.0.1:"+inside) go func() {
if err = proxy(conn, "127.0.0.1:"+inside); err != nil {
_, _ = lib.Warn(err, "proxy error")
}
}()
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@@ -8,7 +9,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -16,7 +16,7 @@ import (
) )
func main() { func main() {
cfg := &tls.Config{} cfg := &tls.Config{} // #nosec G402
var sysRoot, listenAddr, certFile, keyFile string var sysRoot, listenAddr, certFile, keyFile string
var verify bool var verify bool
@@ -47,7 +47,8 @@ func main() {
} }
cfg.Certificates = append(cfg.Certificates, cert) cfg.Certificates = append(cfg.Certificates, cert)
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) var pemList []byte
pemList, err = os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -59,48 +60,54 @@ func main() {
cfg.RootCAs = roots cfg.RootCAs = roots
} }
l, err := net.Listen("tcp", listenAddr) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", listenAddr)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(1) os.Exit(1)
} }
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
}
raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg)
err = tconn.Handshake()
if err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
continue continue
} }
cs := tconn.ConnectionState() handleConn(conn, cfg)
if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr)
continue
}
var chain []byte
for _, cert := range cs.PeerCertificates {
p := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
chain = append(chain, pem.EncodeToMemory(p)...)
}
var nonce [16]byte
_, err = rand.Read(nonce[:])
if err != nil {
panic(err)
}
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
err = ioutil.WriteFile(fname, chain, 0644)
die.If(err)
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
} }
} }
// handleConn performs a TLS handshake, extracts the peer chain, and writes it to a file.
func handleConn(conn net.Conn, cfg *tls.Config) {
defer conn.Close()
raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg)
if err := tconn.HandshakeContext(context.Background()); err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
return
}
cs := tconn.ConnectionState()
if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr)
return
}
var chain []byte
for _, cert := range cs.PeerCertificates {
p := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
chain = append(chain, pem.EncodeToMemory(p)...)
}
var nonce [16]byte
if _, err := rand.Read(nonce[:]); err != nil {
fmt.Printf("[+] %v: failed to generate filename nonce: %v\n", raddr, err)
return
}
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
if err := os.WriteFile(fname, chain, 0o644); err != nil {
fmt.Printf("[+] %v: failed to write %v: %v\n", raddr, fname, err)
return
}
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
}

View File

@@ -1,12 +1,12 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -14,7 +14,7 @@ import (
) )
func main() { func main() {
var cfg = &tls.Config{} var cfg = &tls.Config{} // #nosec G402
var sysRoot, serverName string var sysRoot, serverName string
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle") flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
@@ -23,7 +23,7 @@ func main() {
flag.Parse() flag.Parse()
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) pemList, err := os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -44,10 +44,13 @@ func main() {
if err != nil { if err != nil {
site += ":443" site += ":443"
} }
conn, err := tls.Dial("tcp", site, cfg) d := &tls.Dialer{Config: cfg}
if err != nil { nc, err := d.DialContext(context.Background(), "tcp", site)
fmt.Println(err.Error()) die.If(err)
os.Exit(1)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
cs := conn.ConnectionState() cs := conn.ConnectionState()
@@ -61,8 +64,9 @@ func main() {
chain = append(chain, pem.EncodeToMemory(p)...) chain = append(chain, pem.EncodeToMemory(p)...)
} }
err = ioutil.WriteFile(site+".pem", chain, 0644) err = os.WriteFile(site+".pem", chain, 0644)
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.pem.\n", site) fmt.Printf("[+] wrote %s.pem.\n", site)
} }
} }

View File

@@ -60,7 +60,7 @@ func printDigests(paths []string, issuer bool) {
for _, path := range paths { for _, path := range paths {
cert, err := certlib.LoadCertificate(path) cert, err := certlib.LoadCertificate(path)
if err != nil { if err != nil {
lib.Warn(err, "failed to load certificate from %s", path) _, _ = lib.Warn(err, "failed to load certificate from %s", path)
continue continue
} }
@@ -75,20 +75,19 @@ func matchDigests(paths []string, issuer bool) {
} }
var invalid int var invalid int
for { for len(paths) > 0 {
if len(paths) == 0 {
break
}
fst := paths[0] fst := paths[0]
snd := paths[1] snd := paths[1]
paths = paths[2:] paths = paths[2:]
fstCert, err := certlib.LoadCertificate(fst) fstCert, err := certlib.LoadCertificate(fst)
die.If(err) die.If(err)
sndCert, err := certlib.LoadCertificate(snd) sndCert, err := certlib.LoadCertificate(snd)
die.If(err) die.If(err)
if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) { if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) {
lib.Warnx("certificates don't match: %s and %s", fst, snd) _, _ = lib.Warnx("certificates don't match: %s and %s", fst, snd)
invalid++ invalid++
} }
} }

View File

@@ -1,10 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
"git.wntrmute.dev/kyle/goutils/die"
) )
func main() { func main() {
@@ -13,16 +17,23 @@ func main() {
os.Exit(1) os.Exit(1)
} }
hostPort := os.Args[1] hostPort, err := hosts.ParseHost(os.Args[1])
conn, err := tls.Dial("tcp", hostPort, &tls.Config{ die.If(err)
InsecureSkipVerify: true,
})
if err != nil { d := &tls.Dialer{Config: &tls.Config{
fmt.Printf("Failed to connect to the TLS server: %v\n", err) InsecureSkipVerify: true,
os.Exit(1) }} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", hostPort.String())
die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() state := conn.ConnectionState()
printConnectionDetails(state) printConnectionDetails(state)
} }
@@ -37,7 +48,6 @@ func printConnectionDetails(state tls.ConnectionState) {
func tlsVersion(version uint16) string { func tlsVersion(version uint16) string {
switch version { switch version {
case tls.VersionTLS13: case tls.VersionTLS13:
return "TLS 1.3" return "TLS 1.3"
case tls.VersionTLS12: case tls.VersionTLS12:

View File

@@ -11,10 +11,9 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -32,7 +31,7 @@ const (
curveP521 curveP521
) )
func getECCurve(pub interface{}) int { func getECCurve(pub any) int {
switch pub := pub.(type) { switch pub := pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
switch pub.Curve { switch pub.Curve {
@@ -52,42 +51,88 @@ func getECCurve(pub interface{}) int {
} }
} }
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
// It returns true on match.
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
}
// matchECDSA compares ECDSA public keys for equality and compatible curve.
// It returns match=true when they are on the same curve and have the same X/Y.
// If curves mismatch, match is false.
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
if getECCurve(certPub) != getECCurve(keyPub) {
return false
}
if keyPub.X.Cmp(certPub.X) != 0 {
return false
}
if keyPub.Y.Cmp(certPub.Y) != 0 {
return false
}
return true
}
// matchKeys determines whether the certificate's public key matches the given private key.
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
switch keyPub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if matchRSA(certPub, keyPub) {
return true, ""
}
return false, "public keys don't match"
case *ecdsa.PublicKey:
return false, "RSA private key, EC public key"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
case *ecdsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *ecdsa.PublicKey:
if matchECDSA(certPub, keyPub) {
return true, ""
}
// Determine a more precise reason
kc := getECCurve(keyPub)
cc := getECCurve(certPub)
if kc == curveInvalid {
return false, "invalid private key curve"
}
if cc == curveRSA {
return false, "private key is EC, certificate is RSA"
}
if kc != cc {
return false, "EC curves don't match"
}
return false, "public keys don't match"
case *rsa.PublicKey:
return false, "private key is EC, certificate is RSA"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
default:
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
}
}
func loadKey(path string) (crypto.Signer, error) { func loadKey(path string) (crypto.Signer, error) {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
in = bytes.TrimSpace(in) in = bytes.TrimSpace(in)
p, _ := pem.Decode(in) if p, _ := pem.Decode(in); p != nil {
if p != nil {
if !validPEMs[p.Type] { if !validPEMs[p.Type] {
return nil, errors.New("invalid private key file type " + p.Type) return nil, errors.New("invalid private key file type " + p.Type)
} }
in = p.Bytes return certlib.ParsePrivateKeyPEM(in)
} }
priv, err := x509.ParsePKCS8PrivateKey(in) return certlib.ParsePrivateKeyDER(in)
if err != nil {
priv, err = x509.ParsePKCS1PrivateKey(in)
if err != nil {
priv, err = x509.ParseECPrivateKey(in)
if err != nil {
return nil, err
}
}
}
switch priv.(type) {
case *rsa.PrivateKey:
return priv.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return priv.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, errors.New("invalid private key")
} }
func main() { func main() {
@@ -96,7 +141,7 @@ func main() {
flag.StringVar(&certFile, "c", "", "TLS `certificate` file") flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
flag.Parse() flag.Parse()
in, err := ioutil.ReadFile(certFile) in, err := os.ReadFile(certFile)
die.If(err) die.If(err)
p, _ := pem.Decode(in) p, _ := pem.Decode(in)
@@ -112,50 +157,11 @@ func main() {
priv, err := loadKey(keyFile) priv, err := loadKey(keyFile)
die.If(err) die.If(err)
switch pub := priv.Public().(type) { matched, reason := matchKeys(cert, priv)
case *rsa.PublicKey: if matched {
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.")
return
case *ecdsa.PublicKey:
fmt.Println("No match (RSA private key, EC public key).")
os.Exit(1)
}
case *ecdsa.PublicKey:
privCurve := getECCurve(pub)
certCurve := getECCurve(cert.PublicKey)
log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve)
if certCurve == curveRSA {
fmt.Println("No match (private key is EC, certificate is RSA).")
os.Exit(1)
} else if privCurve == curveInvalid {
fmt.Println("No match (invalid private key curve).")
os.Exit(1)
} else if privCurve != certCurve {
fmt.Println("No match (EC curves don't match).")
os.Exit(1)
}
certPub := cert.PublicKey.(*ecdsa.PublicKey)
if pub.X.Cmp(certPub.X) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
if pub.Y.Cmp(certPub.Y) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.") fmt.Println("Match.")
default: return
fmt.Printf("Unrecognised private key type: %T\n", priv.Public())
os.Exit(1)
} }
fmt.Printf("No match (%s).\n", reason)
os.Exit(1)
} }

View File

@@ -201,10 +201,6 @@ func init() {
os.Exit(1) os.Exit(1)
} }
if fromLoc == time.UTC {
}
toLoc = time.UTC toLoc = time.UTC
} }
@@ -257,15 +253,16 @@ func main() {
showTime(time.Now()) showTime(time.Now())
os.Exit(0) os.Exit(0)
case 1: case 1:
if flag.Arg(0) == "-" { switch {
case flag.Arg(0) == "-":
s := bufio.NewScanner(os.Stdin) s := bufio.NewScanner(os.Stdin)
for s.Scan() { for s.Scan() {
times = append(times, s.Text()) times = append(times, s.Text())
} }
} else if flag.Arg(0) == "help" { case flag.Arg(0) == "help":
usageExamples() usageExamples()
} else { default:
times = flag.Args() times = flag.Args()
} }
default: default:

View File

@@ -4,7 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@@ -12,9 +11,8 @@ import (
type empty struct{} type empty struct{}
func errorf(format string, args ...interface{}) { func errorf(path string, err error) {
format += "\n" fmt.Fprintf(os.Stderr, "%s FAILED: %s\n", path, err)
fmt.Fprintf(os.Stderr, format, args...)
} }
func usage(w io.Writer) { func usage(w io.Writer) {
@@ -44,16 +42,16 @@ func main() {
if flag.NArg() == 1 && flag.Arg(0) == "-" { if flag.NArg() == 1 && flag.Arg(0) == "-" {
path := "stdin" path := "stdin"
in, err := ioutil.ReadAll(os.Stdin) in, err := io.ReadAll(os.Stdin)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
@@ -65,16 +63,16 @@ func main() {
} }
for _, path := range flag.Args() { for _, path := range flag.Args() {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }

View File

@@ -14,16 +14,16 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"git.wntrmute.dev/kyle/goutils/lib"
) )
const defaultDirectory = ".git/objects" const defaultDirectory = ".git/objects"
func errorf(format string, a ...interface{}) { // maxDecompressedSize limits how many bytes we will decompress from a zlib
fmt.Fprintf(os.Stderr, format, a...) // stream to mitigate decompression bombs (gosec G110).
if format[len(format)-1] != '\n' { // Increase this if you expect larger objects.
fmt.Fprintf(os.Stderr, "\n") const maxDecompressedSize int64 = 64 << 30 // 64 GiB
}
}
func isDir(path string) bool { func isDir(path string) bool {
fi, err := os.Stat(path) fi, err := os.Stat(path)
@@ -48,17 +48,21 @@ func loadFile(path string) ([]byte, error) {
} }
defer zread.Close() defer zread.Close()
_, err = io.Copy(buf, zread) // Protect against decompression bombs by limiting how much we read.
if err != nil { lr := io.LimitReader(zread, maxDecompressedSize+1)
if _, err = buf.ReadFrom(lr); err != nil {
return nil, err return nil, err
} }
if int64(buf.Len()) > maxDecompressedSize {
return nil, fmt.Errorf("decompressed size exceeds limit (%d bytes)", maxDecompressedSize)
}
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func showFile(path string) { func showFile(path string) {
fileData, err := loadFile(path) fileData, err := loadFile(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to load %s", path)
return return
} }
@@ -68,39 +72,71 @@ func showFile(path string) {
func searchFile(path string, search *regexp.Regexp) error { func searchFile(path string, search *regexp.Regexp) error {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to open %s", path)
return err return err
} }
defer file.Close() defer file.Close()
zread, err := zlib.NewReader(file) zread, err := zlib.NewReader(file)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to decompress %s", path)
return err return err
} }
defer zread.Close() defer zread.Close()
zbuf := bufio.NewReader(zread) // Limit how much we scan to avoid DoS via huge decompression.
if search.MatchReader(zbuf) { lr := io.LimitReader(zread, maxDecompressedSize+1)
fileData, err := loadFile(path) zbuf := bufio.NewReader(lr)
if err != nil { if !search.MatchReader(zbuf) {
errorf("%v", err) return nil
return err
}
fmt.Printf("%s:\n%s\n", path, fileData)
} }
fileData, err := loadFile(path)
if err != nil {
lib.Warn(err, "failed to load %s", path)
return err
}
fmt.Printf("%s:\n%s\n", path, fileData)
return nil return nil
} }
func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc { func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc {
return func(path string, info os.FileInfo, err error) error { return func(path string, info os.FileInfo, _ error) error {
if info.Mode().IsRegular() { if !info.Mode().IsRegular() {
return searchFile(path, searchExpr) return nil
} }
return nil return searchFile(path, searchExpr)
} }
} }
// runSearch compiles the search expression and processes the provided paths.
// It returns an error for fatal conditions; per-file errors are logged.
func runSearch(expr string) error {
search, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("invalid regexp: %w", err)
}
pathList := flag.Args()
if len(pathList) == 0 {
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
if err2 := filepath.Walk(path, buildWalker(search)); err2 != nil {
return err2
}
continue
}
if err2 := searchFile(path, search); err2 != nil {
// Non-fatal: keep going, but report it.
lib.Warn(err2, "non-fatal error while searching files")
}
}
return nil
}
func main() { func main() {
flSearch := flag.String("s", "", "search string (should be an RE2 regular expression)") flSearch := flag.String("s", "", "search string (should be an RE2 regular expression)")
flag.Parse() flag.Parse()
@@ -109,28 +145,10 @@ func main() {
for _, path := range flag.Args() { for _, path := range flag.Args() {
showFile(path) showFile(path)
} }
} else { return
search, err := regexp.Compile(*flSearch) }
if err != nil {
errorf("Bad regexp: %v", err)
return
}
pathList := flag.Args() if err := runSearch(*flSearch); err != nil {
if len(pathList) == 0 { lib.Err(lib.ExitFailure, err, "failed to run search")
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
err := filepath.Walk(path, buildWalker(search))
if err != nil {
errorf("%v", err)
return
}
} else {
searchFile(path, search)
}
}
} }
} }

1
go.mod
View File

@@ -12,6 +12,7 @@ require (
) )
require ( require (
github.com/benbjohnson/clock v1.3.5
github.com/davecgh/go-spew v1.1.1 github.com/davecgh/go-spew v1.1.1
github.com/google/certificate-transparency-go v1.0.21 github.com/google/certificate-transparency-go v1.0.21
) )

2
go.sum
View File

@@ -1,3 +1,5 @@
github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o=
github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

View File

@@ -1,4 +1,4 @@
// +build freebsd darwin,386 netbsd //go:build bsd
package lib package lib

View File

@@ -1,4 +1,4 @@
// +build unix linux openbsd darwin,amd64 //go:build unix || linux || openbsd || (darwin && amd64)
package lib package lib
@@ -18,7 +18,7 @@ type FileTime struct {
func timeSpecToTime(ts unix.Timespec) time.Time { func timeSpecToTime(ts unix.Timespec) time.Time {
// The casts to int64 are needed because on 386, these are int32s. // The casts to int64 are needed because on 386, these are int32s.
return time.Unix(int64(ts.Sec), int64(ts.Nsec)) return time.Unix(ts.Sec, ts.Nsec)
} }
// LoadFileTime returns a FileTime associated with the file. // LoadFileTime returns a FileTime associated with the file.

View File

@@ -10,6 +10,12 @@ import (
var progname = filepath.Base(os.Args[0]) var progname = filepath.Base(os.Args[0])
const (
daysInYear = 365
digitWidth = 10
hoursInQuarterDay = 6
)
// ProgName returns what lib thinks the program name is, namely the // ProgName returns what lib thinks the program name is, namely the
// basename of argv0. // basename of argv0.
// //
@@ -20,7 +26,7 @@ func ProgName() string {
// Warnx displays a formatted error message to standard error, à la // Warnx displays a formatted error message to standard error, à la
// warnx(3). // warnx(3).
func Warnx(format string, a ...interface{}) (int, error) { func Warnx(format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n" format += "\n"
return fmt.Fprintf(os.Stderr, format, a...) return fmt.Fprintf(os.Stderr, format, a...)
@@ -28,7 +34,7 @@ func Warnx(format string, a ...interface{}) (int, error) {
// Warn displays a formatted error message to standard output, // Warn displays a formatted error message to standard output,
// appending the error string, à la warn(3). // appending the error string, à la warn(3).
func Warn(err error, format string, a ...interface{}) (int, error) { func Warn(err error, format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n" format += ": %v\n"
a = append(a, err) a = append(a, err)
@@ -37,7 +43,7 @@ func Warn(err error, format string, a ...interface{}) (int, error) {
// Errx displays a formatted error message to standard error and exits // Errx displays a formatted error message to standard error and exits
// with the status code from `exit`, à la errx(3). // with the status code from `exit`, à la errx(3).
func Errx(exit int, format string, a ...interface{}) { func Errx(exit int, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n" format += "\n"
fmt.Fprintf(os.Stderr, format, a...) fmt.Fprintf(os.Stderr, format, a...)
@@ -47,7 +53,7 @@ func Errx(exit int, format string, a ...interface{}) {
// Err displays a formatting error message to standard error, // Err displays a formatting error message to standard error,
// appending the error string, and exits with the status code from // appending the error string, and exits with the status code from
// `exit`, à la err(3). // `exit`, à la err(3).
func Err(exit int, err error, format string, a ...interface{}) { func Err(exit int, err error, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n" format += ": %v\n"
a = append(a, err) a = append(a, err)
@@ -62,30 +68,30 @@ func Itoa(i int, wid int) string {
// Assemble decimal in reverse order. // Assemble decimal in reverse order.
var b [20]byte var b [20]byte
bp := len(b) - 1 bp := len(b) - 1
for i >= 10 || wid > 1 { for i >= digitWidth || wid > 1 {
wid-- wid--
q := i / 10 q := i / digitWidth
b[bp] = byte('0' + i - q*10) b[bp] = byte('0' + i - q*digitWidth)
bp-- bp--
i = q i = q
} }
// i < 10
b[bp] = byte('0' + i) b[bp] = byte('0' + i)
return string(b[bp:]) return string(b[bp:])
} }
var ( var (
dayDuration = 24 * time.Hour dayDuration = 24 * time.Hour
yearDuration = (365 * dayDuration) + (6 * time.Hour) yearDuration = (daysInYear * dayDuration) + (hoursInQuarterDay * time.Hour)
) )
// Duration returns a prettier string for time.Durations. // Duration returns a prettier string for time.Durations.
func Duration(d time.Duration) string { func Duration(d time.Duration) string {
var s string var s string
if d >= yearDuration { if d >= yearDuration {
years := d / yearDuration years := int64(d / yearDuration)
s += fmt.Sprintf("%dy", years) s += fmt.Sprintf("%dy", years)
d -= years * yearDuration d -= time.Duration(years) * yearDuration
} }
if d >= dayDuration { if d >= dayDuration {
@@ -98,8 +104,8 @@ func Duration(d time.Duration) string {
} }
d %= 1 * time.Second d %= 1 * time.Second
hours := d / time.Hour hours := int64(d / time.Hour)
d -= hours * time.Hour d -= time.Duration(hours) * time.Hour
s += fmt.Sprintf("%dh%s", hours, d) s += fmt.Sprintf("%dh%s", hours, d)
return s return s
} }

View File

@@ -1,8 +1,9 @@
package logging package logging
import ( import (
"fmt" "errors"
"os" "fmt"
"os"
) )
// File writes its logs to file. // File writes its logs to file.
@@ -59,12 +60,12 @@ func NewSplitFile(outpath, errpath string, overwrite bool) (*File, error) {
fl.fe, err = os.OpenFile(errpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600) fl.fe, err = os.OpenFile(errpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
} }
if err != nil { if err != nil {
if closeErr := fl.Close(); closeErr != nil { if closeErr := fl.Close(); closeErr != nil {
return nil, fmt.Errorf("failed to open error log: cleanup close failed: %v: %w", closeErr, err) return nil, fmt.Errorf("failed to open error log: %w", errors.Join(closeErr, err))
} }
return nil, err return nil, fmt.Errorf("failed to open error log: %w", err)
} }
fl.LogWriter = NewLogWriter(fl.fo, fl.fe) fl.LogWriter = NewLogWriter(fl.fo, fl.fe)
return fl, nil return fl, nil
@@ -94,13 +95,13 @@ func (fl *File) Flush() error {
} }
func (fl *File) Chmod(mode os.FileMode) error { func (fl *File) Chmod(mode os.FileMode) error {
if err := fl.fo.Chmod(mode); err != nil { if err := fl.fo.Chmod(mode); err != nil {
return fmt.Errorf("failed to chmod output log: %w", err) return fmt.Errorf("failed to chmod output log: %w", err)
} }
if err := fl.fe.Chmod(mode); err != nil { if err := fl.fe.Chmod(mode); err != nil {
return fmt.Errorf("failed to chmod error log: %w", err) return fmt.Errorf("failed to chmod error log: %w", err)
} }
return nil return nil
} }

View File

@@ -0,0 +1,42 @@
# Use the latest 2.1 version of CircleCI pipeline process engine.
# See: https://circleci.com/docs/2.0/configuration-reference
version: 2.1
# Define a job to be invoked later in a workflow.
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
jobs:
testbuild:
working_directory: ~/repo
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
# See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
docker:
- image: cimg/go:1.22.2
# Add steps to the job
# See: https://circleci.com/docs/2.0/configuration-reference/#steps
steps:
- checkout
- restore_cache:
keys:
- go-mod-v4-{{ checksum "go.sum" }}
- run:
name: Install Dependencies
command: go mod download
- save_cache:
key: go-mod-v4-{{ checksum "go.sum" }}
paths:
- "/go/pkg/mod"
- run:
name: Run tests
command: go test ./...
- run:
name: Run build
command: go build ./...
- store_test_results:
path: /tmp/test-reports
# Invoke jobs via workflows
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
workflows:
testbuild:
jobs:
- testbuild

19
twofactor/LICENSE Normal file
View File

@@ -0,0 +1,19 @@
Copyright (c) 2017 Kyle Isom <kyle@imap.cc>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

33
twofactor/README.md Normal file
View File

@@ -0,0 +1,33 @@
## `twofactor`
[![GoDoc](https://godoc.org/github.com/gokyle/twofactor?status.svg)](https://godoc.org/github.com/gokyle/twofactor)
### Author
`twofactor` was written by Kyle Isom <kyle@tyrfingr.is>.
### License
```
Copyright (c) 2017 Kyle Isom <kyle@imap.cc>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

5
twofactor/doc.go Normal file
View File

@@ -0,0 +1,5 @@
// twofactor implements two-factor authentication.
//
// Currently supported are RFC 4226 HOTP one-time passwords and
// RFC 6238 TOTP SHA-1 one-time passwords.
package twofactor

8
twofactor/go.mod Normal file
View File

@@ -0,0 +1,8 @@
module github.com/gokyle/twofactor
go 1.14
require (
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3
rsc.io/qr v0.1.0
)

4
twofactor/go.sum Normal file
View File

@@ -0,0 +1,4 @@
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
rsc.io/qr v0.1.0 h1:M/sAxsU2J5mlQ4W84Bxga2EgdQqOaAliipcjPmMUM5Q=
rsc.io/qr v0.1.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=

103
twofactor/hotp.go Normal file
View File

@@ -0,0 +1,103 @@
package twofactor
import (
"crypto"
"crypto/sha1"
"encoding/base32"
"io"
"net/url"
"strconv"
"strings"
)
// HOTP represents an RFC-4226 Hash-based One Time Password instance.
type HOTP struct {
*OATH
}
// Type returns OATH_HOTP.
func (otp *HOTP) Type() Type {
return OATH_HOTP
}
// NewHOTP takes the key, the initial counter value, and the number
// of digits (typically 6 or 8) and returns a new HOTP instance.
func NewHOTP(key []byte, counter uint64, digits int) *HOTP {
return &HOTP{
OATH: &OATH{
key: key,
counter: counter,
size: digits,
hash: sha1.New,
algo: crypto.SHA1,
},
}
}
// OTP returns the next OTP and increments the counter.
func (otp *HOTP) OTP() string {
code := otp.OATH.OTP(otp.counter)
otp.counter++
return code
}
// URL returns an HOTP URL (i.e. for putting in a QR code).
func (otp *HOTP) URL(label string) string {
return otp.OATH.URL(otp.Type(), label)
}
// SetProvider sets up the provider component of the OTP URL.
func (otp *HOTP) SetProvider(provider string) {
otp.provider = provider
}
// GenerateGoogleHOTP generates a new HOTP instance as used by
// Google Authenticator.
func GenerateGoogleHOTP() *HOTP {
key := make([]byte, sha1.Size)
if _, err := io.ReadFull(PRNG, key); err != nil {
return nil
}
return NewHOTP(key, 0, 6)
}
func hotpFromURL(u *url.URL) (*HOTP, string, error) {
label := u.Path[1:]
v := u.Query()
secret := strings.ToUpper(v.Get("secret"))
if secret == "" {
return nil, "", ErrInvalidURL
}
var digits = 6
if sdigit := v.Get("digits"); sdigit != "" {
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
if err != nil {
return nil, "", err
}
digits = int(tmpDigits)
}
var counter uint64 = 0
if scounter := v.Get("counter"); scounter != "" {
var err error
counter, err = strconv.ParseUint(scounter, 10, 64)
if err != nil {
return nil, "", err
}
}
key, err := base32.StdEncoding.DecodeString(Pad(secret))
if err != nil {
// assume secret isn't base32 encoded
key = []byte(secret)
}
otp := NewHOTP(key, counter, digits)
return otp, label, nil
}
// QR generates a new QR code for the HOTP.
func (otp *HOTP) QR(label string) ([]byte, error) {
return otp.OATH.QR(otp.Type(), label)
}

64
twofactor/hotp_test.go Normal file
View File

@@ -0,0 +1,64 @@
package twofactor
import (
"fmt"
"testing"
)
var testKey = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
var rfcHotpKey = []byte("12345678901234567890")
var rfcHotpExpected = []string{
"755224",
"287082",
"359152",
"969429",
"338314",
"254676",
"287922",
"162583",
"399871",
"520489",
}
// This test runs through the test cases presented in the RFC, and
// ensures that this implementation is in compliance.
func TestHotpRFC(t *testing.T) {
otp := NewHOTP(rfcHotpKey, 0, 6)
for i := 0; i < len(rfcHotpExpected); i++ {
if otp.Counter() != uint64(i) {
fmt.Printf("twofactor: invalid counter (should be %d, is %d",
i, otp.Counter())
t.FailNow()
}
code := otp.OTP()
if code == "" {
fmt.Printf("twofactor: failed to produce an OTP\n")
t.FailNow()
} else if code != rfcHotpExpected[i] {
fmt.Printf("twofactor: invalid OTP\n")
fmt.Printf("\tExpected: %s\n", rfcHotpExpected[i])
fmt.Printf("\t Actual: %s\n", code)
fmt.Printf("\t Counter: %d\n", otp.counter)
t.FailNow()
}
}
}
// This test uses a different key than the test cases in the RFC,
// but runs through the same test cases to ensure that they fail as
// expected.
func TestHotpBadRFC(t *testing.T) {
otp := NewHOTP(testKey, 0, 6)
for i := 0; i < len(rfcHotpExpected); i++ {
code := otp.OTP()
switch code {
case "":
fmt.Printf("twofactor: failed to produce an OTP\n")
t.FailNow()
case rfcHotpExpected[i]:
fmt.Printf("twofactor: should not have received a valid OTP\n")
t.FailNow()
}
}
}

150
twofactor/oath.go Normal file
View File

@@ -0,0 +1,150 @@
package twofactor
import (
"crypto"
"crypto/hmac"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"net/url"
"rsc.io/qr"
)
const defaultSize = 6
// OATH provides a baseline structure for the two OATH algorithms.
type OATH struct {
key []byte
counter uint64
size int
hash func() hash.Hash
algo crypto.Hash
provider string
}
// Size returns the output size (in characters) of the password.
func (o OATH) Size() int {
return o.size
}
// Counter returns the OATH token's counter.
func (o OATH) Counter() uint64 {
return o.counter
}
// SetCounter updates the OATH token's counter to a new value.
func (o *OATH) SetCounter(counter uint64) {
o.counter = counter
}
// Key returns the token's secret key.
func (o OATH) Key() []byte {
return o.key[:]
}
// Hash returns the token's hash function.
func (o OATH) Hash() func() hash.Hash {
return o.hash
}
// URL constructs a URL appropriate for the token (i.e. for use in a
// QR code).
func (o OATH) URL(t Type, label string) string {
secret := base32.StdEncoding.EncodeToString(o.key)
u := url.URL{}
v := url.Values{}
u.Scheme = "otpauth"
switch t {
case OATH_HOTP:
u.Host = "hotp"
case OATH_TOTP:
u.Host = "totp"
}
u.Path = label
v.Add("secret", secret)
if o.Counter() != 0 && t == OATH_HOTP {
v.Add("counter", fmt.Sprintf("%d", o.Counter()))
}
if o.Size() != defaultSize {
v.Add("digits", fmt.Sprintf("%d", o.Size()))
}
switch o.algo {
case crypto.SHA256:
v.Add("algorithm", "SHA256")
case crypto.SHA512:
v.Add("algorithm", "SHA512")
}
if o.provider != "" {
v.Add("provider", o.provider)
}
u.RawQuery = v.Encode()
return u.String()
}
var digits = []int64{
0: 1,
1: 10,
2: 100,
3: 1000,
4: 10000,
5: 100000,
6: 1000000,
7: 10000000,
8: 100000000,
9: 1000000000,
10: 10000000000,
}
// The top-level type should provide a counter; for example, HOTP
// will provide the counter directly while TOTP will provide the
// time-stepped counter.
func (o OATH) OTP(counter uint64) string {
var ctr [8]byte
binary.BigEndian.PutUint64(ctr[:], counter)
var mod int64 = 1
if len(digits) > o.size {
for i := 1; i <= o.size; i++ {
mod *= 10
}
} else {
mod = digits[o.size]
}
h := hmac.New(o.hash, o.key)
h.Write(ctr[:])
dt := truncate(h.Sum(nil)) % mod
fmtStr := fmt.Sprintf("%%0%dd", o.size)
return fmt.Sprintf(fmtStr, dt)
}
// truncate contains the DT function from the RFC; this is used to
// deterministically select a sequence of 4 bytes from the HMAC
// counter hash.
func truncate(in []byte) int64 {
offset := int(in[len(in)-1] & 0xF)
p := in[offset : offset+4]
var binCode int32
binCode = int32((p[0] & 0x7f)) << 24
binCode += int32((p[1] & 0xff)) << 16
binCode += int32((p[2] & 0xff)) << 8
binCode += int32((p[3] & 0xff))
return int64(binCode) & 0x7FFFFFFF
}
// QR generates a byte slice containing the a QR code encoded as a
// PNG with level Q error correction.
func (o OATH) QR(t Type, label string) ([]byte, error) {
u := o.URL(t, label)
code, err := qr.Encode(u, qr.Q)
if err != nil {
return nil, err
}
return code.PNG(), nil
}

30
twofactor/oath_test.go Normal file
View File

@@ -0,0 +1,30 @@
package twofactor
import (
"fmt"
"testing"
)
var sha1Hmac = []byte{
0x1f, 0x86, 0x98, 0x69, 0x0e,
0x02, 0xca, 0x16, 0x61, 0x85,
0x50, 0xef, 0x7f, 0x19, 0xda,
0x8e, 0x94, 0x5b, 0x55, 0x5a,
}
var truncExpect int64 = 0x50ef7f19
// This test runs through the truncation example given in the RFC.
func TestTruncate(t *testing.T) {
if result := truncate(sha1Hmac); result != truncExpect {
fmt.Printf("hotp: expected truncate -> %d, saw %d\n",
truncExpect, result)
t.FailNow()
}
sha1Hmac[19]++
if result := truncate(sha1Hmac); result == truncExpect {
fmt.Println("hotp: expected truncation to fail")
t.FailNow()
}
}

86
twofactor/otp.go Normal file
View File

@@ -0,0 +1,86 @@
package twofactor
import (
"crypto/rand"
"errors"
"fmt"
"hash"
"net/url"
)
type Type uint
const (
OATH_HOTP = iota
OATH_TOTP
)
// PRNG is an io.Reader that provides a cryptographically secure
// random byte stream.
var PRNG = rand.Reader
var (
ErrInvalidURL = errors.New("twofactor: invalid URL")
ErrInvalidAlgo = errors.New("twofactor: invalid algorithm")
)
// Type OTP represents a one-time password token -- whether a
// software taken (as in the case of Google Authenticator) or a
// hardware token (as in the case of a YubiKey).
type OTP interface {
// Returns the current counter value; the meaning of the
// returned value is algorithm-specific.
Counter() uint64
// Set the counter to a specific value.
SetCounter(uint64)
// the secret key contained in the OTP
Key() []byte
// generate a new OTP
OTP() string
// the output size of the OTP
Size() int
// the hash function used by the OTP
Hash() func() hash.Hash
// Returns the type of this OTP.
Type() Type
}
func otpString(otp OTP) string {
var typeName string
switch otp.Type() {
case OATH_HOTP:
typeName = "OATH-HOTP"
case OATH_TOTP:
typeName = "OATH-TOTP"
default:
typeName = "UNKNOWN"
}
return fmt.Sprintf("%s, %d", typeName, otp.Size())
}
// FromURL constructs a new OTP token from a URL string.
func FromURL(URL string) (OTP, string, error) {
u, err := url.Parse(URL)
if err != nil {
return nil, "", err
}
if u.Scheme != "otpauth" {
return nil, "", ErrInvalidURL
}
switch u.Host {
case "totp":
return totpFromURL(u)
case "hotp":
return hotpFromURL(u)
default:
return nil, "", ErrInvalidURL
}
}

136
twofactor/otp_test.go Normal file
View File

@@ -0,0 +1,136 @@
package twofactor
import (
"fmt"
"io"
"testing"
)
func TestHOTPString(t *testing.T) {
hotp := NewHOTP(nil, 0, 6)
hotpString := otpString(hotp)
if hotpString != "OATH-HOTP, 6" {
fmt.Println("twofactor: invalid OTP string")
t.FailNow()
}
}
// This test generates a new OTP, outputs the URL for that OTP,
// and attempts to parse that URL. It verifies that the two OTPs
// are the same, and that they produce the same output.
func TestURL(t *testing.T) {
var ident = "testuser@foo"
otp := NewHOTP(testKey, 0, 6)
url := otp.URL("testuser@foo")
otp2, id, err := FromURL(url)
if err != nil {
fmt.Printf("hotp: failed to parse HOTP URL\n")
t.FailNow()
} else if id != ident {
fmt.Printf("hotp: bad label\n")
fmt.Printf("\texpected: %s\n", ident)
fmt.Printf("\t actual: %s\n", id)
t.FailNow()
} else if otp2.Counter() != otp.Counter() {
fmt.Printf("hotp: OTP counters aren't synced\n")
fmt.Printf("\toriginal: %d\n", otp.Counter())
fmt.Printf("\t second: %d\n", otp2.Counter())
t.FailNow()
}
code1 := otp.OTP()
code2 := otp2.OTP()
if code1 != code2 {
fmt.Printf("hotp: mismatched OTPs\n")
fmt.Printf("\texpected: %s\n", code1)
fmt.Printf("\t actual: %s\n", code2)
}
// There's not much we can do test the QR code, except to
// ensure it doesn't fail.
_, err = otp.QR(ident)
if err != nil {
fmt.Printf("hotp: failed to generate QR code PNG (%v)\n", err)
t.FailNow()
}
// This should fail because the maximum size of an alphanumeric
// QR code with the lowest-level of error correction should
// max out at 4296 bytes. 8k may be a bit overkill... but it
// gets the job done. The value is read from the PRNG to
// increase the likelihood that the returned data is
// uncompressible.
var tooBigIdent = make([]byte, 8192)
_, err = io.ReadFull(PRNG, tooBigIdent)
if err != nil {
fmt.Printf("hotp: failed to read identity (%v)\n", err)
t.FailNow()
} else if _, err = otp.QR(string(tooBigIdent)); err == nil {
fmt.Println("hotp: QR code should fail to encode oversized URL")
t.FailNow()
}
}
// This test makes sure we can generate codes for padded and non-padded
// entries
func TestPaddedURL(t *testing.T) {
var urlList = []string{
"otpauth://hotp/?secret=ME",
"otpauth://hotp/?secret=MEFR",
"otpauth://hotp/?secret=MFRGG",
"otpauth://hotp/?secret=MFRGGZA",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by=======",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by%3D%3D%3D%3D%3D%3D%3D",
}
var codeList = []string{
"413198",
"770938",
"670717",
"402378",
"069864",
"069864",
"069864",
}
for i := range urlList {
if o, id, err := FromURL(urlList[i]); err != nil {
fmt.Println("hotp: URL should have parsed successfully (id=", id, ")")
fmt.Printf("\turl was: %s\n", urlList[i])
t.FailNow()
fmt.Printf("\t%s, %s\n", o.OTP(), id)
} else {
code2 := o.OTP()
if code2 != codeList[i] {
fmt.Printf("hotp: mismatched OTPs\n")
fmt.Printf("\texpected: %s\n", codeList[i])
fmt.Printf("\t actual: %s\n", code2)
t.FailNow()
}
}
}
}
// This test attempts a variety of invalid urls against the parser
// to ensure they fail.
func TestBadURL(t *testing.T) {
var urlList = []string{
"http://google.com",
"",
"-",
"foo",
"otpauth:/foo/bar/baz",
"://",
"otpauth://hotp/?digits=",
"otpauth://hotp/?secret=MFRGGZDF&digits=ABCD",
"otpauth://hotp/?secret=MFRGGZDF&counter=ABCD",
}
for i := range urlList {
if _, _, err := FromURL(urlList[i]); err == nil {
fmt.Println("hotp: URL should not have parsed successfully")
fmt.Printf("\turl was: %s\n", urlList[i])
t.FailNow()
}
}
}

172
twofactor/totp.go Normal file
View File

@@ -0,0 +1,172 @@
package twofactor
import (
"crypto"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"hash"
"io"
"net/url"
"strconv"
"strings"
"github.com/benbjohnson/clock"
)
var timeSource = clock.New()
// TOTP represents an RFC 6238 Time-based One-Time Password instance.
type TOTP struct {
*OATH
step uint64
}
// Type returns OATH_TOTP.
func (otp *TOTP) Type() Type {
return OATH_TOTP
}
func (otp *TOTP) otp(counter uint64) string {
return otp.OATH.OTP(counter)
}
// OTP returns the OTP for the current timestep.
func (otp *TOTP) OTP() string {
return otp.otp(otp.OTPCounter())
}
// URL returns a TOTP URL (i.e. for putting in a QR code).
func (otp *TOTP) URL(label string) string {
return otp.OATH.URL(otp.Type(), label)
}
// SetProvider sets up the provider component of the OTP URL.
func (otp *TOTP) SetProvider(provider string) {
otp.provider = provider
}
func (otp *TOTP) otpCounter(t uint64) uint64 {
return (t - otp.counter) / otp.step
}
// OTPCounter returns the current time value for the OTP.
func (otp *TOTP) OTPCounter() uint64 {
return otp.otpCounter(uint64(timeSource.Now().Unix()))
}
// NewTOTP takes a new key, a starting time, a step, the number of
// digits of output (typically 6 or 8) and the hash algorithm to
// use, and builds a new OTP.
func NewTOTP(key []byte, start uint64, step uint64, digits int, algo crypto.Hash) *TOTP {
h := hashFromAlgo(algo)
if h == nil {
return nil
}
return &TOTP{
OATH: &OATH{
key: key,
counter: start,
size: digits,
hash: h,
algo: algo,
},
step: step,
}
}
// NewTOTPSHA1 will build a new TOTP using SHA-1.
func NewTOTPSHA1(key []byte, start uint64, step uint64, digits int) *TOTP {
return NewTOTP(key, start, step, digits, crypto.SHA1)
}
func hashFromAlgo(algo crypto.Hash) func() hash.Hash {
switch algo {
case crypto.SHA1:
return sha1.New
case crypto.SHA256:
return sha256.New
case crypto.SHA512:
return sha512.New
}
return nil
}
// GenerateGoogleTOTP produces a new TOTP token with the defaults expected by
// Google Authenticator.
func GenerateGoogleTOTP() *TOTP {
key := make([]byte, sha1.Size)
if _, err := io.ReadFull(PRNG, key); err != nil {
return nil
}
return NewTOTP(key, 0, 30, 6, crypto.SHA1)
}
// NewGoogleTOTP takes a secret as a base32-encoded string and
// returns an appropriate Google Authenticator TOTP instance.
func NewGoogleTOTP(secret string) (*TOTP, error) {
key, err := base32.StdEncoding.DecodeString(secret)
if err != nil {
return nil, err
}
return NewTOTP(key, 0, 30, 6, crypto.SHA1), nil
}
func totpFromURL(u *url.URL) (*TOTP, string, error) {
label := u.Path[1:]
v := u.Query()
secret := strings.ToUpper(v.Get("secret"))
if secret == "" {
return nil, "", ErrInvalidURL
}
var algo = crypto.SHA1
if algorithm := v.Get("algorithm"); algorithm != "" {
if strings.EqualFold(algorithm, "SHA256") {
algo = crypto.SHA256
} else if strings.EqualFold(algorithm, "SHA512") {
algo = crypto.SHA512
} else if !strings.EqualFold(algorithm, "SHA1") {
return nil, "", ErrInvalidAlgo
}
}
var digits = 6
if sdigit := v.Get("digits"); sdigit != "" {
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
if err != nil {
return nil, "", err
}
digits = int(tmpDigits)
}
var period uint64 = 30
if speriod := v.Get("period"); speriod != "" {
var err error
period, err = strconv.ParseUint(speriod, 10, 64)
if err != nil {
return nil, "", err
}
}
key, err := base32.StdEncoding.DecodeString(Pad(secret))
if err != nil {
// assume secret isn't base32 encoded
key = []byte(secret)
}
otp := NewTOTP(key, 0, period, digits, algo)
return otp, label, nil
}
// QR generates a new TOTP QR code.
func (otp *TOTP) QR(label string) ([]byte, error) {
return otp.OATH.QR(otp.Type(), label)
}
func SetClock(c clock.Clock) {
timeSource = c
}

87
twofactor/totp_test.go Normal file
View File

@@ -0,0 +1,87 @@
package twofactor
import (
"crypto"
"fmt"
"testing"
"time"
"github.com/benbjohnson/clock"
)
var rfcTotpKey = map[crypto.Hash][]byte{
crypto.SHA1: []byte("12345678901234567890"),
crypto.SHA256: []byte("12345678901234567890123456789012"),
crypto.SHA512: []byte("1234567890123456789012345678901234567890123456789012345678901234"),
}
var rfcTotpStep uint64 = 30
var rfcTotpTests = []struct {
Time uint64
Code string
T uint64
Algo crypto.Hash
}{
{59, "94287082", 1, crypto.SHA1},
{59, "46119246", 1, crypto.SHA256},
{59, "90693936", 1, crypto.SHA512},
{1111111109, "07081804", 37037036, crypto.SHA1},
{1111111109, "68084774", 37037036, crypto.SHA256},
{1111111109, "25091201", 37037036, crypto.SHA512},
{1111111111, "14050471", 37037037, crypto.SHA1},
{1111111111, "67062674", 37037037, crypto.SHA256},
{1111111111, "99943326", 37037037, crypto.SHA512},
{1234567890, "89005924", 41152263, crypto.SHA1},
{1234567890, "91819424", 41152263, crypto.SHA256},
{1234567890, "93441116", 41152263, crypto.SHA512},
{2000000000, "69279037", 66666666, crypto.SHA1},
{2000000000, "90698825", 66666666, crypto.SHA256},
{2000000000, "38618901", 66666666, crypto.SHA512},
{20000000000, "65353130", 666666666, crypto.SHA1},
{20000000000, "77737706", 666666666, crypto.SHA256},
{20000000000, "47863826", 666666666, crypto.SHA512},
}
func TestTotpRFC(t *testing.T) {
for _, tc := range rfcTotpTests {
otp := NewTOTP(rfcTotpKey[tc.Algo], 0, rfcTotpStep, 8, tc.Algo)
if otp.otpCounter(tc.Time) != tc.T {
fmt.Printf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
fmt.Printf("\texpected: %d\n", tc.T)
fmt.Printf("\t actual: %d\n", otp.otpCounter(tc.Time))
t.Fail()
}
if code := otp.otp(otp.otpCounter(tc.Time)); code != tc.Code {
fmt.Printf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
fmt.Printf("\texpected: %s\n", tc.Code)
fmt.Printf("\t actual: %s\n", code)
t.Fail()
}
}
}
func TestTOTPTime(t *testing.T) {
otp := GenerateGoogleTOTP()
testClock := clock.NewMock()
testClock.Add(2 * time.Minute)
SetClock(testClock)
code := otp.OTP()
testClock.Add(-1 * time.Minute)
if newCode := otp.OTP(); newCode == code {
t.Errorf("twofactor: TOTP: previous code %s shouldn't match code %s", newCode, code)
}
testClock.Add(2 * time.Minute)
if newCode := otp.OTP(); newCode == code {
t.Errorf("twofactor: TOTP: future code %s shouldn't match code %s", newCode, code)
}
testClock.Add(-1 * time.Minute)
if newCode := otp.OTP(); newCode != code {
t.Errorf("twofactor: TOTP: current code %s shouldn't match code %s", newCode, code)
}
}

16
twofactor/util.go Normal file
View File

@@ -0,0 +1,16 @@
package twofactor
import (
"strings"
)
// Pad calculates the number of '='s to add to our encoded string
// to make base32.StdEncoding.DecodeString happy
func Pad(s string) string {
if !strings.HasSuffix(s, "=") && len(s)%8 != 0 {
for len(s)%8 != 0 {
s += "="
}
}
return s
}

53
twofactor/util_test.go Normal file
View File

@@ -0,0 +1,53 @@
package twofactor
import (
"encoding/base32"
"fmt"
"math/rand"
"strings"
"testing"
)
const letters = "1234567890!@#$%^&*()abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func randString() string {
b := make([]byte, rand.Intn(len(letters)))
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return base32.StdEncoding.EncodeToString(b)
}
func TestPadding(t *testing.T) {
for i := 0; i < 300; i++ {
b := randString()
origEncoding := string(b)
modEncoding := strings.ReplaceAll(string(b), "=", "")
str, err := base32.StdEncoding.DecodeString(origEncoding)
if err != nil {
fmt.Println("Can't decode: ", string(b))
t.FailNow()
}
paddedEncoding := Pad(modEncoding)
if origEncoding != paddedEncoding {
fmt.Println("Padding failed:")
fmt.Printf("Expected: '%s'", origEncoding)
fmt.Printf("Got: '%s'", paddedEncoding)
t.FailNow()
} else {
mstr, err := base32.StdEncoding.DecodeString(paddedEncoding)
if err != nil {
fmt.Println("Can't decode: ", paddedEncoding)
t.FailNow()
}
if string(mstr) != string(str) {
fmt.Println("Re-padding failed:")
fmt.Printf("Expected: '%s'", str)
fmt.Printf("Got: '%s'", mstr)
t.FailNow()
}
}
}
}