Compare commits

..

50 Commits

Author SHA1 Message Date
cfb80355bb Update CHANGELOG for v1.13.1. 2025-11-17 10:08:05 -08:00
77160395a0 Cleaning up a few things. 2025-11-17 10:07:03 -08:00
37d5e04421 Adding Dockerfile 2025-11-17 09:03:43 -08:00
dc54eeacbc Remove cert bundles generated in testdata. 2025-11-17 08:36:31 -08:00
e2a3081ce5 cmd: add certser command. 2025-11-17 07:18:46 -08:00
3149d958f4 cmd: add certser 2025-11-17 06:55:20 -08:00
f296344acf twofactor: linting fixes 2025-11-16 21:51:38 -08:00
3fb2d88a3f go get rsc.io/qr 2025-11-16 20:44:13 -08:00
150c02b377 Fix subtree. 2025-11-16 18:55:43 -08:00
83f88c49fe Import twofactor. 2025-11-16 18:45:34 -08:00
7c437ac45f Add 'twofactor/' from commit 'c999bf35b0e47de4f63d59abbe0d7efc76c13ced'
git-subtree-dir: twofactor
git-subtree-mainline: 4dc135cfe0
git-subtree-split: c999bf35b0
2025-11-16 18:43:03 -08:00
c999bf35b0 linter fixes. 2025-11-16 18:39:18 -08:00
4dc135cfe0 Update CHANGELOG for v1.11.2. 2025-11-16 13:18:38 -08:00
790113e189 cmd: refactor for code reuse. 2025-11-16 13:15:08 -08:00
8348c5fd65 Update CHANGELOG. 2025-11-16 11:09:02 -08:00
1eafb638a8 cmd: finish linting fixes 2025-11-16 11:03:12 -08:00
3ad562b6fa cmd: continuing linter fixes 2025-11-16 02:54:02 -08:00
0f77bd49dc cmd: continue lint fixes. 2025-11-16 01:32:19 -08:00
f31d74243f cmd: start linting fixes. 2025-11-16 00:36:19 -08:00
0dcd18c6f1 clean up code
- travisci is long dead
- golangci-lint the repo
2024-12-02 13:47:43 -08:00
024d552293 add circle ci config 2024-12-02 13:26:34 -08:00
9cd2ced695 There are different keys for different hash sizes. 2024-12-02 13:16:32 -08:00
619c08a13f Update travis to latest Go versions. 2020-10-31 08:06:11 -07:00
944a57bf0e Switch to go modules. 2020-10-31 07:41:31 -07:00
0857b29624 Actually support clock mocking. 2020-10-31 07:24:34 -07:00
CodeLingo Bot
e95404bfc5 Fix function comments based on best practices from Effective Go
Signed-off-by: CodeLingo Bot <bot@codelingo.io>
2020-10-30 14:57:11 -07:00
ujjwalsh
924654e7c4 Added Support for Linux on Power 2020-10-30 07:50:32 -07:00
9e0979e07f Support clock mocking.
This addresses #15.
2018-12-07 08:23:01 -08:00
Aaron Bieber
bbc82ff8de Pad non-padded secrets. This lets us continue building on <= go1.8.
- Add tests for secrets using various padding methods.
- Add a new method/test to append padding to non-padded secrets.
2018-04-18 13:39:21 -07:00
Aaron Bieber
5fd928f69a Decode using WithPadding as pointed out by @gl-sergei.
This makes us print the same 6 digits as oathtool for non-padded
secrets like "a6mryljlbufszudtjdt42nh5by".
2018-04-18 13:39:21 -07:00
Aaron Bieber
acefe4a3b9 Don't assume our secret is base32 encoded.
According to https://en.wikipedia.org/wiki/Time-based_One-time_Password_algorithm
secrets are only base32 encoded in gauthenticator and gauth friendly providers.
2018-04-16 13:14:03 -07:00
a1452cebc9 Travis requires a string for Go 1.10. 2018-04-16 13:03:16 -07:00
6e9812e6f5 Vendor dependencies and add more tests. 2018-04-16 13:03:16 -07:00
Aaron Bieber
8c34415c34 add readme 2018-04-16 12:52:39 -07:00
Paul TREHIOU
2cf2c15def Case insensitive algorithm match 2018-04-16 12:43:27 -07:00
Aaron Bieber
eaad1884d4 Make sure our secret is always uppercase
Non-uppercase secrets that are base32 encoded will fial to decode
unless we upper them.
2017-09-17 18:19:23 -07:00
5d57d844d4 Add license (MIT). 2017-04-13 10:02:20 -07:00
Kyle Isom
31b9d175dd Add travis config. 2017-03-20 14:20:49 -07:00
Aaron Bieber
79e106da2e point to new qr location 2017-03-20 13:18:56 -07:00
Kyle Isom
939b1bc272 Updating imports. 2015-08-12 12:29:34 -07:00
Kyle
89e74f390b Add doc.go, finish YubiKey removal. 2014-04-24 20:43:13 -06:00
Kyle
7881b6fdfc Remove test TOTP client. 2014-04-24 20:40:44 -06:00
Kyle
5bef33245f Remove YubiKey (not currently functional). 2014-04-24 20:37:53 -06:00
Kyle
84250b0501 More documentation. 2014-04-24 20:37:00 -06:00
Kyle Isom
459e9f880f Add function to build Google TOTPs from secret 2014-04-23 16:54:16 -07:00
Kyle Isom
0982f47ce3 Add last night's progress.
Basic functionality for HOTP, TOTP, and YubiKey OTP. Still need YubiKey
HMAC, serialisation, check, and scan.
2013-12-20 17:00:01 -07:00
Kyle Isom
1dec15fd11 add missing files
new files are
	oath_test
	totp code
2013-12-19 00:21:26 -07:00
Kyle Isom
2ee9cae5ba Add basic Google Authenticator TOTP client. 2013-12-19 00:20:00 -07:00
Kyle Isom
dc04475120 HOTP and TOTP-SHA-1 working.
why the frak aren't the SHA-256 and SHA-512 variants working
2013-12-19 00:04:26 -07:00
Kyle Isom
dbbd5116b5 Initial import.
Basic HOTP functionality.
2013-12-18 21:48:14 -07:00
78 changed files with 2473 additions and 1104 deletions

View File

@@ -64,4 +64,4 @@ workflows:
testbuild: testbuild:
jobs: jobs:
- testbuild - testbuild
# - lint - lint

1
.gitignore vendored
View File

@@ -1 +1,2 @@
.idea .idea
cmd/cert-bundler/testdata/pkg/*

View File

@@ -12,12 +12,31 @@
version: "2" version: "2"
output:
sort-order:
- file
- linter
- severity
issues: issues:
# Maximum count of issues with the same text. # Maximum count of issues with the same text.
# Set to 0 to disable. # Set to 0 to disable.
# Default: 3 # Default: 3
max-same-issues: 50 max-same-issues: 50
# Exclude some lints for CLI programs under cmd/ (package main).
# The project allows fmt.Print* in command-line tools; keep forbidigo for libraries.
exclude-rules:
- path: ^cmd/
linters:
- forbidigo
- path: cmd/.*
linters:
- forbidigo
- path: .*/cmd/.*
linters:
- forbidigo
formatters: formatters:
enable: enable:
- goimports # checks if the code and import statements are formatted according to the 'goimports' command - goimports # checks if the code and import statements are formatted according to the 'goimports' command
@@ -73,7 +92,6 @@ linters:
- godoclint # checks Golang's documentation practice - godoclint # checks Golang's documentation practice
- godot # checks if comments end in a period - godot # checks if comments end in a period
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
- goprintffuncname # checks that printf-like functions are named with f at the end
- gosec # inspects source code for security problems - gosec # inspects source code for security problems
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution - iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
@@ -230,6 +248,10 @@ linters:
check-type-assertions: true check-type-assertions: true
exclude-functions: exclude-functions:
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write - (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
- git.wntrmute.dev/kyle/goutils/lib.Warn
- git.wntrmute.dev/kyle/goutils/lib.Warnx
- git.wntrmute.dev/kyle/goutils/lib.Err
- git.wntrmute.dev/kyle/goutils/lib.Errx
exhaustive: exhaustive:
# Program elements to check for exhaustiveness. # Program elements to check for exhaustiveness.
@@ -321,6 +343,12 @@ linters:
# https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link # https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link
- no-unused-link - no-unused-link
gosec:
excludes:
- G104 # handled by errcheck
- G301
- G306
govet: govet:
# Enable all analyzers. # Enable all analyzers.
# Default: false # Default: false
@@ -356,6 +384,12 @@ linters:
- os.WriteFile - os.WriteFile
- prometheus.ExponentialBuckets.* - prometheus.ExponentialBuckets.*
- prometheus.LinearBuckets - prometheus.LinearBuckets
ignored-numbers:
- 1
- 2
- 3
- 4
- 8
nakedret: nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns. # Make an issue if func has more lines of code than this setting, and it has naked returns.
@@ -424,6 +458,10 @@ linters:
# Omit embedded fields from selector expression. # Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008 # https://staticcheck.dev/docs/checks/#QF1008
- -QF1008 - -QF1008
# We often explicitly enable old/deprecated ciphers for research.
- -SA1019
# Covered by revive.
- -ST1003
usetesting: usetesting:
# Enable/disable `os.TempDir()` detections. # Enable/disable `os.TempDir()` detections.
@@ -442,6 +480,8 @@ linters:
rules: rules:
- path: 'ahash/ahash.go' - path: 'ahash/ahash.go'
linters: [ staticcheck, gosec ] linters: [ staticcheck, gosec ]
- path: 'twofactor/.*.go'
linters: [ exhaustive, mnd, revive ]
- path: 'backoff/backoff_test.go' - path: 'backoff/backoff_test.go'
linters: [ testpackage ] linters: [ testpackage ]
- path: 'dbg/dbg_test.go' - path: 'dbg/dbg_test.go'
@@ -452,6 +492,8 @@ linters:
linters: [ testableexamples ] linters: [ testableexamples ]
- path: 'main.go' - path: 'main.go'
linters: [ forbidigo, mnd, reassign ] linters: [ forbidigo, mnd, reassign ]
- path: 'cmd/cruntar/main.go'
linters: [ unparam ]
- source: 'TODO' - source: 'TODO'
linters: [ godot ] linters: [ godot ]
- text: 'should have a package comment' - text: 'should have a package comment'

View File

@@ -1,5 +1,57 @@
CHANGELOG CHANGELOG
v1.13.1 - 2025-11-17
Add:
- Dockerfile for cert-bundler.
v1.13.0 - 2025-11-16
Add:
- cmd/certser: print serial numbers for certificates.
- lib/HexEncode: add a new hex encode function handling multiple output
formats, including with and without colons.
v1.12.4 - 2025-11-16
Changed:
- Linting fixes for twofactor that were previously masked.
v1.12.3 erroneously tagged and pushed
v1.12.2 - 2025-11-16
Changed:
- add rsc.io/qr dependency for twofactor.
v1.12.1 - 2025-11-16
Changed:
- twofactor: Remove go.{mod,sum}.
v1.12.0 - 2025-11-16
Added
- twofactor: the github.com/kisom/twofactor repo has been subtree'd
into this repo.
v1.11.2 - 2025-11-16
Changed
- cmd/ski, cmd/csrpubdump, cmd/tlskeypair: centralize
certificate/private-key/CSR parsing by reusing certlib helpers.
This reduces duplication and improves consistency across commands.
- csr: CSR parsing in the above commands now uses certlib.ParseCSR,
which verifies CSR signatures (behavioral hardening compared to
prior parsing without signature verification).
v1.11.1 - 2025-11-16
Changed
- cmd: complete linting fixes across programs; no functional changes.
v1.11.0 - 2025-11-15 v1.11.0 - 2025-11-15
Added Added

View File

@@ -2,39 +2,52 @@ GOUTILS
This is a collection of small utility code I've written in Go; the `cmd/` This is a collection of small utility code I've written in Go; the `cmd/`
directory has a number of command-line utilities. Rather than keep all directory has a number of command-line utilities. Rather than keep all
of these in superfluous repositories of their own, or rewriting them of these in superfluous repositories of their own or rewriting them
for each project, I'm putting them here. for each project, I'm putting them here.
The project can be built with the standard Go tooling, or it can be built The project can be built with the standard Go tooling.
with Bazel.
Contents: Contents:
ahash/ Provides hashes from string algorithm specifiers. ahash/ Provides hashes from string algorithm specifiers.
assert/ Error handling, assertion-style. assert/ Error handling, assertion-style.
backoff/ Implementation of an intelligent backoff strategy. backoff/ Implementation of an intelligent backoff strategy.
cache/ Implementations of various caches.
lru/ Least-recently-used cache.
mru/ Most-recently-used cache.
certlib/ Library for working with TLS certificates.
cmd/ cmd/
atping/ Automated TCP ping, meant for putting in cronjobs. atping/ Automated TCP ping, meant for putting in cronjobs.
certchain/ Display the certificate chain from a ca-signed/ Validate whether a certificate is signed by a CA.
TLS connection. cert-bundler/
Create certificate bundles from a source of PEM
certificates.
cert-revcheck/
Check whether a certificate has been revoked or is
expired.
certchain/ Display the certificate chain from a TLS connection.
certdump/ Dump certificate information. certdump/ Dump certificate information.
certexpiry/ Print a list of certificate subjects and expiry times certexpiry/ Print a list of certificate subjects and expiry times
or warn about certificates expiring within a certain or warn about certificates expiring within a certain
window. window.
certverify/ Verify a TLS X.509 certificate, optionally printing certverify/ Verify a TLS X.509 certificate file, optionally printing
the time to expiry and checking for revocations. the time to expiry and checking for revocations.
clustersh/ Run commands or transfer files across multiple clustersh/ Run commands or transfer files across multiple
servers via SSH. servers via SSH.
cruntar/ Untar an archive with hard links, copying instead of cruntar/ (Un)tar an archive with hard links, copying instead of
linking. linking.
csrpubdump/ Dump the public key from an X.509 certificate request. csrpubdump/ Dump the public key from an X.509 certificate request.
data_sync/ Sync the user's homedir to external storage. data_sync/ Sync the user's homedir to external storage.
diskimg/ Write a disk image to a device. diskimg/ Write a disk image to a device.
dumpbytes/ Dump the contents of a file as hex bytes, printing it as
a Go []byte literal.
eig/ EEPROM image generator. eig/ EEPROM image generator.
fragment/ Print a fragment of a file. fragment/ Print a fragment of a file.
host/ Go imlpementation of the host(1) command.
jlp/ JSON linter/prettifier. jlp/ JSON linter/prettifier.
kgz/ Custom gzip compressor / decompressor that handles 99% kgz/ Custom gzip compressor / decompressor that handles 99%
of my use cases. of my use cases.
minmax/ Generate a minmax code for use in uLisp.
parts/ Simple parts database management for my collection of parts/ Simple parts database management for my collection of
electronic components. electronic components.
pem2bin/ Dump the binary body of a PEM-encoded block. pem2bin/ Dump the binary body of a PEM-encoded block.
@@ -44,41 +57,45 @@ Contents:
in a bundle. in a bundle.
renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash. renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash.
rhash/ Compute the digest of remote files. rhash/ Compute the digest of remote files.
rolldie/ Roll some dice.
showimp/ List the external (e.g. non-stdlib and outside the showimp/ List the external (e.g. non-stdlib and outside the
current working directory) imports for a Go file. current working directory) imports for a Go file.
ski Display the SKI for PEM-encoded TLS material. ski Display the SKI for PEM-encoded TLS material.
sprox/ Simple TCP proxy. sprox/ Simple TCP proxy.
stealchain/ Dump the verified chain from a TLS stealchain/ Dump the verified chain from a TLS connection to a
connection to a server. server.
stealchain- Dump the verified chain from a TLS stealchain-server/
server/ connection from a client. Dump the verified chain from a TLS connection from
from a client.
subjhash/ Print or match subject info from a certificate. subjhash/ Print or match subject info from a certificate.
tlsinfo/ Print information about a TLS connection (the TLS version
and cipher suite).
tlskeypair/ Check whether a TLS certificate and key file match. tlskeypair/ Check whether a TLS certificate and key file match.
utc/ Convert times to UTC. utc/ Convert times to UTC.
yamll/ A small YAML linter. yamll/ A small YAML linter.
zsearch/ Search for a string in directory of gzipped files.
config/ A simple global configuration system where configuration config/ A simple global configuration system where configuration
data is pulled from a file or an environment variable data is pulled from a file or an environment variable
transparently. transparently.
iniconf/ A simple INI-style configuration system.
dbg/ A debug printer. dbg/ A debug printer.
die/ Death of a program. die/ Death of a program.
fileutil/ Common file functions. fileutil/ Common file functions.
lib/ Commonly-useful functions for writing Go programs. lib/ Commonly-useful functions for writing Go programs.
log/ A syslog library.
logging/ A logging library. logging/ A logging library.
mwc/ MultiwriteCloser implementation. mwc/ MultiwriteCloser implementation.
rand/ Utilities for working with math/rand.
sbuf/ A byte buffer that can be wiped. sbuf/ A byte buffer that can be wiped.
seekbuf/ A read-seekable byte buffer. seekbuf/ A read-seekable byte buffer.
syslog/ Syslog-type logging. syslog/ Syslog-type logging.
tee/ Emulate tee(1)'s functionality in io.Writers. tee/ Emulate tee(1)'s functionality in io.Writers.
testio/ Various I/O utilities useful during testing. testio/ Various I/O utilities useful during testing.
testutil/ Various utility functions useful during testing. twofactor/ Two-factor authentication.
Each program should have a small README in the directory with more Each program should have a small README in the directory with more
information. information.
All code here is licensed under the ISC license. All code here is licensed under the Apache 2.0 license.
Error handling Error handling
-------------- --------------
@@ -99,7 +116,7 @@ Examples:
``` ```
cert, err := certlib.LoadCertificate(path) cert, err := certlib.LoadCertificate(path)
if err != nil { if err != nil {
// sentinel match // sentinel match:
if errors.Is(err, certerr.ErrEmptyCertificate) { if errors.Is(err, certerr.ErrEmptyCertificate) {
// handle empty input // handle empty input
} }
@@ -116,5 +133,3 @@ if err != nil {
} }
} }
``` ```
Avoid including sensitive data (keys, passwords, tokens) in error messages.

