Compare commits

...

46 Commits

Author SHA1 Message Date
e2a3081ce5 cmd: add certser command. 2025-11-17 07:18:46 -08:00
3149d958f4 cmd: add certser 2025-11-17 06:55:20 -08:00
f296344acf twofactor: linting fixes 2025-11-16 21:51:38 -08:00
3fb2d88a3f go get rsc.io/qr 2025-11-16 20:44:13 -08:00
150c02b377 Fix subtree. 2025-11-16 18:55:43 -08:00
83f88c49fe Import twofactor. 2025-11-16 18:45:34 -08:00
7c437ac45f Add 'twofactor/' from commit 'c999bf35b0e47de4f63d59abbe0d7efc76c13ced'
git-subtree-dir: twofactor
git-subtree-mainline: 4dc135cfe0
git-subtree-split: c999bf35b0
2025-11-16 18:43:03 -08:00
c999bf35b0 linter fixes. 2025-11-16 18:39:18 -08:00
4dc135cfe0 Update CHANGELOG for v1.11.2. 2025-11-16 13:18:38 -08:00
790113e189 cmd: refactor for code reuse. 2025-11-16 13:15:08 -08:00
8348c5fd65 Update CHANGELOG. 2025-11-16 11:09:02 -08:00
1eafb638a8 cmd: finish linting fixes 2025-11-16 11:03:12 -08:00
3ad562b6fa cmd: continuing linter fixes 2025-11-16 02:54:02 -08:00
0f77bd49dc cmd: continue lint fixes. 2025-11-16 01:32:19 -08:00
f31d74243f cmd: start linting fixes. 2025-11-16 00:36:19 -08:00
0dcd18c6f1 clean up code
- travisci is long dead
- golangci-lint the repo
2024-12-02 13:47:43 -08:00
024d552293 add circle ci config 2024-12-02 13:26:34 -08:00
9cd2ced695 There are different keys for different hash sizes. 2024-12-02 13:16:32 -08:00
619c08a13f Update travis to latest Go versions. 2020-10-31 08:06:11 -07:00
944a57bf0e Switch to go modules. 2020-10-31 07:41:31 -07:00
0857b29624 Actually support clock mocking. 2020-10-31 07:24:34 -07:00
CodeLingo Bot
e95404bfc5 Fix function comments based on best practices from Effective Go
Signed-off-by: CodeLingo Bot <bot@codelingo.io>
2020-10-30 14:57:11 -07:00
ujjwalsh
924654e7c4 Added Support for Linux on Power 2020-10-30 07:50:32 -07:00
9e0979e07f Support clock mocking.
This addresses #15.
2018-12-07 08:23:01 -08:00
Aaron Bieber
bbc82ff8de Pad non-padded secrets. This lets us continue building on <= go1.8.
- Add tests for secrets using various padding methods.
- Add a new method/test to append padding to non-padded secrets.
2018-04-18 13:39:21 -07:00
Aaron Bieber
5fd928f69a Decode using WithPadding as pointed out by @gl-sergei.
This makes us print the same 6 digits as oathtool for non-padded
secrets like "a6mryljlbufszudtjdt42nh5by".
2018-04-18 13:39:21 -07:00
Aaron Bieber
acefe4a3b9 Don't assume our secret is base32 encoded.
According to https://en.wikipedia.org/wiki/Time-based_One-time_Password_algorithm
secrets are only base32 encoded in gauthenticator and gauth friendly providers.
2018-04-16 13:14:03 -07:00
a1452cebc9 Travis requires a string for Go 1.10. 2018-04-16 13:03:16 -07:00
6e9812e6f5 Vendor dependencies and add more tests. 2018-04-16 13:03:16 -07:00
Aaron Bieber
8c34415c34 add readme 2018-04-16 12:52:39 -07:00
Paul TREHIOU
2cf2c15def Case insensitive algorithm match 2018-04-16 12:43:27 -07:00
Aaron Bieber
eaad1884d4 Make sure our secret is always uppercase
Non-uppercase secrets that are base32 encoded will fial to decode
unless we upper them.
2017-09-17 18:19:23 -07:00
5d57d844d4 Add license (MIT). 2017-04-13 10:02:20 -07:00
Kyle Isom
31b9d175dd Add travis config. 2017-03-20 14:20:49 -07:00
Aaron Bieber
79e106da2e point to new qr location 2017-03-20 13:18:56 -07:00
Kyle Isom
939b1bc272 Updating imports. 2015-08-12 12:29:34 -07:00
Kyle
89e74f390b Add doc.go, finish YubiKey removal. 2014-04-24 20:43:13 -06:00
Kyle
7881b6fdfc Remove test TOTP client. 2014-04-24 20:40:44 -06:00
Kyle
5bef33245f Remove YubiKey (not currently functional). 2014-04-24 20:37:53 -06:00
Kyle
84250b0501 More documentation. 2014-04-24 20:37:00 -06:00
Kyle Isom
459e9f880f Add function to build Google TOTPs from secret 2014-04-23 16:54:16 -07:00
Kyle Isom
0982f47ce3 Add last night's progress.
Basic functionality for HOTP, TOTP, and YubiKey OTP. Still need YubiKey
HMAC, serialisation, check, and scan.
2013-12-20 17:00:01 -07:00
Kyle Isom
1dec15fd11 add missing files
new files are
	oath_test
	totp code
2013-12-19 00:21:26 -07:00
Kyle Isom
2ee9cae5ba Add basic Google Authenticator TOTP client. 2013-12-19 00:20:00 -07:00
Kyle Isom
dc04475120 HOTP and TOTP-SHA-1 working.
why the frak aren't the SHA-256 and SHA-512 variants working
2013-12-19 00:04:26 -07:00
Kyle Isom
dbbd5116b5 Initial import.
Basic HOTP functionality.
2013-12-18 21:48:14 -07:00
68 changed files with 2387 additions and 901 deletions

View File

@@ -64,4 +64,4 @@ workflows:
testbuild:
jobs:
- testbuild
# - lint
- lint

View File

@@ -12,12 +12,31 @@
version: "2"
output:
sort-order:
- file
- linter
- severity
issues:
# Maximum count of issues with the same text.
# Set to 0 to disable.
# Default: 3
max-same-issues: 50
# Exclude some lints for CLI programs under cmd/ (package main).
# The project allows fmt.Print* in command-line tools; keep forbidigo for libraries.
exclude-rules:
- path: ^cmd/
linters:
- forbidigo
- path: cmd/.*
linters:
- forbidigo
- path: .*/cmd/.*
linters:
- forbidigo
formatters:
enable:
- goimports # checks if the code and import statements are formatted according to the 'goimports' command
@@ -73,7 +92,6 @@ linters:
- godoclint # checks Golang's documentation practice
- godot # checks if comments end in a period
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
- goprintffuncname # checks that printf-like functions are named with f at the end
- gosec # inspects source code for security problems
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
@@ -230,6 +248,10 @@ linters:
check-type-assertions: true
exclude-functions:
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
- git.wntrmute.dev/kyle/goutils/lib.Warn
- git.wntrmute.dev/kyle/goutils/lib.Warnx
- git.wntrmute.dev/kyle/goutils/lib.Err
- git.wntrmute.dev/kyle/goutils/lib.Errx
exhaustive:
# Program elements to check for exhaustiveness.
@@ -321,6 +343,12 @@ linters:
# https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link
- no-unused-link
gosec:
excludes:
- G104 # handled by errcheck
- G301
- G306
govet:
# Enable all analyzers.
# Default: false
@@ -356,6 +384,12 @@ linters:
- os.WriteFile
- prometheus.ExponentialBuckets.*
- prometheus.LinearBuckets
ignored-numbers:
- 1
- 2
- 3
- 4
- 8
nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns.
@@ -424,6 +458,10 @@ linters:
# Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008
- -QF1008
# We often explicitly enable old/deprecated ciphers for research.
- -SA1019
# Covered by revive.
- -ST1003
usetesting:
# Enable/disable `os.TempDir()` detections.
@@ -442,6 +480,8 @@ linters:
rules:
- path: 'ahash/ahash.go'
linters: [ staticcheck, gosec ]
- path: 'twofactor/.*.go'
linters: [ exhaustive, mnd, revive ]
- path: 'backoff/backoff_test.go'
linters: [ testpackage ]
- path: 'dbg/dbg_test.go'
@@ -452,6 +492,8 @@ linters:
linters: [ testableexamples ]
- path: 'main.go'
linters: [ forbidigo, mnd, reassign ]
- path: 'cmd/cruntar/main.go'
linters: [ unparam ]
- source: 'TODO'
linters: [ godot ]
- text: 'should have a package comment'

View File

@@ -1,5 +1,52 @@
CHANGELOG
v1.13.0 - 2025-11-16
Add:
- cmd/certser: print serial numbers for certificates.
- lib/HexEncode: add a new hex encode function handling multiple output
formats, including with and without colons.
v1.12.4 - 2025-11-16
Changed:
- Linting fixes for twofactor that were previously masked.
v1.12.3 erroneously tagged and pushed
v1.12.2 - 2025-11-16
Changed:
- add rsc.io/qr dependency for twofactor.
v1.12.1 - 2025-11-16
Changed:
- twofactor: Remove go.{mod,sum}.
v1.12.0 - 2025-11-16
Added
- twofactor: the github.com/kisom/twofactor repo has been subtree'd
into this repo.
v1.11.2 - 2025-11-16
Changed
- cmd/ski, cmd/csrpubdump, cmd/tlskeypair: centralize
certificate/private-key/CSR parsing by reusing certlib helpers.
This reduces duplication and improves consistency across commands.
- csr: CSR parsing in the above commands now uses certlib.ParseCSR,
which verifies CSR signatures (behavioral hardening compared to
prior parsing without signature verification).
v1.11.1 - 2025-11-16
Changed
- cmd: complete linting fixes across programs; no functional changes.
v1.11.0 - 2025-11-15
Added

View File

@@ -2,39 +2,52 @@ GOUTILS
This is a collection of small utility code I've written in Go; the `cmd/`
directory has a number of command-line utilities. Rather than keep all
of these in superfluous repositories of their own, or rewriting them
of these in superfluous repositories of their own or rewriting them
for each project, I'm putting them here.
The project can be built with the standard Go tooling, or it can be built
with Bazel.
The project can be built with the standard Go tooling.
Contents:
ahash/ Provides hashes from string algorithm specifiers.
assert/ Error handling, assertion-style.
backoff/ Implementation of an intelligent backoff strategy.
cache/ Implementations of various caches.
lru/ Least-recently-used cache.
mru/ Most-recently-used cache.
certlib/ Library for working with TLS certificates.
cmd/
atping/ Automated TCP ping, meant for putting in cronjobs.
certchain/ Display the certificate chain from a
TLS connection.
ca-signed/ Validate whether a certificate is signed by a CA.
cert-bundler/
Create certificate bundles from a source of PEM
certificates.
cert-revcheck/
Check whether a certificate has been revoked or is
expired.
certchain/ Display the certificate chain from a TLS connection.
certdump/ Dump certificate information.
certexpiry/ Print a list of certificate subjects and expiry times
or warn about certificates expiring within a certain
window.
certverify/ Verify a TLS X.509 certificate, optionally printing
certverify/ Verify a TLS X.509 certificate file, optionally printing
the time to expiry and checking for revocations.
clustersh/ Run commands or transfer files across multiple
servers via SSH.
cruntar/ Untar an archive with hard links, copying instead of
cruntar/ (Un)tar an archive with hard links, copying instead of
linking.
csrpubdump/ Dump the public key from an X.509 certificate request.
data_sync/ Sync the user's homedir to external storage.
diskimg/ Write a disk image to a device.
dumpbytes/ Dump the contents of a file as hex bytes, printing it as
a Go []byte literal.
eig/ EEPROM image generator.
fragment/ Print a fragment of a file.
host/ Go imlpementation of the host(1) command.
jlp/ JSON linter/prettifier.
kgz/ Custom gzip compressor / decompressor that handles 99%
of my use cases.
minmax/ Generate a minmax code for use in uLisp.
parts/ Simple parts database management for my collection of
electronic components.
pem2bin/ Dump the binary body of a PEM-encoded block.
@@ -44,41 +57,45 @@ Contents:
in a bundle.
renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash.
rhash/ Compute the digest of remote files.
rolldie/ Roll some dice.
showimp/ List the external (e.g. non-stdlib and outside the
current working directory) imports for a Go file.
ski Display the SKI for PEM-encoded TLS material.
sprox/ Simple TCP proxy.
stealchain/ Dump the verified chain from a TLS
connection to a server.
stealchain- Dump the verified chain from a TLS
server/ connection from a client.
stealchain/ Dump the verified chain from a TLS connection to a
server.
stealchain-server/
Dump the verified chain from a TLS connection from
from a client.
subjhash/ Print or match subject info from a certificate.
tlsinfo/ Print information about a TLS connection (the TLS version
and cipher suite).
tlskeypair/ Check whether a TLS certificate and key file match.
utc/ Convert times to UTC.
yamll/ A small YAML linter.
zsearch/ Search for a string in directory of gzipped files.
config/ A simple global configuration system where configuration
data is pulled from a file or an environment variable
transparently.
iniconf/ A simple INI-style configuration system.
dbg/ A debug printer.
die/ Death of a program.
fileutil/ Common file functions.
lib/ Commonly-useful functions for writing Go programs.
log/ A syslog library.
logging/ A logging library.
mwc/ MultiwriteCloser implementation.
rand/ Utilities for working with math/rand.
sbuf/ A byte buffer that can be wiped.
seekbuf/ A read-seekable byte buffer.
syslog/ Syslog-type logging.
tee/ Emulate tee(1)'s functionality in io.Writers.
testio/ Various I/O utilities useful during testing.
testutil/ Various utility functions useful during testing.
twofactor/ Two-factor authentication.
Each program should have a small README in the directory with more
information.
All code here is licensed under the ISC license.
All code here is licensed under the Apache 2.0 license.
Error handling
--------------
@@ -99,7 +116,7 @@ Examples:
```
cert, err := certlib.LoadCertificate(path)
if err != nil {
// sentinel match
// sentinel match:
if errors.Is(err, certerr.ErrEmptyCertificate) {
// handle empty input
}
@@ -116,5 +133,3 @@ if err != nil {
}
}
```
Avoid including sensitive data (keys, passwords, tokens) in error messages.

View File

@@ -91,7 +91,7 @@ func TestReset(t *testing.T) {
}
}
const decay = 5 * time.Millisecond
const decay = 25 * time.Millisecond
const maxDuration = 10 * time.Millisecond
const interval = time.Millisecond

View File