View File

@@ -91,7 +91,7 @@ func TestReset(t *testing.T) {
} }
} }
const decay = 5 * time.Millisecond const decay = 25 * time.Millisecond
const maxDuration = 10 * time.Millisecond const maxDuration = 10 * time.Millisecond
const interval = time.Millisecond const interval = time.Millisecond

View File

@@ -458,8 +458,6 @@ func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
} }
if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") { if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") {
if password != nil { if password != nil {
// nolintlint requires rationale:
//nolint:staticcheck // legacy RFC1423 PEM encryption supported for backward compatibility when caller supplies a password
return x509.DecryptPEMBlock(keyDER, password) return x509.DecryptPEMBlock(keyDER, password)
} }
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey) return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)

View File

@@ -0,0 +1,28 @@
# Build and runtime image for cert-bundler
# Usage (from repo root or cmd/cert-bundler directory):
# docker build -t cert-bundler:latest -f cmd/cert-bundler/Dockerfile .
# docker run --rm -v "$PWD":/work cert-bundler:latest
# This expects a /work/bundle.yaml file in the mounted directory and
# will write generated bundles to /work/bundle.
# Build stage
FROM golang:1.24.3-alpine AS build
WORKDIR /src
# Copy go module files and download dependencies first for better caching
RUN go install git.wntrmute.dev/kyle/goutils/cmd/cert-bundler@v1.13.1 && \
mv /go/bin/cert-bundler /usr/local/bin/cert-bundler
# Runtime stage (kept as golang:alpine per requirement)
FROM golang:1.24.3-alpine
# Create a work directory that users will typically mount into
WORKDIR /work
VOLUME ["/work"]
# Copy the built binary from the builder stage
COPY --from=build /usr/local/bin/cert-bundler /usr/local/bin/cert-bundler
# Default command: read bundle.yaml from current directory and output to ./bundle
ENTRYPOINT ["/usr/local/bin/cert-bundler"]
CMD ["-c", "/work/bundle.yaml", "-o", "/work/bundle"]

View File

@@ -17,8 +17,9 @@ import (
"strings" "strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
"git.wntrmute.dev/kyle/goutils/certlib"
) )
// Config represents the top-level YAML configuration. // Config represents the top-level YAML configuration.
@@ -299,12 +300,18 @@ func prepareArchiveFiles(
) ([]fileEntry, error) { ) ([]fileEntry, error) {
var archiveFiles []fileEntry var archiveFiles []fileEntry
// Track used filenames to avoid collisions inside archives
usedNames := make(map[string]int)
// Handle a single bundle file // Handle a single bundle file
if outputs.IncludeSingle && len(singleFileCerts) > 0 { if outputs.IncludeSingle && len(singleFileCerts) > 0 {
files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true) files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encode single bundle: %w", err) return nil, fmt.Errorf("failed to encode single bundle: %w", err)
} }
for i := range files {
files[i].name = makeUniqueName(files[i].name, usedNames)
}
archiveFiles = append(archiveFiles, files...) archiveFiles = append(archiveFiles, files...)
} }
@@ -316,6 +323,9 @@ func prepareArchiveFiles(
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encode individual cert %s: %w", cp.path, err) return nil, fmt.Errorf("failed to encode individual cert %s: %w", cp.path, err)
} }
for i := range files {
files[i].name = makeUniqueName(files[i].name, usedNames)
}
archiveFiles = append(archiveFiles, files...) archiveFiles = append(archiveFiles, files...)
} }
} }
@@ -323,8 +333,9 @@ func prepareArchiveFiles(
// Generate manifest if requested // Generate manifest if requested
if outputs.Manifest { if outputs.Manifest {
manifestContent := generateManifest(archiveFiles) manifestContent := generateManifest(archiveFiles)
manifestName := makeUniqueName("MANIFEST", usedNames)
archiveFiles = append(archiveFiles, fileEntry{ archiveFiles = append(archiveFiles, fileEntry{
name: "MANIFEST", name: manifestName,
content: manifestContent, content: manifestContent,
}) })
} }
@@ -573,3 +584,29 @@ func generateHashFile(path string, files []string) error {
return nil return nil
} }
// makeUniqueName ensures that each file name within the archive is unique by appending
// an incremental numeric suffix before the extension when collisions occur.
// Example: "root.pem" -> "root-2.pem", "root-3.pem", etc.
func makeUniqueName(name string, used map[string]int) string {
// If unused, mark and return as-is
if _, ok := used[name]; !ok {
used[name] = 1
return name
}
ext := filepath.Ext(name)
base := strings.TrimSuffix(name, ext)
// Track a counter per base+ext key
key := base + ext
counter := max(used[key], 1)
for {
counter++
candidate := fmt.Sprintf("%s-%d%s", base, counter, ext)
if _, exists := used[candidate]; !exists {
used[key] = counter
used[candidate] = 1
return candidate
}
}
}

View File

@@ -1,197 +0,0 @@
This project is an exploration into the utility of Jetbrains' Junie
to write smaller but tedious programs.
Task: build a certificate bundling tool in cmd/cert-bundler. It
creates archives of certificates chains.
A YAML file for this looks something like:
``` yaml
config:
hashes: bundle.sha256
expiry: 1y
chains:
core_certs:
certs:
- root: roots/core-ca.pem
intermediates:
- int/cca1.pem
- int/cca2.pem
- int/cca3.pem
- root: roots/ssh-ca.pem
intermediates:
- ssh/ssh_dmz1.pem
- ssh/ssh_internal.pem
outputs:
include_single: true
include_individual: true
manifest: true
formats:
- zip
- tgz
```
Some requirements:
1. First, all the certificates should be loaded.
2. For each root, each of the indivudal intermediates should be
checked to make sure they are properly signed by the root CA.
3. The program should optionally take an expiration period (defaulting
to one year), specified in config.expiration, and if any certificate
is within that expiration period, a warning should be printed.
4. If outputs.include_single is true, all certificates under chains
should be concatenated into a single file.
5. If outputs.include_individual is true, all certificates under
chains should be included at the root level (e.g. int/cca2.pem
would be cca2.pem in the archive).
6. If bundle.manifest is true, a "MANIFEST" file is created with
SHA256 sums of each file included in the archive.
7. For each of the formats, create an archive file in the output
directory (specified with `-o`) with that format.
- If zip is included, create a .zip file.
- If tgz is included, create a .tar.gz file with default compression
levels.
- All archive files should include any generated files (single
and/or individual) in the top-level directory.
8. In the output directory, create a file with the same name as
config.hashes that contains the SHA256 sum of all files created.
-----
The outputs.include_single and outputs.include_individual describe
what should go in the final archive. If both are specified, the output
archive should include both a single bundle.pem and each individual
certificate, for example.
-----
As it stands, given the following `bundle.yaml`:
``` yaml
config:
hashes: bundle.sha256
expiry: 1y
chains:
core_certs:
certs:
- root: pems/gts-r1.pem
intermediates:
- pems/goog-wr2.pem
outputs:
include_single: true
include_individual: true
manifest: true
formats:
- zip
- tgz
- root: pems/isrg-root-x1.pem
intermediates:
- pems/le-e7.pem
outputs:
include_single: true
include_individual: false
manifest: true
formats:
- zip
- tgz
google_certs:
certs:
- root: pems/gts-r1.pem
intermediates:
- pems/goog-wr2.pem
outputs:
include_single: true
include_individual: false
manifest: true
formats:
- tgz
lets_encrypt:
certs:
- root: pems/isrg-root-x1.pem
intermediates:
- pems/le-e7.pem
outputs:
include_single: false
include_individual: true
manifest: false
formats:
- zip
```
The program outputs the following files:
- bundle.sha256
- core_certs_0.tgz (contains individual certs)
- core_certs_0.zip (contains individual certs)
- core_certs_1.tgz (contains core_certs.pem)
- core_certs_1.zip (contains core_certs.pem)
- google_certs_0.tgz
- lets_encrypt_0.zip
It should output
- bundle.sha256
- core_certs.tgz
- core_certs.zip
- google_certs.tgz
- lets_encrypt.zip
core_certs.* should contain `bundle.pem` and all the individual
certs. There should be no _$n$ variants of archives.
-----
Add an additional field to outputs: encoding. It should accept one of
`der`, `pem`, or `both`. If `der`, certificates should be output as a
`.crt` file containing a DER-encoded certificate. If `pem`, certificates
should be output as a `.pem` file containing a PEM-encoded certificate.
If both, both the `.crt` and `.pem` certificate should be included.
For example, given the previous config, if `encoding` is der, the
google_certs.tgz archive should contain
- bundle.crt
- MANIFEST
Or with lets_encrypt.zip:
- isrg-root-x1.crt
- le-e7.crt
However, if `encoding` is pem, the lets_encrypt.zip archive should contain:
- isrg-root-x1.pem
- le-e7.pem
And if it `encoding` is both, the lets_encrypt.zip archive should contain:
- isrg-root-x1.crt
- isrg-root-x1.pem
- le-e7.crt
- le-e7.pem
-----
The tgz format should output a `.tar.gz` file instead of a `.tgz` file.
-----
Move the format extensions to a global variable.
-----
Write a README.txt with a description of the bundle.yaml format.
Additionally, update the help text for the program (e.g. with `-h`)
to provide the same detailed information.
-----
It may be easier to embed the README.txt in the program on build.
-----
For the archive (tar.gz and zip) writers, make sure errors are
checked at the end, and don't just defer the close operations.

View File

@@ -2,6 +2,19 @@ config:
hashes: bundle.sha256 hashes: bundle.sha256
expiry: 1y expiry: 1y
chains: chains:
weird:
certs:
- root: pems/gts-r1.pem
intermediates:
- pems/goog-wr2.pem
- root: pems/isrg-root-x1.pem
outputs:
include_single: true
include_individual: true
manifest: true
formats:
- zip
- tgz
core_certs: core_certs:
certs: certs:
- root: pems/gts-r1.pem - root: pems/gts-r1.pem

View File

@@ -1,4 +0,0 @@
5ed8bf9ed693045faa8a5cb0edc4a870052e56aef6291ce8b1604565affbc2a4 core_certs.zip
e59eddc590d2f7b790a87c5b56e81697088ab54be382c0e2c51b82034006d308 core_certs.tgz
51b9b63b1335118079e90700a3a5b847c363808e9116e576ca84f301bc433289 google_certs.tgz
3d1910ca8835c3ded1755a8c7d6c48083c2f3ff68b2bfbf932aaf27e29d0a232 lets_encrypt.zip

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,14 +1,15 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"flag"
"errors" "errors"
"flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
@@ -23,6 +24,13 @@ var (
verbose bool verbose bool
) )
var (
strOK = "OK"
strExpired = "EXPIRED"
strRevoked = "REVOKED"
strUnknown = "UNKNOWN"
)
func main() { func main() {
flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal") flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal")
flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects") flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects")
@@ -42,16 +50,16 @@ func main() {
for _, target := range flag.Args() { for _, target := range flag.Args() {
status, err := processTarget(target) status, err := processTarget(target)
switch status { switch status {
case "OK": case strOK:
fmt.Printf("%s: OK\n", target) fmt.Printf("%s: %s\n", target, strOK)
case "EXPIRED": case strExpired:
fmt.Printf("%s: EXPIRED: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strExpired, err)
exitCode = 1 exitCode = 1
case "REVOKED": case strRevoked:
fmt.Printf("%s: REVOKED\n", target) fmt.Printf("%s: %s\n", target, strRevoked)
exitCode = 1 exitCode = 1
case "UNKNOWN": case strUnknown:
fmt.Printf("%s: UNKNOWN: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strUnknown, err)
if hardfail { if hardfail {
// In hardfail, treat unknown as failure // In hardfail, treat unknown as failure
exitCode = 1 exitCode = 1
@@ -67,74 +75,77 @@ func processTarget(target string) (string, error) {
return checkFile(target) return checkFile(target)
} }
// Not a file; treat as site
return checkSite(target) return checkSite(target)
} }
func checkFile(path string) (string, error) { func checkFile(path string) (string, error) {
in, err := ioutil.ReadFile(path) // Prefer high-level helpers from certlib to load certificates from disk
if err != nil { if certs, err := certlib.LoadCertificates(path); err == nil && len(certs) > 0 {
return "UNKNOWN", err
}
// Try PEM first; if that fails, try single DER cert
certs, err := certlib.ReadCertificates(in)
if err != nil || len(certs) == 0 {
cert, _, derr := certlib.ReadCertificate(in)
if derr != nil || cert == nil {
if err == nil {
err = derr
}
return "UNKNOWN", err
}
return evaluateCert(cert)
}
// Evaluate the first certificate (leaf) by default // Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0]) return evaluateCert(certs[0])
}
cert, err := certlib.LoadCertificate(path)
if err != nil || cert == nil {
return strUnknown, err
}
return evaluateCert(cert)
} }
func checkSite(hostport string) (string, error) { func checkSite(hostport string) (string, error) {
// Use certlib/hosts to parse host/port (supports https URLs and host:port) // Use certlib/hosts to parse host/port (supports https URLs and host:port)
target, err := hosts.ParseHost(hostport) target, err := hosts.ParseHost(hostport)
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
d := &net.Dialer{Timeout: timeout} d := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(d, "tcp", target.String(), &tls.Config{InsecureSkipVerify: true, ServerName: target.Host}) tcfg := &tls.Config{
InsecureSkipVerify: true,
ServerName: target.Host,
} // #nosec G402 -- CLI tool only verifies revocation
td := &tls.Dialer{NetDialer: d, Config: tcfg}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := td.DialContext(ctx, "tcp", target.String())
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() tconn, ok := conn.(*tls.Conn)
if !ok {
return strUnknown, errors.New("connection is not TLS")
}
state := tconn.ConnectionState()
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
return "UNKNOWN", errors.New("no peer certificates presented") return strUnknown, errors.New("no peer certificates presented")
} }
return evaluateCert(state.PeerCertificates[0]) return evaluateCert(state.PeerCertificates[0])
} }
func evaluateCert(cert *x509.Certificate) (string, error) { func evaluateCert(cert *x509.Certificate) (string, error) {
// Expiry check // Delegate validity and revocation checks to certlib/revoke helper.
now := time.Now() // It returns revoked=true for both revoked and expired/not-yet-valid.
if !now.Before(cert.NotAfter) { // Map those cases back to our statuses using the returned error text.
return "EXPIRED", fmt.Errorf("expired at %s", cert.NotAfter)
}
if !now.After(cert.NotBefore) {
return "EXPIRED", fmt.Errorf("not valid until %s", cert.NotBefore)
}
// Revocation check using certlib/revoke
revoked, ok, err := revoke.VerifyCertificateError(cert) revoked, ok, err := revoke.VerifyCertificateError(cert)
if revoked { if revoked {
// If revoked is true, ok will be true per implementation, err may describe why if err != nil {
return "REVOKED", err msg := err.Error()
if strings.Contains(msg, "expired") || strings.Contains(msg, "isn't valid until") ||
strings.Contains(msg, "not valid until") {
return strExpired, err
}
}
return strRevoked, err
} }
if !ok { if !ok {
// Revocation status could not be determined // Revocation status could not be determined
return "UNKNOWN", err return strUnknown, err
} }
return "OK", nil return strOK, nil
} }

View File

@@ -1,11 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"os"
"regexp" "regexp"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -20,20 +23,26 @@ func main() {
server += ":443" server += ":443"
} }
var chain string d := &tls.Dialer{Config: &tls.Config{}} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", server)
conn, err := tls.Dial("tcp", server, nil)
die.If(err) die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
}
defer conn.Close()
details := conn.ConnectionState() details := conn.ConnectionState()
var chain strings.Builder
for _, cert := range details.PeerCertificates { for _, cert := range details.PeerCertificates {
p := pem.Block{ p := pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Raw, Bytes: cert.Raw,
} }
chain += string(pem.EncodeToMemory(&p)) chain.Write(pem.EncodeToMemory(&p))
} }
fmt.Println(chain) fmt.Fprintln(os.Stdout, chain.String())
} }
} }

View File

@@ -1,7 +1,9 @@
//lint:file-ignore SA1019 allow strict compatibility for old certs
package main package main
import ( import (
"bytes" "bytes"
"context"
"crypto/dsa" "crypto/dsa"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
@@ -101,30 +103,30 @@ func extUsage(ext []x509.ExtKeyUsage) string {
} }
func showBasicConstraints(cert *x509.Certificate) { func showBasicConstraints(cert *x509.Certificate) {
fmt.Printf("\tBasic constraints: ") fmt.Fprint(os.Stdout, "\tBasic constraints: ")
if cert.BasicConstraintsValid { if cert.BasicConstraintsValid {
fmt.Printf("valid") fmt.Fprint(os.Stdout, "valid")
} else { } else {
fmt.Printf("invalid") fmt.Fprint(os.Stdout, "invalid")
} }
if cert.IsCA { if cert.IsCA {
fmt.Printf(", is a CA certificate") fmt.Fprint(os.Stdout, ", is a CA certificate")
if !cert.BasicConstraintsValid { if !cert.BasicConstraintsValid {
fmt.Printf(" (basic constraint failure)") fmt.Fprint(os.Stdout, " (basic constraint failure)")
} }
} else { } else {
fmt.Printf("is not a CA certificate") fmt.Fprint(os.Stdout, "is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 { if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Printf(" (key encipherment usage enabled!)") fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
} }
} }
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) { if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Printf(", max path length %d", cert.MaxPathLen) fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
} }
fmt.Printf("\n") fmt.Fprintln(os.Stdout)
} }
const oneTrueDateFormat = "2006-01-02T15:04:05-0700" const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
@@ -136,39 +138,41 @@ var (
func wrapPrint(text string, indent int) { func wrapPrint(text string, indent int) {
tabs := "" tabs := ""
for i := 0; i < indent; i++ { var tabsSb140 strings.Builder
tabs += "\t" for range indent {
tabsSb140.WriteString("\t")
} }
tabs += tabsSb140.String()
fmt.Printf(tabs+"%s\n", wrap(text, indent)) fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
} }
func displayCert(cert *x509.Certificate) { func displayCert(cert *x509.Certificate) {
fmt.Println("CERTIFICATE") fmt.Fprintln(os.Stdout, "CERTIFICATE")
if showHash { if showHash {
fmt.Println(wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0)) fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
} }
fmt.Println(wrap("Subject: "+displayName(cert.Subject), 0)) fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Println(wrap("Issuer: "+displayName(cert.Issuer), 0)) fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Printf("\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm), fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm)) sigAlgoHash(cert.SignatureAlgorithm))
fmt.Println("Details:") fmt.Fprintln(os.Stdout, "Details:")
wrapPrint("Public key: "+certPublic(cert), 1) wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Printf("\tSerial number: %s\n", cert.SerialNumber) fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 { if len(cert.AuthorityKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1)) fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
} }
if len(cert.SubjectKeyId) > 0 { if len(cert.SubjectKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1)) fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
} }
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1) wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Printf("\t until: %s\n", cert.NotAfter.Format(dateFormat)) fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Printf("\tKey usages: %s\n", keyUsages(cert.KeyUsage)) fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 { if len(cert.ExtKeyUsage) > 0 {
fmt.Printf("\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage)) fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
} }
showBasicConstraints(cert) showBasicConstraints(cert)
@@ -221,13 +225,13 @@ func displayAllCerts(in []byte, leafOnly bool) {
if err != nil { if err != nil {
certs, _, err = certlib.ParseCertificatesDER(in, "") certs, _, err = certlib.ParseCertificatesDER(in, "")
if err != nil { if err != nil {
lib.Warn(err, "failed to parse certificates") _, _ = lib.Warn(err, "failed to parse certificates")
return return
} }
} }
if len(certs) == 0 { if len(certs) == 0 {
lib.Warnx("no certificates found") _, _ = lib.Warnx("no certificates found")
return return
} }
@@ -243,29 +247,45 @@ func displayAllCerts(in []byte, leafOnly bool) {
func displayAllCertsWeb(uri string, leafOnly bool) { func displayAllCertsWeb(uri string, leafOnly bool) {
ci := getConnInfo(uri) ci := getConnInfo(uri)
conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig()) d := &tls.Dialer{Config: permissiveConfig()}
nc, err := d.DialContext(context.Background(), "tcp", ci.Addr)
if err != nil { if err != nil {
lib.Warn(err, "couldn't connect to %s", ci.Addr) _, _ = lib.Warn(err, "couldn't connect to %s", ci.Addr)
return
}
conn, ok := nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return return
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() state := conn.ConnectionState()
conn.Close() if err = conn.Close(); err != nil {
_, _ = lib.Warn(err, "couldn't close TLS connection")
}
conn, err = tls.Dial("tcp", ci.Addr, verifyConfig(ci.Host)) d = &tls.Dialer{Config: verifyConfig(ci.Host)}
nc, err = d.DialContext(context.Background(), "tcp", ci.Addr)
if err == nil { if err == nil {
conn, ok = nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return
}
err = conn.VerifyHostname(ci.Host) err = conn.VerifyHostname(ci.Host)
if err == nil { if err == nil {
state = conn.ConnectionState() state = conn.ConnectionState()
} }
conn.Close() conn.Close()
} else { } else {
lib.Warn(err, "TLS verification error with server name %s", ci.Host) _, _ = lib.Warn(err, "TLS verification error with server name %s", ci.Host)
} }
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
lib.Warnx("no certificates found") _, _ = lib.Warnx("no certificates found")
return return
} }
@@ -275,14 +295,14 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
} }
if len(state.VerifiedChains) == 0 { if len(state.VerifiedChains) == 0 {
lib.Warnx("no verified chains found; using peer chain") _, _ = lib.Warnx("no verified chains found; using peer chain")
for i := range state.PeerCertificates { for i := range state.PeerCertificates {
displayCert(state.PeerCertificates[i]) displayCert(state.PeerCertificates[i])
} }
} else { } else {
fmt.Println("TLS chain verified successfully.") fmt.Fprintln(os.Stdout, "TLS chain verified successfully.")
for i := range state.VerifiedChains { for i := range state.VerifiedChains {
fmt.Printf("--- Verified certificate chain %d ---\n", i+1) fmt.Fprintf(os.Stdout, "--- Verified certificate chain %d ---%s", i+1, "\n")
for j := range state.VerifiedChains[i] { for j := range state.VerifiedChains[i] {
displayCert(state.VerifiedChains[i][j]) displayCert(state.VerifiedChains[i][j])
} }
@@ -290,6 +310,32 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
} }
} }
func shouldReadStdin(argc int, argv []string) bool {
if argc == 0 {
return true
}
if argc == 1 && argv[0] == "-" {
return true
}
return false
}
func readStdin(leafOnly bool) {
certs, err := io.ReadAll(os.Stdin)
if err != nil {
_, _ = lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
}
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.ReplaceAll(certs, []byte(`\n`), []byte{0xa})
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
}
func main() { func main() {
var leafOnly bool var leafOnly bool
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents") flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
@@ -297,32 +343,23 @@ func main() {
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate") flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.Parse() flag.Parse()
if flag.NArg() == 0 || (flag.NArg() == 1 && flag.Arg(0) == "-") { if shouldReadStdin(flag.NArg(), flag.Args()) {
certs, err := io.ReadAll(os.Stdin) readStdin(leafOnly)
if err != nil { return
lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
} }
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.Replace(certs, []byte(`\n`), []byte{0xa}, -1)
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
} else {
for _, filename := range flag.Args() { for _, filename := range flag.Args() {
fmt.Printf("--%s ---\n", filename) fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
if strings.HasPrefix(filename, "https://") { if strings.HasPrefix(filename, "https://") {
displayAllCertsWeb(filename, leafOnly) displayAllCertsWeb(filename, leafOnly)
} else { } else {
in, err := os.ReadFile(filename) in, err := os.ReadFile(filename)
if err != nil { if err != nil {
lib.Warn(err, "couldn't read certificate") _, _ = lib.Warn(err, "couldn't read certificate")
continue continue
} }
displayAllCerts(in, leafOnly) displayAllCerts(in, leafOnly)
} }
} }
}
} }

View File

@@ -13,6 +13,11 @@ import (
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\)," // following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
// "\2: \1,") // "\2: \1,")
const (
sSHA256 = "SHA256"
sSHA512 = "SHA512"
)
var keyUsage = map[x509.KeyUsage]string{ var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature", x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content committment", x509.KeyUsageContentCommitment: "content committment",
@@ -38,30 +43,24 @@ var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageOCSPSigning: "ocsp signing", x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc", x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc", x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
} x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
func pubKeyAlgo(a x509.PublicKeyAlgorithm) string {
switch a {
case x509.RSA:
return "RSA"
case x509.ECDSA:
return "ECDSA"
case x509.DSA:
return "DSA"
default:
return "unknown public key algorithm"
}
} }
func sigAlgoPK(a x509.SignatureAlgorithm) string { func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a { switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA" return "RSA"
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
return "RSA-PSS"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512: case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA" return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256: case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA" return "DSA"
case x509.PureEd25519:
return "Ed25519"
case x509.UnknownSignatureAlgorithm:
return "unknown public key algorithm"
default: default:
return "unknown public key algorithm" return "unknown public key algorithm"
} }
@@ -76,11 +75,21 @@ func sigAlgoHash(a x509.SignatureAlgorithm) string {
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1: case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1" return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256: case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return "SHA256" return sSHA256
case x509.SHA256WithRSAPSS:
return sSHA256
case x509.SHA384WithRSA, x509.ECDSAWithSHA384: case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384" return "SHA384"
case x509.SHA384WithRSAPSS:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512: case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return "SHA512" return sSHA512
case x509.SHA512WithRSAPSS:
return sSHA512
case x509.PureEd25519:
return sSHA512
case x509.UnknownSignatureAlgorithm:
return "unknown hash algorithm"
default: default:
return "unknown hash algorithm" return "unknown hash algorithm"
} }
@@ -90,9 +99,11 @@ const maxLine = 78
func makeIndent(n int) string { func makeIndent(n int) string {
s := " " s := " "
for i := 0; i < n; i++ { var sSb97 strings.Builder
s += " " for range n {
sSb97.WriteString(" ")
} }
s += sSb97.String()
return s return s
} }
@@ -100,7 +111,7 @@ func indentLen(n int) int {
return 4 + (8 * n) return 4 + (8 * n)
} }
// this isn't real efficient, but that's not a problem here // this isn't real efficient, but that's not a problem here.
func wrap(s string, indent int) string { func wrap(s string, indent int) string {
if indent > 3 { if indent > 3 {
indent = 3 indent = 3
@@ -123,9 +134,11 @@ func wrap(s string, indent int) string {
func dumpHex(in []byte) string { func dumpHex(in []byte) string {
var s string var s string
var sSb130 strings.Builder
for i := range in { for i := range in {
s += fmt.Sprintf("%02X:", in[i]) sSb130.WriteString(fmt.Sprintf("%02X:", in[i]))
} }
s += sSb130.String()
return strings.Trim(s, ":") return strings.Trim(s, ":")
} }
@@ -136,14 +149,14 @@ func dumpHex(in []byte) string {
func permissiveConfig() *tls.Config { func permissiveConfig() *tls.Config {
return &tls.Config{ return &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} } // #nosec G402
} }
// verifyConfig returns a config that will verify the connection. // verifyConfig returns a config that will verify the connection.
func verifyConfig(hostname string) *tls.Config { func verifyConfig(hostname string) *tls.Config {
return &tls.Config{ return &tls.Config{
ServerName: hostname, ServerName: hostname,
} } // #nosec G402
} }
type connInfo struct { type connInfo struct {

View File

@@ -5,7 +5,6 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strings" "strings"
"time" "time"
@@ -54,7 +53,7 @@ func displayName(name pkix.Name) string {
} }
func expires(cert *x509.Certificate) time.Duration { func expires(cert *x509.Certificate) time.Duration {
return cert.NotAfter.Sub(time.Now()) return time.Until(cert.NotAfter)
} }
func inDanger(cert *x509.Certificate) bool { func inDanger(cert *x509.Certificate) bool {
@@ -81,15 +80,15 @@ func main() {
flag.Parse() flag.Parse()
for _, file := range flag.Args() { for _, file := range flag.Args() {
in, err := ioutil.ReadFile(file) in, err := os.ReadFile(file)
if err != nil { if err != nil {
lib.Warn(err, "failed to read file") _, _ = lib.Warn(err, "failed to read file")
continue continue
} }
certs, err := certlib.ParseCertificatesPEM(in) certs, err := certlib.ParseCertificatesPEM(in)
if err != nil { if err != nil {
lib.Warn(err, "while parsing certificates") _, _ = lib.Warn(err, "while parsing certificates")
continue continue
} }

51
cmd/certser/main.go Normal file
View File

@@ -0,0 +1,51 @@
package main
import (
"crypto/x509"
"flag"
"fmt"
"strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
const displayInt lib.HexEncodeMode = iota
func parseDisplayMode(mode string) lib.HexEncodeMode {
mode = strings.ToLower(mode)
if mode == "int" {
return displayInt
}
return lib.ParseHexEncodeMode(mode)
}
func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string {
if mode == displayInt {
return cert.SerialNumber.String()
}
return lib.HexEncode(cert.SerialNumber.Bytes(), mode)
}
func main() {
displayAs := flag.String("d", "int", "display mode (int, hex, uhex)")
showExpiry := flag.Bool("e", false, "show expiry date")
flag.Parse()
displayMode := parseDisplayMode(*displayAs)
for _, arg := range flag.Args() {
cert, err := certlib.LoadCertificate(arg)
die.If(err)
fmt.Printf("%s: %s", arg, serialString(cert, displayMode))
if *showExpiry {
fmt.Printf(" (%s)", cert.NotAfter.Format("2006-01-02"))
}
fmt.Println()
}
}

View File

@@ -4,13 +4,11 @@ import (
"crypto/x509" "crypto/x509"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/certlib/revoke" "git.wntrmute.dev/kyle/goutils/certlib/revoke"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
@@ -30,83 +28,116 @@ func printRevocation(cert *x509.Certificate) {
} }
} }
func main() { type appConfig struct {
var caFile, intFile string caFile, intFile string
var forceIntermediateBundle, revexp, verbose bool forceIntermediateBundle bool
flag.StringVar(&caFile, "ca", "", "CA certificate `bundle`") revexp, verbose bool
flag.StringVar(&intFile, "i", "", "intermediate `bundle`") }
flag.BoolVar(&forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate") func parseFlags() appConfig {
flag.BoolVar(&revexp, "r", false, "print revocation and expiry information") var cfg appConfig
flag.BoolVar(&verbose, "v", false, "verbose") flag.StringVar(&cfg.caFile, "ca", "", "CA certificate `bundle`")
flag.Parse() flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`")
flag.BoolVar(&cfg.forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate")
flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&cfg.verbose, "v", false, "verbose")
flag.Parse()
return cfg
}
func loadRoots(caFile string, verbose bool) (*x509.CertPool, error) {
if caFile == "" {
return x509.SystemCertPool()
}
var roots *x509.CertPool
if caFile != "" {
var err error
if verbose { if verbose {
fmt.Println("[+] loading root certificates from", caFile) fmt.Println("[+] loading root certificates from", caFile)
} }
roots, err = certlib.LoadPEMCertPool(caFile) return certlib.LoadPEMCertPool(caFile)
die.If(err) }
}
var ints *x509.CertPool func loadIntermediates(intFile string, verbose bool) (*x509.CertPool, error) {
if intFile != "" { if intFile == "" {
var err error return x509.NewCertPool(), nil
}
if verbose { if verbose {
fmt.Println("[+] loading intermediate certificates from", intFile) fmt.Println("[+] loading intermediate certificates from", intFile)
} }
ints, err = certlib.LoadPEMCertPool(caFile) // Note: use intFile here (previously used caFile mistakenly)
die.If(err) return certlib.LoadPEMCertPool(intFile)
} else { }
ints = x509.NewCertPool()
}
if flag.NArg() != 1 { func addBundledIntermediates(chain []*x509.Certificate, pool *x509.CertPool, verbose bool) {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert",
lib.ProgName())
}
fileData, err := ioutil.ReadFile(flag.Arg(0))
die.If(err)
chain, err := certlib.ParseCertificatesPEM(fileData)
die.If(err)
if verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 {
if !forceIntermediateBundle {
for _, intermediate := range chain[1:] { for _, intermediate := range chain[1:] {
if verbose { if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId) fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
} }
pool.AddCert(intermediate)
}
}
ints.AddCert(intermediate) func verifyCert(cert *x509.Certificate, roots, ints *x509.CertPool) error {
}
}
}
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Intermediates: ints, Intermediates: ints,
Roots: roots, Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
_, err := cert.Verify(opts)
return err
}
_, err = cert.Verify(opts) func run(cfg appConfig) error {
roots, err := loadRoots(cfg.caFile, cfg.verbose)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err) return err
os.Exit(1)
} }
if verbose { ints, err := loadIntermediates(cfg.intFile, cfg.verbose)
if err != nil {
return err
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
}
fileData, err := os.ReadFile(flag.Arg(0))
if err != nil {
return err
}
chain, err := certlib.ParseCertificatesPEM(fileData)
if err != nil {
return err
}
if cfg.verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 && !cfg.forceIntermediateBundle {
addBundledIntermediates(chain, ints, cfg.verbose)
}
if err = verifyCert(cert, roots, ints); err != nil {
return fmt.Errorf("certificate verification failed: %w", err)
}
if cfg.verbose {
fmt.Println("OK") fmt.Println("OK")
} }
if revexp { if cfg.revexp {
printRevocation(cert) printRevocation(cert)
} }
return nil
}
func main() {
cfg := parseFlags()
if err := run(cfg); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
} }

View File

@@ -2,6 +2,8 @@ package main
import ( import (
"bufio" "bufio"
"context"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -56,7 +58,7 @@ var modes = ssh.TerminalModes{
} }
func sshAgent() ssh.AuthMethod { func sshAgent() ssh.AuthMethod {
a, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) a, err := (&net.Dialer{}).DialContext(context.Background(), "unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil { if err == nil {
return ssh.PublicKeysCallback(agent.NewClient(a).Signers) return ssh.PublicKeysCallback(agent.NewClient(a).Signers)
} }
@@ -82,7 +84,7 @@ func scanner(host string, in io.Reader, out io.Writer) {
} }
} }
func logError(host string, err error, format string, args ...interface{}) { func logError(host string, err error, format string, args ...any) {
msg := fmt.Sprintf(format, args...) msg := fmt.Sprintf(format, args...)
log.Printf("[%s] FAILED: %s: %v\n", host, msg, err) log.Printf("[%s] FAILED: %s: %v\n", host, msg, err)
} }
@@ -93,7 +95,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -115,7 +117,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
} }
shutdown = append(shutdown, session.Close) shutdown = append(shutdown, session.Close)
if err := session.RequestPty("xterm", 80, 40, modes); err != nil { if err = session.RequestPty("xterm", 80, 40, modes); err != nil {
session.Close() session.Close()
logError(host, err, "request for pty failed") logError(host, err, "request for pty failed")
return return
@@ -150,7 +152,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -199,7 +201,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
} }
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
logError(host, err, "reading chunk") logError(host, err, "reading chunk")
@@ -215,7 +217,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -265,7 +267,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
} }
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
logError(host, err, "reading chunk") logError(host, err, "reading chunk")

View File

@@ -10,6 +10,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/fileutil" "git.wntrmute.dev/kyle/goutils/fileutil"
@@ -26,7 +27,7 @@ func setupFile(hdr *tar.Header, file *os.File) error {
if verbose { if verbose {
fmt.Printf("\tchmod %0#o\n", hdr.Mode) fmt.Printf("\tchmod %0#o\n", hdr.Mode)
} }
err := file.Chmod(os.FileMode(hdr.Mode)) err := file.Chmod(os.FileMode(hdr.Mode & 0xFFFFFFFF)) // #nosec G115
if err != nil { if err != nil {
return err return err
} }
@@ -48,54 +49,71 @@ func linkTarget(target, top string) string {
return target return target
} }
return filepath.Clean(filepath.Join(target, top)) return filepath.Clean(filepath.Join(top, target))
} }
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error { // safeJoin joins base and elem and ensures the resulting path does not escape base.
if verbose { func safeJoin(base, elem string) (string, error) {
fmt.Println(hdr.Name) cleanBase := filepath.Clean(base)
joined := filepath.Clean(filepath.Join(cleanBase, elem))
absBase, err := filepath.Abs(cleanBase)
if err != nil {
return "", err
} }
filePath := filepath.Clean(filepath.Join(top, hdr.Name)) absJoined, err := filepath.Abs(joined)
switch hdr.Typeflag { if err != nil {
case tar.TypeReg: return "", err
}
rel, err := filepath.Rel(absBase, absJoined)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected: %s escapes %s", elem, base)
}
return joined, nil
}
func handleTypeReg(tfr *tar.Reader, hdr *tar.Header, filePath string) error {
file, err := os.Create(filePath) file, err := os.Create(filePath)
if err != nil { if err != nil {
return err return err
} }
defer file.Close()
_, err = io.Copy(file, tfr) if _, err = io.Copy(file, tfr); err != nil {
if err != nil {
return err return err
} }
return setupFile(hdr, file)
}
err = setupFile(hdr, file) func handleTypeLink(hdr *tar.Header, top, filePath string) error {
if err != nil {
return err
}
case tar.TypeLink:
file, err := os.Create(filePath) file, err := os.Create(filePath)
if err != nil { if err != nil {
return err return err
} }
defer file.Close()
source, err := os.Open(hdr.Linkname) srcPath, err := safeJoin(top, hdr.Linkname)
if err != nil { if err != nil {
return err return err
} }
source, err := os.Open(srcPath)
_, err = io.Copy(file, source)
if err != nil { if err != nil {
return err return err
} }
defer source.Close()
err = setupFile(hdr, file) if _, err = io.Copy(file, source); err != nil {
if err != nil {
return err return err
} }
case tar.TypeSymlink: return setupFile(hdr, file)
}
func handleTypeSymlink(hdr *tar.Header, top, filePath string) error {
if !fileutil.ValidateSymlink(hdr.Linkname, top) { if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s", return fmt.Errorf("symlink %s is outside the top-level %s", hdr.Linkname, top)
hdr.Linkname, top)
} }
path := linkTarget(hdr.Linkname, top) path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok { if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
@@ -103,18 +121,33 @@ func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
} else if err != nil { } else if err != nil {
return err return err
} }
return os.Symlink(linkTarget(hdr.Linkname, top), filePath)
}
err := os.Symlink(linkTarget(hdr.Linkname, top), filePath) func handleTypeDir(hdr *tar.Header, filePath string) error {
return os.MkdirAll(filePath, os.FileMode(hdr.Mode&0xFFFFFFFF)) // #nosec G115
}
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
if verbose {
fmt.Println(hdr.Name)
}
filePath, err := safeJoin(top, hdr.Name)
if err != nil { if err != nil {
return err return err
} }
switch hdr.Typeflag {
case tar.TypeReg:
return handleTypeReg(tfr, hdr, filePath)
case tar.TypeLink:
return handleTypeLink(hdr, top, filePath)
case tar.TypeSymlink:
return handleTypeSymlink(hdr, top, filePath)
case tar.TypeDir: case tar.TypeDir:
err := os.MkdirAll(filePath, os.FileMode(hdr.Mode)) return handleTypeDir(hdr, filePath)
if err != nil {
return err
} }
}
return nil return nil
} }
@@ -261,16 +294,16 @@ func main() {
die.If(err) die.If(err)
tfr := tar.NewReader(r) tfr := tar.NewReader(r)
var hdr *tar.Header
for { for {
hdr, err := tfr.Next() hdr, err = tfr.Next()
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} }
die.If(err) die.If(err)
err = processFile(tfr, hdr, top) err = processFile(tfr, hdr, top)
die.If(err) die.If(err)
} }
r.Close() r.Close()

View File

@@ -7,9 +7,9 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "os"
"log"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -17,17 +17,10 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
in, err := ioutil.ReadFile(fileName) in, err := os.ReadFile(fileName)
die.If(err) die.If(err)
if p, _ := pem.Decode(in); p != nil { csr, _, err := certlib.ParseCSR(in)
if p.Type != "CERTIFICATE REQUEST" {
log.Fatal("INVALID FILE TYPE")
}
in = p.Bytes
}
csr, err := x509.ParseCertificateRequest(in)
die.If(err) die.If(err)
out, err := x509.MarshalPKIXPublicKey(csr.PublicKey) out, err := x509.MarshalPKIXPublicKey(csr.PublicKey)
@@ -48,8 +41,8 @@ func main() {
Bytes: out, Bytes: out,
} }
err = ioutil.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0644) err = os.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0o644) // #nosec G306
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.\n", fileName+".pub") fmt.Fprintf(os.Stdout, "[+] wrote %s.\n", fileName+".pub")
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -152,7 +153,7 @@ func rsync(syncDir, target, excludeFile string, verboseRsync bool) error {
return err return err
} }
cmd := exec.Command(path, args...) cmd := exec.CommandContext(context.Background(), path, args...)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
return cmd.Run() return cmd.Run()
@@ -163,7 +164,6 @@ func init() {
} }
func main() { func main() {
var logLevel, mountDir, syncDir, target string var logLevel, mountDir, syncDir, target string
var dryRun, quietMode, noSyslog, verboseRsync bool var dryRun, quietMode, noSyslog, verboseRsync bool
@@ -219,7 +219,7 @@ func main() {
if excludeFile != "" { if excludeFile != "" {
defer func() { defer func() {
log.Infof("removing exclude file %s", excludeFile) log.Infof("removing exclude file %s", excludeFile)
if err := os.Remove(excludeFile); err != nil { if rmErr := os.Remove(excludeFile); rmErr != nil {
log.Warningf("failed to remove temp file %s", excludeFile) log.Warningf("failed to remove temp file %s", excludeFile)
} }
}() }()

View File

@@ -19,39 +19,37 @@ var (
debug = dbg.New() debug = dbg.New()
) )
func openImage(imageFile string) (*os.File, []byte, error) {
func openImage(imageFile string) (image *os.File, hash []byte, err error) { f, err := os.Open(imageFile)
image, err = os.Open(imageFile)
if err != nil { if err != nil {
return return nil, nil, err
} }
hash, err = ahash.SumReader(hAlgo, image) h, err := ahash.SumReader(hAlgo, f)
if err != nil { if err != nil {
return return nil, nil, err
} }
_, err = image.Seek(0, 0) if _, err = f.Seek(0, 0); err != nil {
if err != nil { return nil, nil, err
return
} }
debug.Printf("%s %x\n", imageFile, hash) debug.Printf("%s %x\n", imageFile, h)
return return f, h, nil
} }
func openDevice(devicePath string) (device *os.File, err error) { func openDevice(devicePath string) (*os.File, error) {
fi, err := os.Stat(devicePath) fi, err := os.Stat(devicePath)
if err != nil { if err != nil {
return return nil, err
} }
device, err = os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode()) device, err := os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
if err != nil { if err != nil {
return return nil, err
} }
return return device, nil
} }
func main() { func main() {
@@ -105,12 +103,12 @@ func main() {
die.If(err) die.If(err)
if !bytes.Equal(deviceHash, hash) { if !bytes.Equal(deviceHash, hash) {
fmt.Fprintln(os.Stderr, "Hash mismatch:") buf := &bytes.Buffer{}
fmt.Fprintf(os.Stderr, "\t%s: %s\n", imageFile, hash) fmt.Fprintln(buf, "Hash mismatch:")
fmt.Fprintf(os.Stderr, "\t%s: %s\n", devicePath, deviceHash) fmt.Fprintf(buf, "\t%s: %s\n", imageFile, hash)
os.Exit(1) fmt.Fprintf(buf, "\t%s: %s\n", devicePath, deviceHash)
die.With(buf.String())
} }
debug.Println("OK") debug.Println("OK")
os.Exit(0)
} }

View File

@@ -1,30 +1,33 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"git.wntrmute.dev/kyle/goutils/die"
"io" "io"
"os" "os"
"strings"
"git.wntrmute.dev/kyle/goutils/die"
) )
func usage(w io.Writer, exc int) { func usage(w io.Writer, exc int) {
fmt.Fprintln(w, `usage: dumpbytes <file>`) fmt.Fprintln(w, `usage: dumpbytes -n tabs <file>`)
os.Exit(exc) os.Exit(exc)
} }
func printBytes(buf []byte) { func printBytes(buf []byte) {
fmt.Printf("\t") fmt.Printf("\t")
for i := 0; i < len(buf); i++ { for i := range buf {
fmt.Printf("0x%02x, ", buf[i]) fmt.Printf("0x%02x, ", buf[i])
} }
fmt.Println() fmt.Println()
} }
func dumpFile(path string, indentLevel int) error { func dumpFile(path string, indentLevel int) error {
indent := "" var indent strings.Builder
for i := 0; i < indentLevel; i++ { for range indentLevel {
indent += "\t" indent.WriteByte('\t')
} }
file, err := os.Open(path) file, err := os.Open(path)
@@ -34,13 +37,14 @@ func dumpFile(path string, indentLevel int) error {
defer file.Close() defer file.Close()
fmt.Printf("%svar buffer = []byte{\n", indent) fmt.Printf("%svar buffer = []byte{\n", indent.String())
var n int
for { for {
buf := make([]byte, 8) buf := make([]byte, 8)
n, err := file.Read(buf) n, err = file.Read(buf)
if err == io.EOF { if errors.Is(err, io.EOF) {
if n > 0 { if n > 0 {
fmt.Printf("%s", indent) fmt.Printf("%s", indent.String())
printBytes(buf[:n]) printBytes(buf[:n])
} }
break break
@@ -50,11 +54,11 @@ func dumpFile(path string, indentLevel int) error {
return err return err
} }
fmt.Printf("%s", indent) fmt.Printf("%s", indent.String())
printBytes(buf[:n]) printBytes(buf[:n])
} }
fmt.Printf("%s}\n", indent) fmt.Printf("%s}\n", indent.String())
return nil return nil
} }

View File

@@ -7,7 +7,7 @@ import (
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
// size of a kilobit in bytes // size of a kilobit in bytes.
const kilobit = 128 const kilobit = 128
const pageSize = 4096 const pageSize = 4096
@@ -26,10 +26,10 @@ func main() {
path = flag.Arg(0) path = flag.Arg(0)
} }
fillByte := uint8(*fill) fillByte := uint8(*fill & 0xff) // #nosec G115 clearing out of bounds bits
buf := make([]byte, pageSize) buf := make([]byte, pageSize)
for i := 0; i < pageSize; i++ { for i := range pageSize {
buf[i] = fillByte buf[i] = fillByte
} }
@@ -40,7 +40,7 @@ func main() {
die.If(err) die.If(err)
defer file.Close() defer file.Close()
for i := 0; i < pages; i++ { for range pages {
_, err = file.Write(buf) _, err = file.Write(buf)
die.If(err) die.If(err)
} }

View File

@@ -72,15 +72,13 @@ func main() {
if end < start { if end < start {
fmt.Fprintln(os.Stderr, "[!] end < start, swapping values") fmt.Fprintln(os.Stderr, "[!] end < start, swapping values")
tmp := end start, end = end, start
end = start
start = tmp
} }
var fmtStr string var fmtStr string
if !*quiet { if !*quiet {
maxLine := fmt.Sprintf("%d", len(lines)) maxLine := strconv.Itoa(len(lines))
fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine)) fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine))
} }
@@ -98,9 +96,9 @@ func main() {
fmtStr += "\n" fmtStr += "\n"
for i := start; !endFunc(i); i++ { for i := start; !endFunc(i); i++ {
if *quiet { if *quiet {
fmt.Println(lines[i]) fmt.Fprintln(os.Stdout, lines[i])
} else { } else {
fmt.Printf(fmtStr, i, lines[i]) fmt.Fprintf(os.Stdout, fmtStr, i, lines[i])
} }
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"log" "log"
@@ -8,7 +9,8 @@ import (
) )
func lookupHost(host string) error { func lookupHost(host string) error {
cname, err := net.LookupCNAME(host) r := &net.Resolver{}
cname, err := r.LookupCNAME(context.Background(), host)
if err != nil { if err != nil {
return err return err
} }
@@ -18,7 +20,7 @@ func lookupHost(host string) error {
host = cname host = cname
} }
addrs, err := net.LookupHost(host) addrs, err := r.LookupHost(context.Background(), host)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
@@ -16,20 +16,20 @@ func prettify(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Indent(buf, in, "", " ") err = json.Indent(buf, in, "", " ")
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -40,11 +40,11 @@ func prettify(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -55,20 +55,20 @@ func compact(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Compact(buf, in) err = json.Compact(buf, in)
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -79,11 +79,11 @@ func compact(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -91,7 +91,7 @@ func compact(file string, validateOnly bool) error {
func usage() { func usage() {
progname := lib.ProgName() progname := lib.ProgName()
fmt.Printf(`Usage: %s [-h] files... fmt.Fprintf(os.Stdout, `Usage: %s [-h] files...
%s is used to lint and prettify (or compact) JSON files. The %s is used to lint and prettify (or compact) JSON files. The
files will be updated in-place. files will be updated in-place.
@@ -100,7 +100,6 @@ func usage() {
-h Print this help message. -h Print this help message.
-n Don't prettify; only perform validation. -n Don't prettify; only perform validation.
`, progname, progname) `, progname, progname)
} }
func init() { func init() {

View File

@@ -12,6 +12,9 @@ based on whether the source filename ends in ".gz".
Flags: Flags:
-l level Compression level (0-9). Only meaninful when -l level Compression level (0-9). Only meaninful when
compressing a file. compressing a file.
-u Do not restrict the size during decompression. As
a safeguard against gzip bombs, the maximum size
allowed is 32 * the compressed file size.

View File

@@ -40,26 +40,42 @@ func compress(path, target string, level int) error {
return nil return nil
} }
func uncompress(path, target string) error { func uncompress(path, target string, unrestrict bool) error {
sourceFile, err := os.Open(path) sourceFile, err := os.Open(path)
if err != nil { if err != nil {
return fmt.Errorf("opening file for read: %w", err) return fmt.Errorf("opening file for read: %w", err)
} }
defer sourceFile.Close() defer sourceFile.Close()
fi, err := sourceFile.Stat()
if err != nil {
return fmt.Errorf("reading file stats: %w", err)
}
maxDecompressionSize := fi.Size() * 32
gzipUncompressor, err := gzip.NewReader(sourceFile) gzipUncompressor, err := gzip.NewReader(sourceFile)
if err != nil { if err != nil {
return fmt.Errorf("reading gzip headers: %w", err) return fmt.Errorf("reading gzip headers: %w", err)
} }
defer gzipUncompressor.Close() defer gzipUncompressor.Close()
var reader io.Reader = &io.LimitedReader{
R: gzipUncompressor,
N: maxDecompressionSize,
}
if unrestrict {
reader = gzipUncompressor
}
destFile, err := os.Create(target) destFile, err := os.Create(target)
if err != nil { if err != nil {
return fmt.Errorf("opening file for write: %w", err) return fmt.Errorf("opening file for write: %w", err)
} }
defer destFile.Close() defer destFile.Close()
_, err = io.Copy(destFile, gzipUncompressor) _, err = io.Copy(destFile, reader)
if err != nil { if err != nil {
return fmt.Errorf("uncompressing file: %w", err) return fmt.Errorf("uncompressing file: %w", err)
} }
@@ -87,8 +103,8 @@ func isDir(path string) bool {
file, err := os.Open(path) file, err := os.Open(path)
if err == nil { if err == nil {
defer file.Close() defer file.Close()
stat, err := file.Stat() stat, err2 := file.Stat()
if err != nil { if err2 != nil {
return false return false
} }
@@ -132,8 +148,11 @@ func main() {
var level int var level int
var path string var path string
var target = "." var target = "."
var err error
var unrestrict bool
flag.IntVar(&level, "l", flate.DefaultCompression, "compression level") flag.IntVar(&level, "l", flate.DefaultCompression, "compression level")
flag.BoolVar(&unrestrict, "u", false, "do not restrict decompression")
flag.Parse() flag.Parse()
if flag.NArg() < 1 || flag.NArg() > 2 { if flag.NArg() < 1 || flag.NArg() > 2 {
@@ -147,20 +166,22 @@ func main() {
} }
if strings.HasSuffix(path, gzipExt) { if strings.HasSuffix(path, gzipExt) {
target, err := pathForUncompressing(path, target) target, err = pathForUncompressing(path, target)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
err = uncompress(path, target) err = uncompress(path, target, unrestrict)
if err != nil { if err != nil {
os.Remove(target) os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
} else { return
target, err := pathForCompressing(path, target) }
target, err = pathForCompressing(path, target)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
@@ -172,5 +193,4 @@ func main() {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
}
} }

View File

@@ -40,14 +40,14 @@ func main() {
usage() usage()
} }
min, err := strconv.Atoi(flag.Arg(1)) minVal, err := strconv.Atoi(flag.Arg(1))
dieIf(err) dieIf(err)
max, err := strconv.Atoi(flag.Arg(2)) maxVal, err := strconv.Atoi(flag.Arg(2))
dieIf(err) dieIf(err)
code := kind << 6 code := kind << 6
code += (min << 3) code += (minVal << 3)
code += max code += maxVal
fmt.Printf("%0o\n", code) fmt.Fprintf(os.Stdout, "%0o\n", code)
} }

View File

@@ -5,7 +5,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
@@ -47,7 +46,7 @@ func help(w io.Writer) {
} }
func loadDatabase() { func loadDatabase() {
data, err := ioutil.ReadFile(dbFile) data, err := os.ReadFile(dbFile)
if err != nil && os.IsNotExist(err) { if err != nil && os.IsNotExist(err) {
partsDB = &database{ partsDB = &database{
Version: dbVersion, Version: dbVersion,
@@ -74,7 +73,7 @@ func writeDB() {
data, err := json.Marshal(partsDB) data, err := json.Marshal(partsDB)
die.If(err) die.If(err)
err = ioutil.WriteFile(dbFile, data, 0644) err = os.WriteFile(dbFile, data, 0644)
die.If(err) die.If(err)
} }

View File

@@ -4,14 +4,13 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
var ext = ".bin" var ext = ".bin"
func stripPEM(path string) error { func stripPEM(path string) error {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return err return err
} }
@@ -22,7 +21,7 @@ func stripPEM(path string) error {
fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n") fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n")
} }
return ioutil.WriteFile(path+ext, p.Bytes, 0644) return os.WriteFile(path+ext, p.Bytes, 0644)
} }
func main() { func main() {

View File

@@ -3,8 +3,7 @@ package main
import ( import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "io"
"io/ioutil"
"os" "os"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
@@ -21,9 +20,9 @@ func main() {
path := flag.Arg(0) path := flag.Arg(0)
if path == "-" { if path == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(flag.Arg(0)) in, err = os.ReadFile(flag.Arg(0))
} }
if err != nil { if err != nil {
lib.Err(lib.ExitFailure, err, "couldn't read file") lib.Err(lib.ExitFailure, err, "couldn't read file")
@@ -33,5 +32,7 @@ func main() {
if p == nil { if p == nil {
lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0)) lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0))
} }
fmt.Printf("%s", p.Bytes) if _, err = os.Stdout.Write(p.Bytes); err != nil {
lib.Err(lib.ExitFailure, err, "writing body")
}
} }

View File

@@ -70,7 +70,7 @@ func main() {
lib.Err(lib.ExitFailure, err, "failed to read input") lib.Err(lib.ExitFailure, err, "failed to read input")
} }
case argc > 1: case argc > 1:
for i := 0; i < argc; i++ { for i := range argc {
path := flag.Arg(i) path := flag.Arg(i)
err = copyFile(path, buf) err = copyFile(path, buf)
if err != nil { if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
@@ -13,14 +12,14 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
data, err := ioutil.ReadFile(fileName) data, err := os.ReadFile(fileName)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
continue continue
} }
fmt.Printf("[+] %s:\n", fileName) fmt.Fprintf(os.Stdout, "[+] %s:\n", fileName)
rest := data[:] rest := data
for { for {
var p *pem.Block var p *pem.Block
p, rest = pem.Decode(rest) p, rest = pem.Decode(rest)
@@ -28,13 +27,14 @@ func main() {
break break
} }
cert, err := x509.ParseCertificate(p.Bytes) var cert *x509.Certificate
cert, err = x509.ParseCertificate(p.Bytes)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
break break
} }
fmt.Printf("\t%+v\n", cert.Subject.CommonName) fmt.Fprintf(os.Stdout, "\t%+v\n", cert.Subject.CommonName)
} }
} }
} }

View File

@@ -43,7 +43,7 @@ func newName(path string) (string, error) {
return hashName(path, encodedHash), nil return hashName(path, encodedHash), nil
} }
func move(dst, src string, force bool) (err error) { func move(dst, src string, force bool) error {
if fileutil.FileDoesExist(dst) && !force { if fileutil.FileDoesExist(dst) && !force {
return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst) return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst)
} }
@@ -52,21 +52,23 @@ func move(dst, src string, force bool) (err error) {
return err return err
} }
defer func(e error) { var retErr error
defer func(e *error) {
dstFile.Close() dstFile.Close()
if e != nil { if *e != nil {
os.Remove(dst) os.Remove(dst)
} }
}(err) }(&retErr)
srcFile, err := os.Open(src) srcFile, err := os.Open(src)
if err != nil { if err != nil {
retErr = err
return err return err
} }
defer srcFile.Close() defer srcFile.Close()
_, err = io.Copy(dstFile, srcFile) if _, err = io.Copy(dstFile, srcFile); err != nil {
if err != nil { retErr = err
return err return err
} }
@@ -94,6 +96,44 @@ func init() {
flag.Usage = func() { usage(os.Stdout) } flag.Usage = func() { usage(os.Stdout) }
} }
type options struct {
dryRun, force, printChanged, verbose bool
}
func processOne(file string, opt options) error {
renamed, err := newName(file)
if err != nil {
_, _ = lib.Warn(err, "failed to get new file name")
return err
}
if opt.verbose && !opt.printChanged {
fmt.Fprintln(os.Stdout, file)
}
if renamed == file {
return nil
}
if !opt.dryRun {
if err = move(renamed, file, opt.force); err != nil {
_, _ = lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
return err
}
}
if opt.printChanged && !opt.verbose {
fmt.Fprintln(os.Stdout, file, "->", renamed)
}
return nil
}
func run(dryRun, force, printChanged, verbose bool, files []string) {
if verbose && printChanged {
printChanged = false
}
opt := options{dryRun: dryRun, force: force, printChanged: printChanged, verbose: verbose}
for _, file := range files {
_ = processOne(file, opt)
}
}
func main() { func main() {
var dryRun, force, printChanged, verbose bool var dryRun, force, printChanged, verbose bool
flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision") flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision")
@@ -102,34 +142,5 @@ func main() {
flag.BoolVar(&verbose, "v", false, "list all processed files") flag.BoolVar(&verbose, "v", false, "list all processed files")
flag.Parse() flag.Parse()
run(dryRun, force, printChanged, verbose, flag.Args())
if verbose && printChanged {
printChanged = false
}
for _, file := range flag.Args() {
renamed, err := newName(file)
if err != nil {
lib.Warn(err, "failed to get new file name")
continue
}
if verbose && !printChanged {
fmt.Println(file)
}
if renamed != file {
if !dryRun {
err = move(renamed, file, force)
if err != nil {
lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
continue
}
}
if printChanged && !verbose {
fmt.Println(file, "->", renamed)
}
}
}
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -66,24 +67,25 @@ func main() {
for _, remote := range flag.Args() { for _, remote := range flag.Args() {
u, err := url.Parse(remote) u, err := url.Parse(remote)
if err != nil { if err != nil {
lib.Warn(err, "parsing %s", remote) _, _ = lib.Warn(err, "parsing %s", remote)
continue continue
} }
name := filepath.Base(u.Path) name := filepath.Base(u.Path)
if name == "" { if name == "" {
lib.Warnx("source URL doesn't appear to name a file") _, _ = lib.Warnx("source URL doesn't appear to name a file")
continue continue
} }
resp, err := http.Get(remote) req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, remote, nil)
if err != nil { if reqErr != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(reqErr, "building request for %s", remote)
continue continue
} }
client := &http.Client{}
resp, err := client.Do(req)
if err != nil { if err != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(err, "fetching %s", remote)
continue continue
} }

View File

@@ -3,7 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
@@ -17,8 +17,8 @@ func rollDie(count, sides int) []int {
sum := 0 sum := 0
var rolls []int var rolls []int
for i := 0; i < count; i++ { for range count {
roll := rand.Intn(sides) + 1 roll := rand.IntN(sides) + 1 // #nosec G404
sum += roll sum += roll
rolls = append(rolls, roll) rolls = append(rolls, roll)
} }

View File

@@ -53,7 +53,7 @@ func init() {
project = wd[len(gopath):] project = wd[len(gopath):]
} }
func walkFile(path string, info os.FileInfo, err error) error { func walkFile(path string, _ os.FileInfo, err error) error {
if ignores[path] { if ignores[path] {
return filepath.SkipDir return filepath.SkipDir
} }
@@ -62,22 +62,27 @@ func walkFile(path string, info os.FileInfo, err error) error {
return nil return nil
} }
debug.Println(path)
f, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err != nil { if err != nil {
return err return err
} }
debug.Println(path)
f, err2 := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err2 != nil {
return err2
}
for _, importSpec := range f.Imports { for _, importSpec := range f.Imports {
importPath := strings.Trim(importSpec.Path.Value, `"`) importPath := strings.Trim(importSpec.Path.Value, `"`)
if stdLibRegexp.MatchString(importPath) { switch {
case stdLibRegexp.MatchString(importPath):
debug.Println("standard lib:", importPath) debug.Println("standard lib:", importPath)
continue continue
} else if strings.HasPrefix(importPath, project) { case strings.HasPrefix(importPath, project):
debug.Println("internal import:", importPath) debug.Println("internal import:", importPath)
continue continue
} else if strings.HasPrefix(importPath, "golang.org/") { case strings.HasPrefix(importPath, "golang.org/"):
debug.Println("extended lib:", importPath) debug.Println("extended lib:", importPath)
continue continue
} }
@@ -102,7 +107,7 @@ func main() {
ignores["vendor"] = true ignores["vendor"] = true
} }
for _, word := range strings.Split(ignoreLine, ",") { for word := range strings.SplitSeq(ignoreLine, ",") {
ignores[strings.TrimSpace(word)] = true ignores[strings.TrimSpace(word)] = true
} }

View File

@@ -2,10 +2,9 @@ package main
import ( import (
"bytes" "bytes"
"crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1" // #nosec G505
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
@@ -13,14 +12,19 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"strings" "strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
const (
keyTypeRSA = "RSA"
keyTypeECDSA = "ECDSA"
)
func usage(w io.Writer) { func usage(w io.Writer) {
fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files
@@ -39,14 +43,14 @@ func init() {
flag.Usage = func() { usage(os.Stderr) } flag.Usage = func() { usage(os.Stderr) }
} }
func parse(path string) (public []byte, kt, ft string) { func parse(path string) ([]byte, string, string) {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
die.If(err) die.If(err)
data = bytes.TrimSpace(data) data = bytes.TrimSpace(data)
p, rest := pem.Decode(data) p, rest := pem.Decode(data)
if len(rest) > 0 { if len(rest) > 0 {
lib.Warnx("trailing data in PEM file") _, _ = lib.Warnx("trailing data in PEM file")
} }
if p == nil { if p == nil {
@@ -55,6 +59,12 @@ func parse(path string) (public []byte, kt, ft string) {
data = p.Bytes data = p.Bytes
var (
public []byte
kt string
ft string
)
switch p.Type { switch p.Type {
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY": case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
public, kt = parseKey(data) public, kt = parseKey(data)
@@ -69,82 +79,79 @@ func parse(path string) (public []byte, kt, ft string) {
die.With("unknown PEM type %s", p.Type) die.With("unknown PEM type %s", p.Type)
} }
return return public, kt, ft
} }
func parseKey(data []byte) (public []byte, kt string) { func parseKey(data []byte) ([]byte, string) {
privInterface, err := x509.ParsePKCS8PrivateKey(data) priv, err := certlib.ParsePrivateKeyDER(data)
if err != nil { if err != nil {
privInterface, err = x509.ParsePKCS1PrivateKey(data) die.If(err)
if err != nil {
privInterface, err = x509.ParseECPrivateKey(data)
if err != nil {
die.With("couldn't parse private key.")
}
}
} }
var priv crypto.Signer var kt string
switch privInterface.(type) { switch priv.Public().(type) {
case *rsa.PrivateKey: case *rsa.PublicKey:
priv = privInterface.(*rsa.PrivateKey) kt = keyTypeRSA
kt = "RSA" case *ecdsa.PublicKey:
case *ecdsa.PrivateKey: kt = keyTypeECDSA
priv = privInterface.(*ecdsa.PrivateKey)
kt = "ECDSA"
default: default:
die.With("unknown private key type %T", privInterface) die.With("unknown private key type %T", priv)
} }
public, err = x509.MarshalPKIXPublicKey(priv.Public()) public, err := x509.MarshalPKIXPublicKey(priv.Public())
die.If(err) die.If(err)
return return public, kt
} }
func parseCertificate(data []byte) (public []byte, kt string) { func parseCertificate(data []byte) ([]byte, string) {
cert, err := x509.ParseCertificate(data) cert, err := x509.ParseCertificate(data)
die.If(err) die.If(err)
pub := cert.PublicKey pub := cert.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func parseCSR(data []byte) (public []byte, kt string) { func parseCSR(data []byte) ([]byte, string) {
csr, err := x509.ParseCertificateRequest(data) // Use certlib to support both PEM and DER and to centralize validation.
csr, _, err := certlib.ParseCSR(data)
die.If(err) die.If(err)
pub := csr.PublicKey pub := csr.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func dumpHex(in []byte) string { func dumpHex(in []byte) string {
var s string var s string
var sSb153 strings.Builder
for i := range in { for i := range in {
s += fmt.Sprintf("%02X:", in[i]) sSb153.WriteString(fmt.Sprintf("%02X:", in[i]))
} }
s += sSb153.String()
return strings.Trim(s, ":") return strings.Trim(s, ":")
} }
@@ -172,18 +179,18 @@ func main() {
var subPKI subjectPublicKeyInfo var subPKI subjectPublicKeyInfo
_, err := asn1.Unmarshal(public, &subPKI) _, err := asn1.Unmarshal(public, &subPKI)
if err != nil { if err != nil {
lib.Warn(err, "failed to get subject PKI") _, _ = lib.Warn(err, "failed to get subject PKI")
continue continue
} }
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
pubHashString := dumpHex(pubHash[:]) pubHashString := dumpHex(pubHash[:])
if ski == "" { if ski == "" {
ski = pubHashString ski = pubHashString
} }
if shouldMatch && ski != pubHashString { if shouldMatch && ski != pubHashString {
lib.Warnx("%s: SKI mismatch (%s != %s)", _, _ = lib.Warnx("%s: SKI mismatch (%s != %s)",
path, ski, pubHashString) path, ski, pubHashString)
} }
fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft) fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft)

View File

@@ -1,16 +1,17 @@
package main package main
import ( import (
"context"
"flag" "flag"
"io" "io"
"log"
"net" "net"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
) )
func proxy(conn net.Conn, inside string) error { func proxy(conn net.Conn, inside string) error {
proxyConn, err := net.Dial("tcp", inside) proxyConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", inside)
if err != nil { if err != nil {
return err return err
} }
@@ -19,7 +20,7 @@ func proxy(conn net.Conn, inside string) error {
defer conn.Close() defer conn.Close()
go func() { go func() {
io.Copy(conn, proxyConn) _, _ = io.Copy(conn, proxyConn)
}() }()
_, err = io.Copy(proxyConn, conn) _, err = io.Copy(proxyConn, conn)
return err return err
@@ -31,16 +32,22 @@ func main() {
flag.StringVar(&inside, "p", "4000", "inside port") flag.StringVar(&inside, "p", "4000", "inside port")
flag.Parse() flag.Parse()
l, err := net.Listen("tcp", "0.0.0.0:"+outside) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:"+outside)
die.If(err) die.If(err)
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
log.Println(err) _, _ = lib.Warn(err, "accept failed")
continue continue
} }
go proxy(conn, "127.0.0.1:"+inside) go func() {
if err = proxy(conn, "127.0.0.1:"+inside); err != nil {
_, _ = lib.Warn(err, "proxy error")
}
}()
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@@ -8,7 +9,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -16,7 +16,7 @@ import (
) )
func main() { func main() {
cfg := &tls.Config{} cfg := &tls.Config{} // #nosec G402
var sysRoot, listenAddr, certFile, keyFile string var sysRoot, listenAddr, certFile, keyFile string
var verify bool var verify bool
@@ -47,7 +47,8 @@ func main() {
} }
cfg.Certificates = append(cfg.Certificates, cert) cfg.Certificates = append(cfg.Certificates, cert)
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) var pemList []byte
pemList, err = os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -59,48 +60,54 @@ func main() {
cfg.RootCAs = roots cfg.RootCAs = roots
} }
l, err := net.Listen("tcp", listenAddr) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", listenAddr)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(1) os.Exit(1)
} }
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
continue
} }
handleConn(conn, cfg)
}
}
// handleConn performs a TLS handshake, extracts the peer chain, and writes it to a file.
func handleConn(conn net.Conn, cfg *tls.Config) {
defer conn.Close()
raddr := conn.RemoteAddr() raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg) tconn := tls.Server(conn, cfg)
err = tconn.Handshake() if err := tconn.HandshakeContext(context.Background()); err != nil {
if err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err) fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
continue return
} }
cs := tconn.ConnectionState() cs := tconn.ConnectionState()
if len(cs.PeerCertificates) == 0 { if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr) fmt.Printf("[+] %v: no chain presented\n", raddr)
continue return
} }
var chain []byte var chain []byte
for _, cert := range cs.PeerCertificates { for _, cert := range cs.PeerCertificates {
p := &pem.Block{ p := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
chain = append(chain, pem.EncodeToMemory(p)...) chain = append(chain, pem.EncodeToMemory(p)...)
} }
var nonce [16]byte var nonce [16]byte
_, err = rand.Read(nonce[:]) if _, err := rand.Read(nonce[:]); err != nil {
if err != nil { fmt.Printf("[+] %v: failed to generate filename nonce: %v\n", raddr, err)
panic(err) return
} }
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:])) fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
err = ioutil.WriteFile(fname, chain, 0644) if err := os.WriteFile(fname, chain, 0o644); err != nil {
die.If(err) fmt.Printf("[+] %v: failed to write %v: %v\n", raddr, fname, err)
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname) return
} }
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
} }

View File

@@ -1,12 +1,12 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -14,7 +14,7 @@ import (
) )
func main() { func main() {
var cfg = &tls.Config{} var cfg = &tls.Config{} // #nosec G402
var sysRoot, serverName string var sysRoot, serverName string
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle") flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
@@ -23,7 +23,7 @@ func main() {
flag.Parse() flag.Parse()
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) pemList, err := os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -44,10 +44,13 @@ func main() {
if err != nil { if err != nil {
site += ":443" site += ":443"
} }
conn, err := tls.Dial("tcp", site, cfg) d := &tls.Dialer{Config: cfg}
if err != nil { nc, err := d.DialContext(context.Background(), "tcp", site)
fmt.Println(err.Error()) die.If(err)
os.Exit(1)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
cs := conn.ConnectionState() cs := conn.ConnectionState()
@@ -61,8 +64,9 @@ func main() {
chain = append(chain, pem.EncodeToMemory(p)...) chain = append(chain, pem.EncodeToMemory(p)...)
} }
err = ioutil.WriteFile(site+".pem", chain, 0644) err = os.WriteFile(site+".pem", chain, 0644)
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.pem.\n", site) fmt.Printf("[+] wrote %s.pem.\n", site)
} }
} }

View File

@@ -60,7 +60,7 @@ func printDigests(paths []string, issuer bool) {
for _, path := range paths { for _, path := range paths {
cert, err := certlib.LoadCertificate(path) cert, err := certlib.LoadCertificate(path)
if err != nil { if err != nil {
lib.Warn(err, "failed to load certificate from %s", path) _, _ = lib.Warn(err, "failed to load certificate from %s", path)
continue continue
} }
@@ -75,20 +75,19 @@ func matchDigests(paths []string, issuer bool) {
} }
var invalid int var invalid int
for { for len(paths) > 0 {
if len(paths) == 0 {
break
}
fst := paths[0] fst := paths[0]
snd := paths[1] snd := paths[1]
paths = paths[2:] paths = paths[2:]
fstCert, err := certlib.LoadCertificate(fst) fstCert, err := certlib.LoadCertificate(fst)
die.If(err) die.If(err)
sndCert, err := certlib.LoadCertificate(snd) sndCert, err := certlib.LoadCertificate(snd)
die.If(err) die.If(err)
if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) { if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) {
lib.Warnx("certificates don't match: %s and %s", fst, snd) _, _ = lib.Warnx("certificates don't match: %s and %s", fst, snd)
invalid++ invalid++
} }
} }

View File

@@ -1,10 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
"git.wntrmute.dev/kyle/goutils/die"
) )
func main() { func main() {
@@ -13,16 +17,23 @@ func main() {
os.Exit(1) os.Exit(1)
} }
hostPort := os.Args[1] hostPort, err := hosts.ParseHost(os.Args[1])
conn, err := tls.Dial("tcp", hostPort, &tls.Config{ die.If(err)
InsecureSkipVerify: true,
})
if err != nil { d := &tls.Dialer{Config: &tls.Config{
fmt.Printf("Failed to connect to the TLS server: %v\n", err) InsecureSkipVerify: true,
os.Exit(1) }} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", hostPort.String())
die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() state := conn.ConnectionState()
printConnectionDetails(state) printConnectionDetails(state)
} }
@@ -37,7 +48,6 @@ func printConnectionDetails(state tls.ConnectionState) {
func tlsVersion(version uint16) string { func tlsVersion(version uint16) string {
switch version { switch version {
case tls.VersionTLS13: case tls.VersionTLS13:
return "TLS 1.3" return "TLS 1.3"
case tls.VersionTLS12: case tls.VersionTLS12:

View File

@@ -11,10 +11,9 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -32,7 +31,7 @@ const (
curveP521 curveP521
) )
func getECCurve(pub interface{}) int { func getECCurve(pub any) int {
switch pub := pub.(type) { switch pub := pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
switch pub.Curve { switch pub.Curve {
@@ -52,42 +51,88 @@ func getECCurve(pub interface{}) int {
} }
} }
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
// It returns true on match.
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
}
// matchECDSA compares ECDSA public keys for equality and compatible curve.
// It returns match=true when they are on the same curve and have the same X/Y.
// If curves mismatch, match is false.
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
if getECCurve(certPub) != getECCurve(keyPub) {
return false
}
if keyPub.X.Cmp(certPub.X) != 0 {
return false
}
if keyPub.Y.Cmp(certPub.Y) != 0 {
return false
}
return true
}
// matchKeys determines whether the certificate's public key matches the given private key.
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
switch keyPub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if matchRSA(certPub, keyPub) {
return true, ""
}
return false, "public keys don't match"
case *ecdsa.PublicKey:
return false, "RSA private key, EC public key"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
case *ecdsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *ecdsa.PublicKey:
if matchECDSA(certPub, keyPub) {
return true, ""
}
// Determine a more precise reason
kc := getECCurve(keyPub)
cc := getECCurve(certPub)
if kc == curveInvalid {
return false, "invalid private key curve"
}
if cc == curveRSA {
return false, "private key is EC, certificate is RSA"
}
if kc != cc {
return false, "EC curves don't match"
}
return false, "public keys don't match"
case *rsa.PublicKey:
return false, "private key is EC, certificate is RSA"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
default:
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
}
}
func loadKey(path string) (crypto.Signer, error) { func loadKey(path string) (crypto.Signer, error) {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
in = bytes.TrimSpace(in) in = bytes.TrimSpace(in)
p, _ := pem.Decode(in) if p, _ := pem.Decode(in); p != nil {
if p != nil {
if !validPEMs[p.Type] { if !validPEMs[p.Type] {
return nil, errors.New("invalid private key file type " + p.Type) return nil, errors.New("invalid private key file type " + p.Type)
} }
in = p.Bytes return certlib.ParsePrivateKeyPEM(in)
} }
priv, err := x509.ParsePKCS8PrivateKey(in) return certlib.ParsePrivateKeyDER(in)
if err != nil {
priv, err = x509.ParsePKCS1PrivateKey(in)
if err != nil {
priv, err = x509.ParseECPrivateKey(in)
if err != nil {
return nil, err
}
}
}
switch priv.(type) {
case *rsa.PrivateKey:
return priv.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return priv.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, errors.New("invalid private key")
} }
func main() { func main() {
@@ -96,7 +141,7 @@ func main() {
flag.StringVar(&certFile, "c", "", "TLS `certificate` file") flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
flag.Parse() flag.Parse()
in, err := ioutil.ReadFile(certFile) in, err := os.ReadFile(certFile)
die.If(err) die.If(err)
p, _ := pem.Decode(in) p, _ := pem.Decode(in)
@@ -112,50 +157,11 @@ func main() {
priv, err := loadKey(keyFile) priv, err := loadKey(keyFile)
die.If(err) die.If(err)
switch pub := priv.Public().(type) { matched, reason := matchKeys(cert, priv)
case *rsa.PublicKey: if matched {
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.") fmt.Println("Match.")
return return
case *ecdsa.PublicKey:
fmt.Println("No match (RSA private key, EC public key).")
os.Exit(1)
} }
case *ecdsa.PublicKey: fmt.Printf("No match (%s).\n", reason)
privCurve := getECCurve(pub)
certCurve := getECCurve(cert.PublicKey)
log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve)
if certCurve == curveRSA {
fmt.Println("No match (private key is EC, certificate is RSA).")
os.Exit(1) os.Exit(1)
} else if privCurve == curveInvalid {
fmt.Println("No match (invalid private key curve).")
os.Exit(1)
} else if privCurve != certCurve {
fmt.Println("No match (EC curves don't match).")
os.Exit(1)
}
certPub := cert.PublicKey.(*ecdsa.PublicKey)
if pub.X.Cmp(certPub.X) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
if pub.Y.Cmp(certPub.Y) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.")
default:
fmt.Printf("Unrecognised private key type: %T\n", priv.Public())
os.Exit(1)
}
} }

View File

@@ -201,10 +201,6 @@ func init() {
os.Exit(1) os.Exit(1)
} }
if fromLoc == time.UTC {
}
toLoc = time.UTC toLoc = time.UTC
} }
@@ -257,15 +253,16 @@ func main() {
showTime(time.Now()) showTime(time.Now())
os.Exit(0) os.Exit(0)
case 1: case 1:
if flag.Arg(0) == "-" { switch {
case flag.Arg(0) == "-":
s := bufio.NewScanner(os.Stdin) s := bufio.NewScanner(os.Stdin)
for s.Scan() { for s.Scan() {
times = append(times, s.Text()) times = append(times, s.Text())
} }
} else if flag.Arg(0) == "help" { case flag.Arg(0) == "help":
usageExamples() usageExamples()
} else { default:
times = flag.Args() times = flag.Args()
} }
default: default:

View File

@@ -4,7 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@@ -12,9 +11,8 @@ import (
type empty struct{} type empty struct{}
func errorf(format string, args ...interface{}) { func errorf(path string, err error) {
format += "\n" fmt.Fprintf(os.Stderr, "%s FAILED: %s\n", path, err)
fmt.Fprintf(os.Stderr, format, args...)
} }
func usage(w io.Writer) { func usage(w io.Writer) {
@@ -44,16 +42,16 @@ func main() {
if flag.NArg() == 1 && flag.Arg(0) == "-" { if flag.NArg() == 1 && flag.Arg(0) == "-" {
path := "stdin" path := "stdin"
in, err := ioutil.ReadAll(os.Stdin) in, err := io.ReadAll(os.Stdin)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
@@ -65,16 +63,16 @@ func main() {
} }
for _, path := range flag.Args() { for _, path := range flag.Args() {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }

View File

@@ -14,16 +14,16 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"git.wntrmute.dev/kyle/goutils/lib"
) )
const defaultDirectory = ".git/objects" const defaultDirectory = ".git/objects"
func errorf(format string, a ...interface{}) { // maxDecompressedSize limits how many bytes we will decompress from a zlib
fmt.Fprintf(os.Stderr, format, a...) // stream to mitigate decompression bombs (gosec G110).
if format[len(format)-1] != '\n' { // Increase this if you expect larger objects.
fmt.Fprintf(os.Stderr, "\n") const maxDecompressedSize int64 = 64 << 30 // 64 GiB
}
}
func isDir(path string) bool { func isDir(path string) bool {
fi, err := os.Stat(path) fi, err := os.Stat(path)
@@ -48,17 +48,21 @@ func loadFile(path string) ([]byte, error) {
} }
defer zread.Close() defer zread.Close()
_, err = io.Copy(buf, zread) // Protect against decompression bombs by limiting how much we read.
if err != nil { lr := io.LimitReader(zread, maxDecompressedSize+1)
if _, err = buf.ReadFrom(lr); err != nil {
return nil, err return nil, err
} }
if int64(buf.Len()) > maxDecompressedSize {
return nil, fmt.Errorf("decompressed size exceeds limit (%d bytes)", maxDecompressedSize)
}
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func showFile(path string) { func showFile(path string) {
fileData, err := loadFile(path) fileData, err := loadFile(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to load %s", path)
return return
} }
@@ -68,37 +72,69 @@ func showFile(path string) {
func searchFile(path string, search *regexp.Regexp) error { func searchFile(path string, search *regexp.Regexp) error {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to open %s", path)
return err return err
} }
defer file.Close() defer file.Close()
zread, err := zlib.NewReader(file) zread, err := zlib.NewReader(file)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to decompress %s", path)
return err return err
} }
defer zread.Close() defer zread.Close()
zbuf := bufio.NewReader(zread) // Limit how much we scan to avoid DoS via huge decompression.
if search.MatchReader(zbuf) { lr := io.LimitReader(zread, maxDecompressedSize+1)
zbuf := bufio.NewReader(lr)
if !search.MatchReader(zbuf) {
return nil
}
fileData, err := loadFile(path) fileData, err := loadFile(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to load %s", path)
return err return err
} }
fmt.Printf("%s:\n%s\n", path, fileData) fmt.Printf("%s:\n%s\n", path, fileData)
}
return nil return nil
} }
func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc { func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc {
return func(path string, info os.FileInfo, err error) error { return func(path string, info os.FileInfo, _ error) error {
if info.Mode().IsRegular() { if !info.Mode().IsRegular() {
return searchFile(path, searchExpr)
}
return nil return nil
} }
return searchFile(path, searchExpr)
}
}
// runSearch compiles the search expression and processes the provided paths.
// It returns an error for fatal conditions; per-file errors are logged.
func runSearch(expr string) error {
search, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("invalid regexp: %w", err)
}
pathList := flag.Args()
if len(pathList) == 0 {
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
if err2 := filepath.Walk(path, buildWalker(search)); err2 != nil {
return err2
}
continue
}
if err2 := searchFile(path, search); err2 != nil {
// Non-fatal: keep going, but report it.
lib.Warn(err2, "non-fatal error while searching files")
}
}
return nil
} }
func main() { func main() {
@@ -109,28 +145,10 @@ func main() {
for _, path := range flag.Args() { for _, path := range flag.Args() {
showFile(path) showFile(path)
} }
} else {
search, err := regexp.Compile(*flSearch)
if err != nil {
errorf("Bad regexp: %v", err)
return return
} }
pathList := flag.Args() if err := runSearch(*flSearch); err != nil {
if len(pathList) == 0 { lib.Err(lib.ExitFailure, err, "failed to run search")
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
err := filepath.Walk(path, buildWalker(search))
if err != nil {
errorf("%v", err)
return
}
} else {
searchFile(path, search)
}
}
} }
} }

1
go.mod
View File

@@ -22,4 +22,5 @@ require (
github.com/kr/pretty v0.1.0 // indirect github.com/kr/pretty v0.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect github.com/pkg/errors v0.9.1 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
rsc.io/qr v0.2.0 // indirect
) )

2
go.sum
View File

@@ -44,3 +44,5 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=

View File

@@ -1,4 +1,4 @@
// +build freebsd darwin,386 netbsd //go:build bsd
package lib package lib

View File

@@ -1,4 +1,4 @@
// +build unix linux openbsd darwin,amd64 //go:build unix || linux || openbsd || (darwin && amd64)
package lib package lib
@@ -18,7 +18,7 @@ type FileTime struct {
func timeSpecToTime(ts unix.Timespec) time.Time { func timeSpecToTime(ts unix.Timespec) time.Time {
// The casts to int64 are needed because on 386, these are int32s. // The casts to int64 are needed because on 386, these are int32s.
return time.Unix(int64(ts.Sec), int64(ts.Nsec)) return time.Unix(ts.Sec, ts.Nsec)
} }
// LoadFileTime returns a FileTime associated with the file. // LoadFileTime returns a FileTime associated with the file.

View File

@@ -2,14 +2,22 @@
package lib package lib
import ( import (
"encoding/hex"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"time" "time"
) )
var progname = filepath.Base(os.Args[0]) var progname = filepath.Base(os.Args[0])
const (
daysInYear = 365
digitWidth = 10
hoursInQuarterDay = 6
)
// ProgName returns what lib thinks the program name is, namely the // ProgName returns what lib thinks the program name is, namely the
// basename of argv0. // basename of argv0.
// //
@@ -20,7 +28,7 @@ func ProgName() string {
// Warnx displays a formatted error message to standard error, à la // Warnx displays a formatted error message to standard error, à la
// warnx(3). // warnx(3).
func Warnx(format string, a ...interface{}) (int, error) { func Warnx(format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n" format += "\n"
return fmt.Fprintf(os.Stderr, format, a...) return fmt.Fprintf(os.Stderr, format, a...)
@@ -28,7 +36,7 @@ func Warnx(format string, a ...interface{}) (int, error) {
// Warn displays a formatted error message to standard output, // Warn displays a formatted error message to standard output,
// appending the error string, à la warn(3). // appending the error string, à la warn(3).
func Warn(err error, format string, a ...interface{}) (int, error) { func Warn(err error, format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n" format += ": %v\n"
a = append(a, err) a = append(a, err)
@@ -37,7 +45,7 @@ func Warn(err error, format string, a ...interface{}) (int, error) {
// Errx displays a formatted error message to standard error and exits // Errx displays a formatted error message to standard error and exits
// with the status code from `exit`, à la errx(3). // with the status code from `exit`, à la errx(3).
func Errx(exit int, format string, a ...interface{}) { func Errx(exit int, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n" format += "\n"
fmt.Fprintf(os.Stderr, format, a...) fmt.Fprintf(os.Stderr, format, a...)
@@ -47,7 +55,7 @@ func Errx(exit int, format string, a ...interface{}) {
// Err displays a formatting error message to standard error, // Err displays a formatting error message to standard error,
// appending the error string, and exits with the status code from // appending the error string, and exits with the status code from
// `exit`, à la err(3). // `exit`, à la err(3).
func Err(exit int, err error, format string, a ...interface{}) { func Err(exit int, err error, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n" format += ": %v\n"
a = append(a, err) a = append(a, err)
@@ -62,30 +70,30 @@ func Itoa(i int, wid int) string {
// Assemble decimal in reverse order. // Assemble decimal in reverse order.
var b [20]byte var b [20]byte
bp := len(b) - 1 bp := len(b) - 1
for i >= 10 || wid > 1 { for i >= digitWidth || wid > 1 {
wid-- wid--
q := i / 10 q := i / digitWidth
b[bp] = byte('0' + i - q*10) b[bp] = byte('0' + i - q*digitWidth)
bp-- bp--
i = q i = q
} }
// i < 10
b[bp] = byte('0' + i) b[bp] = byte('0' + i)
return string(b[bp:]) return string(b[bp:])
} }
var ( var (
dayDuration = 24 * time.Hour dayDuration = 24 * time.Hour
yearDuration = (365 * dayDuration) + (6 * time.Hour) yearDuration = (daysInYear * dayDuration) + (hoursInQuarterDay * time.Hour)
) )
// Duration returns a prettier string for time.Durations. // Duration returns a prettier string for time.Durations.
func Duration(d time.Duration) string { func Duration(d time.Duration) string {
var s string var s string
if d >= yearDuration { if d >= yearDuration {
years := d / yearDuration years := int64(d / yearDuration)
s += fmt.Sprintf("%dy", years) s += fmt.Sprintf("%dy", years)
d -= years * yearDuration d -= time.Duration(years) * yearDuration
} }
if d >= dayDuration { if d >= dayDuration {
@@ -98,8 +106,116 @@ func Duration(d time.Duration) string {
} }
d %= 1 * time.Second d %= 1 * time.Second
hours := d / time.Hour hours := int64(d / time.Hour)
d -= hours * time.Hour d -= time.Duration(hours) * time.Hour
s += fmt.Sprintf("%dh%s", hours, d) s += fmt.Sprintf("%dh%s", hours, d)
return s return s
} }
type HexEncodeMode uint8
const (
// HexEncodeLower prints the bytes as lowercase hexadecimal.
HexEncodeLower HexEncodeMode = iota + 1
// HexEncodeUpper prints the bytes as uppercase hexadecimal.
HexEncodeUpper
// HexEncodeLowerColon prints the bytes as lowercase hexadecimal
// with colons between each pair of bytes.
HexEncodeLowerColon
// HexEncodeUpperColon prints the bytes as uppercase hexadecimal
// with colons between each pair of bytes.
HexEncodeUpperColon
)
func (m HexEncodeMode) String() string {
switch m {
case HexEncodeLower:
return "lower"
case HexEncodeUpper:
return "upper"
case HexEncodeLowerColon:
return "lcolon"
case HexEncodeUpperColon:
return "ucolon"
default:
panic("invalid hex encode mode")
}
}
func ParseHexEncodeMode(s string) HexEncodeMode {
switch strings.ToLower(s) {
case "lower":
return HexEncodeLower
case "upper":
return HexEncodeUpper
case "lcolon":
return HexEncodeLowerColon
case "ucolon":
return HexEncodeUpperColon
}
panic("invalid hex encode mode")
}
func hexColons(s string) string {
if len(s)%2 != 0 {
fmt.Fprintf(os.Stderr, "hex string: %s\n", s)
fmt.Fprintf(os.Stderr, "hex length: %d\n", len(s))
panic("invalid hex string length")
}
n := len(s)
if n <= 2 {
return s
}
pairCount := n / 2
if n%2 != 0 {
pairCount++
}
var b strings.Builder
b.Grow(n + pairCount - 1)
for i := 0; i < n; i += 2 {
b.WriteByte(s[i])
if i+1 < n {
b.WriteByte(s[i+1])
}
if i+2 < n {
b.WriteByte(':')
}
}
return b.String()
}
func hexEncode(b []byte) string {
s := hex.EncodeToString(b)
if len(s)%2 != 0 {
s = "0" + s
}
return s
}
// HexEncode encodes the given bytes as a hexadecimal string.
func HexEncode(b []byte, mode HexEncodeMode) string {
str := hexEncode(b)
switch mode {
case HexEncodeLower:
return str
case HexEncodeUpper:
return strings.ToUpper(str)
case HexEncodeLowerColon:
return hexColons(str)
case HexEncodeUpperColon:
return strings.ToUpper(hexColons(str))
default:
panic("invalid hex encode mode")
}
}

79
lib/lib_test.go Normal file
View File

@@ -0,0 +1,79 @@
package lib_test
import (
"testing"
"git.wntrmute.dev/kyle/goutils/lib"
)
func TestHexEncode_LowerUpper(t *testing.T) {
b := []byte{0x0f, 0xa1, 0x00, 0xff}
gotLower := lib.HexEncode(b, lib.HexEncodeLower)
if gotLower != "0fa100ff" {
t.Fatalf("lib.HexEncode lower: expected %q, got %q", "0fa100ff", gotLower)
}
gotUpper := lib.HexEncode(b, lib.HexEncodeUpper)
if gotUpper != "0FA100FF" {
t.Fatalf("lib.HexEncode upper: expected %q, got %q", "0FA100FF", gotUpper)
}
}
func TestHexEncode_ColonModes(t *testing.T) {
// Includes leading zero nibble and a zero byte to verify padding and separators
b := []byte{0x0f, 0xa1, 0x00, 0xff}
gotLColon := lib.HexEncode(b, lib.HexEncodeLowerColon)
if gotLColon != "0f:a1:00:ff" {
t.Fatalf("lib.HexEncode colon lower: expected %q, got %q", "0f:a1:00:ff", gotLColon)
}
gotUColon := lib.HexEncode(b, lib.HexEncodeUpperColon)
if gotUColon != "0F:A1:00:FF" {
t.Fatalf("lib.HexEncode colon upper: expected %q, got %q", "0F:A1:00:FF", gotUColon)
}
}
func TestHexEncode_EmptyInput(t *testing.T) {
var b []byte
if got := lib.HexEncode(b, lib.HexEncodeLower); got != "" {
t.Fatalf("empty lower: expected empty string, got %q", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpper); got != "" {
t.Fatalf("empty upper: expected empty string, got %q", got)
}
if got := lib.HexEncode(b, lib.HexEncodeLowerColon); got != "" {
t.Fatalf("empty colon lower: expected empty string, got %q", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpperColon); got != "" {
t.Fatalf("empty colon upper: expected empty string, got %q", got)
}
}
func TestHexEncode_SingleByte(t *testing.T) {
b := []byte{0x0f}
if got := lib.HexEncode(b, lib.HexEncodeLower); got != "0f" {
t.Fatalf("single byte lower: expected %q, got %q", "0f", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpper); got != "0F" {
t.Fatalf("single byte upper: expected %q, got %q", "0F", got)
}
// For a single byte, colon modes should not introduce separators
if got := lib.HexEncode(b, lib.HexEncodeLowerColon); got != "0f" {
t.Fatalf("single byte colon lower: expected %q, got %q", "0f", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpperColon); got != "0F" {
t.Fatalf("single byte colon upper: expected %q, got %q", "0F", got)
}
}
func TestHexEncode_InvalidModePanics(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatalf("expected panic for invalid mode, but function returned normally")
}
}()
// 0 is not a valid lib.HexEncodeMode (valid modes start at 1)
_ = lib.HexEncode([]byte{0x01}, lib.HexEncodeMode(0))
}

View File

@@ -1,6 +1,7 @@
package logging package logging
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
) )
@@ -61,9 +62,9 @@ func NewSplitFile(outpath, errpath string, overwrite bool) (*File, error) {
if err != nil { if err != nil {
if closeErr := fl.Close(); closeErr != nil { if closeErr := fl.Close(); closeErr != nil {
return nil, fmt.Errorf("failed to open error log: cleanup close failed: %v: %w", closeErr, err) return nil, fmt.Errorf("failed to open error log: %w", errors.Join(closeErr, err))
} }
return nil, err return nil, fmt.Errorf("failed to open error log: %w", err)
} }
fl.LogWriter = NewLogWriter(fl.fo, fl.fe) fl.LogWriter = NewLogWriter(fl.fo, fl.fe)

33
twofactor/README.md Normal file
View File

@@ -0,0 +1,33 @@
## `twofactor`
[![GoDoc](https://godoc.org/github.com/gokyle/twofactor?status.svg)](https://godoc.org/github.com/gokyle/twofactor)
### Author
`twofactor` was written by Kyle Isom <kyle@tyrfingr.is>.
### License
```
Copyright (c) 2017 Kyle Isom <kyle@imap.cc>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

5
twofactor/doc.go Normal file
View File

@@ -0,0 +1,5 @@
// Package twofactor implements two-factor authentication.
//
// Currently supported are RFC 4226 HOTP one-time passwords and
// RFC 6238 TOTP SHA-1 one-time passwords.
package twofactor

103
twofactor/hotp.go Normal file
View File

@@ -0,0 +1,103 @@
package twofactor
import (
"crypto"
"crypto/sha1" // #nosec G505 - required by RFC
"encoding/base32"
"io"
"net/url"
"strconv"
"strings"
)
// HOTP represents an RFC-4226 Hash-based One Time Password instance.
type HOTP struct {
*OATH
}
// NewHOTP takes the key, the initial counter value, and the number
// of digits (typically 6 or 8) and returns a new HOTP instance.
func NewHOTP(key []byte, counter uint64, digits int) *HOTP {
return &HOTP{
OATH: &OATH{
key: key,
counter: counter,
size: digits,
hash: sha1.New,
algo: crypto.SHA1,
},
}
}
// Type returns OATH_HOTP.
func (otp *HOTP) Type() Type {
return OATH_HOTP
}
// OTP returns the next OTP and increments the counter.
func (otp *HOTP) OTP() string {
code := otp.OATH.OTP(otp.counter)
otp.counter++
return code
}
// URL returns an HOTP URL (i.e. for putting in a QR code).
func (otp *HOTP) URL(label string) string {
return otp.OATH.URL(otp.Type(), label)
}
// SetProvider sets up the provider component of the OTP URL.
func (otp *HOTP) SetProvider(provider string) {
otp.provider = provider
}
// GenerateGoogleHOTP generates a new HOTP instance as used by
// Google Authenticator.
func GenerateGoogleHOTP() *HOTP {
key := make([]byte, sha1.Size)
if _, err := io.ReadFull(PRNG, key); err != nil {
return nil
}
return NewHOTP(key, 0, 6)
}
func hotpFromURL(u *url.URL) (*HOTP, string, error) {
label := u.Path[1:]
v := u.Query()
secret := strings.ToUpper(v.Get("secret"))
if secret == "" {
return nil, "", ErrInvalidURL
}
var digits = 6
if sdigit := v.Get("digits"); sdigit != "" {
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
if err != nil {
return nil, "", err
}
digits = int(tmpDigits)
}
var counter uint64
if scounter := v.Get("counter"); scounter != "" {
var err error
counter, err = strconv.ParseUint(scounter, 10, 64)
if err != nil {
return nil, "", err
}
}
key, err := base32.StdEncoding.DecodeString(Pad(secret))
if err != nil {
// assume secret isn't base32 encoded
key = []byte(secret)
}
otp := NewHOTP(key, counter, digits)
return otp, label, nil
}
// QR generates a new QR code for the HOTP.
func (otp *HOTP) QR(label string) ([]byte, error) {
return otp.OATH.QR(otp.Type(), label)
}

View File

@@ -0,0 +1,58 @@
package twofactor
import (
"testing"
)
var testKey = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
var rfcHotpKey = []byte("12345678901234567890")
var rfcHotpExpected = []string{
"755224",
"287082",
"359152",
"969429",
"338314",
"254676",
"287922",
"162583",
"399871",
"520489",
}
// This test runs through the test cases presented in the RFC, and
// ensures that this implementation is in compliance.
func TestHotpRFC(t *testing.T) {
otp := NewHOTP(rfcHotpKey, 0, 6)
for i := range rfcHotpExpected {
if otp.Counter() != uint64(i) {
t.Fatalf("twofactor: invalid counter (should be %d, is %d",
i, otp.Counter())
}
code := otp.OTP()
if code == "" {
t.Fatal("twofactor: failed to produce an OTP")
} else if code != rfcHotpExpected[i] {
t.Logf("twofactor: invalid OTP\n")
t.Logf("\tExpected: %s\n", rfcHotpExpected[i])
t.Logf("\t Actual: %s\n", code)
t.Fatalf("\t Counter: %d\n", otp.counter)
}
}
}
// This test uses a different key than the test cases in the RFC,
// but runs through the same test cases to ensure that they fail as
// expected.
func TestHotpBadRFC(t *testing.T) {
otp := NewHOTP(testKey, 0, 6)
for i := range rfcHotpExpected {
code := otp.OTP()
switch code {
case "":
t.Error("twofactor: failed to produce an OTP")
case rfcHotpExpected[i]:
t.Error("twofactor: should not have received a valid OTP")
}
}
}

150
twofactor/oath.go Normal file
View File

@@ -0,0 +1,150 @@
package twofactor
import (
"crypto"
"crypto/hmac"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"net/url"
"strconv"
"rsc.io/qr"
)
const defaultSize = 6
// OATH provides a baseline structure for the two OATH algorithms.
type OATH struct {
key []byte
counter uint64
size int
hash func() hash.Hash
algo crypto.Hash
provider string
}
// Size returns the output size (in characters) of the password.
func (o *OATH) Size() int {
return o.size
}
// Counter returns the OATH token's counter.
func (o *OATH) Counter() uint64 {
return o.counter
}
// SetCounter updates the OATH token's counter to a new value.
func (o *OATH) SetCounter(counter uint64) {
o.counter = counter
}
// Key returns the token's secret key.
func (o *OATH) Key() []byte {
return o.key
}
// Hash returns the token's hash function.
func (o *OATH) Hash() func() hash.Hash {
return o.hash
}
// URL constructs a URL appropriate for the token (i.e. for use in a
// QR code).
func (o *OATH) URL(t Type, label string) string {
secret := base32.StdEncoding.EncodeToString(o.key)
u := url.URL{}
v := url.Values{}
u.Scheme = "otpauth"
switch t {
case OATH_HOTP:
u.Host = "hotp"
case OATH_TOTP:
u.Host = "totp"
}
u.Path = label
v.Add("secret", secret)
if o.Counter() != 0 && t == OATH_HOTP {
v.Add("counter", strconv.FormatUint(o.Counter(), 10))
}
if o.Size() != defaultSize {
v.Add("digits", strconv.Itoa(o.Size()))
}
switch o.algo {
case crypto.SHA256:
v.Add("algorithm", "SHA256")
case crypto.SHA512:
v.Add("algorithm", "SHA512")
}
if o.provider != "" {
v.Add("provider", o.provider)
}
u.RawQuery = v.Encode()
return u.String()
}
var digits = []int64{
0: 1,
1: 10,
2: 100,
3: 1000,
4: 10000,
5: 100000,
6: 1000000,
7: 10000000,
8: 100000000,
9: 1000000000,
10: 10000000000,
}
// OTP top-level type should provide a counter; for example, HOTP
// will provide the counter directly while TOTP will provide the
// time-stepped counter.
func (o *OATH) OTP(counter uint64) string {
var ctr [8]byte
binary.BigEndian.PutUint64(ctr[:], counter)
var mod int64 = 1
if len(digits) > o.size {
for i := 1; i <= o.size; i++ {
mod *= 10
}
} else {
mod = digits[o.size]
}
h := hmac.New(o.hash, o.key)
h.Write(ctr[:])
dt := truncate(h.Sum(nil)) % mod
fmtStr := fmt.Sprintf("%%0%dd", o.size)
return fmt.Sprintf(fmtStr, dt)
}
// truncate contains the DT function from the RFC; this is used to
// deterministically select a sequence of 4 bytes from the HMAC
// counter hash.
func truncate(in []byte) int64 {
offset := int(in[len(in)-1] & 0xF)
p := in[offset : offset+4]
var binCode int32
binCode = int32((p[0] & 0x7f)) << 24
binCode += int32((p[1] & 0xff)) << 16
binCode += int32((p[2] & 0xff)) << 8
binCode += int32((p[3] & 0xff))
return int64(binCode) & 0x7FFFFFFF
}
// QR generates a byte slice containing the a QR code encoded as a
// PNG with level Q error correction.
func (o *OATH) QR(t Type, label string) ([]byte, error) {
u := o.URL(t, label)
code, err := qr.Encode(u, qr.Q)
if err != nil {
return nil, err
}
return code.PNG(), nil
}

View File

@@ -0,0 +1,27 @@
package twofactor
import (
"testing"
)
var sha1Hmac = []byte{
0x1f, 0x86, 0x98, 0x69, 0x0e,
0x02, 0xca, 0x16, 0x61, 0x85,
0x50, 0xef, 0x7f, 0x19, 0xda,
0x8e, 0x94, 0x5b, 0x55, 0x5a,
}
var truncExpect int64 = 0x50ef7f19
// This test runs through the truncation example given in the RFC.
func TestTruncate(t *testing.T) {
if result := truncate(sha1Hmac); result != truncExpect {
t.Fatalf("hotp: expected truncate -> %d, saw %d\n",
truncExpect, result)
}
sha1Hmac[19]++
if result := truncate(sha1Hmac); result == truncExpect {
t.Fatal("hotp: expected truncation to fail")
}
}

86
twofactor/otp.go Normal file
View File

@@ -0,0 +1,86 @@
package twofactor
import (
"crypto/rand"
"errors"
"fmt"
"hash"
"net/url"
)
type Type uint
const (
OATH_HOTP = iota
OATH_TOTP
)
// PRNG is an io.Reader that provides a cryptographically secure
// random byte stream.
var PRNG = rand.Reader
var (
ErrInvalidURL = errors.New("twofactor: invalid URL")
ErrInvalidAlgo = errors.New("twofactor: invalid algorithm")
)
// OTP represents a one-time password token -- whether a
// software taken (as in the case of Google Authenticator) or a
// hardware token (as in the case of a YubiKey).
type OTP interface {
// Returns the current counter value; the meaning of the
// returned value is algorithm-specific.
Counter() uint64
// Set the counter to a specific value.
SetCounter(uint64)
// the secret key contained in the OTP
Key() []byte
// generate a new OTP
OTP() string
// the output size of the OTP
Size() int
// the hash function used by the OTP
Hash() func() hash.Hash
// Returns the type of this OTP.
Type() Type
}
func otpString(otp OTP) string {
var typeName string
switch otp.Type() {
case OATH_HOTP:
typeName = "OATH-HOTP"
case OATH_TOTP:
typeName = "OATH-TOTP"
default:
typeName = "UNKNOWN"
}
return fmt.Sprintf("%s, %d", typeName, otp.Size())
}
// FromURL constructs a new OTP token from a URL string.
func FromURL(otpURL string) (OTP, string, error) {
u, err := url.Parse(otpURL)
if err != nil {
return nil, "", err
}
if u.Scheme != "otpauth" {
return nil, "", ErrInvalidURL
}
switch u.Host {
case "totp":
return totpFromURL(u)
case "hotp":
return hotpFromURL(u)
default:
return nil, "", ErrInvalidURL
}
}

View File

@@ -0,0 +1,126 @@
package twofactor
import (
"io"
"testing"
)
func TestHOTPString(t *testing.T) {
hotp := NewHOTP(nil, 0, 6)
hotpString := otpString(hotp)
if hotpString != "OATH-HOTP, 6" {
t.Fatal("twofactor: invalid OTP string")
}
}
// This test generates a new OTP, outputs the URL for that OTP,
// and attempts to parse that URL. It verifies that the two OTPs
// are the same, and that they produce the same output.
func TestURL(t *testing.T) {
var ident = "testuser@foo"
otp := NewHOTP(testKey, 0, 6)
url := otp.URL("testuser@foo")
otp2, id, err := FromURL(url)
switch {
case err != nil:
t.Fatal("hotp: failed to parse HOTP URL\n")
case id != ident:
t.Logf("hotp: bad label\n")
t.Logf("\texpected: %s\n", ident)
t.Fatalf("\t actual: %s\n", id)
case otp2.Counter() != otp.Counter():
t.Logf("hotp: OTP counters aren't synced\n")
t.Logf("\toriginal: %d\n", otp.Counter())
t.Fatalf("\t second: %d\n", otp2.Counter())
}
code1 := otp.OTP()
code2 := otp2.OTP()
if code1 != code2 {
t.Logf("hotp: mismatched OTPs\n")
t.Logf("\texpected: %s\n", code1)
t.Fatalf("\t actual: %s\n", code2)
}
// There's not much we can do test the QR code, except to
// ensure it doesn't fail.
_, err = otp.QR(ident)
if err != nil {
t.Fatalf("hotp: failed to generate QR code PNG (%v)\n", err)
}
// This should fail because the maximum size of an alphanumeric
// QR code with the lowest-level of error correction should
// max out at 4296 bytes. 8k may be a bit overkill... but it
// gets the job done. The value is read from the PRNG to
// increase the likelihood that the returned data is
// uncompressible.
var tooBigIdent = make([]byte, 8192)
_, err = io.ReadFull(PRNG, tooBigIdent)
if err != nil {
t.Fatalf("hotp: failed to read identity (%v)\n", err)
} else if _, err = otp.QR(string(tooBigIdent)); err == nil {
t.Fatal("hotp: QR code should fail to encode oversized URL")
}
}
// This test makes sure we can generate codes for padded and non-padded
// entries.
func TestPaddedURL(t *testing.T) {
var urlList = []string{
"otpauth://hotp/?secret=ME",
"otpauth://hotp/?secret=MEFR",
"otpauth://hotp/?secret=MFRGG",
"otpauth://hotp/?secret=MFRGGZA",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by=======",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by%3D%3D%3D%3D%3D%3D%3D",
}
var codeList = []string{
"413198",
"770938",
"670717",
"402378",
"069864",
"069864",
"069864",
}
for i := range urlList {
if o, id, err := FromURL(urlList[i]); err != nil {
t.Log("hotp: URL should have parsed successfully (id=", id, ")")
t.Logf("\turl was: %s\n", urlList[i])
t.Fatalf("\t%s, %s\n", o.OTP(), id)
} else {
code2 := o.OTP()
if code2 != codeList[i] {
t.Logf("hotp: mismatched OTPs\n")
t.Logf("\texpected: %s\n", codeList[i])
t.Fatalf("\t actual: %s\n", code2)
}
}
}
}
// This test attempts a variety of invalid urls against the parser
// to ensure they fail.
func TestBadURL(t *testing.T) {
var urlList = []string{
"http://google.com",
"",
"-",
"foo",
"otpauth:/foo/bar/baz",
"://",
"otpauth://hotp/?digits=",
"otpauth://hotp/?secret=MFRGGZDF&digits=ABCD",
"otpauth://hotp/?secret=MFRGGZDF&counter=ABCD",
}
for i := range urlList {
if _, _, err := FromURL(urlList[i]); err == nil {
t.Log("hotp: URL should not have parsed successfully")
t.Fatalf("\turl was: %s\n", urlList[i])
}
}
}

172
twofactor/totp.go Normal file
View File

@@ -0,0 +1,172 @@
package twofactor
import (
"crypto"
"crypto/sha1" // #nosec G505 - required by RFC
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"hash"
"io"
"net/url"
"strconv"
"strings"
"github.com/benbjohnson/clock"
)
var timeSource = clock.New()
// TOTP represents an RFC 6238 Time-based One-Time Password instance.
type TOTP struct {
*OATH
step uint64
}
// NewTOTP takes a new key, a starting time, a step, the number of
// digits of output (typically 6 or 8) and the hash algorithm to
// use, and builds a new OTP.
func NewTOTP(key []byte, start uint64, step uint64, digits int, algo crypto.Hash) *TOTP {
h := hashFromAlgo(algo)
if h == nil {
return nil
}
return &TOTP{
OATH: &OATH{
key: key,
counter: start,
size: digits,
hash: h,
algo: algo,
},
step: step,
}
}
// NewGoogleTOTP takes a secret as a base32-encoded string and
// returns an appropriate Google Authenticator TOTP instance.
func NewGoogleTOTP(secret string) (*TOTP, error) {
key, err := base32.StdEncoding.DecodeString(secret)
if err != nil {
return nil, err
}
return NewTOTP(key, 0, 30, 6, crypto.SHA1), nil
}
// NewTOTPSHA1 will build a new TOTP using SHA-1.
func NewTOTPSHA1(key []byte, start uint64, step uint64, digits int) *TOTP {
return NewTOTP(key, start, step, digits, crypto.SHA1)
}
// Type returns OATH_TOTP.
func (otp *TOTP) Type() Type {
return OATH_TOTP
}
func (otp *TOTP) otp(counter uint64) string {
return otp.OATH.OTP(counter)
}
// OTP returns the OTP for the current timestep.
func (otp *TOTP) OTP() string {
return otp.otp(otp.OTPCounter())
}
// URL returns a TOTP URL (i.e. for putting in a QR code).
func (otp *TOTP) URL(label string) string {
return otp.OATH.URL(otp.Type(), label)
}
// SetProvider sets up the provider component of the OTP URL.
func (otp *TOTP) SetProvider(provider string) {
otp.provider = provider
}
func (otp *TOTP) otpCounter(t uint64) uint64 {
return (t - otp.counter) / otp.step
}
// OTPCounter returns the current time value for the OTP.
func (otp *TOTP) OTPCounter() uint64 {
return otp.otpCounter(uint64(timeSource.Now().Unix() & 0x7FFFFFFF)) //#nosec G115 - masked out overflow bits
}
func hashFromAlgo(algo crypto.Hash) func() hash.Hash {
switch algo {
case crypto.SHA1:
return sha1.New
case crypto.SHA256:
return sha256.New
case crypto.SHA512:
return sha512.New
}
return nil
}
// GenerateGoogleTOTP produces a new TOTP token with the defaults expected by
// Google Authenticator.
func GenerateGoogleTOTP() *TOTP {
key := make([]byte, sha1.Size)
if _, err := io.ReadFull(PRNG, key); err != nil {
return nil
}
return NewTOTP(key, 0, 30, 6, crypto.SHA1)
}
func totpFromURL(u *url.URL) (*TOTP, string, error) {
label := u.Path[1:]
v := u.Query()
secret := strings.ToUpper(v.Get("secret"))
if secret == "" {
return nil, "", ErrInvalidURL
}
var algo = crypto.SHA1
if algorithm := v.Get("algorithm"); algorithm != "" {
switch {
case strings.EqualFold(algorithm, "SHA256"):
algo = crypto.SHA256
case strings.EqualFold(algorithm, "SHA512"):
algo = crypto.SHA512
case !strings.EqualFold(algorithm, "SHA1"):
return nil, "", ErrInvalidAlgo
}
}
var digits = 6
if sdigit := v.Get("digits"); sdigit != "" {
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
if err != nil {
return nil, "", err
}
digits = int(tmpDigits)
}
var period uint64 = 30
if speriod := v.Get("period"); speriod != "" {
var err error
period, err = strconv.ParseUint(speriod, 10, 64)
if err != nil {
return nil, "", err
}
}
key, err := base32.StdEncoding.DecodeString(Pad(secret))
if err != nil {
// assume secret isn't base32 encoded
key = []byte(secret)
}
otp := NewTOTP(key, 0, period, digits, algo)
return otp, label, nil
}
// QR generates a new TOTP QR code.
func (otp *TOTP) QR(label string) ([]byte, error) {
return otp.OATH.QR(otp.Type(), label)
}
func SetClock(c clock.Clock) {
timeSource = c
}

View File

@@ -0,0 +1,85 @@
package twofactor
import (
"crypto"
"testing"
"time"
"github.com/benbjohnson/clock"
)
var rfcTotpKey = map[crypto.Hash][]byte{
crypto.SHA1: []byte("12345678901234567890"),
crypto.SHA256: []byte("12345678901234567890123456789012"),
crypto.SHA512: []byte("1234567890123456789012345678901234567890123456789012345678901234"),
}
var rfcTotpStep uint64 = 30
var rfcTotpTests = []struct {
Time uint64
Code string
T uint64
Algo crypto.Hash
}{
{59, "94287082", 1, crypto.SHA1},
{59, "46119246", 1, crypto.SHA256},
{59, "90693936", 1, crypto.SHA512},
{1111111109, "07081804", 37037036, crypto.SHA1},
{1111111109, "68084774", 37037036, crypto.SHA256},
{1111111109, "25091201", 37037036, crypto.SHA512},
{1111111111, "14050471", 37037037, crypto.SHA1},
{1111111111, "67062674", 37037037, crypto.SHA256},
{1111111111, "99943326", 37037037, crypto.SHA512},
{1234567890, "89005924", 41152263, crypto.SHA1},
{1234567890, "91819424", 41152263, crypto.SHA256},
{1234567890, "93441116", 41152263, crypto.SHA512},
{2000000000, "69279037", 66666666, crypto.SHA1},
{2000000000, "90698825", 66666666, crypto.SHA256},
{2000000000, "38618901", 66666666, crypto.SHA512},
{20000000000, "65353130", 666666666, crypto.SHA1},
{20000000000, "77737706", 666666666, crypto.SHA256},
{20000000000, "47863826", 666666666, crypto.SHA512},
}
func TestTotpRFC(t *testing.T) {
for _, tc := range rfcTotpTests {
otp := NewTOTP(rfcTotpKey[tc.Algo], 0, rfcTotpStep, 8, tc.Algo)
if otp.otpCounter(tc.Time) != tc.T {
t.Logf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
t.Logf("\texpected: %d\n", tc.T)
t.Errorf("\t actual: %d\n", otp.otpCounter(tc.Time))
}
if code := otp.otp(otp.otpCounter(tc.Time)); code != tc.Code {
t.Logf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
t.Logf("\texpected: %s\n", tc.Code)
t.Errorf("\t actual: %s\n", code)
}
}
}
func TestTOTPTime(t *testing.T) {
otp := GenerateGoogleTOTP()
testClock := clock.NewMock()
testClock.Add(2 * time.Minute)
SetClock(testClock)
code := otp.OTP()
testClock.Add(-1 * time.Minute)
if newCode := otp.OTP(); newCode == code {
t.Errorf("twofactor: TOTP: previous code %s shouldn't match code %s", newCode, code)
}
testClock.Add(2 * time.Minute)
if newCode := otp.OTP(); newCode == code {
t.Errorf("twofactor: TOTP: future code %s shouldn't match code %s", newCode, code)
}
testClock.Add(-1 * time.Minute)
if newCode := otp.OTP(); newCode != code {
t.Errorf("twofactor: TOTP: current code %s shouldn't match code %s", newCode, code)
}
}

16
twofactor/util.go Normal file
View File

@@ -0,0 +1,16 @@
package twofactor
import (
"strings"
)
// Pad calculates the number of '='s to add to our encoded string
// to make base32.StdEncoding.DecodeString happy.
func Pad(s string) string {
if !strings.HasSuffix(s, "=") && len(s)%8 != 0 {
for len(s)%8 != 0 {
s += "="
}
}
return s
}

51
twofactor/util_test.go Normal file
View File

@@ -0,0 +1,51 @@
package twofactor_test
import (
"encoding/base32"
"math/rand"
"strings"
"testing"
"git.wntrmute.dev/kyle/goutils/twofactor"
)
const letters = "1234567890!@#$%^&*()abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func randString() string {
b := make([]byte, rand.Intn(len(letters)))
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return base32.StdEncoding.EncodeToString(b)
}
func TestPadding(t *testing.T) {
for range 300 {
b := randString()
origEncoding := b
modEncoding := strings.ReplaceAll(b, "=", "")
str, err := base32.StdEncoding.DecodeString(origEncoding)
if err != nil {
t.Fatal("Can't decode: ", b)
}
paddedEncoding := twofactor.Pad(modEncoding)
if origEncoding != paddedEncoding {
t.Log("Padding failed:")
t.Logf("Expected: '%s'", origEncoding)
t.Fatalf("Got: '%s'", paddedEncoding)
} else {
var mstr []byte
mstr, err = base32.StdEncoding.DecodeString(paddedEncoding)
if err != nil {
t.Fatal("Can't decode: ", paddedEncoding)
}
if string(mstr) != string(str) {
t.Log("Re-padding failed:")
t.Logf("Expected: '%s'", str)
t.Fatalf("Got: '%s'", mstr)
}
}
}
}