@@ -1,87 +1,87 @@
package lru
import (
"testing"
"time"
"testing"
"time"
"github.com/benbjohnson/clock"
"github.com/benbjohnson/clock"
)
// These tests mirror the MRU-style behavior present in this LRU package
// implementation (eviction removes the most-recently-used entry).
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
// Since this implementation evicts the most-recently-used item, inserting
// "goat" when full evicts "owl" (the most recent at that time).
mock.Add(time.Second)
if _, ok := c.Get("owl"); ok {
t.Fatal("store should not have an entry for owl (MRU-evicted)")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// Since this implementation evicts the most-recently-used item, inserting
// "goat" when full evicts "owl" (the most recent at that time).
mock.Add(time.Second)
if _, ok := c.Get("owl"); ok {
t.Fatal("store should not have an entry for owl (MRU-evicted)")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
// Before storing elk, keys were: raven (older), goat (newer). Evict MRU -> goat.
if !c.Has("raven") {
t.Fatal("store should contain an entry for 'raven'")
}
// Before storing elk, keys were: raven (older), goat (newer). Evict MRU -> goat.
if !c.Has("raven") {
t.Fatal("store should contain an entry for 'raven'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
}

View File

@@ -1,50 +1,50 @@
package lru
import (
"testing"
"time"
"testing"
"time"
"github.com/benbjohnson/clock"
"github.com/benbjohnson/clock"
)
// These tests validate timestamps ordering semantics for the LRU package.
// Note: The LRU timestamps are sorted with most-recent-first (descending by t).
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// make owl the most recent
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// make owl the most recent
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// For LRU timestamps: most recent first. Expected order: owl, goat, raven.
if ts.K(0) != "owl" {
t.Fatalf("first key should be owl, have %s", ts.K(0))
}
// For LRU timestamps: most recent first. Expected order: owl, goat, raven.
if ts.K(0) != "owl" {
t.Fatalf("first key should be owl, have %s", ts.K(0))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(2) != "raven" {
t.Fatalf("third key should be raven, have %s", ts.K(2))
}
if ts.K(2) != "raven" {
t.Fatalf("third key should be raven, have %s", ts.K(2))
}
}

View File

@@ -8,9 +8,9 @@ import (
)
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
@@ -55,18 +55,18 @@ func TestBasicCacheEviction(t *testing.T) {
}
mock.Add(time.Second)
v, ok := c.Get("owl")
if !ok {
t.Fatal("store should have an entry for owl")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
v, ok := c.Get("owl")
if !ok {
t.Fatal("store should have an entry for owl")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
itm := v
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
itm := v
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if itm != 2 {
t.Fatalf("stored item should be 2, have %d", itm)

View File

@@ -8,8 +8,8 @@ import (
)
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())

View File

@@ -458,8 +458,6 @@ func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
}
if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") {
if password != nil {
// nolintlint requires rationale:
//nolint:staticcheck // legacy RFC1423 PEM encryption supported for backward compatibility when caller supplies a password
return x509.DecryptPEMBlock(keyDER, password)
}
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)

View File

@@ -1,14 +1,15 @@
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"flag"
"errors"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
"strings"
"time"
"git.wntrmute.dev/kyle/goutils/certlib"
@@ -23,6 +24,13 @@ var (
verbose bool
)
var (
strOK = "OK"
strExpired = "EXPIRED"
strRevoked = "REVOKED"
strUnknown = "UNKNOWN"
)
func main() {
flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal")
flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects")
@@ -42,16 +50,16 @@ func main() {
for _, target := range flag.Args() {
status, err := processTarget(target)
switch status {
case "OK":
fmt.Printf("%s: OK\n", target)
case "EXPIRED":
fmt.Printf("%s: EXPIRED: %v\n", target, err)
case strOK:
fmt.Printf("%s: %s\n", target, strOK)
case strExpired:
fmt.Printf("%s: %s: %v\n", target, strExpired, err)
exitCode = 1
case "REVOKED":
fmt.Printf("%s: REVOKED\n", target)
case strRevoked:
fmt.Printf("%s: %s\n", target, strRevoked)
exitCode = 1
case "UNKNOWN":
fmt.Printf("%s: UNKNOWN: %v\n", target, err)
case strUnknown:
fmt.Printf("%s: %s: %v\n", target, strUnknown, err)
if hardfail {
// In hardfail, treat unknown as failure
exitCode = 1
@@ -67,74 +75,77 @@ func processTarget(target string) (string, error) {
return checkFile(target)
}
// Not a file; treat as site
return checkSite(target)
}
func checkFile(path string) (string, error) {
in, err := ioutil.ReadFile(path)
if err != nil {
return "UNKNOWN", err
// Prefer high-level helpers from certlib to load certificates from disk
if certs, err := certlib.LoadCertificates(path); err == nil && len(certs) > 0 {
// Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0])
}
// Try PEM first; if that fails, try single DER cert
certs, err := certlib.ReadCertificates(in)
if err != nil || len(certs) == 0 {
cert, _, derr := certlib.ReadCertificate(in)
if derr != nil || cert == nil {
if err == nil {
err = derr
}
return "UNKNOWN", err
}
return evaluateCert(cert)
cert, err := certlib.LoadCertificate(path)
if err != nil || cert == nil {
return strUnknown, err
}
// Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0])
return evaluateCert(cert)
}
func checkSite(hostport string) (string, error) {
// Use certlib/hosts to parse host/port (supports https URLs and host:port)
target, err := hosts.ParseHost(hostport)
if err != nil {
return "UNKNOWN", err
return strUnknown, err
}
d := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(d, "tcp", target.String(), &tls.Config{InsecureSkipVerify: true, ServerName: target.Host})
tcfg := &tls.Config{
InsecureSkipVerify: true,
ServerName: target.Host,
} // #nosec G402 -- CLI tool only verifies revocation
td := &tls.Dialer{NetDialer: d, Config: tcfg}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := td.DialContext(ctx, "tcp", target.String())
if err != nil {
return "UNKNOWN", err
return strUnknown, err
}
defer conn.Close()
state := conn.ConnectionState()
tconn, ok := conn.(*tls.Conn)
if !ok {
return strUnknown, errors.New("connection is not TLS")
}
state := tconn.ConnectionState()
if len(state.PeerCertificates) == 0 {
return "UNKNOWN", errors.New("no peer certificates presented")
return strUnknown, errors.New("no peer certificates presented")
}
return evaluateCert(state.PeerCertificates[0])
}
func evaluateCert(cert *x509.Certificate) (string, error) {
// Expiry check
now := time.Now()
if !now.Before(cert.NotAfter) {
return "EXPIRED", fmt.Errorf("expired at %s", cert.NotAfter)
}
if !now.After(cert.NotBefore) {
return "EXPIRED", fmt.Errorf("not valid until %s", cert.NotBefore)
}
// Revocation check using certlib/revoke
// Delegate validity and revocation checks to certlib/revoke helper.
// It returns revoked=true for both revoked and expired/not-yet-valid.
// Map those cases back to our statuses using the returned error text.
revoked, ok, err := revoke.VerifyCertificateError(cert)
if revoked {
// If revoked is true, ok will be true per implementation, err may describe why
return "REVOKED", err
if err != nil {
msg := err.Error()
if strings.Contains(msg, "expired") || strings.Contains(msg, "isn't valid until") ||
strings.Contains(msg, "not valid until") {
return strExpired, err
}
}
return strRevoked, err
}
if !ok {
// Revocation status could not be determined
return "UNKNOWN", err
return strUnknown, err
}
return "OK", nil
return strOK, nil
}

View File

@@ -1,11 +1,14 @@
package main
import (
"context"
"crypto/tls"
"encoding/pem"
"flag"
"fmt"
"os"
"regexp"
"strings"
"git.wntrmute.dev/kyle/goutils/die"
)
@@ -20,20 +23,26 @@ func main() {
server += ":443"
}
var chain string
conn, err := tls.Dial("tcp", server, nil)
d := &tls.Dialer{Config: &tls.Config{}} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", server)
die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
}
defer conn.Close()
details := conn.ConnectionState()
var chain strings.Builder
for _, cert := range details.PeerCertificates {
p := pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
chain += string(pem.EncodeToMemory(&p))
chain.Write(pem.EncodeToMemory(&p))
}
fmt.Println(chain)
fmt.Fprintln(os.Stdout, chain.String())
}
}

View File

@@ -1,7 +1,9 @@
//lint:file-ignore SA1019 allow strict compatibility for old certs
package main
import (
"bytes"
"context"
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
@@ -101,30 +103,30 @@ func extUsage(ext []x509.ExtKeyUsage) string {
}
func showBasicConstraints(cert *x509.Certificate) {
fmt.Printf("\tBasic constraints: ")
fmt.Fprint(os.Stdout, "\tBasic constraints: ")
if cert.BasicConstraintsValid {
fmt.Printf("valid")
fmt.Fprint(os.Stdout, "valid")
} else {
fmt.Printf("invalid")
fmt.Fprint(os.Stdout, "invalid")
}
if cert.IsCA {
fmt.Printf(", is a CA certificate")
fmt.Fprint(os.Stdout, ", is a CA certificate")
if !cert.BasicConstraintsValid {
fmt.Printf(" (basic constraint failure)")
fmt.Fprint(os.Stdout, " (basic constraint failure)")
}
} else {
fmt.Printf("is not a CA certificate")
fmt.Fprint(os.Stdout, "is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Printf(" (key encipherment usage enabled!)")
fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
}
}
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Printf(", max path length %d", cert.MaxPathLen)
fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
}
fmt.Printf("\n")
fmt.Fprintln(os.Stdout)
}
const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
@@ -136,39 +138,41 @@ var (
func wrapPrint(text string, indent int) {
tabs := ""
for i := 0; i < indent; i++ {
tabs += "\t"
var tabsSb140 strings.Builder
for range indent {
tabsSb140.WriteString("\t")
}
tabs += tabsSb140.String()
fmt.Printf(tabs+"%s\n", wrap(text, indent))
fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
}
func displayCert(cert *x509.Certificate) {
fmt.Println("CERTIFICATE")
fmt.Fprintln(os.Stdout, "CERTIFICATE")
if showHash {
fmt.Println(wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
}
fmt.Println(wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Println(wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Printf("\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm))
fmt.Println("Details:")
fmt.Fprintln(os.Stdout, "Details:")
wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Printf("\tSerial number: %s\n", cert.SerialNumber)
fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
}
if len(cert.SubjectKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
}
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Printf("\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Printf("\tKey usages: %s\n", keyUsages(cert.KeyUsage))
fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 {
fmt.Printf("\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
}
showBasicConstraints(cert)
@@ -221,13 +225,13 @@ func displayAllCerts(in []byte, leafOnly bool) {
if err != nil {
certs, _, err = certlib.ParseCertificatesDER(in, "")
if err != nil {
lib.Warn(err, "failed to parse certificates")
_, _ = lib.Warn(err, "failed to parse certificates")
return
}
}
if len(certs) == 0 {
lib.Warnx("no certificates found")
_, _ = lib.Warnx("no certificates found")
return
}
@@ -243,29 +247,45 @@ func displayAllCerts(in []byte, leafOnly bool) {
func displayAllCertsWeb(uri string, leafOnly bool) {
ci := getConnInfo(uri)
conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig())
d := &tls.Dialer{Config: permissiveConfig()}
nc, err := d.DialContext(context.Background(), "tcp", ci.Addr)
if err != nil {
lib.Warn(err, "couldn't connect to %s", ci.Addr)
_, _ = lib.Warn(err, "couldn't connect to %s", ci.Addr)
return
}
conn, ok := nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return
}
defer conn.Close()
state := conn.ConnectionState()
conn.Close()
if err = conn.Close(); err != nil {
_, _ = lib.Warn(err, "couldn't close TLS connection")
}
conn, err = tls.Dial("tcp", ci.Addr, verifyConfig(ci.Host))
d = &tls.Dialer{Config: verifyConfig(ci.Host)}
nc, err = d.DialContext(context.Background(), "tcp", ci.Addr)
if err == nil {
conn, ok = nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return
}
err = conn.VerifyHostname(ci.Host)
if err == nil {
state = conn.ConnectionState()
}
conn.Close()
} else {
lib.Warn(err, "TLS verification error with server name %s", ci.Host)
_, _ = lib.Warn(err, "TLS verification error with server name %s", ci.Host)
}
if len(state.PeerCertificates) == 0 {
lib.Warnx("no certificates found")
_, _ = lib.Warnx("no certificates found")
return
}
@@ -275,14 +295,14 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
}
if len(state.VerifiedChains) == 0 {
lib.Warnx("no verified chains found; using peer chain")
_, _ = lib.Warnx("no verified chains found; using peer chain")
for i := range state.PeerCertificates {
displayCert(state.PeerCertificates[i])
}
} else {
fmt.Println("TLS chain verified successfully.")
fmt.Fprintln(os.Stdout, "TLS chain verified successfully.")
for i := range state.VerifiedChains {
fmt.Printf("--- Verified certificate chain %d ---\n", i+1)
fmt.Fprintf(os.Stdout, "--- Verified certificate chain %d ---%s", i+1, "\n")
for j := range state.VerifiedChains[i] {
displayCert(state.VerifiedChains[i][j])
}
@@ -290,6 +310,32 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
}
}
func shouldReadStdin(argc int, argv []string) bool {
if argc == 0 {
return true
}
if argc == 1 && argv[0] == "-" {
return true
}
return false
}
func readStdin(leafOnly bool) {
certs, err := io.ReadAll(os.Stdin)
if err != nil {
_, _ = lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
}
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.ReplaceAll(certs, []byte(`\n`), []byte{0xa})
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
}
func main() {
var leafOnly bool
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
@@ -297,32 +343,23 @@ func main() {
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.Parse()
if flag.NArg() == 0 || (flag.NArg() == 1 && flag.Arg(0) == "-") {
certs, err := io.ReadAll(os.Stdin)
if err != nil {
lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
}
if shouldReadStdin(flag.NArg(), flag.Args()) {
readStdin(leafOnly)
return
}
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.Replace(certs, []byte(`\n`), []byte{0xa}, -1)
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
} else {
for _, filename := range flag.Args() {
fmt.Printf("--%s ---\n", filename)
if strings.HasPrefix(filename, "https://") {
displayAllCertsWeb(filename, leafOnly)
} else {
in, err := os.ReadFile(filename)
if err != nil {
lib.Warn(err, "couldn't read certificate")
continue
}
displayAllCerts(in, leafOnly)
for _, filename := range flag.Args() {
fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
if strings.HasPrefix(filename, "https://") {
displayAllCertsWeb(filename, leafOnly)
} else {
in, err := os.ReadFile(filename)
if err != nil {
_, _ = lib.Warn(err, "couldn't read certificate")
continue
}
displayAllCerts(in, leafOnly)
}
}
}

View File

@@ -13,6 +13,11 @@ import (
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
// "\2: \1,")
const (
sSHA256 = "SHA256"
sSHA512 = "SHA512"
)
var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content committment",
@@ -26,42 +31,36 @@ var keyUsage = map[x509.KeyUsage]string{
}
var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageAny: "any",
x509.ExtKeyUsageServerAuth: "server auth",
x509.ExtKeyUsageClientAuth: "client auth",
x509.ExtKeyUsageCodeSigning: "code signing",
x509.ExtKeyUsageEmailProtection: "s/mime",
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
x509.ExtKeyUsageIPSECUser: "ipsec user",
x509.ExtKeyUsageTimeStamping: "timestamping",
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
}
func pubKeyAlgo(a x509.PublicKeyAlgorithm) string {
switch a {
case x509.RSA:
return "RSA"
case x509.ECDSA:
return "ECDSA"
case x509.DSA:
return "DSA"
default:
return "unknown public key algorithm"
}
x509.ExtKeyUsageAny: "any",
x509.ExtKeyUsageServerAuth: "server auth",
x509.ExtKeyUsageClientAuth: "client auth",
x509.ExtKeyUsageCodeSigning: "code signing",
x509.ExtKeyUsageEmailProtection: "s/mime",
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
x509.ExtKeyUsageIPSECUser: "ipsec user",
x509.ExtKeyUsageTimeStamping: "timestamping",
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
}
func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA"
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
return "RSA-PSS"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA"
case x509.PureEd25519:
return "Ed25519"
case x509.UnknownSignatureAlgorithm:
return "unknown public key algorithm"
default:
return "unknown public key algorithm"
}
@@ -76,11 +75,21 @@ func sigAlgoHash(a x509.SignatureAlgorithm) string {
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return "SHA256"
return sSHA256
case x509.SHA256WithRSAPSS:
return sSHA256
case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384"
case x509.SHA384WithRSAPSS:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return "SHA512"
return sSHA512
case x509.SHA512WithRSAPSS:
return sSHA512
case x509.PureEd25519:
return sSHA512
case x509.UnknownSignatureAlgorithm:
return "unknown hash algorithm"
default:
return "unknown hash algorithm"
}
@@ -90,9 +99,11 @@ const maxLine = 78
func makeIndent(n int) string {
s := " "
for i := 0; i < n; i++ {
s += " "
var sSb97 strings.Builder
for range n {
sSb97.WriteString(" ")
}
s += sSb97.String()
return s
}
@@ -100,7 +111,7 @@ func indentLen(n int) int {
return 4 + (8 * n)
}
// this isn't real efficient, but that's not a problem here
// this isn't real efficient, but that's not a problem here.
func wrap(s string, indent int) string {
if indent > 3 {
indent = 3
@@ -123,9 +134,11 @@ func wrap(s string, indent int) string {
func dumpHex(in []byte) string {
var s string
var sSb130 strings.Builder
for i := range in {
s += fmt.Sprintf("%02X:", in[i])
sSb130.WriteString(fmt.Sprintf("%02X:", in[i]))
}
s += sSb130.String()
return strings.Trim(s, ":")
}
@@ -136,14 +149,14 @@ func dumpHex(in []byte) string {
func permissiveConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: true,
}
} // #nosec G402
}
// verifyConfig returns a config that will verify the connection.
func verifyConfig(hostname string) *tls.Config {
return &tls.Config{
ServerName: hostname,
}
} // #nosec G402
}
type connInfo struct {

View File

@@ -5,7 +5,6 @@ import (
"crypto/x509/pkix"
"flag"
"fmt"
"io/ioutil"
"os"
"strings"
"time"
@@ -54,7 +53,7 @@ func displayName(name pkix.Name) string {
}
func expires(cert *x509.Certificate) time.Duration {
return cert.NotAfter.Sub(time.Now())
return time.Until(cert.NotAfter)
}
func inDanger(cert *x509.Certificate) bool {
@@ -81,15 +80,15 @@ func main() {
flag.Parse()
for _, file := range flag.Args() {
in, err := ioutil.ReadFile(file)
in, err := os.ReadFile(file)
if err != nil {
lib.Warn(err, "failed to read file")
_, _ = lib.Warn(err, "failed to read file")
continue
}
certs, err := certlib.ParseCertificatesPEM(in)
if err != nil {
lib.Warn(err, "while parsing certificates")
_, _ = lib.Warn(err, "while parsing certificates")
continue
}

51
cmd/certser/main.go Normal file
View File

@@ -0,0 +1,51 @@
package main
import (
"crypto/x509"
"flag"
"fmt"
"strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
const displayInt lib.HexEncodeMode = iota
func parseDisplayMode(mode string) lib.HexEncodeMode {
mode = strings.ToLower(mode)
if mode == "int" {
return displayInt
}
return lib.ParseHexEncodeMode(mode)
}
func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string {
if mode == displayInt {
return cert.SerialNumber.String()
}
return lib.HexEncode(cert.SerialNumber.Bytes(), mode)
}
func main() {
displayAs := flag.String("d", "int", "display mode (int, hex, uhex)")
showExpiry := flag.Bool("e", false, "show expiry date")
flag.Parse()
displayMode := parseDisplayMode(*displayAs)
for _, arg := range flag.Args() {
cert, err := certlib.LoadCertificate(arg)
die.If(err)
fmt.Printf("%s: %s", arg, serialString(cert, displayMode))
if *showExpiry {
fmt.Printf(" (%s)", cert.NotAfter.Format("2006-01-02"))
}
fmt.Println()
}
}

View File

@@ -4,13 +4,11 @@ import (
"crypto/x509"
"flag"
"fmt"
"io/ioutil"
"os"
"time"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/certlib/revoke"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
@@ -30,83 +28,116 @@ func printRevocation(cert *x509.Certificate) {
}
}
func main() {
var caFile, intFile string
var forceIntermediateBundle, revexp, verbose bool
flag.StringVar(&caFile, "ca", "", "CA certificate `bundle`")
flag.StringVar(&intFile, "i", "", "intermediate `bundle`")
flag.BoolVar(&forceIntermediateBundle, "f", false,
type appConfig struct {
caFile, intFile string
forceIntermediateBundle bool
revexp, verbose bool
}
func parseFlags() appConfig {
var cfg appConfig
flag.StringVar(&cfg.caFile, "ca", "", "CA certificate `bundle`")
flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`")
flag.BoolVar(&cfg.forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate")
flag.BoolVar(&revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&verbose, "v", false, "verbose")
flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&cfg.verbose, "v", false, "verbose")
flag.Parse()
return cfg
}
var roots *x509.CertPool
if caFile != "" {
var err error
if verbose {
fmt.Println("[+] loading root certificates from", caFile)
}
roots, err = certlib.LoadPEMCertPool(caFile)
die.If(err)
func loadRoots(caFile string, verbose bool) (*x509.CertPool, error) {
if caFile == "" {
return x509.SystemCertPool()
}
var ints *x509.CertPool
if intFile != "" {
var err error
if verbose {
fmt.Println("[+] loading intermediate certificates from", intFile)
}
ints, err = certlib.LoadPEMCertPool(caFile)
die.If(err)
} else {
ints = x509.NewCertPool()
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert",
lib.ProgName())
}
fileData, err := ioutil.ReadFile(flag.Arg(0))
die.If(err)
chain, err := certlib.ParseCertificatesPEM(fileData)
die.If(err)
if verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
fmt.Println("[+] loading root certificates from", caFile)
}
return certlib.LoadPEMCertPool(caFile)
}
cert := chain[0]
if len(chain) > 1 {
if !forceIntermediateBundle {
for _, intermediate := range chain[1:] {
if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
}
func loadIntermediates(intFile string, verbose bool) (*x509.CertPool, error) {
if intFile == "" {
return x509.NewCertPool(), nil
}
if verbose {
fmt.Println("[+] loading intermediate certificates from", intFile)
}
// Note: use intFile here (previously used caFile mistakenly)
return certlib.LoadPEMCertPool(intFile)
}
ints.AddCert(intermediate)
}
func addBundledIntermediates(chain []*x509.Certificate, pool *x509.CertPool, verbose bool) {
for _, intermediate := range chain[1:] {
if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
}
pool.AddCert(intermediate)
}
}
func verifyCert(cert *x509.Certificate, roots, ints *x509.CertPool) error {
opts := x509.VerifyOptions{
Intermediates: ints,
Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
}
_, err := cert.Verify(opts)
return err
}
_, err = cert.Verify(opts)
func run(cfg appConfig) error {
roots, err := loadRoots(cfg.caFile, cfg.verbose)
if err != nil {
fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err)
os.Exit(1)
return err
}
if verbose {
ints, err := loadIntermediates(cfg.intFile, cfg.verbose)
if err != nil {
return err
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
}
fileData, err := os.ReadFile(flag.Arg(0))
if err != nil {
return err
}
chain, err := certlib.ParseCertificatesPEM(fileData)
if err != nil {
return err
}
if cfg.verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 && !cfg.forceIntermediateBundle {
addBundledIntermediates(chain, ints, cfg.verbose)
}
if err = verifyCert(cert, roots, ints); err != nil {
return fmt.Errorf("certificate verification failed: %w", err)
}
if cfg.verbose {
fmt.Println("OK")
}
if revexp {
if cfg.revexp {
printRevocation(cert)
}
return nil
}
func main() {
cfg := parseFlags()
if err := run(cfg); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
}

View File

@@ -2,6 +2,8 @@ package main
import (
"bufio"
"context"
"errors"
"flag"
"fmt"
"io"
@@ -56,7 +58,7 @@ var modes = ssh.TerminalModes{
}
func sshAgent() ssh.AuthMethod {
a, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
a, err := (&net.Dialer{}).DialContext(context.Background(), "unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil {
return ssh.PublicKeysCallback(agent.NewClient(a).Signers)
}
@@ -82,7 +84,7 @@ func scanner(host string, in io.Reader, out io.Writer) {
}
}
func logError(host string, err error, format string, args ...interface{}) {
func logError(host string, err error, format string, args ...any) {
msg := fmt.Sprintf(format, args...)
log.Printf("[%s] FAILED: %s: %v\n", host, msg, err)
}
@@ -93,7 +95,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
defer func() {
for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]()
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down")
}
}
@@ -115,7 +117,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
}
shutdown = append(shutdown, session.Close)
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
if err = session.RequestPty("xterm", 80, 40, modes); err != nil {
session.Close()
logError(host, err, "request for pty failed")
return
@@ -150,7 +152,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() {
for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]()
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down")
}
}
@@ -199,7 +201,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
}
if err == io.EOF {
if errors.Is(err, io.EOF) {
break
} else if err != nil {
logError(host, err, "reading chunk")
@@ -215,7 +217,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() {
for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]()
if err != nil && err != io.EOF {
if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down")
}
}
@@ -265,7 +267,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
}
if err == io.EOF {
if errors.Is(err, io.EOF) {
break
} else if err != nil {
logError(host, err, "reading chunk")

View File

@@ -10,6 +10,7 @@ import (
"io"
"os"
"path/filepath"
"strings"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/fileutil"
@@ -26,7 +27,7 @@ func setupFile(hdr *tar.Header, file *os.File) error {
if verbose {
fmt.Printf("\tchmod %0#o\n", hdr.Mode)
}
err := file.Chmod(os.FileMode(hdr.Mode))
err := file.Chmod(os.FileMode(hdr.Mode & 0xFFFFFFFF)) // #nosec G115
if err != nil {
return err
}
@@ -48,73 +49,105 @@ func linkTarget(target, top string) string {
return target
}
return filepath.Clean(filepath.Join(target, top))
return filepath.Clean(filepath.Join(top, target))
}
// safeJoin joins base and elem and ensures the resulting path does not escape base.
func safeJoin(base, elem string) (string, error) {
cleanBase := filepath.Clean(base)
joined := filepath.Clean(filepath.Join(cleanBase, elem))
absBase, err := filepath.Abs(cleanBase)
if err != nil {
return "", err
}
absJoined, err := filepath.Abs(joined)
if err != nil {
return "", err
}
rel, err := filepath.Rel(absBase, absJoined)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected: %s escapes %s", elem, base)
}
return joined, nil
}
func handleTypeReg(tfr *tar.Reader, hdr *tar.Header, filePath string) error {
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
if _, err = io.Copy(file, tfr); err != nil {
return err
}
return setupFile(hdr, file)
}
func handleTypeLink(hdr *tar.Header, top, filePath string) error {
file, err := os.Create(filePath)
if err != nil {
return err
}
defer file.Close()
srcPath, err := safeJoin(top, hdr.Linkname)
if err != nil {
return err
}
source, err := os.Open(srcPath)
if err != nil {
return err
}
defer source.Close()
if _, err = io.Copy(file, source); err != nil {
return err
}
return setupFile(hdr, file)
}
func handleTypeSymlink(hdr *tar.Header, top, filePath string) error {
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s", hdr.Linkname, top)
}
path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
} else if err != nil {
return err
}
return os.Symlink(linkTarget(hdr.Linkname, top), filePath)
}
func handleTypeDir(hdr *tar.Header, filePath string) error {
return os.MkdirAll(filePath, os.FileMode(hdr.Mode&0xFFFFFFFF)) // #nosec G115
}
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
if verbose {
fmt.Println(hdr.Name)
}
filePath := filepath.Clean(filepath.Join(top, hdr.Name))
switch hdr.Typeflag {
case tar.TypeReg:
file, err := os.Create(filePath)
if err != nil {
return err
}
_, err = io.Copy(file, tfr)
if err != nil {
return err
}
err = setupFile(hdr, file)
if err != nil {
return err
}
case tar.TypeLink:
file, err := os.Create(filePath)
if err != nil {
return err
}
source, err := os.Open(hdr.Linkname)
if err != nil {
return err
}
_, err = io.Copy(file, source)
if err != nil {
return err
}
err = setupFile(hdr, file)
if err != nil {
return err
}
case tar.TypeSymlink:
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s",
hdr.Linkname, top)
}
path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
} else if err != nil {
return err
}
err := os.Symlink(linkTarget(hdr.Linkname, top), filePath)
if err != nil {
return err
}
case tar.TypeDir:
err := os.MkdirAll(filePath, os.FileMode(hdr.Mode))
if err != nil {
return err
}
filePath, err := safeJoin(top, hdr.Name)
if err != nil {
return err
}
switch hdr.Typeflag {
case tar.TypeReg:
return handleTypeReg(tfr, hdr, filePath)
case tar.TypeLink:
return handleTypeLink(hdr, top, filePath)
case tar.TypeSymlink:
return handleTypeSymlink(hdr, top, filePath)
case tar.TypeDir:
return handleTypeDir(hdr, filePath)
}
return nil
}
@@ -261,16 +294,16 @@ func main() {
die.If(err)
tfr := tar.NewReader(r)
var hdr *tar.Header
for {
hdr, err := tfr.Next()
if err == io.EOF {
hdr, err = tfr.Next()
if errors.Is(err, io.EOF) {
break
}
die.If(err)
err = processFile(tfr, hdr, top)
die.If(err)
}
r.Close()

View File

@@ -7,9 +7,9 @@ import (
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
)
@@ -17,17 +17,10 @@ func main() {
flag.Parse()
for _, fileName := range flag.Args() {
in, err := ioutil.ReadFile(fileName)
in, err := os.ReadFile(fileName)
die.If(err)
if p, _ := pem.Decode(in); p != nil {
if p.Type != "CERTIFICATE REQUEST" {
log.Fatal("INVALID FILE TYPE")
}
in = p.Bytes
}
csr, err := x509.ParseCertificateRequest(in)
csr, _, err := certlib.ParseCSR(in)
die.If(err)
out, err := x509.MarshalPKIXPublicKey(csr.PublicKey)
@@ -48,8 +41,8 @@ func main() {
Bytes: out,
}
err = ioutil.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0644)
err = os.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0o644) // #nosec G306
die.If(err)
fmt.Printf("[+] wrote %s.\n", fileName+".pub")
fmt.Fprintf(os.Stdout, "[+] wrote %s.\n", fileName+".pub")
}
}

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"flag"
"fmt"
"io"
@@ -152,7 +153,7 @@ func rsync(syncDir, target, excludeFile string, verboseRsync bool) error {
return err
}
cmd := exec.Command(path, args...)
cmd := exec.CommandContext(context.Background(), path, args...)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()
@@ -163,7 +164,6 @@ func init() {
}
func main() {
var logLevel, mountDir, syncDir, target string
var dryRun, quietMode, noSyslog, verboseRsync bool
@@ -219,7 +219,7 @@ func main() {
if excludeFile != "" {
defer func() {
log.Infof("removing exclude file %s", excludeFile)
if err := os.Remove(excludeFile); err != nil {
if rmErr := os.Remove(excludeFile); rmErr != nil {
log.Warningf("failed to remove temp file %s", excludeFile)
}
}()

View File

@@ -15,43 +15,41 @@ import (
const defaultHashAlgorithm = "sha256"
var (
hAlgo string
hAlgo string
debug = dbg.New()
)
func openImage(imageFile string) (image *os.File, hash []byte, err error) {
image, err = os.Open(imageFile)
func openImage(imageFile string) (*os.File, []byte, error) {
f, err := os.Open(imageFile)
if err != nil {
return
return nil, nil, err
}
hash, err = ahash.SumReader(hAlgo, image)
h, err := ahash.SumReader(hAlgo, f)
if err != nil {
return
return nil, nil, err
}
_, err = image.Seek(0, 0)
if err != nil {
return
if _, err = f.Seek(0, 0); err != nil {
return nil, nil, err
}
debug.Printf("%s %x\n", imageFile, hash)
return
debug.Printf("%s %x\n", imageFile, h)
return f, h, nil
}
func openDevice(devicePath string) (device *os.File, err error) {
func openDevice(devicePath string) (*os.File, error) {
fi, err := os.Stat(devicePath)
if err != nil {
return
return nil, err
}
device, err = os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
device, err := os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
if err != nil {
return
return nil, err
}
return
return device, nil
}
func main() {
@@ -105,12 +103,12 @@ func main() {
die.If(err)
if !bytes.Equal(deviceHash, hash) {
fmt.Fprintln(os.Stderr, "Hash mismatch:")
fmt.Fprintf(os.Stderr, "\t%s: %s\n", imageFile, hash)
fmt.Fprintf(os.Stderr, "\t%s: %s\n", devicePath, deviceHash)
os.Exit(1)
buf := &bytes.Buffer{}
fmt.Fprintln(buf, "Hash mismatch:")
fmt.Fprintf(buf, "\t%s: %s\n", imageFile, hash)
fmt.Fprintf(buf, "\t%s: %s\n", devicePath, deviceHash)
die.With(buf.String())
}
debug.Println("OK")
os.Exit(0)
}

View File

@@ -1,30 +1,33 @@
package main
import (
"errors"
"flag"
"fmt"
"git.wntrmute.dev/kyle/goutils/die"
"io"
"os"
"strings"
"git.wntrmute.dev/kyle/goutils/die"
)
func usage(w io.Writer, exc int) {
fmt.Fprintln(w, `usage: dumpbytes <file>`)
fmt.Fprintln(w, `usage: dumpbytes -n tabs <file>`)
os.Exit(exc)
}
func printBytes(buf []byte) {
fmt.Printf("\t")
for i := 0; i < len(buf); i++ {
for i := range buf {
fmt.Printf("0x%02x, ", buf[i])
}
fmt.Println()
}
func dumpFile(path string, indentLevel int) error {
indent := ""
for i := 0; i < indentLevel; i++ {
indent += "\t"
var indent strings.Builder
for range indentLevel {
indent.WriteByte('\t')
}
file, err := os.Open(path)
@@ -34,13 +37,14 @@ func dumpFile(path string, indentLevel int) error {
defer file.Close()
fmt.Printf("%svar buffer = []byte{\n", indent)
fmt.Printf("%svar buffer = []byte{\n", indent.String())
var n int
for {
buf := make([]byte, 8)
n, err := file.Read(buf)
if err == io.EOF {
n, err = file.Read(buf)
if errors.Is(err, io.EOF) {
if n > 0 {
fmt.Printf("%s", indent)
fmt.Printf("%s", indent.String())
printBytes(buf[:n])
}
break
@@ -50,11 +54,11 @@ func dumpFile(path string, indentLevel int) error {
return err
}
fmt.Printf("%s", indent)
fmt.Printf("%s", indent.String())
printBytes(buf[:n])
}
fmt.Printf("%s}\n", indent)
fmt.Printf("%s}\n", indent.String())
return nil
}

View File

@@ -7,7 +7,7 @@ import (
"git.wntrmute.dev/kyle/goutils/die"
)
// size of a kilobit in bytes
// size of a kilobit in bytes.
const kilobit = 128
const pageSize = 4096
@@ -26,10 +26,10 @@ func main() {
path = flag.Arg(0)
}
fillByte := uint8(*fill)
fillByte := uint8(*fill & 0xff) // #nosec G115 clearing out of bounds bits
buf := make([]byte, pageSize)
for i := 0; i < pageSize; i++ {
for i := range pageSize {
buf[i] = fillByte
}
@@ -40,7 +40,7 @@ func main() {
die.If(err)
defer file.Close()
for i := 0; i < pages; i++ {
for range pages {
_, err = file.Write(buf)
die.If(err)
}

View File

@@ -72,15 +72,13 @@ func main() {
if end < start {
fmt.Fprintln(os.Stderr, "[!] end < start, swapping values")
tmp := end
end = start
start = tmp
start, end = end, start
}
var fmtStr string
if !*quiet {
maxLine := fmt.Sprintf("%d", len(lines))
maxLine := strconv.Itoa(len(lines))
fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine))
}
@@ -98,9 +96,9 @@ func main() {
fmtStr += "\n"
for i := start; !endFunc(i); i++ {
if *quiet {
fmt.Println(lines[i])
fmt.Fprintln(os.Stdout, lines[i])
} else {
fmt.Printf(fmtStr, i, lines[i])
fmt.Fprintf(os.Stdout, fmtStr, i, lines[i])
}
}
}

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"flag"
"fmt"
"log"
@@ -8,7 +9,8 @@ import (
)
func lookupHost(host string) error {
cname, err := net.LookupCNAME(host)
r := &net.Resolver{}
cname, err := r.LookupCNAME(context.Background(), host)
if err != nil {
return err
}
@@ -18,7 +20,7 @@ func lookupHost(host string) error {
host = cname
}
addrs, err := net.LookupHost(host)
addrs, err := r.LookupHost(context.Background(), host)
if err != nil {
return err
}

View File

@@ -5,7 +5,7 @@ import (
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"io"
"os"
"git.wntrmute.dev/kyle/goutils/lib"
@@ -16,20 +16,20 @@ func prettify(file string, validateOnly bool) error {
var err error
if file == "-" {
in, err = ioutil.ReadAll(os.Stdin)
in, err = io.ReadAll(os.Stdin)
} else {
in, err = ioutil.ReadFile(file)
in, err = os.ReadFile(file)
}
if err != nil {
lib.Warn(err, "ReadFile")
_, _ = lib.Warn(err, "ReadFile")
return err
}
var buf = &bytes.Buffer{}
err = json.Indent(buf, in, "", " ")
if err != nil {
lib.Warn(err, "%s", file)
_, _ = lib.Warn(err, "%s", file)
return err
}
@@ -40,11 +40,11 @@ func prettify(file string, validateOnly bool) error {
if file == "-" {
_, err = os.Stdout.Write(buf.Bytes())
} else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644)
err = os.WriteFile(file, buf.Bytes(), 0o644)
}
if err != nil {
lib.Warn(err, "WriteFile")
_, _ = lib.Warn(err, "WriteFile")
}
return err
@@ -55,20 +55,20 @@ func compact(file string, validateOnly bool) error {
var err error
if file == "-" {
in, err = ioutil.ReadAll(os.Stdin)
in, err = io.ReadAll(os.Stdin)
} else {
in, err = ioutil.ReadFile(file)
in, err = os.ReadFile(file)
}
if err != nil {
lib.Warn(err, "ReadFile")
_, _ = lib.Warn(err, "ReadFile")
return err
}
var buf = &bytes.Buffer{}
err = json.Compact(buf, in)
if err != nil {
lib.Warn(err, "%s", file)
_, _ = lib.Warn(err, "%s", file)
return err
}
@@ -79,11 +79,11 @@ func compact(file string, validateOnly bool) error {
if file == "-" {
_, err = os.Stdout.Write(buf.Bytes())
} else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644)
err = os.WriteFile(file, buf.Bytes(), 0o644)
}
if err != nil {
lib.Warn(err, "WriteFile")
_, _ = lib.Warn(err, "WriteFile")
}
return err
@@ -91,7 +91,7 @@ func compact(file string, validateOnly bool) error {
func usage() {
progname := lib.ProgName()
fmt.Printf(`Usage: %s [-h] files...
fmt.Fprintf(os.Stdout, `Usage: %s [-h] files...
%s is used to lint and prettify (or compact) JSON files. The
files will be updated in-place.
@@ -100,7 +100,6 @@ func usage() {
-h Print this help message.
-n Don't prettify; only perform validation.
`, progname, progname)
}
func init() {

View File

@@ -11,7 +11,10 @@ based on whether the source filename ends in ".gz".
Flags:
-l level Compression level (0-9). Only meaninful when
compressing a file.
compressing a file.
-u Do not restrict the size during decompression. As
a safeguard against gzip bombs, the maximum size
allowed is 32 * the compressed file size.

View File

@@ -1,68 +1,84 @@
package main
import (
"compress/flate"
"compress/gzip"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"compress/flate"
"compress/gzip"
"flag"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
const gzipExt = ".gz"
func compress(path, target string, level int) error {
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("opening file for read: %w", err)
}
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("opening file for read: %w", err)
}
defer sourceFile.Close()
destFile, err := os.Create(target)
if err != nil {
return fmt.Errorf("opening file for write: %w", err)
}
destFile, err := os.Create(target)
if err != nil {
return fmt.Errorf("opening file for write: %w", err)
}
defer destFile.Close()
gzipCompressor, err := gzip.NewWriterLevel(destFile, level)
if err != nil {
return fmt.Errorf("invalid compression level: %w", err)
}
gzipCompressor, err := gzip.NewWriterLevel(destFile, level)
if err != nil {
return fmt.Errorf("invalid compression level: %w", err)
}
defer gzipCompressor.Close()
_, err = io.Copy(gzipCompressor, sourceFile)
if err != nil {
return fmt.Errorf("compressing file: %w", err)
}
_, err = io.Copy(gzipCompressor, sourceFile)
if err != nil {
return fmt.Errorf("compressing file: %w", err)
}
return nil
}
func uncompress(path, target string) error {
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("opening file for read: %w", err)
}
func uncompress(path, target string, unrestrict bool) error {
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("opening file for read: %w", err)
}
defer sourceFile.Close()
gzipUncompressor, err := gzip.NewReader(sourceFile)
if err != nil {
return fmt.Errorf("reading gzip headers: %w", err)
}
fi, err := sourceFile.Stat()
if err != nil {
return fmt.Errorf("reading file stats: %w", err)
}
maxDecompressionSize := fi.Size() * 32
gzipUncompressor, err := gzip.NewReader(sourceFile)
if err != nil {
return fmt.Errorf("reading gzip headers: %w", err)
}
defer gzipUncompressor.Close()
destFile, err := os.Create(target)
if err != nil {
return fmt.Errorf("opening file for write: %w", err)
}
var reader io.Reader = &io.LimitedReader{
R: gzipUncompressor,
N: maxDecompressionSize,
}
if unrestrict {
reader = gzipUncompressor
}
destFile, err := os.Create(target)
if err != nil {
return fmt.Errorf("opening file for write: %w", err)
}
defer destFile.Close()
_, err = io.Copy(destFile, gzipUncompressor)
if err != nil {
return fmt.Errorf("uncompressing file: %w", err)
}
_, err = io.Copy(destFile, reader)
if err != nil {
return fmt.Errorf("uncompressing file: %w", err)
}
return nil
}
@@ -87,8 +103,8 @@ func isDir(path string) bool {
file, err := os.Open(path)
if err == nil {
defer file.Close()
stat, err := file.Stat()
if err != nil {
stat, err2 := file.Stat()
if err2 != nil {
return false
}
@@ -106,9 +122,9 @@ func pathForUncompressing(source, dest string) (string, error) {
}
source = filepath.Base(source)
if !strings.HasSuffix(source, gzipExt) {
return "", fmt.Errorf("%s is a not gzip-compressed file", source)
}
if !strings.HasSuffix(source, gzipExt) {
return "", fmt.Errorf("%s is a not gzip-compressed file", source)
}
outFile := source[:len(source)-len(gzipExt)]
outFile = filepath.Join(dest, outFile)
return outFile, nil
@@ -120,9 +136,9 @@ func pathForCompressing(source, dest string) (string, error) {
}
source = filepath.Base(source)
if strings.HasSuffix(source, gzipExt) {
return "", fmt.Errorf("%s is a gzip-compressed file", source)
}
if strings.HasSuffix(source, gzipExt) {
return "", fmt.Errorf("%s is a gzip-compressed file", source)
}
dest = filepath.Join(dest, source+gzipExt)
return dest, nil
@@ -132,8 +148,11 @@ func main() {
var level int
var path string
var target = "."
var err error
var unrestrict bool
flag.IntVar(&level, "l", flate.DefaultCompression, "compression level")
flag.BoolVar(&unrestrict, "u", false, "do not restrict decompression")
flag.Parse()
if flag.NArg() < 1 || flag.NArg() > 2 {
@@ -147,30 +166,31 @@ func main() {
}
if strings.HasSuffix(path, gzipExt) {
target, err := pathForUncompressing(path, target)
target, err = pathForUncompressing(path, target)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
err = uncompress(path, target)
err = uncompress(path, target, unrestrict)
if err != nil {
os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
} else {
target, err := pathForCompressing(path, target)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
return
}
err = compress(path, target, level)
if err != nil {
os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
target, err = pathForCompressing(path, target)
if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
err = compress(path, target, level)
if err != nil {
os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
}

View File

@@ -40,14 +40,14 @@ func main() {
usage()
}
min, err := strconv.Atoi(flag.Arg(1))
minVal, err := strconv.Atoi(flag.Arg(1))
dieIf(err)
max, err := strconv.Atoi(flag.Arg(2))
maxVal, err := strconv.Atoi(flag.Arg(2))
dieIf(err)
code := kind << 6
code += (min << 3)
code += max
fmt.Printf("%0o\n", code)
code += (minVal << 3)
code += maxVal
fmt.Fprintf(os.Stdout, "%0o\n", code)
}

View File

@@ -5,7 +5,6 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"sort"
@@ -47,7 +46,7 @@ func help(w io.Writer) {
}
func loadDatabase() {
data, err := ioutil.ReadFile(dbFile)
data, err := os.ReadFile(dbFile)
if err != nil && os.IsNotExist(err) {
partsDB = &database{
Version: dbVersion,
@@ -74,7 +73,7 @@ func writeDB() {
data, err := json.Marshal(partsDB)
die.If(err)
err = ioutil.WriteFile(dbFile, data, 0644)
err = os.WriteFile(dbFile, data, 0644)
die.If(err)
}

View File

@@ -4,14 +4,13 @@ import (
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"os"
)
var ext = ".bin"
func stripPEM(path string) error {
data, err := ioutil.ReadFile(path)
data, err := os.ReadFile(path)
if err != nil {
return err
}
@@ -22,7 +21,7 @@ func stripPEM(path string) error {
fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n")
}
return ioutil.WriteFile(path+ext, p.Bytes, 0644)
return os.WriteFile(path+ext, p.Bytes, 0644)
}
func main() {

View File

@@ -3,8 +3,7 @@ package main
import (
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"io"
"os"
"git.wntrmute.dev/kyle/goutils/lib"
@@ -21,9 +20,9 @@ func main() {
path := flag.Arg(0)
if path == "-" {
in, err = ioutil.ReadAll(os.Stdin)
in, err = io.ReadAll(os.Stdin)
} else {
in, err = ioutil.ReadFile(flag.Arg(0))
in, err = os.ReadFile(flag.Arg(0))
}
if err != nil {
lib.Err(lib.ExitFailure, err, "couldn't read file")
@@ -33,5 +32,7 @@ func main() {
if p == nil {
lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0))
}
fmt.Printf("%s", p.Bytes)
if _, err = os.Stdout.Write(p.Bytes); err != nil {
lib.Err(lib.ExitFailure, err, "writing body")
}
}

View File

@@ -70,7 +70,7 @@ func main() {
lib.Err(lib.ExitFailure, err, "failed to read input")
}
case argc > 1:
for i := 0; i < argc; i++ {
for i := range argc {
path := flag.Arg(i)
err = copyFile(path, buf)
if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"os"
)
@@ -13,14 +12,14 @@ func main() {
flag.Parse()
for _, fileName := range flag.Args() {
data, err := ioutil.ReadFile(fileName)
data, err := os.ReadFile(fileName)
if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
continue
}
fmt.Printf("[+] %s:\n", fileName)
rest := data[:]
fmt.Fprintf(os.Stdout, "[+] %s:\n", fileName)
rest := data
for {
var p *pem.Block
p, rest = pem.Decode(rest)
@@ -28,13 +27,14 @@ func main() {
break
}
cert, err := x509.ParseCertificate(p.Bytes)
var cert *x509.Certificate
cert, err = x509.ParseCertificate(p.Bytes)
if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
break
}
fmt.Printf("\t%+v\n", cert.Subject.CommonName)
fmt.Fprintf(os.Stdout, "\t%+v\n", cert.Subject.CommonName)
}
}
}

View File

@@ -43,7 +43,7 @@ func newName(path string) (string, error) {
return hashName(path, encodedHash), nil
}
func move(dst, src string, force bool) (err error) {
func move(dst, src string, force bool) error {
if fileutil.FileDoesExist(dst) && !force {
return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst)
}
@@ -52,21 +52,23 @@ func move(dst, src string, force bool) (err error) {
return err
}
defer func(e error) {
var retErr error
defer func(e *error) {
dstFile.Close()
if e != nil {
if *e != nil {
os.Remove(dst)
}
}(err)
}(&retErr)
srcFile, err := os.Open(src)
if err != nil {
retErr = err
return err
}
defer srcFile.Close()
_, err = io.Copy(dstFile, srcFile)
if err != nil {
if _, err = io.Copy(dstFile, srcFile); err != nil {
retErr = err
return err
}
@@ -94,6 +96,44 @@ func init() {
flag.Usage = func() { usage(os.Stdout) }
}
type options struct {
dryRun, force, printChanged, verbose bool
}
func processOne(file string, opt options) error {
renamed, err := newName(file)
if err != nil {
_, _ = lib.Warn(err, "failed to get new file name")
return err
}
if opt.verbose && !opt.printChanged {
fmt.Fprintln(os.Stdout, file)
}
if renamed == file {
return nil
}
if !opt.dryRun {
if err = move(renamed, file, opt.force); err != nil {
_, _ = lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
return err
}
}
if opt.printChanged && !opt.verbose {
fmt.Fprintln(os.Stdout, file, "->", renamed)
}
return nil
}
func run(dryRun, force, printChanged, verbose bool, files []string) {
if verbose && printChanged {
printChanged = false
}
opt := options{dryRun: dryRun, force: force, printChanged: printChanged, verbose: verbose}
for _, file := range files {
_ = processOne(file, opt)
}
}
func main() {
var dryRun, force, printChanged, verbose bool
flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision")
@@ -102,34 +142,5 @@ func main() {
flag.BoolVar(&verbose, "v", false, "list all processed files")
flag.Parse()
if verbose && printChanged {
printChanged = false
}
for _, file := range flag.Args() {
renamed, err := newName(file)
if err != nil {
lib.Warn(err, "failed to get new file name")
continue
}
if verbose && !printChanged {
fmt.Println(file)
}
if renamed != file {
if !dryRun {
err = move(renamed, file, force)
if err != nil {
lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
continue
}
}
if printChanged && !verbose {
fmt.Println(file, "->", renamed)
}
}
}
run(dryRun, force, printChanged, verbose, flag.Args())
}

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"flag"
"fmt"
"io"
@@ -66,24 +67,25 @@ func main() {
for _, remote := range flag.Args() {
u, err := url.Parse(remote)
if err != nil {
lib.Warn(err, "parsing %s", remote)
_, _ = lib.Warn(err, "parsing %s", remote)
continue
}
name := filepath.Base(u.Path)
if name == "" {
lib.Warnx("source URL doesn't appear to name a file")
_, _ = lib.Warnx("source URL doesn't appear to name a file")
continue
}
resp, err := http.Get(remote)
if err != nil {
lib.Warn(err, "fetching %s", remote)
req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, remote, nil)
if reqErr != nil {
_, _ = lib.Warn(reqErr, "building request for %s", remote)
continue
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
lib.Warn(err, "fetching %s", remote)
_, _ = lib.Warn(err, "fetching %s", remote)
continue
}

View File

@@ -3,7 +3,7 @@ package main
import (
"flag"
"fmt"
"math/rand"
"math/rand/v2"
"os"
"regexp"
"strconv"
@@ -17,8 +17,8 @@ func rollDie(count, sides int) []int {
sum := 0
var rolls []int
for i := 0; i < count; i++ {
roll := rand.Intn(sides) + 1
for range count {
roll := rand.IntN(sides) + 1 // #nosec G404
sum += roll
rolls = append(rolls, roll)
}

View File

@@ -53,7 +53,7 @@ func init() {
project = wd[len(gopath):]
}
func walkFile(path string, info os.FileInfo, err error) error {
func walkFile(path string, _ os.FileInfo, err error) error {
if ignores[path] {
return filepath.SkipDir
}
@@ -62,22 +62,27 @@ func walkFile(path string, info os.FileInfo, err error) error {
return nil
}
debug.Println(path)
f, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err != nil {
return err
}
debug.Println(path)
f, err2 := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err2 != nil {
return err2
}
for _, importSpec := range f.Imports {
importPath := strings.Trim(importSpec.Path.Value, `"`)
if stdLibRegexp.MatchString(importPath) {
switch {
case stdLibRegexp.MatchString(importPath):
debug.Println("standard lib:", importPath)
continue
} else if strings.HasPrefix(importPath, project) {
case strings.HasPrefix(importPath, project):
debug.Println("internal import:", importPath)
continue
} else if strings.HasPrefix(importPath, "golang.org/") {
case strings.HasPrefix(importPath, "golang.org/"):
debug.Println("extended lib:", importPath)
continue
}
@@ -102,7 +107,7 @@ func main() {
ignores["vendor"] = true
}
for _, word := range strings.Split(ignoreLine, ",") {
for word := range strings.SplitSeq(ignoreLine, ",") {
ignores[strings.TrimSpace(word)] = true
}

View File

@@ -2,10 +2,9 @@ package main
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha1"
"crypto/sha1" // #nosec G505
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
@@ -13,14 +12,19 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"strings"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
const (
keyTypeRSA = "RSA"
keyTypeECDSA = "ECDSA"
)
func usage(w io.Writer) {
fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files
@@ -39,14 +43,14 @@ func init() {
flag.Usage = func() { usage(os.Stderr) }
}
func parse(path string) (public []byte, kt, ft string) {
data, err := ioutil.ReadFile(path)
func parse(path string) ([]byte, string, string) {
data, err := os.ReadFile(path)
die.If(err)
data = bytes.TrimSpace(data)
p, rest := pem.Decode(data)
if len(rest) > 0 {
lib.Warnx("trailing data in PEM file")
_, _ = lib.Warnx("trailing data in PEM file")
}
if p == nil {
@@ -55,6 +59,12 @@ func parse(path string) (public []byte, kt, ft string) {
data = p.Bytes
var (
public []byte
kt string
ft string
)
switch p.Type {
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
public, kt = parseKey(data)
@@ -69,82 +79,79 @@ func parse(path string) (public []byte, kt, ft string) {
die.With("unknown PEM type %s", p.Type)
}
return
return public, kt, ft
}
func parseKey(data []byte) (public []byte, kt string) {
privInterface, err := x509.ParsePKCS8PrivateKey(data)
func parseKey(data []byte) ([]byte, string) {
priv, err := certlib.ParsePrivateKeyDER(data)
if err != nil {
privInterface, err = x509.ParsePKCS1PrivateKey(data)
if err != nil {
privInterface, err = x509.ParseECPrivateKey(data)
if err != nil {
die.With("couldn't parse private key.")
}
}
die.If(err)
}
var priv crypto.Signer
switch privInterface.(type) {
case *rsa.PrivateKey:
priv = privInterface.(*rsa.PrivateKey)
kt = "RSA"
case *ecdsa.PrivateKey:
priv = privInterface.(*ecdsa.PrivateKey)
kt = "ECDSA"
var kt string
switch priv.Public().(type) {
case *rsa.PublicKey:
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = keyTypeECDSA
default:
die.With("unknown private key type %T", privInterface)
die.With("unknown private key type %T", priv)
}
public, err = x509.MarshalPKIXPublicKey(priv.Public())
public, err := x509.MarshalPKIXPublicKey(priv.Public())
die.If(err)
return
return public, kt
}
func parseCertificate(data []byte) (public []byte, kt string) {
func parseCertificate(data []byte) ([]byte, string) {
cert, err := x509.ParseCertificate(data)
die.If(err)
pub := cert.PublicKey
var kt string
switch pub.(type) {
case *rsa.PublicKey:
kt = "RSA"
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = "ECDSA"
kt = keyTypeECDSA
default:
die.With("unknown public key type %T", pub)
}
public, err = x509.MarshalPKIXPublicKey(pub)
public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err)
return
return public, kt
}
func parseCSR(data []byte) (public []byte, kt string) {
csr, err := x509.ParseCertificateRequest(data)
func parseCSR(data []byte) ([]byte, string) {
// Use certlib to support both PEM and DER and to centralize validation.
csr, _, err := certlib.ParseCSR(data)
die.If(err)
pub := csr.PublicKey
var kt string
switch pub.(type) {
case *rsa.PublicKey:
kt = "RSA"
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = "ECDSA"
kt = keyTypeECDSA
default:
die.With("unknown public key type %T", pub)
}
public, err = x509.MarshalPKIXPublicKey(pub)
public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err)
return
return public, kt
}
func dumpHex(in []byte) string {
var s string
var sSb153 strings.Builder
for i := range in {
s += fmt.Sprintf("%02X:", in[i])
sSb153.WriteString(fmt.Sprintf("%02X:", in[i]))
}
s += sSb153.String()
return strings.Trim(s, ":")
}
@@ -172,18 +179,18 @@ func main() {
var subPKI subjectPublicKeyInfo
_, err := asn1.Unmarshal(public, &subPKI)
if err != nil {
lib.Warn(err, "failed to get subject PKI")
_, _ = lib.Warn(err, "failed to get subject PKI")
continue
}
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes)
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
pubHashString := dumpHex(pubHash[:])
if ski == "" {
ski = pubHashString
}
if shouldMatch && ski != pubHashString {
lib.Warnx("%s: SKI mismatch (%s != %s)",
_, _ = lib.Warnx("%s: SKI mismatch (%s != %s)",
path, ski, pubHashString)
}
fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft)

View File

@@ -1,16 +1,17 @@
package main
import (
"context"
"flag"
"io"
"log"
"net"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
func proxy(conn net.Conn, inside string) error {
proxyConn, err := net.Dial("tcp", inside)
proxyConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", inside)
if err != nil {
return err
}
@@ -19,7 +20,7 @@ func proxy(conn net.Conn, inside string) error {
defer conn.Close()
go func() {
io.Copy(conn, proxyConn)
_, _ = io.Copy(conn, proxyConn)
}()
_, err = io.Copy(proxyConn, conn)
return err
@@ -31,16 +32,22 @@ func main() {
flag.StringVar(&inside, "p", "4000", "inside port")
flag.Parse()
l, err := net.Listen("tcp", "0.0.0.0:"+outside)
lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:"+outside)
die.If(err)
for {
conn, err := l.Accept()
var conn net.Conn
conn, err = l.Accept()
if err != nil {
log.Println(err)
_, _ = lib.Warn(err, "accept failed")
continue
}
go proxy(conn, "127.0.0.1:"+inside)
go func() {
if err = proxy(conn, "127.0.0.1:"+inside); err != nil {
_, _ = lib.Warn(err, "proxy error")
}
}()
}
}

View File

@@ -1,6 +1,7 @@
package main
import (
"context"
"crypto/rand"
"crypto/tls"
"crypto/x509"
@@ -8,7 +9,6 @@ import (
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
@@ -16,7 +16,7 @@ import (
)
func main() {
cfg := &tls.Config{}
cfg := &tls.Config{} // #nosec G402
var sysRoot, listenAddr, certFile, keyFile string
var verify bool
@@ -47,7 +47,8 @@ func main() {
}
cfg.Certificates = append(cfg.Certificates, cert)
if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot)
var pemList []byte
pemList, err = os.ReadFile(sysRoot)
die.If(err)
roots := x509.NewCertPool()
@@ -59,48 +60,54 @@ func main() {
cfg.RootCAs = roots
}
l, err := net.Listen("tcp", listenAddr)
lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", listenAddr)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
}
for {
conn, err := l.Accept()
var conn net.Conn
conn, err = l.Accept()
if err != nil {
fmt.Println(err.Error())
}
raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg)
err = tconn.Handshake()
if err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
continue
}
cs := tconn.ConnectionState()
if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr)
continue
}
var chain []byte
for _, cert := range cs.PeerCertificates {
p := &pem.Block{
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
chain = append(chain, pem.EncodeToMemory(p)...)
}
var nonce [16]byte
_, err = rand.Read(nonce[:])
if err != nil {
panic(err)
}
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
err = ioutil.WriteFile(fname, chain, 0644)
die.If(err)
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
handleConn(conn, cfg)
}
}
// handleConn performs a TLS handshake, extracts the peer chain, and writes it to a file.
func handleConn(conn net.Conn, cfg *tls.Config) {
defer conn.Close()
raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg)
if err := tconn.HandshakeContext(context.Background()); err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
return
}
cs := tconn.ConnectionState()
if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr)
return
}
var chain []byte
for _, cert := range cs.PeerCertificates {
p := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
chain = append(chain, pem.EncodeToMemory(p)...)
}
var nonce [16]byte
if _, err := rand.Read(nonce[:]); err != nil {
fmt.Printf("[+] %v: failed to generate filename nonce: %v\n", raddr, err)
return
}
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
if err := os.WriteFile(fname, chain, 0o644); err != nil {
fmt.Printf("[+] %v: failed to write %v: %v\n", raddr, fname, err)
return
}
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
}

View File

@@ -1,12 +1,12 @@
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
@@ -14,7 +14,7 @@ import (
)
func main() {
var cfg = &tls.Config{}
var cfg = &tls.Config{} // #nosec G402
var sysRoot, serverName string
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
@@ -23,7 +23,7 @@ func main() {
flag.Parse()
if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot)
pemList, err := os.ReadFile(sysRoot)
die.If(err)
roots := x509.NewCertPool()
@@ -44,10 +44,13 @@ func main() {
if err != nil {
site += ":443"
}
conn, err := tls.Dial("tcp", site, cfg)
if err != nil {
fmt.Println(err.Error())
os.Exit(1)
d := &tls.Dialer{Config: cfg}
nc, err := d.DialContext(context.Background(), "tcp", site)
die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
}
cs := conn.ConnectionState()
@@ -61,8 +64,9 @@ func main() {
chain = append(chain, pem.EncodeToMemory(p)...)
}
err = ioutil.WriteFile(site+".pem", chain, 0644)
err = os.WriteFile(site+".pem", chain, 0644)
die.If(err)
fmt.Printf("[+] wrote %s.pem.\n", site)
}
}

View File

@@ -60,7 +60,7 @@ func printDigests(paths []string, issuer bool) {
for _, path := range paths {
cert, err := certlib.LoadCertificate(path)
if err != nil {
lib.Warn(err, "failed to load certificate from %s", path)
_, _ = lib.Warn(err, "failed to load certificate from %s", path)
continue
}
@@ -75,20 +75,19 @@ func matchDigests(paths []string, issuer bool) {
}
var invalid int
for {
if len(paths) == 0 {
break
}
for len(paths) > 0 {
fst := paths[0]
snd := paths[1]
paths = paths[2:]
fstCert, err := certlib.LoadCertificate(fst)
die.If(err)
sndCert, err := certlib.LoadCertificate(snd)
die.If(err)
if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) {
lib.Warnx("certificates don't match: %s and %s", fst, snd)
_, _ = lib.Warnx("certificates don't match: %s and %s", fst, snd)
invalid++
}
}

View File

@@ -1,10 +1,14 @@
package main
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"os"
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
"git.wntrmute.dev/kyle/goutils/die"
)
func main() {
@@ -13,16 +17,23 @@ func main() {
os.Exit(1)
}
hostPort := os.Args[1]
conn, err := tls.Dial("tcp", hostPort, &tls.Config{
InsecureSkipVerify: true,
})
hostPort, err := hosts.ParseHost(os.Args[1])
die.If(err)
if err != nil {
fmt.Printf("Failed to connect to the TLS server: %v\n", err)
os.Exit(1)
d := &tls.Dialer{Config: &tls.Config{
InsecureSkipVerify: true,
}} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", hostPort.String())
die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
}
defer conn.Close()
state := conn.ConnectionState()
printConnectionDetails(state)
}
@@ -37,7 +48,6 @@ func printConnectionDetails(state tls.ConnectionState) {
func tlsVersion(version uint16) string {
switch version {
case tls.VersionTLS13:
return "TLS 1.3"
case tls.VersionTLS12:

View File

@@ -11,10 +11,9 @@ import (
"errors"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
)
@@ -32,7 +31,7 @@ const (
curveP521
)
func getECCurve(pub interface{}) int {
func getECCurve(pub any) int {
switch pub := pub.(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
@@ -52,42 +51,88 @@ func getECCurve(pub interface{}) int {
}
}
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
// It returns true on match.
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
}
// matchECDSA compares ECDSA public keys for equality and compatible curve.
// It returns match=true when they are on the same curve and have the same X/Y.
// If curves mismatch, match is false.
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
if getECCurve(certPub) != getECCurve(keyPub) {
return false
}
if keyPub.X.Cmp(certPub.X) != 0 {
return false
}
if keyPub.Y.Cmp(certPub.Y) != 0 {
return false
}
return true
}
// matchKeys determines whether the certificate's public key matches the given private key.
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
switch keyPub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if matchRSA(certPub, keyPub) {
return true, ""
}
return false, "public keys don't match"
case *ecdsa.PublicKey:
return false, "RSA private key, EC public key"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
case *ecdsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *ecdsa.PublicKey:
if matchECDSA(certPub, keyPub) {
return true, ""
}
// Determine a more precise reason
kc := getECCurve(keyPub)
cc := getECCurve(certPub)
if kc == curveInvalid {
return false, "invalid private key curve"
}
if cc == curveRSA {
return false, "private key is EC, certificate is RSA"
}
if kc != cc {
return false, "EC curves don't match"
}
return false, "public keys don't match"
case *rsa.PublicKey:
return false, "private key is EC, certificate is RSA"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
default:
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
}
}
func loadKey(path string) (crypto.Signer, error) {
in, err := ioutil.ReadFile(path)
in, err := os.ReadFile(path)
if err != nil {
return nil, err
}
in = bytes.TrimSpace(in)
p, _ := pem.Decode(in)
if p != nil {
if p, _ := pem.Decode(in); p != nil {
if !validPEMs[p.Type] {
return nil, errors.New("invalid private key file type " + p.Type)
}
in = p.Bytes
return certlib.ParsePrivateKeyPEM(in)
}
priv, err := x509.ParsePKCS8PrivateKey(in)
if err != nil {
priv, err = x509.ParsePKCS1PrivateKey(in)
if err != nil {
priv, err = x509.ParseECPrivateKey(in)
if err != nil {
return nil, err
}
}
}
switch priv.(type) {
case *rsa.PrivateKey:
return priv.(*rsa.PrivateKey), nil
case *ecdsa.PrivateKey:
return priv.(*ecdsa.PrivateKey), nil
}
// should never reach here
return nil, errors.New("invalid private key")
return certlib.ParsePrivateKeyDER(in)
}
func main() {
@@ -96,7 +141,7 @@ func main() {
flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
flag.Parse()
in, err := ioutil.ReadFile(certFile)
in, err := os.ReadFile(certFile)
die.If(err)
p, _ := pem.Decode(in)
@@ -112,50 +157,11 @@ func main() {
priv, err := loadKey(keyFile)
die.If(err)
switch pub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.")
return
case *ecdsa.PublicKey:
fmt.Println("No match (RSA private key, EC public key).")
os.Exit(1)
}
case *ecdsa.PublicKey:
privCurve := getECCurve(pub)
certCurve := getECCurve(cert.PublicKey)
log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve)
if certCurve == curveRSA {
fmt.Println("No match (private key is EC, certificate is RSA).")
os.Exit(1)
} else if privCurve == curveInvalid {
fmt.Println("No match (invalid private key curve).")
os.Exit(1)
} else if privCurve != certCurve {
fmt.Println("No match (EC curves don't match).")
os.Exit(1)
}
certPub := cert.PublicKey.(*ecdsa.PublicKey)
if pub.X.Cmp(certPub.X) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
if pub.Y.Cmp(certPub.Y) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
matched, reason := matchKeys(cert, priv)
if matched {
fmt.Println("Match.")
default:
fmt.Printf("Unrecognised private key type: %T\n", priv.Public())
os.Exit(1)
return
}
fmt.Printf("No match (%s).\n", reason)
os.Exit(1)
}

View File

@@ -201,10 +201,6 @@ func init() {
os.Exit(1)
}
if fromLoc == time.UTC {
}
toLoc = time.UTC
}
@@ -257,15 +253,16 @@ func main() {
showTime(time.Now())
os.Exit(0)
case 1:
if flag.Arg(0) == "-" {
switch {
case flag.Arg(0) == "-":
s := bufio.NewScanner(os.Stdin)
for s.Scan() {
times = append(times, s.Text())
}
} else if flag.Arg(0) == "help" {
case flag.Arg(0) == "help":
usageExamples()
} else {
default:
times = flag.Args()
}
default:

View File

@@ -4,7 +4,6 @@ import (
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"gopkg.in/yaml.v2"
@@ -12,9 +11,8 @@ import (
type empty struct{}
func errorf(format string, args ...interface{}) {
format += "\n"
fmt.Fprintf(os.Stderr, format, args...)
func errorf(path string, err error) {
fmt.Fprintf(os.Stderr, "%s FAILED: %s\n", path, err)
}
func usage(w io.Writer) {
@@ -44,16 +42,16 @@ func main() {
if flag.NArg() == 1 && flag.Arg(0) == "-" {
path := "stdin"
in, err := ioutil.ReadAll(os.Stdin)
in, err := io.ReadAll(os.Stdin)
if err != nil {
errorf("%s FAILED: %s", path, err)
errorf(path, err)
os.Exit(1)
}
var e empty
err = yaml.Unmarshal(in, &e)
if err != nil {
errorf("%s FAILED: %s", path, err)
errorf(path, err)
os.Exit(1)
}
@@ -65,16 +63,16 @@ func main() {
}
for _, path := range flag.Args() {
in, err := ioutil.ReadFile(path)
in, err := os.ReadFile(path)
if err != nil {
errorf("%s FAILED: %s", path, err)
errorf(path, err)
continue
}
var e empty
err = yaml.Unmarshal(in, &e)
if err != nil {
errorf("%s FAILED: %s", path, err)
errorf(path, err)
continue
}

View File

@@ -14,16 +14,16 @@ import (
"os"
"path/filepath"
"regexp"
"git.wntrmute.dev/kyle/goutils/lib"
)
const defaultDirectory = ".git/objects"
func errorf(format string, a ...interface{}) {
fmt.Fprintf(os.Stderr, format, a...)
if format[len(format)-1] != '\n' {
fmt.Fprintf(os.Stderr, "\n")
}
}
// maxDecompressedSize limits how many bytes we will decompress from a zlib
// stream to mitigate decompression bombs (gosec G110).
// Increase this if you expect larger objects.
const maxDecompressedSize int64 = 64 << 30 // 64 GiB
func isDir(path string) bool {
fi, err := os.Stat(path)
@@ -48,17 +48,21 @@ func loadFile(path string) ([]byte, error) {
}
defer zread.Close()
_, err = io.Copy(buf, zread)
if err != nil {
// Protect against decompression bombs by limiting how much we read.
lr := io.LimitReader(zread, maxDecompressedSize+1)
if _, err = buf.ReadFrom(lr); err != nil {
return nil, err
}
if int64(buf.Len()) > maxDecompressedSize {
return nil, fmt.Errorf("decompressed size exceeds limit (%d bytes)", maxDecompressedSize)
}
return buf.Bytes(), nil
}
func showFile(path string) {
fileData, err := loadFile(path)
if err != nil {
errorf("%v", err)
lib.Warn(err, "failed to load %s", path)
return
}
@@ -68,39 +72,71 @@ func showFile(path string) {
func searchFile(path string, search *regexp.Regexp) error {
file, err := os.Open(path)
if err != nil {
errorf("%v", err)
lib.Warn(err, "failed to open %s", path)
return err
}
defer file.Close()
zread, err := zlib.NewReader(file)
if err != nil {
errorf("%v", err)
lib.Warn(err, "failed to decompress %s", path)
return err
}
defer zread.Close()
zbuf := bufio.NewReader(zread)
if search.MatchReader(zbuf) {
fileData, err := loadFile(path)
if err != nil {
errorf("%v", err)
return err
}
fmt.Printf("%s:\n%s\n", path, fileData)
// Limit how much we scan to avoid DoS via huge decompression.
lr := io.LimitReader(zread, maxDecompressedSize+1)
zbuf := bufio.NewReader(lr)
if !search.MatchReader(zbuf) {
return nil
}
fileData, err := loadFile(path)
if err != nil {
lib.Warn(err, "failed to load %s", path)
return err
}
fmt.Printf("%s:\n%s\n", path, fileData)
return nil
}
func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc {
return func(path string, info os.FileInfo, err error) error {
if info.Mode().IsRegular() {
return searchFile(path, searchExpr)
return func(path string, info os.FileInfo, _ error) error {
if !info.Mode().IsRegular() {
return nil
}
return nil
return searchFile(path, searchExpr)
}
}
// runSearch compiles the search expression and processes the provided paths.
// It returns an error for fatal conditions; per-file errors are logged.
func runSearch(expr string) error {
search, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("invalid regexp: %w", err)
}
pathList := flag.Args()
if len(pathList) == 0 {
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
if err2 := filepath.Walk(path, buildWalker(search)); err2 != nil {
return err2
}
continue
}
if err2 := searchFile(path, search); err2 != nil {
// Non-fatal: keep going, but report it.
lib.Warn(err2, "non-fatal error while searching files")
}
}
return nil
}
func main() {
flSearch := flag.String("s", "", "search string (should be an RE2 regular expression)")
flag.Parse()
@@ -109,28 +145,10 @@ func main() {
for _, path := range flag.Args() {
showFile(path)
}
} else {
search, err := regexp.Compile(*flSearch)
if err != nil {
errorf("Bad regexp: %v", err)
return
}
return
}
pathList := flag.Args()
if len(pathList) == 0 {
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
err := filepath.Walk(path, buildWalker(search))
if err != nil {
errorf("%v", err)
return
}
} else {
searchFile(path, search)
}
}
if err := runSearch(*flSearch); err != nil {
lib.Err(lib.ExitFailure, err, "failed to run search")
}
}

1
go.mod
View File

@@ -22,4 +22,5 @@ require (
github.com/kr/pretty v0.1.0 // indirect
github.com/pkg/errors v0.9.1 // indirect
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
rsc.io/qr v0.2.0 // indirect
)

2
go.sum
View File

@@ -44,3 +44,5 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=

View File

@@ -1,4 +1,4 @@
// +build freebsd darwin,386 netbsd
//go:build bsd
package lib

View File

@@ -1,4 +1,4 @@
// +build unix linux openbsd darwin,amd64
//go:build unix || linux || openbsd || (darwin && amd64)
package lib
@@ -18,7 +18,7 @@ type FileTime struct {
func timeSpecToTime(ts unix.Timespec) time.Time {
// The casts to int64 are needed because on 386, these are int32s.
return time.Unix(int64(ts.Sec), int64(ts.Nsec))
return time.Unix(ts.Sec, ts.Nsec)
}
// LoadFileTime returns a FileTime associated with the file.

View File

@@ -2,14 +2,22 @@
package lib
import (
"encoding/hex"
"fmt"
"os"
"path/filepath"
"strings"
"time"
)
var progname = filepath.Base(os.Args[0])
const (
daysInYear = 365
digitWidth = 10
hoursInQuarterDay = 6
)
// ProgName returns what lib thinks the program name is, namely the
// basename of argv0.
//
@@ -20,7 +28,7 @@ func ProgName() string {
// Warnx displays a formatted error message to standard error, à la
// warnx(3).
func Warnx(format string, a ...interface{}) (int, error) {
func Warnx(format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n"
return fmt.Fprintf(os.Stderr, format, a...)
@@ -28,7 +36,7 @@ func Warnx(format string, a ...interface{}) (int, error) {
// Warn displays a formatted error message to standard output,
// appending the error string, à la warn(3).
func Warn(err error, format string, a ...interface{}) (int, error) {
func Warn(err error, format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n"
a = append(a, err)
@@ -37,7 +45,7 @@ func Warn(err error, format string, a ...interface{}) (int, error) {
// Errx displays a formatted error message to standard error and exits
// with the status code from `exit`, à la errx(3).
func Errx(exit int, format string, a ...interface{}) {
func Errx(exit int, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n"
fmt.Fprintf(os.Stderr, format, a...)
@@ -47,7 +55,7 @@ func Errx(exit int, format string, a ...interface{}) {
// Err displays a formatting error message to standard error,
// appending the error string, and exits with the status code from
// `exit`, à la err(3).
func Err(exit int, err error, format string, a ...interface{}) {
func Err(exit int, err error, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n"
a = append(a, err)
@@ -62,30 +70,30 @@ func Itoa(i int, wid int) string {
// Assemble decimal in reverse order.
var b [20]byte
bp := len(b) - 1
for i >= 10 || wid > 1 {
for i >= digitWidth || wid > 1 {
wid--
q := i / 10
b[bp] = byte('0' + i - q*10)
q := i / digitWidth
b[bp] = byte('0' + i - q*digitWidth)
bp--
i = q
}
// i < 10
b[bp] = byte('0' + i)
return string(b[bp:])
}
var (
dayDuration = 24 * time.Hour
yearDuration = (365 * dayDuration) + (6 * time.Hour)
yearDuration = (daysInYear * dayDuration) + (hoursInQuarterDay * time.Hour)
)
// Duration returns a prettier string for time.Durations.
func Duration(d time.Duration) string {
var s string
if d >= yearDuration {
years := d / yearDuration
years := int64(d / yearDuration)
s += fmt.Sprintf("%dy", years)
d -= years * yearDuration
d -= time.Duration(years) * yearDuration
}
if d >= dayDuration {
@@ -98,8 +106,116 @@ func Duration(d time.Duration) string {
}
d %= 1 * time.Second
hours := d / time.Hour
d -= hours * time.Hour
hours := int64(d / time.Hour)
d -= time.Duration(hours) * time.Hour
s += fmt.Sprintf("%dh%s", hours, d)
return s
}
type HexEncodeMode uint8
const (
// HexEncodeLower prints the bytes as lowercase hexadecimal.
HexEncodeLower HexEncodeMode = iota + 1
// HexEncodeUpper prints the bytes as uppercase hexadecimal.
HexEncodeUpper
// HexEncodeLowerColon prints the bytes as lowercase hexadecimal
// with colons between each pair of bytes.
HexEncodeLowerColon
// HexEncodeUpperColon prints the bytes as uppercase hexadecimal
// with colons between each pair of bytes.
HexEncodeUpperColon
)
func (m HexEncodeMode) String() string {
switch m {
case HexEncodeLower:
return "lower"
case HexEncodeUpper:
return "upper"
case HexEncodeLowerColon:
return "lcolon"
case HexEncodeUpperColon:
return "ucolon"
default:
panic("invalid hex encode mode")
}
}
func ParseHexEncodeMode(s string) HexEncodeMode {
switch strings.ToLower(s) {
case "lower":
return HexEncodeLower
case "upper":
return HexEncodeUpper
case "lcolon":
return HexEncodeLowerColon
case "ucolon":
return HexEncodeUpperColon
}
panic("invalid hex encode mode")
}
func hexColons(s string) string {
if len(s)%2 != 0 {
fmt.Fprintf(os.Stderr, "hex string: %s\n", s)
fmt.Fprintf(os.Stderr, "hex length: %d\n", len(s))
panic("invalid hex string length")
}
n := len(s)
if n <= 2 {
return s
}
pairCount := n / 2
if n%2 != 0 {
pairCount++
}
var b strings.Builder
b.Grow(n + pairCount - 1)
for i := 0; i < n; i += 2 {
b.WriteByte(s[i])
if i+1 < n {
b.WriteByte(s[i+1])
}
if i+2 < n {
b.WriteByte(':')
}
}
return b.String()
}
func hexEncode(b []byte) string {
s := hex.EncodeToString(b)
if len(s)%2 != 0 {
s = "0" + s
}
return s
}
// HexEncode encodes the given bytes as a hexadecimal string.
func HexEncode(b []byte, mode HexEncodeMode) string {
str := hexEncode(b)
switch mode {
case HexEncodeLower:
return str
case HexEncodeUpper:
return strings.ToUpper(str)
case HexEncodeLowerColon:
return hexColons(str)
case HexEncodeUpperColon:
return strings.ToUpper(hexColons(str))
default:
panic("invalid hex encode mode")
}
}

79
lib/lib_test.go Normal file
View File

@@ -0,0 +1,79 @@
package lib_test
import (
"testing"
"git.wntrmute.dev/kyle/goutils/lib"
)
func TestHexEncode_LowerUpper(t *testing.T) {
b := []byte{0x0f, 0xa1, 0x00, 0xff}
gotLower := lib.HexEncode(b, lib.HexEncodeLower)
if gotLower != "0fa100ff" {
t.Fatalf("lib.HexEncode lower: expected %q, got %q", "0fa100ff", gotLower)
}
gotUpper := lib.HexEncode(b, lib.HexEncodeUpper)
if gotUpper != "0FA100FF" {
t.Fatalf("lib.HexEncode upper: expected %q, got %q", "0FA100FF", gotUpper)
}
}
func TestHexEncode_ColonModes(t *testing.T) {
// Includes leading zero nibble and a zero byte to verify padding and separators
b := []byte{0x0f, 0xa1, 0x00, 0xff}
gotLColon := lib.HexEncode(b, lib.HexEncodeLowerColon)
if gotLColon != "0f:a1:00:ff" {
t.Fatalf("lib.HexEncode colon lower: expected %q, got %q", "0f:a1:00:ff", gotLColon)
}
gotUColon := lib.HexEncode(b, lib.HexEncodeUpperColon)
if gotUColon != "0F:A1:00:FF" {
t.Fatalf("lib.HexEncode colon upper: expected %q, got %q", "0F:A1:00:FF", gotUColon)
}
}
func TestHexEncode_EmptyInput(t *testing.T) {
var b []byte
if got := lib.HexEncode(b, lib.HexEncodeLower); got != "" {
t.Fatalf("empty lower: expected empty string, got %q", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpper); got != "" {
t.Fatalf("empty upper: expected empty string, got %q", got)
}
if got := lib.HexEncode(b, lib.HexEncodeLowerColon); got != "" {
t.Fatalf("empty colon lower: expected empty string, got %q", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpperColon); got != "" {
t.Fatalf("empty colon upper: expected empty string, got %q", got)
}
}
func TestHexEncode_SingleByte(t *testing.T) {
b := []byte{0x0f}
if got := lib.HexEncode(b, lib.HexEncodeLower); got != "0f" {
t.Fatalf("single byte lower: expected %q, got %q", "0f", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpper); got != "0F" {
t.Fatalf("single byte upper: expected %q, got %q", "0F", got)
}
// For a single byte, colon modes should not introduce separators
if got := lib.HexEncode(b, lib.HexEncodeLowerColon); got != "0f" {
t.Fatalf("single byte colon lower: expected %q, got %q", "0f", got)
}
if got := lib.HexEncode(b, lib.HexEncodeUpperColon); got != "0F" {
t.Fatalf("single byte colon upper: expected %q, got %q", "0F", got)
}
}
func TestHexEncode_InvalidModePanics(t *testing.T) {
defer func() {
if r := recover(); r == nil {
t.Fatalf("expected panic for invalid mode, but function returned normally")
}
}()
// 0 is not a valid lib.HexEncodeMode (valid modes start at 1)
_ = lib.HexEncode([]byte{0x01}, lib.HexEncodeMode(0))
}

View File

@@ -1,8 +1,9 @@
package logging
import (
"fmt"
"os"
"errors"
"fmt"
"os"
)
// File writes its logs to file.
@@ -59,12 +60,12 @@ func NewSplitFile(outpath, errpath string, overwrite bool) (*File, error) {
fl.fe, err = os.OpenFile(errpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
}
if err != nil {
if closeErr := fl.Close(); closeErr != nil {
return nil, fmt.Errorf("failed to open error log: cleanup close failed: %v: %w", closeErr, err)
}
return nil, err
}
if err != nil {
if closeErr := fl.Close(); closeErr != nil {
return nil, fmt.Errorf("failed to open error log: %w", errors.Join(closeErr, err))
}
return nil, fmt.Errorf("failed to open error log: %w", err)
}
fl.LogWriter = NewLogWriter(fl.fo, fl.fe)
return fl, nil
@@ -94,13 +95,13 @@ func (fl *File) Flush() error {
}
func (fl *File) Chmod(mode os.FileMode) error {
if err := fl.fo.Chmod(mode); err != nil {
return fmt.Errorf("failed to chmod output log: %w", err)
}
if err := fl.fo.Chmod(mode); err != nil {
return fmt.Errorf("failed to chmod output log: %w", err)
}
if err := fl.fe.Chmod(mode); err != nil {
return fmt.Errorf("failed to chmod error log: %w", err)
}
if err := fl.fe.Chmod(mode); err != nil {
return fmt.Errorf("failed to chmod error log: %w", err)
}
return nil
return nil
}

33
twofactor/README.md Normal file
View File

@@ -0,0 +1,33 @@
## `twofactor`
[![GoDoc](https://godoc.org/github.com/gokyle/twofactor?status.svg)](https://godoc.org/github.com/gokyle/twofactor)
### Author
`twofactor` was written by Kyle Isom <kyle@tyrfingr.is>.
### License
```
Copyright (c) 2017 Kyle Isom <kyle@imap.cc>
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

5
twofactor/doc.go Normal file
View File

@@ -0,0 +1,5 @@
// Package twofactor implements two-factor authentication.
//
// Currently supported are RFC 4226 HOTP one-time passwords and
// RFC 6238 TOTP SHA-1 one-time passwords.
package twofactor

103
twofactor/hotp.go Normal file
View File

@@ -0,0 +1,103 @@
package twofactor
import (
"crypto"
"crypto/sha1" // #nosec G505 - required by RFC
"encoding/base32"
"io"
"net/url"
"strconv"
"strings"
)
// HOTP represents an RFC-4226 Hash-based One Time Password instance.
type HOTP struct {
*OATH
}
// NewHOTP takes the key, the initial counter value, and the number
// of digits (typically 6 or 8) and returns a new HOTP instance.
func NewHOTP(key []byte, counter uint64, digits int) *HOTP {
return &HOTP{
OATH: &OATH{
key: key,
counter: counter,
size: digits,
hash: sha1.New,
algo: crypto.SHA1,
},
}
}
// Type returns OATH_HOTP.
func (otp *HOTP) Type() Type {
return OATH_HOTP
}
// OTP returns the next OTP and increments the counter.
func (otp *HOTP) OTP() string {
code := otp.OATH.OTP(otp.counter)
otp.counter++
return code
}
// URL returns an HOTP URL (i.e. for putting in a QR code).
func (otp *HOTP) URL(label string) string {
return otp.OATH.URL(otp.Type(), label)
}
// SetProvider sets up the provider component of the OTP URL.
func (otp *HOTP) SetProvider(provider string) {
otp.provider = provider
}
// GenerateGoogleHOTP generates a new HOTP instance as used by
// Google Authenticator.
func GenerateGoogleHOTP() *HOTP {
key := make([]byte, sha1.Size)
if _, err := io.ReadFull(PRNG, key); err != nil {
return nil
}
return NewHOTP(key, 0, 6)
}
func hotpFromURL(u *url.URL) (*HOTP, string, error) {
label := u.Path[1:]
v := u.Query()
secret := strings.ToUpper(v.Get("secret"))
if secret == "" {
return nil, "", ErrInvalidURL
}
var digits = 6
if sdigit := v.Get("digits"); sdigit != "" {
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
if err != nil {
return nil, "", err
}
digits = int(tmpDigits)
}
var counter uint64
if scounter := v.Get("counter"); scounter != "" {
var err error
counter, err = strconv.ParseUint(scounter, 10, 64)
if err != nil {
return nil, "", err
}
}
key, err := base32.StdEncoding.DecodeString(Pad(secret))
if err != nil {
// assume secret isn't base32 encoded
key = []byte(secret)
}
otp := NewHOTP(key, counter, digits)
return otp, label, nil
}
// QR generates a new QR code for the HOTP.
func (otp *HOTP) QR(label string) ([]byte, error) {
return otp.OATH.QR(otp.Type(), label)
}

View File

@@ -0,0 +1,58 @@
package twofactor
import (
"testing"
)
var testKey = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
var rfcHotpKey = []byte("12345678901234567890")
var rfcHotpExpected = []string{
"755224",
"287082",
"359152",
"969429",
"338314",
"254676",
"287922",
"162583",
"399871",
"520489",
}
// This test runs through the test cases presented in the RFC, and
// ensures that this implementation is in compliance.
func TestHotpRFC(t *testing.T) {
otp := NewHOTP(rfcHotpKey, 0, 6)
for i := range rfcHotpExpected {
if otp.Counter() != uint64(i) {
t.Fatalf("twofactor: invalid counter (should be %d, is %d",
i, otp.Counter())
}
code := otp.OTP()
if code == "" {
t.Fatal("twofactor: failed to produce an OTP")
} else if code != rfcHotpExpected[i] {
t.Logf("twofactor: invalid OTP\n")
t.Logf("\tExpected: %s\n", rfcHotpExpected[i])
t.Logf("\t Actual: %s\n", code)
t.Fatalf("\t Counter: %d\n", otp.counter)
}
}
}
// This test uses a different key than the test cases in the RFC,
// but runs through the same test cases to ensure that they fail as
// expected.
func TestHotpBadRFC(t *testing.T) {
otp := NewHOTP(testKey, 0, 6)
for i := range rfcHotpExpected {
code := otp.OTP()
switch code {
case "":
t.Error("twofactor: failed to produce an OTP")
case rfcHotpExpected[i]:
t.Error("twofactor: should not have received a valid OTP")
}
}
}

150
twofactor/oath.go Normal file
View File

@@ -0,0 +1,150 @@
package twofactor
import (
"crypto"
"crypto/hmac"
"encoding/base32"
"encoding/binary"
"fmt"
"hash"
"net/url"
"strconv"
"rsc.io/qr"
)
const defaultSize = 6
// OATH provides a baseline structure for the two OATH algorithms.
type OATH struct {
key []byte
counter uint64
size int
hash func() hash.Hash
algo crypto.Hash
provider string
}
// Size returns the output size (in characters) of the password.
func (o *OATH) Size() int {
return o.size
}
// Counter returns the OATH token's counter.
func (o *OATH) Counter() uint64 {
return o.counter
}
// SetCounter updates the OATH token's counter to a new value.
func (o *OATH) SetCounter(counter uint64) {
o.counter = counter
}
// Key returns the token's secret key.
func (o *OATH) Key() []byte {
return o.key
}
// Hash returns the token's hash function.
func (o *OATH) Hash() func() hash.Hash {
return o.hash
}
// URL constructs a URL appropriate for the token (i.e. for use in a
// QR code).
func (o *OATH) URL(t Type, label string) string {
secret := base32.StdEncoding.EncodeToString(o.key)
u := url.URL{}
v := url.Values{}
u.Scheme = "otpauth"
switch t {
case OATH_HOTP:
u.Host = "hotp"
case OATH_TOTP:
u.Host = "totp"
}
u.Path = label
v.Add("secret", secret)
if o.Counter() != 0 && t == OATH_HOTP {
v.Add("counter", strconv.FormatUint(o.Counter(), 10))
}
if o.Size() != defaultSize {
v.Add("digits", strconv.Itoa(o.Size()))
}
switch o.algo {
case crypto.SHA256:
v.Add("algorithm", "SHA256")
case crypto.SHA512:
v.Add("algorithm", "SHA512")
}
if o.provider != "" {
v.Add("provider", o.provider)
}
u.RawQuery = v.Encode()
return u.String()
}
var digits = []int64{
0: 1,
1: 10,
2: 100,
3: 1000,
4: 10000,
5: 100000,
6: 1000000,
7: 10000000,
8: 100000000,
9: 1000000000,
10: 10000000000,
}
// OTP top-level type should provide a counter; for example, HOTP
// will provide the counter directly while TOTP will provide the
// time-stepped counter.
func (o *OATH) OTP(counter uint64) string {
var ctr [8]byte
binary.BigEndian.PutUint64(ctr[:], counter)
var mod int64 = 1
if len(digits) > o.size {
for i := 1; i <= o.size; i++ {
mod *= 10
}
} else {
mod = digits[o.size]
}
h := hmac.New(o.hash, o.key)
h.Write(ctr[:])
dt := truncate(h.Sum(nil)) % mod
fmtStr := fmt.Sprintf("%%0%dd", o.size)
return fmt.Sprintf(fmtStr, dt)
}
// truncate contains the DT function from the RFC; this is used to
// deterministically select a sequence of 4 bytes from the HMAC
// counter hash.
func truncate(in []byte) int64 {
offset := int(in[len(in)-1] & 0xF)
p := in[offset : offset+4]
var binCode int32
binCode = int32((p[0] & 0x7f)) << 24
binCode += int32((p[1] & 0xff)) << 16
binCode += int32((p[2] & 0xff)) << 8
binCode += int32((p[3] & 0xff))
return int64(binCode) & 0x7FFFFFFF
}
// QR generates a byte slice containing the a QR code encoded as a
// PNG with level Q error correction.
func (o *OATH) QR(t Type, label string) ([]byte, error) {
u := o.URL(t, label)
code, err := qr.Encode(u, qr.Q)
if err != nil {
return nil, err
}
return code.PNG(), nil
}

View File

@@ -0,0 +1,27 @@
package twofactor
import (
"testing"
)
var sha1Hmac = []byte{
0x1f, 0x86, 0x98, 0x69, 0x0e,
0x02, 0xca, 0x16, 0x61, 0x85,
0x50, 0xef, 0x7f, 0x19, 0xda,
0x8e, 0x94, 0x5b, 0x55, 0x5a,
}
var truncExpect int64 = 0x50ef7f19
// This test runs through the truncation example given in the RFC.
func TestTruncate(t *testing.T) {
if result := truncate(sha1Hmac); result != truncExpect {
t.Fatalf("hotp: expected truncate -> %d, saw %d\n",
truncExpect, result)
}
sha1Hmac[19]++
if result := truncate(sha1Hmac); result == truncExpect {
t.Fatal("hotp: expected truncation to fail")
}
}

86
twofactor/otp.go Normal file
View File

@@ -0,0 +1,86 @@
package twofactor
import (
"crypto/rand"
"errors"
"fmt"
"hash"
"net/url"
)
type Type uint
const (
OATH_HOTP = iota
OATH_TOTP
)
// PRNG is an io.Reader that provides a cryptographically secure
// random byte stream.
var PRNG = rand.Reader
var (
ErrInvalidURL = errors.New("twofactor: invalid URL")
ErrInvalidAlgo = errors.New("twofactor: invalid algorithm")
)
// OTP represents a one-time password token -- whether a
// software taken (as in the case of Google Authenticator) or a
// hardware token (as in the case of a YubiKey).
type OTP interface {
// Returns the current counter value; the meaning of the
// returned value is algorithm-specific.
Counter() uint64
// Set the counter to a specific value.
SetCounter(uint64)
// the secret key contained in the OTP
Key() []byte
// generate a new OTP
OTP() string
// the output size of the OTP
Size() int
// the hash function used by the OTP
Hash() func() hash.Hash
// Returns the type of this OTP.
Type() Type
}
func otpString(otp OTP) string {
var typeName string
switch otp.Type() {
case OATH_HOTP:
typeName = "OATH-HOTP"
case OATH_TOTP:
typeName = "OATH-TOTP"
default:
typeName = "UNKNOWN"
}
return fmt.Sprintf("%s, %d", typeName, otp.Size())
}
// FromURL constructs a new OTP token from a URL string.
func FromURL(otpURL string) (OTP, string, error) {
u, err := url.Parse(otpURL)
if err != nil {
return nil, "", err
}
if u.Scheme != "otpauth" {
return nil, "", ErrInvalidURL
}
switch u.Host {
case "totp":
return totpFromURL(u)
case "hotp":
return hotpFromURL(u)
default:
return nil, "", ErrInvalidURL
}
}

View File

@@ -0,0 +1,126 @@
package twofactor
import (
"io"
"testing"
)
func TestHOTPString(t *testing.T) {
hotp := NewHOTP(nil, 0, 6)
hotpString := otpString(hotp)
if hotpString != "OATH-HOTP, 6" {
t.Fatal("twofactor: invalid OTP string")
}
}
// This test generates a new OTP, outputs the URL for that OTP,
// and attempts to parse that URL. It verifies that the two OTPs
// are the same, and that they produce the same output.
func TestURL(t *testing.T) {
var ident = "testuser@foo"
otp := NewHOTP(testKey, 0, 6)
url := otp.URL("testuser@foo")
otp2, id, err := FromURL(url)
switch {
case err != nil:
t.Fatal("hotp: failed to parse HOTP URL\n")
case id != ident:
t.Logf("hotp: bad label\n")
t.Logf("\texpected: %s\n", ident)
t.Fatalf("\t actual: %s\n", id)
case otp2.Counter() != otp.Counter():
t.Logf("hotp: OTP counters aren't synced\n")
t.Logf("\toriginal: %d\n", otp.Counter())
t.Fatalf("\t second: %d\n", otp2.Counter())
}
code1 := otp.OTP()
code2 := otp2.OTP()
if code1 != code2 {
t.Logf("hotp: mismatched OTPs\n")
t.Logf("\texpected: %s\n", code1)
t.Fatalf("\t actual: %s\n", code2)
}
// There's not much we can do test the QR code, except to
// ensure it doesn't fail.
_, err = otp.QR(ident)
if err != nil {
t.Fatalf("hotp: failed to generate QR code PNG (%v)\n", err)
}
// This should fail because the maximum size of an alphanumeric
// QR code with the lowest-level of error correction should
// max out at 4296 bytes. 8k may be a bit overkill... but it
// gets the job done. The value is read from the PRNG to
// increase the likelihood that the returned data is
// uncompressible.
var tooBigIdent = make([]byte, 8192)
_, err = io.ReadFull(PRNG, tooBigIdent)
if err != nil {
t.Fatalf("hotp: failed to read identity (%v)\n", err)
} else if _, err = otp.QR(string(tooBigIdent)); err == nil {
t.Fatal("hotp: QR code should fail to encode oversized URL")
}
}
// This test makes sure we can generate codes for padded and non-padded
// entries.
func TestPaddedURL(t *testing.T) {
var urlList = []string{
"otpauth://hotp/?secret=ME",
"otpauth://hotp/?secret=MEFR",
"otpauth://hotp/?secret=MFRGG",
"otpauth://hotp/?secret=MFRGGZA",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by=======",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by",
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by%3D%3D%3D%3D%3D%3D%3D",
}
var codeList = []string{
"413198",
"770938",
"670717",
"402378",
"069864",
"069864",
"069864",
}
for i := range urlList {
if o, id, err := FromURL(urlList[i]); err != nil {
t.Log("hotp: URL should have parsed successfully (id=", id, ")")
t.Logf("\turl was: %s\n", urlList[i])
t.Fatalf("\t%s, %s\n", o.OTP(), id)
} else {
code2 := o.OTP()
if code2 != codeList[i] {
t.Logf("hotp: mismatched OTPs\n")
t.Logf("\texpected: %s\n", codeList[i])
t.Fatalf("\t actual: %s\n", code2)
}
}
}
}
// This test attempts a variety of invalid urls against the parser
// to ensure they fail.
func TestBadURL(t *testing.T) {
var urlList = []string{
"http://google.com",
"",
"-",
"foo",
"otpauth:/foo/bar/baz",
"://",
"otpauth://hotp/?digits=",
"otpauth://hotp/?secret=MFRGGZDF&digits=ABCD",
"otpauth://hotp/?secret=MFRGGZDF&counter=ABCD",
}
for i := range urlList {
if _, _, err := FromURL(urlList[i]); err == nil {
t.Log("hotp: URL should not have parsed successfully")
t.Fatalf("\turl was: %s\n", urlList[i])
}
}
}

172
twofactor/totp.go Normal file
View File

@@ -0,0 +1,172 @@
package twofactor
import (
"crypto"
"crypto/sha1" // #nosec G505 - required by RFC
"crypto/sha256"
"crypto/sha512"
"encoding/base32"
"hash"
"io"
"net/url"
"strconv"
"strings"
"github.com/benbjohnson/clock"
)
var timeSource = clock.New()
// TOTP represents an RFC 6238 Time-based One-Time Password instance.
type TOTP struct {
*OATH
step uint64
}
// NewTOTP takes a new key, a starting time, a step, the number of
// digits of output (typically 6 or 8) and the hash algorithm to
// use, and builds a new OTP.
func NewTOTP(key []byte, start uint64, step uint64, digits int, algo crypto.Hash) *TOTP {
h := hashFromAlgo(algo)
if h == nil {
return nil
}
return &TOTP{
OATH: &OATH{
key: key,
counter: start,
size: digits,
hash: h,
algo: algo,
},
step: step,
}
}
// NewGoogleTOTP takes a secret as a base32-encoded string and
// returns an appropriate Google Authenticator TOTP instance.
func NewGoogleTOTP(secret string) (*TOTP, error) {
key, err := base32.StdEncoding.DecodeString(secret)
if err != nil {
return nil, err
}
return NewTOTP(key, 0, 30, 6, crypto.SHA1), nil
}
// NewTOTPSHA1 will build a new TOTP using SHA-1.
func NewTOTPSHA1(key []byte, start uint64, step uint64, digits int) *TOTP {
return NewTOTP(key, start, step, digits, crypto.SHA1)
}
// Type returns OATH_TOTP.
func (otp *TOTP) Type() Type {
return OATH_TOTP
}
func (otp *TOTP) otp(counter uint64) string {
return otp.OATH.OTP(counter)
}
// OTP returns the OTP for the current timestep.
func (otp *TOTP) OTP() string {
return otp.otp(otp.OTPCounter())
}
// URL returns a TOTP URL (i.e. for putting in a QR code).
func (otp *TOTP) URL(label string) string {
return otp.OATH.URL(otp.Type(), label)
}
// SetProvider sets up the provider component of the OTP URL.
func (otp *TOTP) SetProvider(provider string) {
otp.provider = provider
}
func (otp *TOTP) otpCounter(t uint64) uint64 {
return (t - otp.counter) / otp.step
}
// OTPCounter returns the current time value for the OTP.
func (otp *TOTP) OTPCounter() uint64 {
return otp.otpCounter(uint64(timeSource.Now().Unix() & 0x7FFFFFFF)) //#nosec G115 - masked out overflow bits
}
func hashFromAlgo(algo crypto.Hash) func() hash.Hash {
switch algo {
case crypto.SHA1:
return sha1.New
case crypto.SHA256:
return sha256.New
case crypto.SHA512:
return sha512.New
}
return nil
}
// GenerateGoogleTOTP produces a new TOTP token with the defaults expected by
// Google Authenticator.
func GenerateGoogleTOTP() *TOTP {
key := make([]byte, sha1.Size)
if _, err := io.ReadFull(PRNG, key); err != nil {
return nil
}
return NewTOTP(key, 0, 30, 6, crypto.SHA1)
}
func totpFromURL(u *url.URL) (*TOTP, string, error) {
label := u.Path[1:]
v := u.Query()
secret := strings.ToUpper(v.Get("secret"))
if secret == "" {
return nil, "", ErrInvalidURL
}
var algo = crypto.SHA1
if algorithm := v.Get("algorithm"); algorithm != "" {
switch {
case strings.EqualFold(algorithm, "SHA256"):
algo = crypto.SHA256
case strings.EqualFold(algorithm, "SHA512"):
algo = crypto.SHA512
case !strings.EqualFold(algorithm, "SHA1"):
return nil, "", ErrInvalidAlgo
}
}
var digits = 6
if sdigit := v.Get("digits"); sdigit != "" {
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
if err != nil {
return nil, "", err
}
digits = int(tmpDigits)
}
var period uint64 = 30
if speriod := v.Get("period"); speriod != "" {
var err error
period, err = strconv.ParseUint(speriod, 10, 64)
if err != nil {
return nil, "", err
}
}
key, err := base32.StdEncoding.DecodeString(Pad(secret))
if err != nil {
// assume secret isn't base32 encoded
key = []byte(secret)
}
otp := NewTOTP(key, 0, period, digits, algo)
return otp, label, nil
}
// QR generates a new TOTP QR code.
func (otp *TOTP) QR(label string) ([]byte, error) {
return otp.OATH.QR(otp.Type(), label)
}
func SetClock(c clock.Clock) {
timeSource = c
}

View File

@@ -0,0 +1,85 @@
package twofactor
import (
"crypto"
"testing"
"time"
"github.com/benbjohnson/clock"
)
var rfcTotpKey = map[crypto.Hash][]byte{
crypto.SHA1: []byte("12345678901234567890"),
crypto.SHA256: []byte("12345678901234567890123456789012"),
crypto.SHA512: []byte("1234567890123456789012345678901234567890123456789012345678901234"),
}
var rfcTotpStep uint64 = 30
var rfcTotpTests = []struct {
Time uint64
Code string
T uint64
Algo crypto.Hash
}{
{59, "94287082", 1, crypto.SHA1},
{59, "46119246", 1, crypto.SHA256},
{59, "90693936", 1, crypto.SHA512},
{1111111109, "07081804", 37037036, crypto.SHA1},
{1111111109, "68084774", 37037036, crypto.SHA256},
{1111111109, "25091201", 37037036, crypto.SHA512},
{1111111111, "14050471", 37037037, crypto.SHA1},
{1111111111, "67062674", 37037037, crypto.SHA256},
{1111111111, "99943326", 37037037, crypto.SHA512},
{1234567890, "89005924", 41152263, crypto.SHA1},
{1234567890, "91819424", 41152263, crypto.SHA256},
{1234567890, "93441116", 41152263, crypto.SHA512},
{2000000000, "69279037", 66666666, crypto.SHA1},
{2000000000, "90698825", 66666666, crypto.SHA256},
{2000000000, "38618901", 66666666, crypto.SHA512},
{20000000000, "65353130", 666666666, crypto.SHA1},
{20000000000, "77737706", 666666666, crypto.SHA256},
{20000000000, "47863826", 666666666, crypto.SHA512},
}
func TestTotpRFC(t *testing.T) {
for _, tc := range rfcTotpTests {
otp := NewTOTP(rfcTotpKey[tc.Algo], 0, rfcTotpStep, 8, tc.Algo)
if otp.otpCounter(tc.Time) != tc.T {
t.Logf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
t.Logf("\texpected: %d\n", tc.T)
t.Errorf("\t actual: %d\n", otp.otpCounter(tc.Time))
}
if code := otp.otp(otp.otpCounter(tc.Time)); code != tc.Code {
t.Logf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
t.Logf("\texpected: %s\n", tc.Code)
t.Errorf("\t actual: %s\n", code)
}
}
}
func TestTOTPTime(t *testing.T) {
otp := GenerateGoogleTOTP()
testClock := clock.NewMock()
testClock.Add(2 * time.Minute)
SetClock(testClock)
code := otp.OTP()
testClock.Add(-1 * time.Minute)
if newCode := otp.OTP(); newCode == code {
t.Errorf("twofactor: TOTP: previous code %s shouldn't match code %s", newCode, code)
}
testClock.Add(2 * time.Minute)
if newCode := otp.OTP(); newCode == code {
t.Errorf("twofactor: TOTP: future code %s shouldn't match code %s", newCode, code)
}
testClock.Add(-1 * time.Minute)
if newCode := otp.OTP(); newCode != code {
t.Errorf("twofactor: TOTP: current code %s shouldn't match code %s", newCode, code)
}
}

16
twofactor/util.go Normal file
View File

@@ -0,0 +1,16 @@
package twofactor
import (
"strings"
)
// Pad calculates the number of '='s to add to our encoded string
// to make base32.StdEncoding.DecodeString happy.
func Pad(s string) string {
if !strings.HasSuffix(s, "=") && len(s)%8 != 0 {
for len(s)%8 != 0 {
s += "="
}
}
return s
}

51
twofactor/util_test.go Normal file
View File

@@ -0,0 +1,51 @@
package twofactor_test
import (
"encoding/base32"
"math/rand"
"strings"
"testing"
"git.wntrmute.dev/kyle/goutils/twofactor"
)
const letters = "1234567890!@#$%^&*()abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func randString() string {
b := make([]byte, rand.Intn(len(letters)))
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
}
return base32.StdEncoding.EncodeToString(b)
}
func TestPadding(t *testing.T) {
for range 300 {
b := randString()
origEncoding := b
modEncoding := strings.ReplaceAll(b, "=", "")
str, err := base32.StdEncoding.DecodeString(origEncoding)
if err != nil {
t.Fatal("Can't decode: ", b)
}
paddedEncoding := twofactor.Pad(modEncoding)
if origEncoding != paddedEncoding {
t.Log("Padding failed:")
t.Logf("Expected: '%s'", origEncoding)
t.Fatalf("Got: '%s'", paddedEncoding)
} else {
var mstr []byte
mstr, err = base32.StdEncoding.DecodeString(paddedEncoding)
if err != nil {
t.Fatal("Can't decode: ", paddedEncoding)
}
if string(mstr) != string(str) {
t.Log("Re-padding failed:")
t.Logf("Expected: '%s'", str)
t.Fatalf("Got: '%s'", mstr)
}
}
}
}