Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 83f88c49fe | |||
| 7c437ac45f | |||
| c999bf35b0 | |||
| 4dc135cfe0 | |||
| 790113e189 | |||
| 8348c5fd65 | |||
| 1eafb638a8 | |||
| 3ad562b6fa | |||
| 0f77bd49dc | |||
| f31d74243f | |||
| 0dcd18c6f1 | |||
| 024d552293 | |||
| 9cd2ced695 | |||
| 619c08a13f | |||
| 944a57bf0e | |||
| 0857b29624 | |||
|
|
e95404bfc5 | ||
|
|
924654e7c4 | ||
| 9e0979e07f | |||
|
|
bbc82ff8de | ||
|
|
5fd928f69a | ||
|
|
acefe4a3b9 | ||
| a1452cebc9 | |||
| 6e9812e6f5 | |||
|
|
8c34415c34 | ||
|
|
2cf2c15def | ||
|
|
eaad1884d4 | ||
| 5d57d844d4 | |||
|
|
31b9d175dd | ||
|
|
79e106da2e | ||
|
|
939b1bc272 | ||
|
|
89e74f390b | ||
|
|
7881b6fdfc | ||
|
|
5bef33245f | ||
|
|
84250b0501 | ||
|
|
459e9f880f | ||
|
|
0982f47ce3 | ||
|
|
1dec15fd11 | ||
|
|
2ee9cae5ba | ||
|
|
dc04475120 | ||
|
|
dbbd5116b5 |
@@ -64,4 +64,4 @@ workflows:
|
||||
testbuild:
|
||||
jobs:
|
||||
- testbuild
|
||||
# - lint
|
||||
- lint
|
||||
|
||||
@@ -18,6 +18,19 @@ issues:
|
||||
# Default: 3
|
||||
max-same-issues: 50
|
||||
|
||||
# Exclude some lints for CLI programs under cmd/ (package main).
|
||||
# The project allows fmt.Print* in command-line tools; keep forbidigo for libraries.
|
||||
exclude-rules:
|
||||
- path: ^cmd/
|
||||
linters:
|
||||
- forbidigo
|
||||
- path: cmd/.*
|
||||
linters:
|
||||
- forbidigo
|
||||
- path: .*/cmd/.*
|
||||
linters:
|
||||
- forbidigo
|
||||
|
||||
formatters:
|
||||
enable:
|
||||
- goimports # checks if the code and import statements are formatted according to the 'goimports' command
|
||||
@@ -73,7 +86,6 @@ linters:
|
||||
- godoclint # checks Golang's documentation practice
|
||||
- godot # checks if comments end in a period
|
||||
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
|
||||
- goprintffuncname # checks that printf-like functions are named with f at the end
|
||||
- gosec # inspects source code for security problems
|
||||
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
|
||||
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
|
||||
@@ -230,6 +242,10 @@ linters:
|
||||
check-type-assertions: true
|
||||
exclude-functions:
|
||||
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
|
||||
- git.wntrmute.dev/kyle/goutils/lib.Warn
|
||||
- git.wntrmute.dev/kyle/goutils/lib.Warnx
|
||||
- git.wntrmute.dev/kyle/goutils/lib.Err
|
||||
- git.wntrmute.dev/kyle/goutils/lib.Errx
|
||||
|
||||
exhaustive:
|
||||
# Program elements to check for exhaustiveness.
|
||||
@@ -321,6 +337,12 @@ linters:
|
||||
# https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link
|
||||
- no-unused-link
|
||||
|
||||
gosec:
|
||||
excludes:
|
||||
- G104 # handled by errcheck
|
||||
- G301
|
||||
- G306
|
||||
|
||||
govet:
|
||||
# Enable all analyzers.
|
||||
# Default: false
|
||||
@@ -356,6 +378,12 @@ linters:
|
||||
- os.WriteFile
|
||||
- prometheus.ExponentialBuckets.*
|
||||
- prometheus.LinearBuckets
|
||||
ignored-numbers:
|
||||
- 1
|
||||
- 2
|
||||
- 3
|
||||
- 4
|
||||
- 8
|
||||
|
||||
nakedret:
|
||||
# Make an issue if func has more lines of code than this setting, and it has naked returns.
|
||||
@@ -424,6 +452,8 @@ linters:
|
||||
# Omit embedded fields from selector expression.
|
||||
# https://staticcheck.dev/docs/checks/#QF1008
|
||||
- -QF1008
|
||||
# We often explicitly enable old/deprecated ciphers for research.
|
||||
- -SA1019
|
||||
|
||||
usetesting:
|
||||
# Enable/disable `os.TempDir()` detections.
|
||||
@@ -452,6 +482,8 @@ linters:
|
||||
linters: [ testableexamples ]
|
||||
- path: 'main.go'
|
||||
linters: [ forbidigo, mnd, reassign ]
|
||||
- path: 'cmd/cruntar/main.go'
|
||||
linters: [ unparam ]
|
||||
- source: 'TODO'
|
||||
linters: [ godot ]
|
||||
- text: 'should have a package comment'
|
||||
|
||||
21
CHANGELOG
21
CHANGELOG
@@ -1,5 +1,26 @@
|
||||
CHANGELOG
|
||||
|
||||
v1.12.0 - 2025-11-16
|
||||
|
||||
Added
|
||||
- twofactor: the github.com/kisom/twofactor repo has been subtree'd
|
||||
into this repo.
|
||||
|
||||
v1.11.2 - 2025-11-16
|
||||
|
||||
Changed
|
||||
- cmd/ski, cmd/csrpubdump, cmd/tlskeypair: centralize
|
||||
certificate/private-key/CSR parsing by reusing certlib helpers.
|
||||
This reduces duplication and improves consistency across commands.
|
||||
- csr: CSR parsing in the above commands now uses certlib.ParseCSR,
|
||||
which verifies CSR signatures (behavioral hardening compared to
|
||||
prior parsing without signature verification).
|
||||
|
||||
v1.11.1 - 2025-11-16
|
||||
|
||||
Changed
|
||||
- cmd: complete linting fixes across programs; no functional changes.
|
||||
|
||||
v1.11.0 - 2025-11-15
|
||||
|
||||
Added
|
||||
|
||||
53
README.md
53
README.md
@@ -2,39 +2,52 @@ GOUTILS
|
||||
|
||||
This is a collection of small utility code I've written in Go; the `cmd/`
|
||||
directory has a number of command-line utilities. Rather than keep all
|
||||
of these in superfluous repositories of their own, or rewriting them
|
||||
of these in superfluous repositories of their own or rewriting them
|
||||
for each project, I'm putting them here.
|
||||
|
||||
The project can be built with the standard Go tooling, or it can be built
|
||||
with Bazel.
|
||||
The project can be built with the standard Go tooling.
|
||||
|
||||
Contents:
|
||||
|
||||
ahash/ Provides hashes from string algorithm specifiers.
|
||||
assert/ Error handling, assertion-style.
|
||||
backoff/ Implementation of an intelligent backoff strategy.
|
||||
cache/ Implementations of various caches.
|
||||
lru/ Least-recently-used cache.
|
||||
mru/ Most-recently-used cache.
|
||||
certlib/ Library for working with TLS certificates.
|
||||
cmd/
|
||||
atping/ Automated TCP ping, meant for putting in cronjobs.
|
||||
certchain/ Display the certificate chain from a
|
||||
TLS connection.
|
||||
ca-signed/ Validate whether a certificate is signed by a CA.
|
||||
cert-bundler/
|
||||
Create certificate bundles from a source of PEM
|
||||
certificates.
|
||||
cert-revcheck/
|
||||
Check whether a certificate has been revoked or is
|
||||
expired.
|
||||
certchain/ Display the certificate chain from a TLS connection.
|
||||
certdump/ Dump certificate information.
|
||||
certexpiry/ Print a list of certificate subjects and expiry times
|
||||
or warn about certificates expiring within a certain
|
||||
window.
|
||||
certverify/ Verify a TLS X.509 certificate, optionally printing
|
||||
certverify/ Verify a TLS X.509 certificate file, optionally printing
|
||||
the time to expiry and checking for revocations.
|
||||
clustersh/ Run commands or transfer files across multiple
|
||||
servers via SSH.
|
||||
cruntar/ Untar an archive with hard links, copying instead of
|
||||
cruntar/ (Un)tar an archive with hard links, copying instead of
|
||||
linking.
|
||||
csrpubdump/ Dump the public key from an X.509 certificate request.
|
||||
data_sync/ Sync the user's homedir to external storage.
|
||||
diskimg/ Write a disk image to a device.
|
||||
dumpbytes/ Dump the contents of a file as hex bytes, printing it as
|
||||
a Go []byte literal.
|
||||
eig/ EEPROM image generator.
|
||||
fragment/ Print a fragment of a file.
|
||||
host/ Go imlpementation of the host(1) command.
|
||||
jlp/ JSON linter/prettifier.
|
||||
kgz/ Custom gzip compressor / decompressor that handles 99%
|
||||
of my use cases.
|
||||
minmax/ Generate a minmax code for use in uLisp.
|
||||
parts/ Simple parts database management for my collection of
|
||||
electronic components.
|
||||
pem2bin/ Dump the binary body of a PEM-encoded block.
|
||||
@@ -44,41 +57,45 @@ Contents:
|
||||
in a bundle.
|
||||
renfnv/ Rename a file to base32-encoded 64-bit FNV-1a hash.
|
||||
rhash/ Compute the digest of remote files.
|
||||
rolldie/ Roll some dice.
|
||||
showimp/ List the external (e.g. non-stdlib and outside the
|
||||
current working directory) imports for a Go file.
|
||||
ski Display the SKI for PEM-encoded TLS material.
|
||||
sprox/ Simple TCP proxy.
|
||||
stealchain/ Dump the verified chain from a TLS
|
||||
connection to a server.
|
||||
stealchain- Dump the verified chain from a TLS
|
||||
server/ connection from a client.
|
||||
stealchain/ Dump the verified chain from a TLS connection to a
|
||||
server.
|
||||
stealchain-server/
|
||||
Dump the verified chain from a TLS connection from
|
||||
from a client.
|
||||
subjhash/ Print or match subject info from a certificate.
|
||||
tlsinfo/ Print information about a TLS connection (the TLS version
|
||||
and cipher suite).
|
||||
tlskeypair/ Check whether a TLS certificate and key file match.
|
||||
utc/ Convert times to UTC.
|
||||
yamll/ A small YAML linter.
|
||||
zsearch/ Search for a string in directory of gzipped files.
|
||||
config/ A simple global configuration system where configuration
|
||||
data is pulled from a file or an environment variable
|
||||
transparently.
|
||||
iniconf/ A simple INI-style configuration system.
|
||||
dbg/ A debug printer.
|
||||
die/ Death of a program.
|
||||
fileutil/ Common file functions.
|
||||
lib/ Commonly-useful functions for writing Go programs.
|
||||
log/ A syslog library.
|
||||
logging/ A logging library.
|
||||
mwc/ MultiwriteCloser implementation.
|
||||
rand/ Utilities for working with math/rand.
|
||||
sbuf/ A byte buffer that can be wiped.
|
||||
seekbuf/ A read-seekable byte buffer.
|
||||
syslog/ Syslog-type logging.
|
||||
tee/ Emulate tee(1)'s functionality in io.Writers.
|
||||
testio/ Various I/O utilities useful during testing.
|
||||
testutil/ Various utility functions useful during testing.
|
||||
|
||||
twofactor/ Two-factor authentication.
|
||||
|
||||
Each program should have a small README in the directory with more
|
||||
information.
|
||||
|
||||
All code here is licensed under the ISC license.
|
||||
|
||||
All code here is licensed under the Apache 2.0 license.
|
||||
|
||||
Error handling
|
||||
--------------
|
||||
@@ -99,7 +116,7 @@ Examples:
|
||||
```
|
||||
cert, err := certlib.LoadCertificate(path)
|
||||
if err != nil {
|
||||
// sentinel match
|
||||
// sentinel match:
|
||||
if errors.Is(err, certerr.ErrEmptyCertificate) {
|
||||
// handle empty input
|
||||
}
|
||||
@@ -116,5 +133,3 @@ if err != nil {
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Avoid including sensitive data (keys, passwords, tokens) in error messages.
|
||||
|
||||
@@ -91,7 +91,7 @@ func TestReset(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
const decay = 5 * time.Millisecond
|
||||
const decay = 25 * time.Millisecond
|
||||
const maxDuration = 10 * time.Millisecond
|
||||
const interval = time.Millisecond
|
||||
|
||||
|
||||
126
cache/lru/lru_internal_test.go
vendored
126
cache/lru/lru_internal_test.go
vendored
@@ -1,87 +1,87 @@
|
||||
package lru
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
// These tests mirror the MRU-style behavior present in this LRU package
|
||||
// implementation (eviction removes the most-recently-used entry).
|
||||
func TestBasicCacheEviction(t *testing.T) {
|
||||
mock := clock.NewMock()
|
||||
c := NewStringKeyCache[int](2)
|
||||
c.clock = mock
|
||||
mock := clock.NewMock()
|
||||
c := NewStringKeyCache[int](2)
|
||||
c.clock = mock
|
||||
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if c.Len() != 0 {
|
||||
t.Fatal("cache should have size 0")
|
||||
}
|
||||
if c.Len() != 0 {
|
||||
t.Fatal("cache should have size 0")
|
||||
}
|
||||
|
||||
c.evict()
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.evict()
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c.Store("raven", 1)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.Store("raven", 1)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(c.store) != 1 {
|
||||
t.Fatalf("store should have length=1, have length=%d", len(c.store))
|
||||
}
|
||||
if len(c.store) != 1 {
|
||||
t.Fatalf("store should have length=1, have length=%d", len(c.store))
|
||||
}
|
||||
|
||||
mock.Add(time.Second)
|
||||
c.Store("owl", 2)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mock.Add(time.Second)
|
||||
c.Store("owl", 2)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(c.store) != 2 {
|
||||
t.Fatalf("store should have length=2, have length=%d", len(c.store))
|
||||
}
|
||||
if len(c.store) != 2 {
|
||||
t.Fatalf("store should have length=2, have length=%d", len(c.store))
|
||||
}
|
||||
|
||||
mock.Add(time.Second)
|
||||
c.Store("goat", 3)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mock.Add(time.Second)
|
||||
c.Store("goat", 3)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(c.store) != 2 {
|
||||
t.Fatalf("store should have length=2, have length=%d", len(c.store))
|
||||
}
|
||||
if len(c.store) != 2 {
|
||||
t.Fatalf("store should have length=2, have length=%d", len(c.store))
|
||||
}
|
||||
|
||||
// Since this implementation evicts the most-recently-used item, inserting
|
||||
// "goat" when full evicts "owl" (the most recent at that time).
|
||||
mock.Add(time.Second)
|
||||
if _, ok := c.Get("owl"); ok {
|
||||
t.Fatal("store should not have an entry for owl (MRU-evicted)")
|
||||
}
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Since this implementation evicts the most-recently-used item, inserting
|
||||
// "goat" when full evicts "owl" (the most recent at that time).
|
||||
mock.Add(time.Second)
|
||||
if _, ok := c.Get("owl"); ok {
|
||||
t.Fatal("store should not have an entry for owl (MRU-evicted)")
|
||||
}
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
mock.Add(time.Second)
|
||||
c.Store("elk", 4)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
mock.Add(time.Second)
|
||||
c.Store("elk", 4)
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !c.Has("elk") {
|
||||
t.Fatal("store should contain an entry for 'elk'")
|
||||
}
|
||||
if !c.Has("elk") {
|
||||
t.Fatal("store should contain an entry for 'elk'")
|
||||
}
|
||||
|
||||
// Before storing elk, keys were: raven (older), goat (newer). Evict MRU -> goat.
|
||||
if !c.Has("raven") {
|
||||
t.Fatal("store should contain an entry for 'raven'")
|
||||
}
|
||||
// Before storing elk, keys were: raven (older), goat (newer). Evict MRU -> goat.
|
||||
if !c.Has("raven") {
|
||||
t.Fatal("store should contain an entry for 'raven'")
|
||||
}
|
||||
|
||||
if c.Has("goat") {
|
||||
t.Fatal("store should not contain an entry for 'goat'")
|
||||
}
|
||||
if c.Has("goat") {
|
||||
t.Fatal("store should not contain an entry for 'goat'")
|
||||
}
|
||||
}
|
||||
|
||||
64
cache/lru/timestamps_internal_test.go
vendored
64
cache/lru/timestamps_internal_test.go
vendored
@@ -1,50 +1,50 @@
|
||||
package lru
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
// These tests validate timestamps ordering semantics for the LRU package.
|
||||
// Note: The LRU timestamps are sorted with most-recent-first (descending by t).
|
||||
func TestTimestamps(t *testing.T) {
|
||||
ts := newTimestamps[string](3)
|
||||
mock := clock.NewMock()
|
||||
ts := newTimestamps[string](3)
|
||||
mock := clock.NewMock()
|
||||
|
||||
// raven
|
||||
ts.Update("raven", mock.Now().UnixNano())
|
||||
// raven
|
||||
ts.Update("raven", mock.Now().UnixNano())
|
||||
|
||||
// raven, owl
|
||||
mock.Add(time.Millisecond)
|
||||
ts.Update("owl", mock.Now().UnixNano())
|
||||
// raven, owl
|
||||
mock.Add(time.Millisecond)
|
||||
ts.Update("owl", mock.Now().UnixNano())
|
||||
|
||||
// raven, owl, goat
|
||||
mock.Add(time.Second)
|
||||
ts.Update("goat", mock.Now().UnixNano())
|
||||
// raven, owl, goat
|
||||
mock.Add(time.Second)
|
||||
ts.Update("goat", mock.Now().UnixNano())
|
||||
|
||||
if err := ts.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := ts.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// make owl the most recent
|
||||
mock.Add(time.Millisecond)
|
||||
ts.Update("owl", mock.Now().UnixNano())
|
||||
if err := ts.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// make owl the most recent
|
||||
mock.Add(time.Millisecond)
|
||||
ts.Update("owl", mock.Now().UnixNano())
|
||||
if err := ts.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// For LRU timestamps: most recent first. Expected order: owl, goat, raven.
|
||||
if ts.K(0) != "owl" {
|
||||
t.Fatalf("first key should be owl, have %s", ts.K(0))
|
||||
}
|
||||
// For LRU timestamps: most recent first. Expected order: owl, goat, raven.
|
||||
if ts.K(0) != "owl" {
|
||||
t.Fatalf("first key should be owl, have %s", ts.K(0))
|
||||
}
|
||||
|
||||
if ts.K(1) != "goat" {
|
||||
t.Fatalf("second key should be goat, have %s", ts.K(1))
|
||||
}
|
||||
if ts.K(1) != "goat" {
|
||||
t.Fatalf("second key should be goat, have %s", ts.K(1))
|
||||
}
|
||||
|
||||
if ts.K(2) != "raven" {
|
||||
t.Fatalf("third key should be raven, have %s", ts.K(2))
|
||||
}
|
||||
if ts.K(2) != "raven" {
|
||||
t.Fatalf("third key should be raven, have %s", ts.K(2))
|
||||
}
|
||||
}
|
||||
|
||||
28
cache/mru/mru_internal_test.go
vendored
28
cache/mru/mru_internal_test.go
vendored
@@ -8,9 +8,9 @@ import (
|
||||
)
|
||||
|
||||
func TestBasicCacheEviction(t *testing.T) {
|
||||
mock := clock.NewMock()
|
||||
c := NewStringKeyCache[int](2)
|
||||
c.clock = mock
|
||||
mock := clock.NewMock()
|
||||
c := NewStringKeyCache[int](2)
|
||||
c.clock = mock
|
||||
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -55,18 +55,18 @@ func TestBasicCacheEviction(t *testing.T) {
|
||||
}
|
||||
|
||||
mock.Add(time.Second)
|
||||
v, ok := c.Get("owl")
|
||||
if !ok {
|
||||
t.Fatal("store should have an entry for owl")
|
||||
}
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
v, ok := c.Get("owl")
|
||||
if !ok {
|
||||
t.Fatal("store should have an entry for owl")
|
||||
}
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
itm := v
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
itm := v
|
||||
if err := c.ConsistencyCheck(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if itm != 2 {
|
||||
t.Fatalf("stored item should be 2, have %d", itm)
|
||||
|
||||
4
cache/mru/timestamps_internal_test.go
vendored
4
cache/mru/timestamps_internal_test.go
vendored
@@ -8,8 +8,8 @@ import (
|
||||
)
|
||||
|
||||
func TestTimestamps(t *testing.T) {
|
||||
ts := newTimestamps[string](3)
|
||||
mock := clock.NewMock()
|
||||
ts := newTimestamps[string](3)
|
||||
mock := clock.NewMock()
|
||||
|
||||
// raven
|
||||
ts.Update("raven", mock.Now().UnixNano())
|
||||
|
||||
@@ -458,8 +458,6 @@ func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
|
||||
}
|
||||
if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") {
|
||||
if password != nil {
|
||||
// nolintlint requires rationale:
|
||||
//nolint:staticcheck // legacy RFC1423 PEM encryption supported for backward compatibility when caller supplies a password
|
||||
return x509.DecryptPEMBlock(keyDER, password)
|
||||
}
|
||||
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"flag"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/certlib"
|
||||
@@ -23,6 +24,13 @@ var (
|
||||
verbose bool
|
||||
)
|
||||
|
||||
var (
|
||||
strOK = "OK"
|
||||
strExpired = "EXPIRED"
|
||||
strRevoked = "REVOKED"
|
||||
strUnknown = "UNKNOWN"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal")
|
||||
flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects")
|
||||
@@ -42,16 +50,16 @@ func main() {
|
||||
for _, target := range flag.Args() {
|
||||
status, err := processTarget(target)
|
||||
switch status {
|
||||
case "OK":
|
||||
fmt.Printf("%s: OK\n", target)
|
||||
case "EXPIRED":
|
||||
fmt.Printf("%s: EXPIRED: %v\n", target, err)
|
||||
case strOK:
|
||||
fmt.Printf("%s: %s\n", target, strOK)
|
||||
case strExpired:
|
||||
fmt.Printf("%s: %s: %v\n", target, strExpired, err)
|
||||
exitCode = 1
|
||||
case "REVOKED":
|
||||
fmt.Printf("%s: REVOKED\n", target)
|
||||
case strRevoked:
|
||||
fmt.Printf("%s: %s\n", target, strRevoked)
|
||||
exitCode = 1
|
||||
case "UNKNOWN":
|
||||
fmt.Printf("%s: UNKNOWN: %v\n", target, err)
|
||||
case strUnknown:
|
||||
fmt.Printf("%s: %s: %v\n", target, strUnknown, err)
|
||||
if hardfail {
|
||||
// In hardfail, treat unknown as failure
|
||||
exitCode = 1
|
||||
@@ -67,74 +75,77 @@ func processTarget(target string) (string, error) {
|
||||
return checkFile(target)
|
||||
}
|
||||
|
||||
// Not a file; treat as site
|
||||
return checkSite(target)
|
||||
}
|
||||
|
||||
func checkFile(path string) (string, error) {
|
||||
in, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return "UNKNOWN", err
|
||||
// Prefer high-level helpers from certlib to load certificates from disk
|
||||
if certs, err := certlib.LoadCertificates(path); err == nil && len(certs) > 0 {
|
||||
// Evaluate the first certificate (leaf) by default
|
||||
return evaluateCert(certs[0])
|
||||
}
|
||||
|
||||
// Try PEM first; if that fails, try single DER cert
|
||||
certs, err := certlib.ReadCertificates(in)
|
||||
if err != nil || len(certs) == 0 {
|
||||
cert, _, derr := certlib.ReadCertificate(in)
|
||||
if derr != nil || cert == nil {
|
||||
if err == nil {
|
||||
err = derr
|
||||
}
|
||||
return "UNKNOWN", err
|
||||
}
|
||||
return evaluateCert(cert)
|
||||
cert, err := certlib.LoadCertificate(path)
|
||||
if err != nil || cert == nil {
|
||||
return strUnknown, err
|
||||
}
|
||||
|
||||
// Evaluate the first certificate (leaf) by default
|
||||
return evaluateCert(certs[0])
|
||||
return evaluateCert(cert)
|
||||
}
|
||||
|
||||
func checkSite(hostport string) (string, error) {
|
||||
// Use certlib/hosts to parse host/port (supports https URLs and host:port)
|
||||
target, err := hosts.ParseHost(hostport)
|
||||
if err != nil {
|
||||
return "UNKNOWN", err
|
||||
return strUnknown, err
|
||||
}
|
||||
|
||||
d := &net.Dialer{Timeout: timeout}
|
||||
conn, err := tls.DialWithDialer(d, "tcp", target.String(), &tls.Config{InsecureSkipVerify: true, ServerName: target.Host})
|
||||
tcfg := &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
ServerName: target.Host,
|
||||
} // #nosec G402 -- CLI tool only verifies revocation
|
||||
td := &tls.Dialer{NetDialer: d, Config: tcfg}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
||||
defer cancel()
|
||||
|
||||
conn, err := td.DialContext(ctx, "tcp", target.String())
|
||||
if err != nil {
|
||||
return "UNKNOWN", err
|
||||
return strUnknown, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
state := conn.ConnectionState()
|
||||
tconn, ok := conn.(*tls.Conn)
|
||||
if !ok {
|
||||
return strUnknown, errors.New("connection is not TLS")
|
||||
}
|
||||
|
||||
state := tconn.ConnectionState()
|
||||
if len(state.PeerCertificates) == 0 {
|
||||
return "UNKNOWN", errors.New("no peer certificates presented")
|
||||
return strUnknown, errors.New("no peer certificates presented")
|
||||
}
|
||||
return evaluateCert(state.PeerCertificates[0])
|
||||
}
|
||||
|
||||
func evaluateCert(cert *x509.Certificate) (string, error) {
|
||||
// Expiry check
|
||||
now := time.Now()
|
||||
if !now.Before(cert.NotAfter) {
|
||||
return "EXPIRED", fmt.Errorf("expired at %s", cert.NotAfter)
|
||||
}
|
||||
if !now.After(cert.NotBefore) {
|
||||
return "EXPIRED", fmt.Errorf("not valid until %s", cert.NotBefore)
|
||||
}
|
||||
|
||||
// Revocation check using certlib/revoke
|
||||
// Delegate validity and revocation checks to certlib/revoke helper.
|
||||
// It returns revoked=true for both revoked and expired/not-yet-valid.
|
||||
// Map those cases back to our statuses using the returned error text.
|
||||
revoked, ok, err := revoke.VerifyCertificateError(cert)
|
||||
if revoked {
|
||||
// If revoked is true, ok will be true per implementation, err may describe why
|
||||
return "REVOKED", err
|
||||
if err != nil {
|
||||
msg := err.Error()
|
||||
if strings.Contains(msg, "expired") || strings.Contains(msg, "isn't valid until") ||
|
||||
strings.Contains(msg, "not valid until") {
|
||||
return strExpired, err
|
||||
}
|
||||
}
|
||||
return strRevoked, err
|
||||
}
|
||||
if !ok {
|
||||
// Revocation status could not be determined
|
||||
return "UNKNOWN", err
|
||||
return strUnknown, err
|
||||
}
|
||||
|
||||
return "OK", nil
|
||||
return strOK, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
)
|
||||
@@ -20,20 +23,26 @@ func main() {
|
||||
server += ":443"
|
||||
}
|
||||
|
||||
var chain string
|
||||
|
||||
conn, err := tls.Dial("tcp", server, nil)
|
||||
d := &tls.Dialer{Config: &tls.Config{}} // #nosec G402
|
||||
nc, err := d.DialContext(context.Background(), "tcp", server)
|
||||
die.If(err)
|
||||
conn, ok := nc.(*tls.Conn)
|
||||
if !ok {
|
||||
die.With("invalid TLS connection (not a *tls.Conn)")
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
details := conn.ConnectionState()
|
||||
var chain strings.Builder
|
||||
for _, cert := range details.PeerCertificates {
|
||||
p := pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}
|
||||
chain += string(pem.EncodeToMemory(&p))
|
||||
chain.Write(pem.EncodeToMemory(&p))
|
||||
}
|
||||
|
||||
fmt.Println(chain)
|
||||
fmt.Fprintln(os.Stdout, chain.String())
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
//lint:file-ignore SA1019 allow strict compatibility for old certs
|
||||
package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/dsa"
|
||||
"crypto/ecdsa"
|
||||
"crypto/elliptic"
|
||||
@@ -101,30 +103,30 @@ func extUsage(ext []x509.ExtKeyUsage) string {
|
||||
}
|
||||
|
||||
func showBasicConstraints(cert *x509.Certificate) {
|
||||
fmt.Printf("\tBasic constraints: ")
|
||||
fmt.Fprint(os.Stdout, "\tBasic constraints: ")
|
||||
if cert.BasicConstraintsValid {
|
||||
fmt.Printf("valid")
|
||||
fmt.Fprint(os.Stdout, "valid")
|
||||
} else {
|
||||
fmt.Printf("invalid")
|
||||
fmt.Fprint(os.Stdout, "invalid")
|
||||
}
|
||||
|
||||
if cert.IsCA {
|
||||
fmt.Printf(", is a CA certificate")
|
||||
fmt.Fprint(os.Stdout, ", is a CA certificate")
|
||||
if !cert.BasicConstraintsValid {
|
||||
fmt.Printf(" (basic constraint failure)")
|
||||
fmt.Fprint(os.Stdout, " (basic constraint failure)")
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("is not a CA certificate")
|
||||
fmt.Fprint(os.Stdout, "is not a CA certificate")
|
||||
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
|
||||
fmt.Printf(" (key encipherment usage enabled!)")
|
||||
fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
|
||||
}
|
||||
}
|
||||
|
||||
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
|
||||
fmt.Printf(", max path length %d", cert.MaxPathLen)
|
||||
fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
|
||||
}
|
||||
|
||||
fmt.Printf("\n")
|
||||
fmt.Fprintln(os.Stdout)
|
||||
}
|
||||
|
||||
const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
|
||||
@@ -136,39 +138,41 @@ var (
|
||||
|
||||
func wrapPrint(text string, indent int) {
|
||||
tabs := ""
|
||||
for i := 0; i < indent; i++ {
|
||||
tabs += "\t"
|
||||
var tabsSb140 strings.Builder
|
||||
for range indent {
|
||||
tabsSb140.WriteString("\t")
|
||||
}
|
||||
tabs += tabsSb140.String()
|
||||
|
||||
fmt.Printf(tabs+"%s\n", wrap(text, indent))
|
||||
fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
|
||||
}
|
||||
|
||||
func displayCert(cert *x509.Certificate) {
|
||||
fmt.Println("CERTIFICATE")
|
||||
fmt.Fprintln(os.Stdout, "CERTIFICATE")
|
||||
if showHash {
|
||||
fmt.Println(wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
|
||||
fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
|
||||
}
|
||||
fmt.Println(wrap("Subject: "+displayName(cert.Subject), 0))
|
||||
fmt.Println(wrap("Issuer: "+displayName(cert.Issuer), 0))
|
||||
fmt.Printf("\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
|
||||
fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0))
|
||||
fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0))
|
||||
fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
|
||||
sigAlgoHash(cert.SignatureAlgorithm))
|
||||
fmt.Println("Details:")
|
||||
fmt.Fprintln(os.Stdout, "Details:")
|
||||
wrapPrint("Public key: "+certPublic(cert), 1)
|
||||
fmt.Printf("\tSerial number: %s\n", cert.SerialNumber)
|
||||
fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber)
|
||||
|
||||
if len(cert.AuthorityKeyId) > 0 {
|
||||
fmt.Printf("\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
|
||||
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
|
||||
}
|
||||
if len(cert.SubjectKeyId) > 0 {
|
||||
fmt.Printf("\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
|
||||
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
|
||||
}
|
||||
|
||||
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
|
||||
fmt.Printf("\t until: %s\n", cert.NotAfter.Format(dateFormat))
|
||||
fmt.Printf("\tKey usages: %s\n", keyUsages(cert.KeyUsage))
|
||||
fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
|
||||
fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
|
||||
|
||||
if len(cert.ExtKeyUsage) > 0 {
|
||||
fmt.Printf("\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
|
||||
fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
|
||||
}
|
||||
|
||||
showBasicConstraints(cert)
|
||||
@@ -221,13 +225,13 @@ func displayAllCerts(in []byte, leafOnly bool) {
|
||||
if err != nil {
|
||||
certs, _, err = certlib.ParseCertificatesDER(in, "")
|
||||
if err != nil {
|
||||
lib.Warn(err, "failed to parse certificates")
|
||||
_, _ = lib.Warn(err, "failed to parse certificates")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(certs) == 0 {
|
||||
lib.Warnx("no certificates found")
|
||||
_, _ = lib.Warnx("no certificates found")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -243,29 +247,45 @@ func displayAllCerts(in []byte, leafOnly bool) {
|
||||
|
||||
func displayAllCertsWeb(uri string, leafOnly bool) {
|
||||
ci := getConnInfo(uri)
|
||||
conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig())
|
||||
d := &tls.Dialer{Config: permissiveConfig()}
|
||||
nc, err := d.DialContext(context.Background(), "tcp", ci.Addr)
|
||||
if err != nil {
|
||||
lib.Warn(err, "couldn't connect to %s", ci.Addr)
|
||||
_, _ = lib.Warn(err, "couldn't connect to %s", ci.Addr)
|
||||
return
|
||||
}
|
||||
|
||||
conn, ok := nc.(*tls.Conn)
|
||||
if !ok {
|
||||
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
state := conn.ConnectionState()
|
||||
conn.Close()
|
||||
if err = conn.Close(); err != nil {
|
||||
_, _ = lib.Warn(err, "couldn't close TLS connection")
|
||||
}
|
||||
|
||||
conn, err = tls.Dial("tcp", ci.Addr, verifyConfig(ci.Host))
|
||||
d = &tls.Dialer{Config: verifyConfig(ci.Host)}
|
||||
nc, err = d.DialContext(context.Background(), "tcp", ci.Addr)
|
||||
if err == nil {
|
||||
conn, ok = nc.(*tls.Conn)
|
||||
if !ok {
|
||||
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
|
||||
return
|
||||
}
|
||||
|
||||
err = conn.VerifyHostname(ci.Host)
|
||||
if err == nil {
|
||||
state = conn.ConnectionState()
|
||||
}
|
||||
conn.Close()
|
||||
} else {
|
||||
lib.Warn(err, "TLS verification error with server name %s", ci.Host)
|
||||
_, _ = lib.Warn(err, "TLS verification error with server name %s", ci.Host)
|
||||
}
|
||||
|
||||
if len(state.PeerCertificates) == 0 {
|
||||
lib.Warnx("no certificates found")
|
||||
_, _ = lib.Warnx("no certificates found")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -275,14 +295,14 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
|
||||
}
|
||||
|
||||
if len(state.VerifiedChains) == 0 {
|
||||
lib.Warnx("no verified chains found; using peer chain")
|
||||
_, _ = lib.Warnx("no verified chains found; using peer chain")
|
||||
for i := range state.PeerCertificates {
|
||||
displayCert(state.PeerCertificates[i])
|
||||
}
|
||||
} else {
|
||||
fmt.Println("TLS chain verified successfully.")
|
||||
fmt.Fprintln(os.Stdout, "TLS chain verified successfully.")
|
||||
for i := range state.VerifiedChains {
|
||||
fmt.Printf("--- Verified certificate chain %d ---\n", i+1)
|
||||
fmt.Fprintf(os.Stdout, "--- Verified certificate chain %d ---%s", i+1, "\n")
|
||||
for j := range state.VerifiedChains[i] {
|
||||
displayCert(state.VerifiedChains[i][j])
|
||||
}
|
||||
@@ -290,6 +310,32 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func shouldReadStdin(argc int, argv []string) bool {
|
||||
if argc == 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if argc == 1 && argv[0] == "-" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func readStdin(leafOnly bool) {
|
||||
certs, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
_, _ = lib.Warn(err, "couldn't read certificates from standard input")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// This is needed for getting certs from JSON/jq.
|
||||
certs = bytes.TrimSpace(certs)
|
||||
certs = bytes.ReplaceAll(certs, []byte(`\n`), []byte{0xa})
|
||||
certs = bytes.Trim(certs, `"`)
|
||||
displayAllCerts(certs, leafOnly)
|
||||
}
|
||||
|
||||
func main() {
|
||||
var leafOnly bool
|
||||
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
|
||||
@@ -297,32 +343,23 @@ func main() {
|
||||
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
|
||||
flag.Parse()
|
||||
|
||||
if flag.NArg() == 0 || (flag.NArg() == 1 && flag.Arg(0) == "-") {
|
||||
certs, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
lib.Warn(err, "couldn't read certificates from standard input")
|
||||
os.Exit(1)
|
||||
}
|
||||
if shouldReadStdin(flag.NArg(), flag.Args()) {
|
||||
readStdin(leafOnly)
|
||||
return
|
||||
}
|
||||
|
||||
// This is needed for getting certs from JSON/jq.
|
||||
certs = bytes.TrimSpace(certs)
|
||||
certs = bytes.Replace(certs, []byte(`\n`), []byte{0xa}, -1)
|
||||
certs = bytes.Trim(certs, `"`)
|
||||
displayAllCerts(certs, leafOnly)
|
||||
} else {
|
||||
for _, filename := range flag.Args() {
|
||||
fmt.Printf("--%s ---\n", filename)
|
||||
if strings.HasPrefix(filename, "https://") {
|
||||
displayAllCertsWeb(filename, leafOnly)
|
||||
} else {
|
||||
in, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
lib.Warn(err, "couldn't read certificate")
|
||||
continue
|
||||
}
|
||||
|
||||
displayAllCerts(in, leafOnly)
|
||||
for _, filename := range flag.Args() {
|
||||
fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
|
||||
if strings.HasPrefix(filename, "https://") {
|
||||
displayAllCertsWeb(filename, leafOnly)
|
||||
} else {
|
||||
in, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
_, _ = lib.Warn(err, "couldn't read certificate")
|
||||
continue
|
||||
}
|
||||
|
||||
displayAllCerts(in, leafOnly)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -13,6 +13,11 @@ import (
|
||||
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
|
||||
// "\2: \1,")
|
||||
|
||||
const (
|
||||
sSHA256 = "SHA256"
|
||||
sSHA512 = "SHA512"
|
||||
)
|
||||
|
||||
var keyUsage = map[x509.KeyUsage]string{
|
||||
x509.KeyUsageDigitalSignature: "digital signature",
|
||||
x509.KeyUsageContentCommitment: "content committment",
|
||||
@@ -26,42 +31,36 @@ var keyUsage = map[x509.KeyUsage]string{
|
||||
}
|
||||
|
||||
var extKeyUsages = map[x509.ExtKeyUsage]string{
|
||||
x509.ExtKeyUsageAny: "any",
|
||||
x509.ExtKeyUsageServerAuth: "server auth",
|
||||
x509.ExtKeyUsageClientAuth: "client auth",
|
||||
x509.ExtKeyUsageCodeSigning: "code signing",
|
||||
x509.ExtKeyUsageEmailProtection: "s/mime",
|
||||
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
|
||||
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
|
||||
x509.ExtKeyUsageIPSECUser: "ipsec user",
|
||||
x509.ExtKeyUsageTimeStamping: "timestamping",
|
||||
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
|
||||
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
|
||||
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
|
||||
}
|
||||
|
||||
func pubKeyAlgo(a x509.PublicKeyAlgorithm) string {
|
||||
switch a {
|
||||
case x509.RSA:
|
||||
return "RSA"
|
||||
case x509.ECDSA:
|
||||
return "ECDSA"
|
||||
case x509.DSA:
|
||||
return "DSA"
|
||||
default:
|
||||
return "unknown public key algorithm"
|
||||
}
|
||||
x509.ExtKeyUsageAny: "any",
|
||||
x509.ExtKeyUsageServerAuth: "server auth",
|
||||
x509.ExtKeyUsageClientAuth: "client auth",
|
||||
x509.ExtKeyUsageCodeSigning: "code signing",
|
||||
x509.ExtKeyUsageEmailProtection: "s/mime",
|
||||
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
|
||||
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
|
||||
x509.ExtKeyUsageIPSECUser: "ipsec user",
|
||||
x509.ExtKeyUsageTimeStamping: "timestamping",
|
||||
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
|
||||
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
|
||||
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
|
||||
x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
|
||||
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
|
||||
}
|
||||
|
||||
func sigAlgoPK(a x509.SignatureAlgorithm) string {
|
||||
switch a {
|
||||
|
||||
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
|
||||
return "RSA"
|
||||
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
|
||||
return "RSA-PSS"
|
||||
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
|
||||
return "ECDSA"
|
||||
case x509.DSAWithSHA1, x509.DSAWithSHA256:
|
||||
return "DSA"
|
||||
case x509.PureEd25519:
|
||||
return "Ed25519"
|
||||
case x509.UnknownSignatureAlgorithm:
|
||||
return "unknown public key algorithm"
|
||||
default:
|
||||
return "unknown public key algorithm"
|
||||
}
|
||||
@@ -76,11 +75,21 @@ func sigAlgoHash(a x509.SignatureAlgorithm) string {
|
||||
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
|
||||
return "SHA1"
|
||||
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
|
||||
return "SHA256"
|
||||
return sSHA256
|
||||
case x509.SHA256WithRSAPSS:
|
||||
return sSHA256
|
||||
case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
|
||||
return "SHA384"
|
||||
case x509.SHA384WithRSAPSS:
|
||||
return "SHA384"
|
||||
case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
|
||||
return "SHA512"
|
||||
return sSHA512
|
||||
case x509.SHA512WithRSAPSS:
|
||||
return sSHA512
|
||||
case x509.PureEd25519:
|
||||
return sSHA512
|
||||
case x509.UnknownSignatureAlgorithm:
|
||||
return "unknown hash algorithm"
|
||||
default:
|
||||
return "unknown hash algorithm"
|
||||
}
|
||||
@@ -90,9 +99,11 @@ const maxLine = 78
|
||||
|
||||
func makeIndent(n int) string {
|
||||
s := " "
|
||||
for i := 0; i < n; i++ {
|
||||
s += " "
|
||||
var sSb97 strings.Builder
|
||||
for range n {
|
||||
sSb97.WriteString(" ")
|
||||
}
|
||||
s += sSb97.String()
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -100,7 +111,7 @@ func indentLen(n int) int {
|
||||
return 4 + (8 * n)
|
||||
}
|
||||
|
||||
// this isn't real efficient, but that's not a problem here
|
||||
// this isn't real efficient, but that's not a problem here.
|
||||
func wrap(s string, indent int) string {
|
||||
if indent > 3 {
|
||||
indent = 3
|
||||
@@ -123,9 +134,11 @@ func wrap(s string, indent int) string {
|
||||
|
||||
func dumpHex(in []byte) string {
|
||||
var s string
|
||||
var sSb130 strings.Builder
|
||||
for i := range in {
|
||||
s += fmt.Sprintf("%02X:", in[i])
|
||||
sSb130.WriteString(fmt.Sprintf("%02X:", in[i]))
|
||||
}
|
||||
s += sSb130.String()
|
||||
|
||||
return strings.Trim(s, ":")
|
||||
}
|
||||
@@ -136,14 +149,14 @@ func dumpHex(in []byte) string {
|
||||
func permissiveConfig() *tls.Config {
|
||||
return &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}
|
||||
} // #nosec G402
|
||||
}
|
||||
|
||||
// verifyConfig returns a config that will verify the connection.
|
||||
func verifyConfig(hostname string) *tls.Config {
|
||||
return &tls.Config{
|
||||
ServerName: hostname,
|
||||
}
|
||||
} // #nosec G402
|
||||
}
|
||||
|
||||
type connInfo struct {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"crypto/x509/pkix"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -54,7 +53,7 @@ func displayName(name pkix.Name) string {
|
||||
}
|
||||
|
||||
func expires(cert *x509.Certificate) time.Duration {
|
||||
return cert.NotAfter.Sub(time.Now())
|
||||
return time.Until(cert.NotAfter)
|
||||
}
|
||||
|
||||
func inDanger(cert *x509.Certificate) bool {
|
||||
@@ -81,15 +80,15 @@ func main() {
|
||||
flag.Parse()
|
||||
|
||||
for _, file := range flag.Args() {
|
||||
in, err := ioutil.ReadFile(file)
|
||||
in, err := os.ReadFile(file)
|
||||
if err != nil {
|
||||
lib.Warn(err, "failed to read file")
|
||||
_, _ = lib.Warn(err, "failed to read file")
|
||||
continue
|
||||
}
|
||||
|
||||
certs, err := certlib.ParseCertificatesPEM(in)
|
||||
if err != nil {
|
||||
lib.Warn(err, "while parsing certificates")
|
||||
_, _ = lib.Warn(err, "while parsing certificates")
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -4,13 +4,11 @@ import (
|
||||
"crypto/x509"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/certlib"
|
||||
"git.wntrmute.dev/kyle/goutils/certlib/revoke"
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
"git.wntrmute.dev/kyle/goutils/lib"
|
||||
)
|
||||
|
||||
@@ -30,83 +28,116 @@ func printRevocation(cert *x509.Certificate) {
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
var caFile, intFile string
|
||||
var forceIntermediateBundle, revexp, verbose bool
|
||||
flag.StringVar(&caFile, "ca", "", "CA certificate `bundle`")
|
||||
flag.StringVar(&intFile, "i", "", "intermediate `bundle`")
|
||||
flag.BoolVar(&forceIntermediateBundle, "f", false,
|
||||
type appConfig struct {
|
||||
caFile, intFile string
|
||||
forceIntermediateBundle bool
|
||||
revexp, verbose bool
|
||||
}
|
||||
|
||||
func parseFlags() appConfig {
|
||||
var cfg appConfig
|
||||
flag.StringVar(&cfg.caFile, "ca", "", "CA certificate `bundle`")
|
||||
flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`")
|
||||
flag.BoolVar(&cfg.forceIntermediateBundle, "f", false,
|
||||
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate")
|
||||
flag.BoolVar(&revexp, "r", false, "print revocation and expiry information")
|
||||
flag.BoolVar(&verbose, "v", false, "verbose")
|
||||
flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
|
||||
flag.BoolVar(&cfg.verbose, "v", false, "verbose")
|
||||
flag.Parse()
|
||||
return cfg
|
||||
}
|
||||
|
||||
var roots *x509.CertPool
|
||||
if caFile != "" {
|
||||
var err error
|
||||
if verbose {
|
||||
fmt.Println("[+] loading root certificates from", caFile)
|
||||
}
|
||||
roots, err = certlib.LoadPEMCertPool(caFile)
|
||||
die.If(err)
|
||||
func loadRoots(caFile string, verbose bool) (*x509.CertPool, error) {
|
||||
if caFile == "" {
|
||||
return x509.SystemCertPool()
|
||||
}
|
||||
|
||||
var ints *x509.CertPool
|
||||
if intFile != "" {
|
||||
var err error
|
||||
if verbose {
|
||||
fmt.Println("[+] loading intermediate certificates from", intFile)
|
||||
}
|
||||
ints, err = certlib.LoadPEMCertPool(caFile)
|
||||
die.If(err)
|
||||
} else {
|
||||
ints = x509.NewCertPool()
|
||||
}
|
||||
|
||||
if flag.NArg() != 1 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert",
|
||||
lib.ProgName())
|
||||
}
|
||||
|
||||
fileData, err := ioutil.ReadFile(flag.Arg(0))
|
||||
die.If(err)
|
||||
|
||||
chain, err := certlib.ParseCertificatesPEM(fileData)
|
||||
die.If(err)
|
||||
if verbose {
|
||||
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
|
||||
fmt.Println("[+] loading root certificates from", caFile)
|
||||
}
|
||||
return certlib.LoadPEMCertPool(caFile)
|
||||
}
|
||||
|
||||
cert := chain[0]
|
||||
if len(chain) > 1 {
|
||||
if !forceIntermediateBundle {
|
||||
for _, intermediate := range chain[1:] {
|
||||
if verbose {
|
||||
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
|
||||
}
|
||||
func loadIntermediates(intFile string, verbose bool) (*x509.CertPool, error) {
|
||||
if intFile == "" {
|
||||
return x509.NewCertPool(), nil
|
||||
}
|
||||
if verbose {
|
||||
fmt.Println("[+] loading intermediate certificates from", intFile)
|
||||
}
|
||||
// Note: use intFile here (previously used caFile mistakenly)
|
||||
return certlib.LoadPEMCertPool(intFile)
|
||||
}
|
||||
|
||||
ints.AddCert(intermediate)
|
||||
}
|
||||
func addBundledIntermediates(chain []*x509.Certificate, pool *x509.CertPool, verbose bool) {
|
||||
for _, intermediate := range chain[1:] {
|
||||
if verbose {
|
||||
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
|
||||
}
|
||||
pool.AddCert(intermediate)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyCert(cert *x509.Certificate, roots, ints *x509.CertPool) error {
|
||||
opts := x509.VerifyOptions{
|
||||
Intermediates: ints,
|
||||
Roots: roots,
|
||||
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
|
||||
}
|
||||
_, err := cert.Verify(opts)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = cert.Verify(opts)
|
||||
func run(cfg appConfig) error {
|
||||
roots, err := loadRoots(cfg.caFile, cfg.verbose)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
return err
|
||||
}
|
||||
|
||||
if verbose {
|
||||
ints, err := loadIntermediates(cfg.intFile, cfg.verbose)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if flag.NArg() != 1 {
|
||||
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
|
||||
}
|
||||
|
||||
fileData, err := os.ReadFile(flag.Arg(0))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chain, err := certlib.ParseCertificatesPEM(fileData)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cfg.verbose {
|
||||
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
|
||||
}
|
||||
|
||||
cert := chain[0]
|
||||
if len(chain) > 1 && !cfg.forceIntermediateBundle {
|
||||
addBundledIntermediates(chain, ints, cfg.verbose)
|
||||
}
|
||||
|
||||
if err = verifyCert(cert, roots, ints); err != nil {
|
||||
return fmt.Errorf("certificate verification failed: %w", err)
|
||||
}
|
||||
|
||||
if cfg.verbose {
|
||||
fmt.Println("OK")
|
||||
}
|
||||
|
||||
if revexp {
|
||||
if cfg.revexp {
|
||||
printRevocation(cert)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
cfg := parseFlags()
|
||||
if err := run(cfg); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,8 @@ package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -56,7 +58,7 @@ var modes = ssh.TerminalModes{
|
||||
}
|
||||
|
||||
func sshAgent() ssh.AuthMethod {
|
||||
a, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK"))
|
||||
a, err := (&net.Dialer{}).DialContext(context.Background(), "unix", os.Getenv("SSH_AUTH_SOCK"))
|
||||
if err == nil {
|
||||
return ssh.PublicKeysCallback(agent.NewClient(a).Signers)
|
||||
}
|
||||
@@ -82,7 +84,7 @@ func scanner(host string, in io.Reader, out io.Writer) {
|
||||
}
|
||||
}
|
||||
|
||||
func logError(host string, err error, format string, args ...interface{}) {
|
||||
func logError(host string, err error, format string, args ...any) {
|
||||
msg := fmt.Sprintf(format, args...)
|
||||
log.Printf("[%s] FAILED: %s: %v\n", host, msg, err)
|
||||
}
|
||||
@@ -93,7 +95,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
|
||||
defer func() {
|
||||
for i := len(shutdown) - 1; i >= 0; i-- {
|
||||
err := shutdown[i]()
|
||||
if err != nil && err != io.EOF {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logError(host, err, "shutting down")
|
||||
}
|
||||
}
|
||||
@@ -115,7 +117,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
|
||||
}
|
||||
shutdown = append(shutdown, session.Close)
|
||||
|
||||
if err := session.RequestPty("xterm", 80, 40, modes); err != nil {
|
||||
if err = session.RequestPty("xterm", 80, 40, modes); err != nil {
|
||||
session.Close()
|
||||
logError(host, err, "request for pty failed")
|
||||
return
|
||||
@@ -150,7 +152,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
|
||||
defer func() {
|
||||
for i := len(shutdown) - 1; i >= 0; i-- {
|
||||
err := shutdown[i]()
|
||||
if err != nil && err != io.EOF {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logError(host, err, "shutting down")
|
||||
}
|
||||
}
|
||||
@@ -199,7 +201,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
|
||||
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
logError(host, err, "reading chunk")
|
||||
@@ -215,7 +217,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
|
||||
defer func() {
|
||||
for i := len(shutdown) - 1; i >= 0; i-- {
|
||||
err := shutdown[i]()
|
||||
if err != nil && err != io.EOF {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logError(host, err, "shutting down")
|
||||
}
|
||||
}
|
||||
@@ -265,7 +267,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
|
||||
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
|
||||
}
|
||||
|
||||
if err == io.EOF {
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if err != nil {
|
||||
logError(host, err, "reading chunk")
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
"git.wntrmute.dev/kyle/goutils/fileutil"
|
||||
@@ -26,7 +27,7 @@ func setupFile(hdr *tar.Header, file *os.File) error {
|
||||
if verbose {
|
||||
fmt.Printf("\tchmod %0#o\n", hdr.Mode)
|
||||
}
|
||||
err := file.Chmod(os.FileMode(hdr.Mode))
|
||||
err := file.Chmod(os.FileMode(hdr.Mode & 0xFFFFFFFF)) // #nosec G115
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -48,73 +49,105 @@ func linkTarget(target, top string) string {
|
||||
return target
|
||||
}
|
||||
|
||||
return filepath.Clean(filepath.Join(target, top))
|
||||
return filepath.Clean(filepath.Join(top, target))
|
||||
}
|
||||
|
||||
// safeJoin joins base and elem and ensures the resulting path does not escape base.
|
||||
func safeJoin(base, elem string) (string, error) {
|
||||
cleanBase := filepath.Clean(base)
|
||||
joined := filepath.Clean(filepath.Join(cleanBase, elem))
|
||||
|
||||
absBase, err := filepath.Abs(cleanBase)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
absJoined, err := filepath.Abs(joined)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
rel, err := filepath.Rel(absBase, absJoined)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
|
||||
return "", fmt.Errorf("path traversal detected: %s escapes %s", elem, base)
|
||||
}
|
||||
return joined, nil
|
||||
}
|
||||
|
||||
func handleTypeReg(tfr *tar.Reader, hdr *tar.Header, filePath string) error {
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
if _, err = io.Copy(file, tfr); err != nil {
|
||||
return err
|
||||
}
|
||||
return setupFile(hdr, file)
|
||||
}
|
||||
|
||||
func handleTypeLink(hdr *tar.Header, top, filePath string) error {
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
srcPath, err := safeJoin(top, hdr.Linkname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
source, err := os.Open(srcPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer source.Close()
|
||||
|
||||
if _, err = io.Copy(file, source); err != nil {
|
||||
return err
|
||||
}
|
||||
return setupFile(hdr, file)
|
||||
}
|
||||
|
||||
func handleTypeSymlink(hdr *tar.Header, top, filePath string) error {
|
||||
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
|
||||
return fmt.Errorf("symlink %s is outside the top-level %s", hdr.Linkname, top)
|
||||
}
|
||||
path := linkTarget(hdr.Linkname, top)
|
||||
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
|
||||
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
return os.Symlink(linkTarget(hdr.Linkname, top), filePath)
|
||||
}
|
||||
|
||||
func handleTypeDir(hdr *tar.Header, filePath string) error {
|
||||
return os.MkdirAll(filePath, os.FileMode(hdr.Mode&0xFFFFFFFF)) // #nosec G115
|
||||
}
|
||||
|
||||
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
|
||||
if verbose {
|
||||
fmt.Println(hdr.Name)
|
||||
}
|
||||
filePath := filepath.Clean(filepath.Join(top, hdr.Name))
|
||||
switch hdr.Typeflag {
|
||||
case tar.TypeReg:
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(file, tfr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = setupFile(hdr, file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case tar.TypeLink:
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
source, err := os.Open(hdr.Linkname)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(file, source)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = setupFile(hdr, file)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case tar.TypeSymlink:
|
||||
if !fileutil.ValidateSymlink(hdr.Linkname, top) {
|
||||
return fmt.Errorf("symlink %s is outside the top-level %s",
|
||||
hdr.Linkname, top)
|
||||
}
|
||||
path := linkTarget(hdr.Linkname, top)
|
||||
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
|
||||
return fmt.Errorf("symlink %s isn't in %s", hdr.Linkname, top)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := os.Symlink(linkTarget(hdr.Linkname, top), filePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case tar.TypeDir:
|
||||
err := os.MkdirAll(filePath, os.FileMode(hdr.Mode))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
filePath, err := safeJoin(top, hdr.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch hdr.Typeflag {
|
||||
case tar.TypeReg:
|
||||
return handleTypeReg(tfr, hdr, filePath)
|
||||
case tar.TypeLink:
|
||||
return handleTypeLink(hdr, top, filePath)
|
||||
case tar.TypeSymlink:
|
||||
return handleTypeSymlink(hdr, top, filePath)
|
||||
case tar.TypeDir:
|
||||
return handleTypeDir(hdr, filePath)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -261,16 +294,16 @@ func main() {
|
||||
die.If(err)
|
||||
|
||||
tfr := tar.NewReader(r)
|
||||
var hdr *tar.Header
|
||||
for {
|
||||
hdr, err := tfr.Next()
|
||||
if err == io.EOF {
|
||||
hdr, err = tfr.Next()
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
}
|
||||
die.If(err)
|
||||
|
||||
err = processFile(tfr, hdr, top)
|
||||
die.If(err)
|
||||
|
||||
}
|
||||
|
||||
r.Close()
|
||||
|
||||
@@ -7,9 +7,9 @@ import (
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/certlib"
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
)
|
||||
|
||||
@@ -17,17 +17,10 @@ func main() {
|
||||
flag.Parse()
|
||||
|
||||
for _, fileName := range flag.Args() {
|
||||
in, err := ioutil.ReadFile(fileName)
|
||||
in, err := os.ReadFile(fileName)
|
||||
die.If(err)
|
||||
|
||||
if p, _ := pem.Decode(in); p != nil {
|
||||
if p.Type != "CERTIFICATE REQUEST" {
|
||||
log.Fatal("INVALID FILE TYPE")
|
||||
}
|
||||
in = p.Bytes
|
||||
}
|
||||
|
||||
csr, err := x509.ParseCertificateRequest(in)
|
||||
csr, _, err := certlib.ParseCSR(in)
|
||||
die.If(err)
|
||||
|
||||
out, err := x509.MarshalPKIXPublicKey(csr.PublicKey)
|
||||
@@ -48,8 +41,8 @@ func main() {
|
||||
Bytes: out,
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0644)
|
||||
err = os.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0o644) // #nosec G306
|
||||
die.If(err)
|
||||
fmt.Printf("[+] wrote %s.\n", fileName+".pub")
|
||||
fmt.Fprintf(os.Stdout, "[+] wrote %s.\n", fileName+".pub")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -152,7 +153,7 @@ func rsync(syncDir, target, excludeFile string, verboseRsync bool) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd := exec.Command(path, args...)
|
||||
cmd := exec.CommandContext(context.Background(), path, args...)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
return cmd.Run()
|
||||
@@ -163,7 +164,6 @@ func init() {
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
var logLevel, mountDir, syncDir, target string
|
||||
var dryRun, quietMode, noSyslog, verboseRsync bool
|
||||
|
||||
@@ -219,7 +219,7 @@ func main() {
|
||||
if excludeFile != "" {
|
||||
defer func() {
|
||||
log.Infof("removing exclude file %s", excludeFile)
|
||||
if err := os.Remove(excludeFile); err != nil {
|
||||
if rmErr := os.Remove(excludeFile); rmErr != nil {
|
||||
log.Warningf("failed to remove temp file %s", excludeFile)
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -15,43 +15,41 @@ import (
|
||||
const defaultHashAlgorithm = "sha256"
|
||||
|
||||
var (
|
||||
hAlgo string
|
||||
hAlgo string
|
||||
debug = dbg.New()
|
||||
)
|
||||
|
||||
|
||||
func openImage(imageFile string) (image *os.File, hash []byte, err error) {
|
||||
image, err = os.Open(imageFile)
|
||||
func openImage(imageFile string) (*os.File, []byte, error) {
|
||||
f, err := os.Open(imageFile)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
hash, err = ahash.SumReader(hAlgo, image)
|
||||
h, err := ahash.SumReader(hAlgo, f)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
_, err = image.Seek(0, 0)
|
||||
if err != nil {
|
||||
return
|
||||
if _, err = f.Seek(0, 0); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
debug.Printf("%s %x\n", imageFile, hash)
|
||||
return
|
||||
debug.Printf("%s %x\n", imageFile, h)
|
||||
return f, h, nil
|
||||
}
|
||||
|
||||
func openDevice(devicePath string) (device *os.File, err error) {
|
||||
func openDevice(devicePath string) (*os.File, error) {
|
||||
fi, err := os.Stat(devicePath)
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
device, err = os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
|
||||
device, err := os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
|
||||
if err != nil {
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -105,12 +103,12 @@ func main() {
|
||||
die.If(err)
|
||||
|
||||
if !bytes.Equal(deviceHash, hash) {
|
||||
fmt.Fprintln(os.Stderr, "Hash mismatch:")
|
||||
fmt.Fprintf(os.Stderr, "\t%s: %s\n", imageFile, hash)
|
||||
fmt.Fprintf(os.Stderr, "\t%s: %s\n", devicePath, deviceHash)
|
||||
os.Exit(1)
|
||||
buf := &bytes.Buffer{}
|
||||
fmt.Fprintln(buf, "Hash mismatch:")
|
||||
fmt.Fprintf(buf, "\t%s: %s\n", imageFile, hash)
|
||||
fmt.Fprintf(buf, "\t%s: %s\n", devicePath, deviceHash)
|
||||
die.With(buf.String())
|
||||
}
|
||||
|
||||
debug.Println("OK")
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
@@ -1,30 +1,33 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
)
|
||||
|
||||
func usage(w io.Writer, exc int) {
|
||||
fmt.Fprintln(w, `usage: dumpbytes <file>`)
|
||||
fmt.Fprintln(w, `usage: dumpbytes -n tabs <file>`)
|
||||
os.Exit(exc)
|
||||
}
|
||||
|
||||
func printBytes(buf []byte) {
|
||||
fmt.Printf("\t")
|
||||
for i := 0; i < len(buf); i++ {
|
||||
for i := range buf {
|
||||
fmt.Printf("0x%02x, ", buf[i])
|
||||
}
|
||||
fmt.Println()
|
||||
}
|
||||
|
||||
func dumpFile(path string, indentLevel int) error {
|
||||
indent := ""
|
||||
for i := 0; i < indentLevel; i++ {
|
||||
indent += "\t"
|
||||
var indent strings.Builder
|
||||
for range indentLevel {
|
||||
indent.WriteByte('\t')
|
||||
}
|
||||
|
||||
file, err := os.Open(path)
|
||||
@@ -34,13 +37,14 @@ func dumpFile(path string, indentLevel int) error {
|
||||
|
||||
defer file.Close()
|
||||
|
||||
fmt.Printf("%svar buffer = []byte{\n", indent)
|
||||
fmt.Printf("%svar buffer = []byte{\n", indent.String())
|
||||
var n int
|
||||
for {
|
||||
buf := make([]byte, 8)
|
||||
n, err := file.Read(buf)
|
||||
if err == io.EOF {
|
||||
n, err = file.Read(buf)
|
||||
if errors.Is(err, io.EOF) {
|
||||
if n > 0 {
|
||||
fmt.Printf("%s", indent)
|
||||
fmt.Printf("%s", indent.String())
|
||||
printBytes(buf[:n])
|
||||
}
|
||||
break
|
||||
@@ -50,11 +54,11 @@ func dumpFile(path string, indentLevel int) error {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("%s", indent)
|
||||
fmt.Printf("%s", indent.String())
|
||||
printBytes(buf[:n])
|
||||
}
|
||||
|
||||
fmt.Printf("%s}\n", indent)
|
||||
fmt.Printf("%s}\n", indent.String())
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
)
|
||||
|
||||
// size of a kilobit in bytes
|
||||
// size of a kilobit in bytes.
|
||||
const kilobit = 128
|
||||
const pageSize = 4096
|
||||
|
||||
@@ -26,10 +26,10 @@ func main() {
|
||||
path = flag.Arg(0)
|
||||
}
|
||||
|
||||
fillByte := uint8(*fill)
|
||||
fillByte := uint8(*fill & 0xff) // #nosec G115 clearing out of bounds bits
|
||||
|
||||
buf := make([]byte, pageSize)
|
||||
for i := 0; i < pageSize; i++ {
|
||||
for i := range pageSize {
|
||||
buf[i] = fillByte
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ func main() {
|
||||
die.If(err)
|
||||
defer file.Close()
|
||||
|
||||
for i := 0; i < pages; i++ {
|
||||
for range pages {
|
||||
_, err = file.Write(buf)
|
||||
die.If(err)
|
||||
}
|
||||
|
||||
@@ -72,15 +72,13 @@ func main() {
|
||||
|
||||
if end < start {
|
||||
fmt.Fprintln(os.Stderr, "[!] end < start, swapping values")
|
||||
tmp := end
|
||||
end = start
|
||||
start = tmp
|
||||
start, end = end, start
|
||||
}
|
||||
|
||||
var fmtStr string
|
||||
|
||||
if !*quiet {
|
||||
maxLine := fmt.Sprintf("%d", len(lines))
|
||||
maxLine := strconv.Itoa(len(lines))
|
||||
fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine))
|
||||
}
|
||||
|
||||
@@ -98,9 +96,9 @@ func main() {
|
||||
fmtStr += "\n"
|
||||
for i := start; !endFunc(i); i++ {
|
||||
if *quiet {
|
||||
fmt.Println(lines[i])
|
||||
fmt.Fprintln(os.Stdout, lines[i])
|
||||
} else {
|
||||
fmt.Printf(fmtStr, i, lines[i])
|
||||
fmt.Fprintf(os.Stdout, fmtStr, i, lines[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
@@ -8,7 +9,8 @@ import (
|
||||
)
|
||||
|
||||
func lookupHost(host string) error {
|
||||
cname, err := net.LookupCNAME(host)
|
||||
r := &net.Resolver{}
|
||||
cname, err := r.LookupCNAME(context.Background(), host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -18,7 +20,7 @@ func lookupHost(host string) error {
|
||||
host = cname
|
||||
}
|
||||
|
||||
addrs, err := net.LookupHost(host)
|
||||
addrs, err := r.LookupHost(context.Background(), host)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/lib"
|
||||
@@ -16,20 +16,20 @@ func prettify(file string, validateOnly bool) error {
|
||||
var err error
|
||||
|
||||
if file == "-" {
|
||||
in, err = ioutil.ReadAll(os.Stdin)
|
||||
in, err = io.ReadAll(os.Stdin)
|
||||
} else {
|
||||
in, err = ioutil.ReadFile(file)
|
||||
in, err = os.ReadFile(file)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lib.Warn(err, "ReadFile")
|
||||
_, _ = lib.Warn(err, "ReadFile")
|
||||
return err
|
||||
}
|
||||
|
||||
var buf = &bytes.Buffer{}
|
||||
err = json.Indent(buf, in, "", " ")
|
||||
if err != nil {
|
||||
lib.Warn(err, "%s", file)
|
||||
_, _ = lib.Warn(err, "%s", file)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -40,11 +40,11 @@ func prettify(file string, validateOnly bool) error {
|
||||
if file == "-" {
|
||||
_, err = os.Stdout.Write(buf.Bytes())
|
||||
} else {
|
||||
err = ioutil.WriteFile(file, buf.Bytes(), 0644)
|
||||
err = os.WriteFile(file, buf.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lib.Warn(err, "WriteFile")
|
||||
_, _ = lib.Warn(err, "WriteFile")
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -55,20 +55,20 @@ func compact(file string, validateOnly bool) error {
|
||||
var err error
|
||||
|
||||
if file == "-" {
|
||||
in, err = ioutil.ReadAll(os.Stdin)
|
||||
in, err = io.ReadAll(os.Stdin)
|
||||
} else {
|
||||
in, err = ioutil.ReadFile(file)
|
||||
in, err = os.ReadFile(file)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lib.Warn(err, "ReadFile")
|
||||
_, _ = lib.Warn(err, "ReadFile")
|
||||
return err
|
||||
}
|
||||
|
||||
var buf = &bytes.Buffer{}
|
||||
err = json.Compact(buf, in)
|
||||
if err != nil {
|
||||
lib.Warn(err, "%s", file)
|
||||
_, _ = lib.Warn(err, "%s", file)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -79,11 +79,11 @@ func compact(file string, validateOnly bool) error {
|
||||
if file == "-" {
|
||||
_, err = os.Stdout.Write(buf.Bytes())
|
||||
} else {
|
||||
err = ioutil.WriteFile(file, buf.Bytes(), 0644)
|
||||
err = os.WriteFile(file, buf.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
lib.Warn(err, "WriteFile")
|
||||
_, _ = lib.Warn(err, "WriteFile")
|
||||
}
|
||||
|
||||
return err
|
||||
@@ -91,7 +91,7 @@ func compact(file string, validateOnly bool) error {
|
||||
|
||||
func usage() {
|
||||
progname := lib.ProgName()
|
||||
fmt.Printf(`Usage: %s [-h] files...
|
||||
fmt.Fprintf(os.Stdout, `Usage: %s [-h] files...
|
||||
%s is used to lint and prettify (or compact) JSON files. The
|
||||
files will be updated in-place.
|
||||
|
||||
@@ -100,7 +100,6 @@ func usage() {
|
||||
-h Print this help message.
|
||||
-n Don't prettify; only perform validation.
|
||||
`, progname, progname)
|
||||
|
||||
}
|
||||
|
||||
func init() {
|
||||
@@ -11,7 +11,10 @@ based on whether the source filename ends in ".gz".
|
||||
|
||||
Flags:
|
||||
-l level Compression level (0-9). Only meaninful when
|
||||
compressing a file.
|
||||
compressing a file.
|
||||
-u Do not restrict the size during decompression. As
|
||||
a safeguard against gzip bombs, the maximum size
|
||||
allowed is 32 * the compressed file size.
|
||||
|
||||
|
||||
|
||||
|
||||
146
cmd/kgz/main.go
146
cmd/kgz/main.go
@@ -1,68 +1,84 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"compress/flate"
|
||||
"compress/gzip"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const gzipExt = ".gz"
|
||||
|
||||
func compress(path, target string, level int) error {
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for read: %w", err)
|
||||
}
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for read: %w", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
destFile, err := os.Create(target)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for write: %w", err)
|
||||
}
|
||||
destFile, err := os.Create(target)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for write: %w", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
gzipCompressor, err := gzip.NewWriterLevel(destFile, level)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid compression level: %w", err)
|
||||
}
|
||||
gzipCompressor, err := gzip.NewWriterLevel(destFile, level)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid compression level: %w", err)
|
||||
}
|
||||
defer gzipCompressor.Close()
|
||||
|
||||
_, err = io.Copy(gzipCompressor, sourceFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compressing file: %w", err)
|
||||
}
|
||||
_, err = io.Copy(gzipCompressor, sourceFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("compressing file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func uncompress(path, target string) error {
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for read: %w", err)
|
||||
}
|
||||
func uncompress(path, target string, unrestrict bool) error {
|
||||
sourceFile, err := os.Open(path)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for read: %w", err)
|
||||
}
|
||||
defer sourceFile.Close()
|
||||
|
||||
gzipUncompressor, err := gzip.NewReader(sourceFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading gzip headers: %w", err)
|
||||
}
|
||||
fi, err := sourceFile.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading file stats: %w", err)
|
||||
}
|
||||
|
||||
maxDecompressionSize := fi.Size() * 32
|
||||
|
||||
gzipUncompressor, err := gzip.NewReader(sourceFile)
|
||||
if err != nil {
|
||||
return fmt.Errorf("reading gzip headers: %w", err)
|
||||
}
|
||||
defer gzipUncompressor.Close()
|
||||
|
||||
destFile, err := os.Create(target)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for write: %w", err)
|
||||
}
|
||||
var reader io.Reader = &io.LimitedReader{
|
||||
R: gzipUncompressor,
|
||||
N: maxDecompressionSize,
|
||||
}
|
||||
|
||||
if unrestrict {
|
||||
reader = gzipUncompressor
|
||||
}
|
||||
|
||||
destFile, err := os.Create(target)
|
||||
if err != nil {
|
||||
return fmt.Errorf("opening file for write: %w", err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
|
||||
_, err = io.Copy(destFile, gzipUncompressor)
|
||||
if err != nil {
|
||||
return fmt.Errorf("uncompressing file: %w", err)
|
||||
}
|
||||
_, err = io.Copy(destFile, reader)
|
||||
if err != nil {
|
||||
return fmt.Errorf("uncompressing file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -87,8 +103,8 @@ func isDir(path string) bool {
|
||||
file, err := os.Open(path)
|
||||
if err == nil {
|
||||
defer file.Close()
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
stat, err2 := file.Stat()
|
||||
if err2 != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -106,9 +122,9 @@ func pathForUncompressing(source, dest string) (string, error) {
|
||||
}
|
||||
|
||||
source = filepath.Base(source)
|
||||
if !strings.HasSuffix(source, gzipExt) {
|
||||
return "", fmt.Errorf("%s is a not gzip-compressed file", source)
|
||||
}
|
||||
if !strings.HasSuffix(source, gzipExt) {
|
||||
return "", fmt.Errorf("%s is a not gzip-compressed file", source)
|
||||
}
|
||||
outFile := source[:len(source)-len(gzipExt)]
|
||||
outFile = filepath.Join(dest, outFile)
|
||||
return outFile, nil
|
||||
@@ -120,9 +136,9 @@ func pathForCompressing(source, dest string) (string, error) {
|
||||
}
|
||||
|
||||
source = filepath.Base(source)
|
||||
if strings.HasSuffix(source, gzipExt) {
|
||||
return "", fmt.Errorf("%s is a gzip-compressed file", source)
|
||||
}
|
||||
if strings.HasSuffix(source, gzipExt) {
|
||||
return "", fmt.Errorf("%s is a gzip-compressed file", source)
|
||||
}
|
||||
|
||||
dest = filepath.Join(dest, source+gzipExt)
|
||||
return dest, nil
|
||||
@@ -132,8 +148,11 @@ func main() {
|
||||
var level int
|
||||
var path string
|
||||
var target = "."
|
||||
var err error
|
||||
var unrestrict bool
|
||||
|
||||
flag.IntVar(&level, "l", flate.DefaultCompression, "compression level")
|
||||
flag.BoolVar(&unrestrict, "u", false, "do not restrict decompression")
|
||||
flag.Parse()
|
||||
|
||||
if flag.NArg() < 1 || flag.NArg() > 2 {
|
||||
@@ -147,30 +166,31 @@ func main() {
|
||||
}
|
||||
|
||||
if strings.HasSuffix(path, gzipExt) {
|
||||
target, err := pathForUncompressing(path, target)
|
||||
target, err = pathForUncompressing(path, target)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = uncompress(path, target)
|
||||
err = uncompress(path, target, unrestrict)
|
||||
if err != nil {
|
||||
os.Remove(target)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
} else {
|
||||
target, err := pathForCompressing(path, target)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err = compress(path, target, level)
|
||||
if err != nil {
|
||||
os.Remove(target)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
target, err = pathForCompressing(path, target)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
err = compress(path, target, level)
|
||||
if err != nil {
|
||||
os.Remove(target)
|
||||
fmt.Fprintf(os.Stderr, "%s\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,14 +40,14 @@ func main() {
|
||||
usage()
|
||||
}
|
||||
|
||||
min, err := strconv.Atoi(flag.Arg(1))
|
||||
minVal, err := strconv.Atoi(flag.Arg(1))
|
||||
dieIf(err)
|
||||
|
||||
max, err := strconv.Atoi(flag.Arg(2))
|
||||
maxVal, err := strconv.Atoi(flag.Arg(2))
|
||||
dieIf(err)
|
||||
|
||||
code := kind << 6
|
||||
code += (min << 3)
|
||||
code += max
|
||||
fmt.Printf("%0o\n", code)
|
||||
code += (minVal << 3)
|
||||
code += maxVal
|
||||
fmt.Fprintf(os.Stdout, "%0o\n", code)
|
||||
}
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
@@ -47,7 +46,7 @@ func help(w io.Writer) {
|
||||
}
|
||||
|
||||
func loadDatabase() {
|
||||
data, err := ioutil.ReadFile(dbFile)
|
||||
data, err := os.ReadFile(dbFile)
|
||||
if err != nil && os.IsNotExist(err) {
|
||||
partsDB = &database{
|
||||
Version: dbVersion,
|
||||
@@ -74,7 +73,7 @@ func writeDB() {
|
||||
data, err := json.Marshal(partsDB)
|
||||
die.If(err)
|
||||
|
||||
err = ioutil.WriteFile(dbFile, data, 0644)
|
||||
err = os.WriteFile(dbFile, data, 0644)
|
||||
die.If(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -4,14 +4,13 @@ import (
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
var ext = ".bin"
|
||||
|
||||
func stripPEM(path string) error {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -22,7 +21,7 @@ func stripPEM(path string) error {
|
||||
fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n")
|
||||
}
|
||||
|
||||
return ioutil.WriteFile(path+ext, p.Bytes, 0644)
|
||||
return os.WriteFile(path+ext, p.Bytes, 0644)
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -3,8 +3,7 @@ package main
|
||||
import (
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/lib"
|
||||
@@ -21,9 +20,9 @@ func main() {
|
||||
|
||||
path := flag.Arg(0)
|
||||
if path == "-" {
|
||||
in, err = ioutil.ReadAll(os.Stdin)
|
||||
in, err = io.ReadAll(os.Stdin)
|
||||
} else {
|
||||
in, err = ioutil.ReadFile(flag.Arg(0))
|
||||
in, err = os.ReadFile(flag.Arg(0))
|
||||
}
|
||||
if err != nil {
|
||||
lib.Err(lib.ExitFailure, err, "couldn't read file")
|
||||
@@ -33,5 +32,7 @@ func main() {
|
||||
if p == nil {
|
||||
lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0))
|
||||
}
|
||||
fmt.Printf("%s", p.Bytes)
|
||||
if _, err = os.Stdout.Write(p.Bytes); err != nil {
|
||||
lib.Err(lib.ExitFailure, err, "writing body")
|
||||
}
|
||||
}
|
||||
@@ -70,7 +70,7 @@ func main() {
|
||||
lib.Err(lib.ExitFailure, err, "failed to read input")
|
||||
}
|
||||
case argc > 1:
|
||||
for i := 0; i < argc; i++ {
|
||||
for i := range argc {
|
||||
path := flag.Arg(i)
|
||||
err = copyFile(path, buf)
|
||||
if err != nil {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
)
|
||||
|
||||
@@ -13,14 +12,14 @@ func main() {
|
||||
flag.Parse()
|
||||
|
||||
for _, fileName := range flag.Args() {
|
||||
data, err := ioutil.ReadFile(fileName)
|
||||
data, err := os.ReadFile(fileName)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
|
||||
continue
|
||||
}
|
||||
|
||||
fmt.Printf("[+] %s:\n", fileName)
|
||||
rest := data[:]
|
||||
fmt.Fprintf(os.Stdout, "[+] %s:\n", fileName)
|
||||
rest := data
|
||||
for {
|
||||
var p *pem.Block
|
||||
p, rest = pem.Decode(rest)
|
||||
@@ -28,13 +27,14 @@ func main() {
|
||||
break
|
||||
}
|
||||
|
||||
cert, err := x509.ParseCertificate(p.Bytes)
|
||||
var cert *x509.Certificate
|
||||
cert, err = x509.ParseCertificate(p.Bytes)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
|
||||
break
|
||||
}
|
||||
|
||||
fmt.Printf("\t%+v\n", cert.Subject.CommonName)
|
||||
fmt.Fprintf(os.Stdout, "\t%+v\n", cert.Subject.CommonName)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -43,7 +43,7 @@ func newName(path string) (string, error) {
|
||||
return hashName(path, encodedHash), nil
|
||||
}
|
||||
|
||||
func move(dst, src string, force bool) (err error) {
|
||||
func move(dst, src string, force bool) error {
|
||||
if fileutil.FileDoesExist(dst) && !force {
|
||||
return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst)
|
||||
}
|
||||
@@ -52,21 +52,23 @@ func move(dst, src string, force bool) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func(e error) {
|
||||
var retErr error
|
||||
defer func(e *error) {
|
||||
dstFile.Close()
|
||||
if e != nil {
|
||||
if *e != nil {
|
||||
os.Remove(dst)
|
||||
}
|
||||
}(err)
|
||||
}(&retErr)
|
||||
|
||||
srcFile, err := os.Open(src)
|
||||
if err != nil {
|
||||
retErr = err
|
||||
return err
|
||||
}
|
||||
defer srcFile.Close()
|
||||
|
||||
_, err = io.Copy(dstFile, srcFile)
|
||||
if err != nil {
|
||||
if _, err = io.Copy(dstFile, srcFile); err != nil {
|
||||
retErr = err
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -94,6 +96,44 @@ func init() {
|
||||
flag.Usage = func() { usage(os.Stdout) }
|
||||
}
|
||||
|
||||
type options struct {
|
||||
dryRun, force, printChanged, verbose bool
|
||||
}
|
||||
|
||||
func processOne(file string, opt options) error {
|
||||
renamed, err := newName(file)
|
||||
if err != nil {
|
||||
_, _ = lib.Warn(err, "failed to get new file name")
|
||||
return err
|
||||
}
|
||||
if opt.verbose && !opt.printChanged {
|
||||
fmt.Fprintln(os.Stdout, file)
|
||||
}
|
||||
if renamed == file {
|
||||
return nil
|
||||
}
|
||||
if !opt.dryRun {
|
||||
if err = move(renamed, file, opt.force); err != nil {
|
||||
_, _ = lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
|
||||
return err
|
||||
}
|
||||
}
|
||||
if opt.printChanged && !opt.verbose {
|
||||
fmt.Fprintln(os.Stdout, file, "->", renamed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func run(dryRun, force, printChanged, verbose bool, files []string) {
|
||||
if verbose && printChanged {
|
||||
printChanged = false
|
||||
}
|
||||
opt := options{dryRun: dryRun, force: force, printChanged: printChanged, verbose: verbose}
|
||||
for _, file := range files {
|
||||
_ = processOne(file, opt)
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
var dryRun, force, printChanged, verbose bool
|
||||
flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision")
|
||||
@@ -102,34 +142,5 @@ func main() {
|
||||
flag.BoolVar(&verbose, "v", false, "list all processed files")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
if verbose && printChanged {
|
||||
printChanged = false
|
||||
}
|
||||
|
||||
for _, file := range flag.Args() {
|
||||
renamed, err := newName(file)
|
||||
if err != nil {
|
||||
lib.Warn(err, "failed to get new file name")
|
||||
continue
|
||||
}
|
||||
|
||||
if verbose && !printChanged {
|
||||
fmt.Println(file)
|
||||
}
|
||||
|
||||
if renamed != file {
|
||||
if !dryRun {
|
||||
err = move(renamed, file, force)
|
||||
if err != nil {
|
||||
lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if printChanged && !verbose {
|
||||
fmt.Println(file, "->", renamed)
|
||||
}
|
||||
}
|
||||
}
|
||||
run(dryRun, force, printChanged, verbose, flag.Args())
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -66,24 +67,25 @@ func main() {
|
||||
for _, remote := range flag.Args() {
|
||||
u, err := url.Parse(remote)
|
||||
if err != nil {
|
||||
lib.Warn(err, "parsing %s", remote)
|
||||
_, _ = lib.Warn(err, "parsing %s", remote)
|
||||
continue
|
||||
}
|
||||
|
||||
name := filepath.Base(u.Path)
|
||||
if name == "" {
|
||||
lib.Warnx("source URL doesn't appear to name a file")
|
||||
_, _ = lib.Warnx("source URL doesn't appear to name a file")
|
||||
continue
|
||||
}
|
||||
|
||||
resp, err := http.Get(remote)
|
||||
if err != nil {
|
||||
lib.Warn(err, "fetching %s", remote)
|
||||
req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, remote, nil)
|
||||
if reqErr != nil {
|
||||
_, _ = lib.Warn(reqErr, "building request for %s", remote)
|
||||
continue
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
lib.Warn(err, "fetching %s", remote)
|
||||
_, _ = lib.Warn(err, "fetching %s", remote)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ package main
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"regexp"
|
||||
"strconv"
|
||||
@@ -17,8 +17,8 @@ func rollDie(count, sides int) []int {
|
||||
sum := 0
|
||||
var rolls []int
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
roll := rand.Intn(sides) + 1
|
||||
for range count {
|
||||
roll := rand.IntN(sides) + 1 // #nosec G404
|
||||
sum += roll
|
||||
rolls = append(rolls, roll)
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func init() {
|
||||
project = wd[len(gopath):]
|
||||
}
|
||||
|
||||
func walkFile(path string, info os.FileInfo, err error) error {
|
||||
func walkFile(path string, _ os.FileInfo, err error) error {
|
||||
if ignores[path] {
|
||||
return filepath.SkipDir
|
||||
}
|
||||
@@ -62,22 +62,27 @@ func walkFile(path string, info os.FileInfo, err error) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
debug.Println(path)
|
||||
|
||||
f, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
debug.Println(path)
|
||||
|
||||
f, err2 := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
|
||||
for _, importSpec := range f.Imports {
|
||||
importPath := strings.Trim(importSpec.Path.Value, `"`)
|
||||
if stdLibRegexp.MatchString(importPath) {
|
||||
switch {
|
||||
case stdLibRegexp.MatchString(importPath):
|
||||
debug.Println("standard lib:", importPath)
|
||||
continue
|
||||
} else if strings.HasPrefix(importPath, project) {
|
||||
case strings.HasPrefix(importPath, project):
|
||||
debug.Println("internal import:", importPath)
|
||||
continue
|
||||
} else if strings.HasPrefix(importPath, "golang.org/") {
|
||||
case strings.HasPrefix(importPath, "golang.org/"):
|
||||
debug.Println("extended lib:", importPath)
|
||||
continue
|
||||
}
|
||||
@@ -102,7 +107,7 @@ func main() {
|
||||
ignores["vendor"] = true
|
||||
}
|
||||
|
||||
for _, word := range strings.Split(ignoreLine, ",") {
|
||||
for word := range strings.SplitSeq(ignoreLine, ",") {
|
||||
ignores[strings.TrimSpace(word)] = true
|
||||
}
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto"
|
||||
"crypto/ecdsa"
|
||||
"crypto/rsa"
|
||||
"crypto/sha1"
|
||||
"crypto/sha1" // #nosec G505
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
@@ -13,14 +12,19 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/certlib"
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
"git.wntrmute.dev/kyle/goutils/lib"
|
||||
)
|
||||
|
||||
const (
|
||||
keyTypeRSA = "RSA"
|
||||
keyTypeECDSA = "ECDSA"
|
||||
)
|
||||
|
||||
func usage(w io.Writer) {
|
||||
fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files
|
||||
|
||||
@@ -39,14 +43,14 @@ func init() {
|
||||
flag.Usage = func() { usage(os.Stderr) }
|
||||
}
|
||||
|
||||
func parse(path string) (public []byte, kt, ft string) {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
func parse(path string) ([]byte, string, string) {
|
||||
data, err := os.ReadFile(path)
|
||||
die.If(err)
|
||||
|
||||
data = bytes.TrimSpace(data)
|
||||
p, rest := pem.Decode(data)
|
||||
if len(rest) > 0 {
|
||||
lib.Warnx("trailing data in PEM file")
|
||||
_, _ = lib.Warnx("trailing data in PEM file")
|
||||
}
|
||||
|
||||
if p == nil {
|
||||
@@ -55,6 +59,12 @@ func parse(path string) (public []byte, kt, ft string) {
|
||||
|
||||
data = p.Bytes
|
||||
|
||||
var (
|
||||
public []byte
|
||||
kt string
|
||||
ft string
|
||||
)
|
||||
|
||||
switch p.Type {
|
||||
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
|
||||
public, kt = parseKey(data)
|
||||
@@ -69,82 +79,79 @@ func parse(path string) (public []byte, kt, ft string) {
|
||||
die.With("unknown PEM type %s", p.Type)
|
||||
}
|
||||
|
||||
return
|
||||
return public, kt, ft
|
||||
}
|
||||
|
||||
func parseKey(data []byte) (public []byte, kt string) {
|
||||
privInterface, err := x509.ParsePKCS8PrivateKey(data)
|
||||
func parseKey(data []byte) ([]byte, string) {
|
||||
priv, err := certlib.ParsePrivateKeyDER(data)
|
||||
if err != nil {
|
||||
privInterface, err = x509.ParsePKCS1PrivateKey(data)
|
||||
if err != nil {
|
||||
privInterface, err = x509.ParseECPrivateKey(data)
|
||||
if err != nil {
|
||||
die.With("couldn't parse private key.")
|
||||
}
|
||||
}
|
||||
die.If(err)
|
||||
}
|
||||
|
||||
var priv crypto.Signer
|
||||
switch privInterface.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
priv = privInterface.(*rsa.PrivateKey)
|
||||
kt = "RSA"
|
||||
case *ecdsa.PrivateKey:
|
||||
priv = privInterface.(*ecdsa.PrivateKey)
|
||||
kt = "ECDSA"
|
||||
var kt string
|
||||
switch priv.Public().(type) {
|
||||
case *rsa.PublicKey:
|
||||
kt = keyTypeRSA
|
||||
case *ecdsa.PublicKey:
|
||||
kt = keyTypeECDSA
|
||||
default:
|
||||
die.With("unknown private key type %T", privInterface)
|
||||
die.With("unknown private key type %T", priv)
|
||||
}
|
||||
|
||||
public, err = x509.MarshalPKIXPublicKey(priv.Public())
|
||||
public, err := x509.MarshalPKIXPublicKey(priv.Public())
|
||||
die.If(err)
|
||||
|
||||
return
|
||||
return public, kt
|
||||
}
|
||||
|
||||
func parseCertificate(data []byte) (public []byte, kt string) {
|
||||
func parseCertificate(data []byte) ([]byte, string) {
|
||||
cert, err := x509.ParseCertificate(data)
|
||||
die.If(err)
|
||||
|
||||
pub := cert.PublicKey
|
||||
var kt string
|
||||
switch pub.(type) {
|
||||
case *rsa.PublicKey:
|
||||
kt = "RSA"
|
||||
kt = keyTypeRSA
|
||||
case *ecdsa.PublicKey:
|
||||
kt = "ECDSA"
|
||||
kt = keyTypeECDSA
|
||||
default:
|
||||
die.With("unknown public key type %T", pub)
|
||||
}
|
||||
|
||||
public, err = x509.MarshalPKIXPublicKey(pub)
|
||||
public, err := x509.MarshalPKIXPublicKey(pub)
|
||||
die.If(err)
|
||||
return
|
||||
return public, kt
|
||||
}
|
||||
|
||||
func parseCSR(data []byte) (public []byte, kt string) {
|
||||
csr, err := x509.ParseCertificateRequest(data)
|
||||
func parseCSR(data []byte) ([]byte, string) {
|
||||
// Use certlib to support both PEM and DER and to centralize validation.
|
||||
csr, _, err := certlib.ParseCSR(data)
|
||||
die.If(err)
|
||||
|
||||
pub := csr.PublicKey
|
||||
var kt string
|
||||
switch pub.(type) {
|
||||
case *rsa.PublicKey:
|
||||
kt = "RSA"
|
||||
kt = keyTypeRSA
|
||||
case *ecdsa.PublicKey:
|
||||
kt = "ECDSA"
|
||||
kt = keyTypeECDSA
|
||||
default:
|
||||
die.With("unknown public key type %T", pub)
|
||||
}
|
||||
|
||||
public, err = x509.MarshalPKIXPublicKey(pub)
|
||||
public, err := x509.MarshalPKIXPublicKey(pub)
|
||||
die.If(err)
|
||||
return
|
||||
return public, kt
|
||||
}
|
||||
|
||||
func dumpHex(in []byte) string {
|
||||
var s string
|
||||
var sSb153 strings.Builder
|
||||
for i := range in {
|
||||
s += fmt.Sprintf("%02X:", in[i])
|
||||
sSb153.WriteString(fmt.Sprintf("%02X:", in[i]))
|
||||
}
|
||||
s += sSb153.String()
|
||||
|
||||
return strings.Trim(s, ":")
|
||||
}
|
||||
@@ -172,18 +179,18 @@ func main() {
|
||||
var subPKI subjectPublicKeyInfo
|
||||
_, err := asn1.Unmarshal(public, &subPKI)
|
||||
if err != nil {
|
||||
lib.Warn(err, "failed to get subject PKI")
|
||||
_, _ = lib.Warn(err, "failed to get subject PKI")
|
||||
continue
|
||||
}
|
||||
|
||||
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes)
|
||||
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
|
||||
pubHashString := dumpHex(pubHash[:])
|
||||
if ski == "" {
|
||||
ski = pubHashString
|
||||
}
|
||||
|
||||
if shouldMatch && ski != pubHashString {
|
||||
lib.Warnx("%s: SKI mismatch (%s != %s)",
|
||||
_, _ = lib.Warnx("%s: SKI mismatch (%s != %s)",
|
||||
path, ski, pubHashString)
|
||||
}
|
||||
fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
"git.wntrmute.dev/kyle/goutils/lib"
|
||||
)
|
||||
|
||||
func proxy(conn net.Conn, inside string) error {
|
||||
proxyConn, err := net.Dial("tcp", inside)
|
||||
proxyConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", inside)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -19,7 +20,7 @@ func proxy(conn net.Conn, inside string) error {
|
||||
defer conn.Close()
|
||||
|
||||
go func() {
|
||||
io.Copy(conn, proxyConn)
|
||||
_, _ = io.Copy(conn, proxyConn)
|
||||
}()
|
||||
_, err = io.Copy(proxyConn, conn)
|
||||
return err
|
||||
@@ -31,16 +32,22 @@ func main() {
|
||||
flag.StringVar(&inside, "p", "4000", "inside port")
|
||||
flag.Parse()
|
||||
|
||||
l, err := net.Listen("tcp", "0.0.0.0:"+outside)
|
||||
lc := &net.ListenConfig{}
|
||||
l, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:"+outside)
|
||||
die.If(err)
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
var conn net.Conn
|
||||
conn, err = l.Accept()
|
||||
if err != nil {
|
||||
log.Println(err)
|
||||
_, _ = lib.Warn(err, "accept failed")
|
||||
continue
|
||||
}
|
||||
|
||||
go proxy(conn, "127.0.0.1:"+inside)
|
||||
go func() {
|
||||
if err = proxy(conn, "127.0.0.1:"+inside); err != nil {
|
||||
_, _ = lib.Warn(err, "proxy error")
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
@@ -8,7 +9,6 @@ import (
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
cfg := &tls.Config{}
|
||||
cfg := &tls.Config{} // #nosec G402
|
||||
|
||||
var sysRoot, listenAddr, certFile, keyFile string
|
||||
var verify bool
|
||||
@@ -47,7 +47,8 @@ func main() {
|
||||
}
|
||||
cfg.Certificates = append(cfg.Certificates, cert)
|
||||
if sysRoot != "" {
|
||||
pemList, err := ioutil.ReadFile(sysRoot)
|
||||
var pemList []byte
|
||||
pemList, err = os.ReadFile(sysRoot)
|
||||
die.If(err)
|
||||
|
||||
roots := x509.NewCertPool()
|
||||
@@ -59,48 +60,54 @@ func main() {
|
||||
cfg.RootCAs = roots
|
||||
}
|
||||
|
||||
l, err := net.Listen("tcp", listenAddr)
|
||||
lc := &net.ListenConfig{}
|
||||
l, err := lc.Listen(context.Background(), "tcp", listenAddr)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := l.Accept()
|
||||
var conn net.Conn
|
||||
conn, err = l.Accept()
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
}
|
||||
|
||||
raddr := conn.RemoteAddr()
|
||||
tconn := tls.Server(conn, cfg)
|
||||
err = tconn.Handshake()
|
||||
if err != nil {
|
||||
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
|
||||
continue
|
||||
}
|
||||
cs := tconn.ConnectionState()
|
||||
if len(cs.PeerCertificates) == 0 {
|
||||
fmt.Printf("[+] %v: no chain presented\n", raddr)
|
||||
continue
|
||||
}
|
||||
|
||||
var chain []byte
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
p := &pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: cert.Raw,
|
||||
}
|
||||
chain = append(chain, pem.EncodeToMemory(p)...)
|
||||
}
|
||||
|
||||
var nonce [16]byte
|
||||
_, err = rand.Read(nonce[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
|
||||
err = ioutil.WriteFile(fname, chain, 0644)
|
||||
die.If(err)
|
||||
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
|
||||
handleConn(conn, cfg)
|
||||
}
|
||||
}
|
||||
|
||||
// handleConn performs a TLS handshake, extracts the peer chain, and writes it to a file.
|
||||
func handleConn(conn net.Conn, cfg *tls.Config) {
|
||||
defer conn.Close()
|
||||
raddr := conn.RemoteAddr()
|
||||
tconn := tls.Server(conn, cfg)
|
||||
if err := tconn.HandshakeContext(context.Background()); err != nil {
|
||||
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
|
||||
return
|
||||
}
|
||||
cs := tconn.ConnectionState()
|
||||
if len(cs.PeerCertificates) == 0 {
|
||||
fmt.Printf("[+] %v: no chain presented\n", raddr)
|
||||
return
|
||||
}
|
||||
|
||||
var chain []byte
|
||||
for _, cert := range cs.PeerCertificates {
|
||||
p := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
|
||||
chain = append(chain, pem.EncodeToMemory(p)...)
|
||||
}
|
||||
|
||||
var nonce [16]byte
|
||||
if _, err := rand.Read(nonce[:]); err != nil {
|
||||
fmt.Printf("[+] %v: failed to generate filename nonce: %v\n", raddr, err)
|
||||
return
|
||||
}
|
||||
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
|
||||
if err := os.WriteFile(fname, chain, 0o644); err != nil {
|
||||
fmt.Printf("[+] %v: failed to write %v: %v\n", raddr, fname, err)
|
||||
return
|
||||
}
|
||||
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
)
|
||||
|
||||
func main() {
|
||||
var cfg = &tls.Config{}
|
||||
var cfg = &tls.Config{} // #nosec G402
|
||||
|
||||
var sysRoot, serverName string
|
||||
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
|
||||
@@ -23,7 +23,7 @@ func main() {
|
||||
flag.Parse()
|
||||
|
||||
if sysRoot != "" {
|
||||
pemList, err := ioutil.ReadFile(sysRoot)
|
||||
pemList, err := os.ReadFile(sysRoot)
|
||||
die.If(err)
|
||||
|
||||
roots := x509.NewCertPool()
|
||||
@@ -44,10 +44,13 @@ func main() {
|
||||
if err != nil {
|
||||
site += ":443"
|
||||
}
|
||||
conn, err := tls.Dial("tcp", site, cfg)
|
||||
if err != nil {
|
||||
fmt.Println(err.Error())
|
||||
os.Exit(1)
|
||||
d := &tls.Dialer{Config: cfg}
|
||||
nc, err := d.DialContext(context.Background(), "tcp", site)
|
||||
die.If(err)
|
||||
|
||||
conn, ok := nc.(*tls.Conn)
|
||||
if !ok {
|
||||
die.With("invalid TLS connection (not a *tls.Conn)")
|
||||
}
|
||||
|
||||
cs := conn.ConnectionState()
|
||||
@@ -61,8 +64,9 @@ func main() {
|
||||
chain = append(chain, pem.EncodeToMemory(p)...)
|
||||
}
|
||||
|
||||
err = ioutil.WriteFile(site+".pem", chain, 0644)
|
||||
err = os.WriteFile(site+".pem", chain, 0644)
|
||||
die.If(err)
|
||||
|
||||
fmt.Printf("[+] wrote %s.pem.\n", site)
|
||||
}
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func printDigests(paths []string, issuer bool) {
|
||||
for _, path := range paths {
|
||||
cert, err := certlib.LoadCertificate(path)
|
||||
if err != nil {
|
||||
lib.Warn(err, "failed to load certificate from %s", path)
|
||||
_, _ = lib.Warn(err, "failed to load certificate from %s", path)
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -75,20 +75,19 @@ func matchDigests(paths []string, issuer bool) {
|
||||
}
|
||||
|
||||
var invalid int
|
||||
for {
|
||||
if len(paths) == 0 {
|
||||
break
|
||||
}
|
||||
for len(paths) > 0 {
|
||||
fst := paths[0]
|
||||
snd := paths[1]
|
||||
paths = paths[2:]
|
||||
|
||||
fstCert, err := certlib.LoadCertificate(fst)
|
||||
die.If(err)
|
||||
|
||||
sndCert, err := certlib.LoadCertificate(snd)
|
||||
die.If(err)
|
||||
|
||||
if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) {
|
||||
lib.Warnx("certificates don't match: %s and %s", fst, snd)
|
||||
_, _ = lib.Warnx("certificates don't match: %s and %s", fst, snd)
|
||||
invalid++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -13,16 +17,23 @@ func main() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
hostPort := os.Args[1]
|
||||
conn, err := tls.Dial("tcp", hostPort, &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
hostPort, err := hosts.ParseHost(os.Args[1])
|
||||
die.If(err)
|
||||
|
||||
if err != nil {
|
||||
fmt.Printf("Failed to connect to the TLS server: %v\n", err)
|
||||
os.Exit(1)
|
||||
d := &tls.Dialer{Config: &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
}} // #nosec G402
|
||||
|
||||
nc, err := d.DialContext(context.Background(), "tcp", hostPort.String())
|
||||
die.If(err)
|
||||
|
||||
conn, ok := nc.(*tls.Conn)
|
||||
if !ok {
|
||||
die.With("invalid TLS connection (not a *tls.Conn)")
|
||||
}
|
||||
|
||||
defer conn.Close()
|
||||
|
||||
state := conn.ConnectionState()
|
||||
printConnectionDetails(state)
|
||||
}
|
||||
@@ -37,7 +48,6 @@ func printConnectionDetails(state tls.ConnectionState) {
|
||||
|
||||
func tlsVersion(version uint16) string {
|
||||
switch version {
|
||||
|
||||
case tls.VersionTLS13:
|
||||
return "TLS 1.3"
|
||||
case tls.VersionTLS12:
|
||||
|
||||
@@ -11,10 +11,9 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/certlib"
|
||||
"git.wntrmute.dev/kyle/goutils/die"
|
||||
)
|
||||
|
||||
@@ -32,7 +31,7 @@ const (
|
||||
curveP521
|
||||
)
|
||||
|
||||
func getECCurve(pub interface{}) int {
|
||||
func getECCurve(pub any) int {
|
||||
switch pub := pub.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
switch pub.Curve {
|
||||
@@ -52,42 +51,88 @@ func getECCurve(pub interface{}) int {
|
||||
}
|
||||
}
|
||||
|
||||
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
|
||||
// It returns true on match.
|
||||
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
|
||||
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
|
||||
}
|
||||
|
||||
// matchECDSA compares ECDSA public keys for equality and compatible curve.
|
||||
// It returns match=true when they are on the same curve and have the same X/Y.
|
||||
// If curves mismatch, match is false.
|
||||
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
|
||||
if getECCurve(certPub) != getECCurve(keyPub) {
|
||||
return false
|
||||
}
|
||||
if keyPub.X.Cmp(certPub.X) != 0 {
|
||||
return false
|
||||
}
|
||||
if keyPub.Y.Cmp(certPub.Y) != 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// matchKeys determines whether the certificate's public key matches the given private key.
|
||||
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
|
||||
func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
|
||||
switch keyPub := priv.Public().(type) {
|
||||
case *rsa.PublicKey:
|
||||
switch certPub := cert.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if matchRSA(certPub, keyPub) {
|
||||
return true, ""
|
||||
}
|
||||
return false, "public keys don't match"
|
||||
case *ecdsa.PublicKey:
|
||||
return false, "RSA private key, EC public key"
|
||||
default:
|
||||
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
switch certPub := cert.PublicKey.(type) {
|
||||
case *ecdsa.PublicKey:
|
||||
if matchECDSA(certPub, keyPub) {
|
||||
return true, ""
|
||||
}
|
||||
// Determine a more precise reason
|
||||
kc := getECCurve(keyPub)
|
||||
cc := getECCurve(certPub)
|
||||
if kc == curveInvalid {
|
||||
return false, "invalid private key curve"
|
||||
}
|
||||
if cc == curveRSA {
|
||||
return false, "private key is EC, certificate is RSA"
|
||||
}
|
||||
if kc != cc {
|
||||
return false, "EC curves don't match"
|
||||
}
|
||||
return false, "public keys don't match"
|
||||
case *rsa.PublicKey:
|
||||
return false, "private key is EC, certificate is RSA"
|
||||
default:
|
||||
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
|
||||
}
|
||||
default:
|
||||
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
|
||||
}
|
||||
}
|
||||
|
||||
func loadKey(path string) (crypto.Signer, error) {
|
||||
in, err := ioutil.ReadFile(path)
|
||||
in, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
in = bytes.TrimSpace(in)
|
||||
p, _ := pem.Decode(in)
|
||||
if p != nil {
|
||||
if p, _ := pem.Decode(in); p != nil {
|
||||
if !validPEMs[p.Type] {
|
||||
return nil, errors.New("invalid private key file type " + p.Type)
|
||||
}
|
||||
in = p.Bytes
|
||||
return certlib.ParsePrivateKeyPEM(in)
|
||||
}
|
||||
|
||||
priv, err := x509.ParsePKCS8PrivateKey(in)
|
||||
if err != nil {
|
||||
priv, err = x509.ParsePKCS1PrivateKey(in)
|
||||
if err != nil {
|
||||
priv, err = x509.ParseECPrivateKey(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
switch priv.(type) {
|
||||
case *rsa.PrivateKey:
|
||||
return priv.(*rsa.PrivateKey), nil
|
||||
case *ecdsa.PrivateKey:
|
||||
return priv.(*ecdsa.PrivateKey), nil
|
||||
}
|
||||
|
||||
// should never reach here
|
||||
return nil, errors.New("invalid private key")
|
||||
|
||||
return certlib.ParsePrivateKeyDER(in)
|
||||
}
|
||||
|
||||
func main() {
|
||||
@@ -96,7 +141,7 @@ func main() {
|
||||
flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
|
||||
flag.Parse()
|
||||
|
||||
in, err := ioutil.ReadFile(certFile)
|
||||
in, err := os.ReadFile(certFile)
|
||||
die.If(err)
|
||||
|
||||
p, _ := pem.Decode(in)
|
||||
@@ -112,50 +157,11 @@ func main() {
|
||||
priv, err := loadKey(keyFile)
|
||||
die.If(err)
|
||||
|
||||
switch pub := priv.Public().(type) {
|
||||
case *rsa.PublicKey:
|
||||
switch certPub := cert.PublicKey.(type) {
|
||||
case *rsa.PublicKey:
|
||||
if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E {
|
||||
fmt.Println("No match (public keys don't match).")
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("Match.")
|
||||
return
|
||||
case *ecdsa.PublicKey:
|
||||
fmt.Println("No match (RSA private key, EC public key).")
|
||||
os.Exit(1)
|
||||
}
|
||||
case *ecdsa.PublicKey:
|
||||
privCurve := getECCurve(pub)
|
||||
certCurve := getECCurve(cert.PublicKey)
|
||||
log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve)
|
||||
|
||||
if certCurve == curveRSA {
|
||||
fmt.Println("No match (private key is EC, certificate is RSA).")
|
||||
os.Exit(1)
|
||||
} else if privCurve == curveInvalid {
|
||||
fmt.Println("No match (invalid private key curve).")
|
||||
os.Exit(1)
|
||||
} else if privCurve != certCurve {
|
||||
fmt.Println("No match (EC curves don't match).")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
certPub := cert.PublicKey.(*ecdsa.PublicKey)
|
||||
if pub.X.Cmp(certPub.X) != 0 {
|
||||
fmt.Println("No match (public keys don't match).")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if pub.Y.Cmp(certPub.Y) != 0 {
|
||||
fmt.Println("No match (public keys don't match).")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
matched, reason := matchKeys(cert, priv)
|
||||
if matched {
|
||||
fmt.Println("Match.")
|
||||
default:
|
||||
fmt.Printf("Unrecognised private key type: %T\n", priv.Public())
|
||||
os.Exit(1)
|
||||
return
|
||||
}
|
||||
fmt.Printf("No match (%s).\n", reason)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -201,10 +201,6 @@ func init() {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if fromLoc == time.UTC {
|
||||
|
||||
}
|
||||
|
||||
toLoc = time.UTC
|
||||
}
|
||||
|
||||
@@ -257,15 +253,16 @@ func main() {
|
||||
showTime(time.Now())
|
||||
os.Exit(0)
|
||||
case 1:
|
||||
if flag.Arg(0) == "-" {
|
||||
switch {
|
||||
case flag.Arg(0) == "-":
|
||||
s := bufio.NewScanner(os.Stdin)
|
||||
|
||||
for s.Scan() {
|
||||
times = append(times, s.Text())
|
||||
}
|
||||
} else if flag.Arg(0) == "help" {
|
||||
case flag.Arg(0) == "help":
|
||||
usageExamples()
|
||||
} else {
|
||||
default:
|
||||
times = flag.Args()
|
||||
}
|
||||
default:
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
"gopkg.in/yaml.v2"
|
||||
@@ -12,9 +11,8 @@ import (
|
||||
|
||||
type empty struct{}
|
||||
|
||||
func errorf(format string, args ...interface{}) {
|
||||
format += "\n"
|
||||
fmt.Fprintf(os.Stderr, format, args...)
|
||||
func errorf(path string, err error) {
|
||||
fmt.Fprintf(os.Stderr, "%s FAILED: %s\n", path, err)
|
||||
}
|
||||
|
||||
func usage(w io.Writer) {
|
||||
@@ -44,16 +42,16 @@ func main() {
|
||||
|
||||
if flag.NArg() == 1 && flag.Arg(0) == "-" {
|
||||
path := "stdin"
|
||||
in, err := ioutil.ReadAll(os.Stdin)
|
||||
in, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
errorf("%s FAILED: %s", path, err)
|
||||
errorf(path, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var e empty
|
||||
err = yaml.Unmarshal(in, &e)
|
||||
if err != nil {
|
||||
errorf("%s FAILED: %s", path, err)
|
||||
errorf(path, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
@@ -65,16 +63,16 @@ func main() {
|
||||
}
|
||||
|
||||
for _, path := range flag.Args() {
|
||||
in, err := ioutil.ReadFile(path)
|
||||
in, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
errorf("%s FAILED: %s", path, err)
|
||||
errorf(path, err)
|
||||
continue
|
||||
}
|
||||
|
||||
var e empty
|
||||
err = yaml.Unmarshal(in, &e)
|
||||
if err != nil {
|
||||
errorf("%s FAILED: %s", path, err)
|
||||
errorf(path, err)
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -14,16 +14,16 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
|
||||
"git.wntrmute.dev/kyle/goutils/lib"
|
||||
)
|
||||
|
||||
const defaultDirectory = ".git/objects"
|
||||
|
||||
func errorf(format string, a ...interface{}) {
|
||||
fmt.Fprintf(os.Stderr, format, a...)
|
||||
if format[len(format)-1] != '\n' {
|
||||
fmt.Fprintf(os.Stderr, "\n")
|
||||
}
|
||||
}
|
||||
// maxDecompressedSize limits how many bytes we will decompress from a zlib
|
||||
// stream to mitigate decompression bombs (gosec G110).
|
||||
// Increase this if you expect larger objects.
|
||||
const maxDecompressedSize int64 = 64 << 30 // 64 GiB
|
||||
|
||||
func isDir(path string) bool {
|
||||
fi, err := os.Stat(path)
|
||||
@@ -48,17 +48,21 @@ func loadFile(path string) ([]byte, error) {
|
||||
}
|
||||
defer zread.Close()
|
||||
|
||||
_, err = io.Copy(buf, zread)
|
||||
if err != nil {
|
||||
// Protect against decompression bombs by limiting how much we read.
|
||||
lr := io.LimitReader(zread, maxDecompressedSize+1)
|
||||
if _, err = buf.ReadFrom(lr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if int64(buf.Len()) > maxDecompressedSize {
|
||||
return nil, fmt.Errorf("decompressed size exceeds limit (%d bytes)", maxDecompressedSize)
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func showFile(path string) {
|
||||
fileData, err := loadFile(path)
|
||||
if err != nil {
|
||||
errorf("%v", err)
|
||||
lib.Warn(err, "failed to load %s", path)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -68,39 +72,71 @@ func showFile(path string) {
|
||||
func searchFile(path string, search *regexp.Regexp) error {
|
||||
file, err := os.Open(path)
|
||||
if err != nil {
|
||||
errorf("%v", err)
|
||||
lib.Warn(err, "failed to open %s", path)
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
zread, err := zlib.NewReader(file)
|
||||
if err != nil {
|
||||
errorf("%v", err)
|
||||
lib.Warn(err, "failed to decompress %s", path)
|
||||
return err
|
||||
}
|
||||
defer zread.Close()
|
||||
|
||||
zbuf := bufio.NewReader(zread)
|
||||
if search.MatchReader(zbuf) {
|
||||
fileData, err := loadFile(path)
|
||||
if err != nil {
|
||||
errorf("%v", err)
|
||||
return err
|
||||
}
|
||||
fmt.Printf("%s:\n%s\n", path, fileData)
|
||||
// Limit how much we scan to avoid DoS via huge decompression.
|
||||
lr := io.LimitReader(zread, maxDecompressedSize+1)
|
||||
zbuf := bufio.NewReader(lr)
|
||||
if !search.MatchReader(zbuf) {
|
||||
return nil
|
||||
}
|
||||
|
||||
fileData, err := loadFile(path)
|
||||
if err != nil {
|
||||
lib.Warn(err, "failed to load %s", path)
|
||||
return err
|
||||
}
|
||||
fmt.Printf("%s:\n%s\n", path, fileData)
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc {
|
||||
return func(path string, info os.FileInfo, err error) error {
|
||||
if info.Mode().IsRegular() {
|
||||
return searchFile(path, searchExpr)
|
||||
return func(path string, info os.FileInfo, _ error) error {
|
||||
if !info.Mode().IsRegular() {
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
return searchFile(path, searchExpr)
|
||||
}
|
||||
}
|
||||
|
||||
// runSearch compiles the search expression and processes the provided paths.
|
||||
// It returns an error for fatal conditions; per-file errors are logged.
|
||||
func runSearch(expr string) error {
|
||||
search, err := regexp.Compile(expr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid regexp: %w", err)
|
||||
}
|
||||
|
||||
pathList := flag.Args()
|
||||
if len(pathList) == 0 {
|
||||
pathList = []string{defaultDirectory}
|
||||
}
|
||||
|
||||
for _, path := range pathList {
|
||||
if isDir(path) {
|
||||
if err2 := filepath.Walk(path, buildWalker(search)); err2 != nil {
|
||||
return err2
|
||||
}
|
||||
continue
|
||||
}
|
||||
if err2 := searchFile(path, search); err2 != nil {
|
||||
// Non-fatal: keep going, but report it.
|
||||
lib.Warn(err2, "non-fatal error while searching files")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func main() {
|
||||
flSearch := flag.String("s", "", "search string (should be an RE2 regular expression)")
|
||||
flag.Parse()
|
||||
@@ -109,28 +145,10 @@ func main() {
|
||||
for _, path := range flag.Args() {
|
||||
showFile(path)
|
||||
}
|
||||
} else {
|
||||
search, err := regexp.Compile(*flSearch)
|
||||
if err != nil {
|
||||
errorf("Bad regexp: %v", err)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
pathList := flag.Args()
|
||||
if len(pathList) == 0 {
|
||||
pathList = []string{defaultDirectory}
|
||||
}
|
||||
|
||||
for _, path := range pathList {
|
||||
if isDir(path) {
|
||||
err := filepath.Walk(path, buildWalker(search))
|
||||
if err != nil {
|
||||
errorf("%v", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
searchFile(path, search)
|
||||
}
|
||||
}
|
||||
if err := runSearch(*flSearch); err != nil {
|
||||
lib.Err(lib.ExitFailure, err, "failed to run search")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// +build freebsd darwin,386 netbsd
|
||||
//go:build bsd
|
||||
|
||||
package lib
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// +build unix linux openbsd darwin,amd64
|
||||
//go:build unix || linux || openbsd || (darwin && amd64)
|
||||
|
||||
package lib
|
||||
|
||||
@@ -18,7 +18,7 @@ type FileTime struct {
|
||||
|
||||
func timeSpecToTime(ts unix.Timespec) time.Time {
|
||||
// The casts to int64 are needed because on 386, these are int32s.
|
||||
return time.Unix(int64(ts.Sec), int64(ts.Nsec))
|
||||
return time.Unix(ts.Sec, ts.Nsec)
|
||||
}
|
||||
|
||||
// LoadFileTime returns a FileTime associated with the file.
|
||||
|
||||
32
lib/lib.go
32
lib/lib.go
@@ -10,6 +10,12 @@ import (
|
||||
|
||||
var progname = filepath.Base(os.Args[0])
|
||||
|
||||
const (
|
||||
daysInYear = 365
|
||||
digitWidth = 10
|
||||
hoursInQuarterDay = 6
|
||||
)
|
||||
|
||||
// ProgName returns what lib thinks the program name is, namely the
|
||||
// basename of argv0.
|
||||
//
|
||||
@@ -20,7 +26,7 @@ func ProgName() string {
|
||||
|
||||
// Warnx displays a formatted error message to standard error, à la
|
||||
// warnx(3).
|
||||
func Warnx(format string, a ...interface{}) (int, error) {
|
||||
func Warnx(format string, a ...any) (int, error) {
|
||||
format = fmt.Sprintf("[%s] %s", progname, format)
|
||||
format += "\n"
|
||||
return fmt.Fprintf(os.Stderr, format, a...)
|
||||
@@ -28,7 +34,7 @@ func Warnx(format string, a ...interface{}) (int, error) {
|
||||
|
||||
// Warn displays a formatted error message to standard output,
|
||||
// appending the error string, à la warn(3).
|
||||
func Warn(err error, format string, a ...interface{}) (int, error) {
|
||||
func Warn(err error, format string, a ...any) (int, error) {
|
||||
format = fmt.Sprintf("[%s] %s", progname, format)
|
||||
format += ": %v\n"
|
||||
a = append(a, err)
|
||||
@@ -37,7 +43,7 @@ func Warn(err error, format string, a ...interface{}) (int, error) {
|
||||
|
||||
// Errx displays a formatted error message to standard error and exits
|
||||
// with the status code from `exit`, à la errx(3).
|
||||
func Errx(exit int, format string, a ...interface{}) {
|
||||
func Errx(exit int, format string, a ...any) {
|
||||
format = fmt.Sprintf("[%s] %s", progname, format)
|
||||
format += "\n"
|
||||
fmt.Fprintf(os.Stderr, format, a...)
|
||||
@@ -47,7 +53,7 @@ func Errx(exit int, format string, a ...interface{}) {
|
||||
// Err displays a formatting error message to standard error,
|
||||
// appending the error string, and exits with the status code from
|
||||
// `exit`, à la err(3).
|
||||
func Err(exit int, err error, format string, a ...interface{}) {
|
||||
func Err(exit int, err error, format string, a ...any) {
|
||||
format = fmt.Sprintf("[%s] %s", progname, format)
|
||||
format += ": %v\n"
|
||||
a = append(a, err)
|
||||
@@ -62,30 +68,30 @@ func Itoa(i int, wid int) string {
|
||||
// Assemble decimal in reverse order.
|
||||
var b [20]byte
|
||||
bp := len(b) - 1
|
||||
for i >= 10 || wid > 1 {
|
||||
for i >= digitWidth || wid > 1 {
|
||||
wid--
|
||||
q := i / 10
|
||||
b[bp] = byte('0' + i - q*10)
|
||||
q := i / digitWidth
|
||||
b[bp] = byte('0' + i - q*digitWidth)
|
||||
bp--
|
||||
i = q
|
||||
}
|
||||
// i < 10
|
||||
|
||||
b[bp] = byte('0' + i)
|
||||
return string(b[bp:])
|
||||
}
|
||||
|
||||
var (
|
||||
dayDuration = 24 * time.Hour
|
||||
yearDuration = (365 * dayDuration) + (6 * time.Hour)
|
||||
yearDuration = (daysInYear * dayDuration) + (hoursInQuarterDay * time.Hour)
|
||||
)
|
||||
|
||||
// Duration returns a prettier string for time.Durations.
|
||||
func Duration(d time.Duration) string {
|
||||
var s string
|
||||
if d >= yearDuration {
|
||||
years := d / yearDuration
|
||||
years := int64(d / yearDuration)
|
||||
s += fmt.Sprintf("%dy", years)
|
||||
d -= years * yearDuration
|
||||
d -= time.Duration(years) * yearDuration
|
||||
}
|
||||
|
||||
if d >= dayDuration {
|
||||
@@ -98,8 +104,8 @@ func Duration(d time.Duration) string {
|
||||
}
|
||||
|
||||
d %= 1 * time.Second
|
||||
hours := d / time.Hour
|
||||
d -= hours * time.Hour
|
||||
hours := int64(d / time.Hour)
|
||||
d -= time.Duration(hours) * time.Hour
|
||||
s += fmt.Sprintf("%dh%s", hours, d)
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
// File writes its logs to file.
|
||||
@@ -59,12 +60,12 @@ func NewSplitFile(outpath, errpath string, overwrite bool) (*File, error) {
|
||||
fl.fe, err = os.OpenFile(errpath, os.O_WRONLY|os.O_APPEND|os.O_CREATE, 0600)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if closeErr := fl.Close(); closeErr != nil {
|
||||
return nil, fmt.Errorf("failed to open error log: cleanup close failed: %v: %w", closeErr, err)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
if closeErr := fl.Close(); closeErr != nil {
|
||||
return nil, fmt.Errorf("failed to open error log: %w", errors.Join(closeErr, err))
|
||||
}
|
||||
return nil, fmt.Errorf("failed to open error log: %w", err)
|
||||
}
|
||||
|
||||
fl.LogWriter = NewLogWriter(fl.fo, fl.fe)
|
||||
return fl, nil
|
||||
@@ -94,13 +95,13 @@ func (fl *File) Flush() error {
|
||||
}
|
||||
|
||||
func (fl *File) Chmod(mode os.FileMode) error {
|
||||
if err := fl.fo.Chmod(mode); err != nil {
|
||||
return fmt.Errorf("failed to chmod output log: %w", err)
|
||||
}
|
||||
if err := fl.fo.Chmod(mode); err != nil {
|
||||
return fmt.Errorf("failed to chmod output log: %w", err)
|
||||
}
|
||||
|
||||
if err := fl.fe.Chmod(mode); err != nil {
|
||||
return fmt.Errorf("failed to chmod error log: %w", err)
|
||||
}
|
||||
if err := fl.fe.Chmod(mode); err != nil {
|
||||
return fmt.Errorf("failed to chmod error log: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil
|
||||
}
|
||||
|
||||
42
twofactor/.circleci/config.yml
Normal file
42
twofactor/.circleci/config.yml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Use the latest 2.1 version of CircleCI pipeline process engine.
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference
|
||||
version: 2.1
|
||||
|
||||
# Define a job to be invoked later in a workflow.
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#jobs
|
||||
jobs:
|
||||
testbuild:
|
||||
working_directory: ~/repo
|
||||
# Specify the execution environment. You can specify an image from Dockerhub or use one of our Convenience Images from CircleCI's Developer Hub.
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#docker-machine-macos-windows-executor
|
||||
docker:
|
||||
- image: cimg/go:1.22.2
|
||||
# Add steps to the job
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#steps
|
||||
steps:
|
||||
- checkout
|
||||
- restore_cache:
|
||||
keys:
|
||||
- go-mod-v4-{{ checksum "go.sum" }}
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: go mod download
|
||||
- save_cache:
|
||||
key: go-mod-v4-{{ checksum "go.sum" }}
|
||||
paths:
|
||||
- "/go/pkg/mod"
|
||||
- run:
|
||||
name: Run tests
|
||||
command: go test ./...
|
||||
- run:
|
||||
name: Run build
|
||||
command: go build ./...
|
||||
- store_test_results:
|
||||
path: /tmp/test-reports
|
||||
|
||||
# Invoke jobs via workflows
|
||||
# See: https://circleci.com/docs/2.0/configuration-reference/#workflows
|
||||
workflows:
|
||||
testbuild:
|
||||
jobs:
|
||||
- testbuild
|
||||
19
twofactor/LICENSE
Normal file
19
twofactor/LICENSE
Normal file
@@ -0,0 +1,19 @@
|
||||
Copyright (c) 2017 Kyle Isom <kyle@imap.cc>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
33
twofactor/README.md
Normal file
33
twofactor/README.md
Normal file
@@ -0,0 +1,33 @@
|
||||
## `twofactor`
|
||||
|
||||
[](https://godoc.org/github.com/gokyle/twofactor)
|
||||
|
||||
|
||||
### Author
|
||||
|
||||
`twofactor` was written by Kyle Isom <kyle@tyrfingr.is>.
|
||||
|
||||
|
||||
### License
|
||||
|
||||
```
|
||||
Copyright (c) 2017 Kyle Isom <kyle@imap.cc>
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
```
|
||||
5
twofactor/doc.go
Normal file
5
twofactor/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// twofactor implements two-factor authentication.
|
||||
//
|
||||
// Currently supported are RFC 4226 HOTP one-time passwords and
|
||||
// RFC 6238 TOTP SHA-1 one-time passwords.
|
||||
package twofactor
|
||||
8
twofactor/go.mod
Normal file
8
twofactor/go.mod
Normal file
@@ -0,0 +1,8 @@
|
||||
module github.com/gokyle/twofactor
|
||||
|
||||
go 1.14
|
||||
|
||||
require (
|
||||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3
|
||||
rsc.io/qr v0.1.0
|
||||
)
|
||||
4
twofactor/go.sum
Normal file
4
twofactor/go.sum
Normal file
@@ -0,0 +1,4 @@
|
||||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3 h1:wOysYcIdqv3WnvwqFFzrYCFALPED7qkUGaLXu359GSc=
|
||||
github.com/benbjohnson/clock v0.0.0-20161215174838-7dc76406b6d3/go.mod h1:UMqtWQTnOe4byzwe7Zhwh8f8s+36uszN51sJrSIZlTE=
|
||||
rsc.io/qr v0.1.0 h1:M/sAxsU2J5mlQ4W84Bxga2EgdQqOaAliipcjPmMUM5Q=
|
||||
rsc.io/qr v0.1.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
||||
103
twofactor/hotp.go
Normal file
103
twofactor/hotp.go
Normal file
@@ -0,0 +1,103 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/sha1"
|
||||
"encoding/base32"
|
||||
"io"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// HOTP represents an RFC-4226 Hash-based One Time Password instance.
|
||||
type HOTP struct {
|
||||
*OATH
|
||||
}
|
||||
|
||||
// Type returns OATH_HOTP.
|
||||
func (otp *HOTP) Type() Type {
|
||||
return OATH_HOTP
|
||||
}
|
||||
|
||||
// NewHOTP takes the key, the initial counter value, and the number
|
||||
// of digits (typically 6 or 8) and returns a new HOTP instance.
|
||||
func NewHOTP(key []byte, counter uint64, digits int) *HOTP {
|
||||
return &HOTP{
|
||||
OATH: &OATH{
|
||||
key: key,
|
||||
counter: counter,
|
||||
size: digits,
|
||||
hash: sha1.New,
|
||||
algo: crypto.SHA1,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// OTP returns the next OTP and increments the counter.
|
||||
func (otp *HOTP) OTP() string {
|
||||
code := otp.OATH.OTP(otp.counter)
|
||||
otp.counter++
|
||||
return code
|
||||
}
|
||||
|
||||
// URL returns an HOTP URL (i.e. for putting in a QR code).
|
||||
func (otp *HOTP) URL(label string) string {
|
||||
return otp.OATH.URL(otp.Type(), label)
|
||||
}
|
||||
|
||||
// SetProvider sets up the provider component of the OTP URL.
|
||||
func (otp *HOTP) SetProvider(provider string) {
|
||||
otp.provider = provider
|
||||
}
|
||||
|
||||
// GenerateGoogleHOTP generates a new HOTP instance as used by
|
||||
// Google Authenticator.
|
||||
func GenerateGoogleHOTP() *HOTP {
|
||||
key := make([]byte, sha1.Size)
|
||||
if _, err := io.ReadFull(PRNG, key); err != nil {
|
||||
return nil
|
||||
}
|
||||
return NewHOTP(key, 0, 6)
|
||||
}
|
||||
|
||||
func hotpFromURL(u *url.URL) (*HOTP, string, error) {
|
||||
label := u.Path[1:]
|
||||
v := u.Query()
|
||||
|
||||
secret := strings.ToUpper(v.Get("secret"))
|
||||
if secret == "" {
|
||||
return nil, "", ErrInvalidURL
|
||||
}
|
||||
|
||||
var digits = 6
|
||||
if sdigit := v.Get("digits"); sdigit != "" {
|
||||
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
digits = int(tmpDigits)
|
||||
}
|
||||
|
||||
var counter uint64 = 0
|
||||
if scounter := v.Get("counter"); scounter != "" {
|
||||
var err error
|
||||
counter, err = strconv.ParseUint(scounter, 10, 64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
key, err := base32.StdEncoding.DecodeString(Pad(secret))
|
||||
if err != nil {
|
||||
// assume secret isn't base32 encoded
|
||||
key = []byte(secret)
|
||||
}
|
||||
otp := NewHOTP(key, counter, digits)
|
||||
return otp, label, nil
|
||||
}
|
||||
|
||||
// QR generates a new QR code for the HOTP.
|
||||
func (otp *HOTP) QR(label string) ([]byte, error) {
|
||||
return otp.OATH.QR(otp.Type(), label)
|
||||
}
|
||||
64
twofactor/hotp_test.go
Normal file
64
twofactor/hotp_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var testKey = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20}
|
||||
|
||||
var rfcHotpKey = []byte("12345678901234567890")
|
||||
var rfcHotpExpected = []string{
|
||||
"755224",
|
||||
"287082",
|
||||
"359152",
|
||||
"969429",
|
||||
"338314",
|
||||
"254676",
|
||||
"287922",
|
||||
"162583",
|
||||
"399871",
|
||||
"520489",
|
||||
}
|
||||
|
||||
// This test runs through the test cases presented in the RFC, and
|
||||
// ensures that this implementation is in compliance.
|
||||
func TestHotpRFC(t *testing.T) {
|
||||
otp := NewHOTP(rfcHotpKey, 0, 6)
|
||||
for i := 0; i < len(rfcHotpExpected); i++ {
|
||||
if otp.Counter() != uint64(i) {
|
||||
fmt.Printf("twofactor: invalid counter (should be %d, is %d",
|
||||
i, otp.Counter())
|
||||
t.FailNow()
|
||||
}
|
||||
code := otp.OTP()
|
||||
if code == "" {
|
||||
fmt.Printf("twofactor: failed to produce an OTP\n")
|
||||
t.FailNow()
|
||||
} else if code != rfcHotpExpected[i] {
|
||||
fmt.Printf("twofactor: invalid OTP\n")
|
||||
fmt.Printf("\tExpected: %s\n", rfcHotpExpected[i])
|
||||
fmt.Printf("\t Actual: %s\n", code)
|
||||
fmt.Printf("\t Counter: %d\n", otp.counter)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This test uses a different key than the test cases in the RFC,
|
||||
// but runs through the same test cases to ensure that they fail as
|
||||
// expected.
|
||||
func TestHotpBadRFC(t *testing.T) {
|
||||
otp := NewHOTP(testKey, 0, 6)
|
||||
for i := 0; i < len(rfcHotpExpected); i++ {
|
||||
code := otp.OTP()
|
||||
switch code {
|
||||
case "":
|
||||
fmt.Printf("twofactor: failed to produce an OTP\n")
|
||||
t.FailNow()
|
||||
case rfcHotpExpected[i]:
|
||||
fmt.Printf("twofactor: should not have received a valid OTP\n")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
150
twofactor/oath.go
Normal file
150
twofactor/oath.go
Normal file
@@ -0,0 +1,150 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/hmac"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"hash"
|
||||
"net/url"
|
||||
|
||||
"rsc.io/qr"
|
||||
)
|
||||
|
||||
const defaultSize = 6
|
||||
|
||||
// OATH provides a baseline structure for the two OATH algorithms.
|
||||
type OATH struct {
|
||||
key []byte
|
||||
counter uint64
|
||||
size int
|
||||
hash func() hash.Hash
|
||||
algo crypto.Hash
|
||||
provider string
|
||||
}
|
||||
|
||||
// Size returns the output size (in characters) of the password.
|
||||
func (o OATH) Size() int {
|
||||
return o.size
|
||||
}
|
||||
|
||||
// Counter returns the OATH token's counter.
|
||||
func (o OATH) Counter() uint64 {
|
||||
return o.counter
|
||||
}
|
||||
|
||||
// SetCounter updates the OATH token's counter to a new value.
|
||||
func (o *OATH) SetCounter(counter uint64) {
|
||||
o.counter = counter
|
||||
}
|
||||
|
||||
// Key returns the token's secret key.
|
||||
func (o OATH) Key() []byte {
|
||||
return o.key[:]
|
||||
}
|
||||
|
||||
// Hash returns the token's hash function.
|
||||
func (o OATH) Hash() func() hash.Hash {
|
||||
return o.hash
|
||||
}
|
||||
|
||||
// URL constructs a URL appropriate for the token (i.e. for use in a
|
||||
// QR code).
|
||||
func (o OATH) URL(t Type, label string) string {
|
||||
secret := base32.StdEncoding.EncodeToString(o.key)
|
||||
u := url.URL{}
|
||||
v := url.Values{}
|
||||
u.Scheme = "otpauth"
|
||||
switch t {
|
||||
case OATH_HOTP:
|
||||
u.Host = "hotp"
|
||||
case OATH_TOTP:
|
||||
u.Host = "totp"
|
||||
}
|
||||
u.Path = label
|
||||
v.Add("secret", secret)
|
||||
if o.Counter() != 0 && t == OATH_HOTP {
|
||||
v.Add("counter", fmt.Sprintf("%d", o.Counter()))
|
||||
}
|
||||
if o.Size() != defaultSize {
|
||||
v.Add("digits", fmt.Sprintf("%d", o.Size()))
|
||||
}
|
||||
|
||||
switch o.algo {
|
||||
case crypto.SHA256:
|
||||
v.Add("algorithm", "SHA256")
|
||||
case crypto.SHA512:
|
||||
v.Add("algorithm", "SHA512")
|
||||
}
|
||||
|
||||
if o.provider != "" {
|
||||
v.Add("provider", o.provider)
|
||||
}
|
||||
|
||||
u.RawQuery = v.Encode()
|
||||
return u.String()
|
||||
|
||||
}
|
||||
|
||||
var digits = []int64{
|
||||
0: 1,
|
||||
1: 10,
|
||||
2: 100,
|
||||
3: 1000,
|
||||
4: 10000,
|
||||
5: 100000,
|
||||
6: 1000000,
|
||||
7: 10000000,
|
||||
8: 100000000,
|
||||
9: 1000000000,
|
||||
10: 10000000000,
|
||||
}
|
||||
|
||||
// The top-level type should provide a counter; for example, HOTP
|
||||
// will provide the counter directly while TOTP will provide the
|
||||
// time-stepped counter.
|
||||
func (o OATH) OTP(counter uint64) string {
|
||||
var ctr [8]byte
|
||||
binary.BigEndian.PutUint64(ctr[:], counter)
|
||||
|
||||
var mod int64 = 1
|
||||
if len(digits) > o.size {
|
||||
for i := 1; i <= o.size; i++ {
|
||||
mod *= 10
|
||||
}
|
||||
} else {
|
||||
mod = digits[o.size]
|
||||
}
|
||||
|
||||
h := hmac.New(o.hash, o.key)
|
||||
h.Write(ctr[:])
|
||||
dt := truncate(h.Sum(nil)) % mod
|
||||
fmtStr := fmt.Sprintf("%%0%dd", o.size)
|
||||
return fmt.Sprintf(fmtStr, dt)
|
||||
}
|
||||
|
||||
// truncate contains the DT function from the RFC; this is used to
|
||||
// deterministically select a sequence of 4 bytes from the HMAC
|
||||
// counter hash.
|
||||
func truncate(in []byte) int64 {
|
||||
offset := int(in[len(in)-1] & 0xF)
|
||||
p := in[offset : offset+4]
|
||||
var binCode int32
|
||||
binCode = int32((p[0] & 0x7f)) << 24
|
||||
binCode += int32((p[1] & 0xff)) << 16
|
||||
binCode += int32((p[2] & 0xff)) << 8
|
||||
binCode += int32((p[3] & 0xff))
|
||||
return int64(binCode) & 0x7FFFFFFF
|
||||
}
|
||||
|
||||
// QR generates a byte slice containing the a QR code encoded as a
|
||||
// PNG with level Q error correction.
|
||||
func (o OATH) QR(t Type, label string) ([]byte, error) {
|
||||
u := o.URL(t, label)
|
||||
code, err := qr.Encode(u, qr.Q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return code.PNG(), nil
|
||||
}
|
||||
30
twofactor/oath_test.go
Normal file
30
twofactor/oath_test.go
Normal file
@@ -0,0 +1,30 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var sha1Hmac = []byte{
|
||||
0x1f, 0x86, 0x98, 0x69, 0x0e,
|
||||
0x02, 0xca, 0x16, 0x61, 0x85,
|
||||
0x50, 0xef, 0x7f, 0x19, 0xda,
|
||||
0x8e, 0x94, 0x5b, 0x55, 0x5a,
|
||||
}
|
||||
|
||||
var truncExpect int64 = 0x50ef7f19
|
||||
|
||||
// This test runs through the truncation example given in the RFC.
|
||||
func TestTruncate(t *testing.T) {
|
||||
if result := truncate(sha1Hmac); result != truncExpect {
|
||||
fmt.Printf("hotp: expected truncate -> %d, saw %d\n",
|
||||
truncExpect, result)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
sha1Hmac[19]++
|
||||
if result := truncate(sha1Hmac); result == truncExpect {
|
||||
fmt.Println("hotp: expected truncation to fail")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
86
twofactor/otp.go
Normal file
86
twofactor/otp.go
Normal file
@@ -0,0 +1,86 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type Type uint
|
||||
|
||||
const (
|
||||
OATH_HOTP = iota
|
||||
OATH_TOTP
|
||||
)
|
||||
|
||||
// PRNG is an io.Reader that provides a cryptographically secure
|
||||
// random byte stream.
|
||||
var PRNG = rand.Reader
|
||||
|
||||
var (
|
||||
ErrInvalidURL = errors.New("twofactor: invalid URL")
|
||||
ErrInvalidAlgo = errors.New("twofactor: invalid algorithm")
|
||||
)
|
||||
|
||||
// Type OTP represents a one-time password token -- whether a
|
||||
// software taken (as in the case of Google Authenticator) or a
|
||||
// hardware token (as in the case of a YubiKey).
|
||||
type OTP interface {
|
||||
// Returns the current counter value; the meaning of the
|
||||
// returned value is algorithm-specific.
|
||||
Counter() uint64
|
||||
|
||||
// Set the counter to a specific value.
|
||||
SetCounter(uint64)
|
||||
|
||||
// the secret key contained in the OTP
|
||||
Key() []byte
|
||||
|
||||
// generate a new OTP
|
||||
OTP() string
|
||||
|
||||
// the output size of the OTP
|
||||
Size() int
|
||||
|
||||
// the hash function used by the OTP
|
||||
Hash() func() hash.Hash
|
||||
|
||||
// Returns the type of this OTP.
|
||||
Type() Type
|
||||
}
|
||||
|
||||
func otpString(otp OTP) string {
|
||||
var typeName string
|
||||
switch otp.Type() {
|
||||
case OATH_HOTP:
|
||||
typeName = "OATH-HOTP"
|
||||
case OATH_TOTP:
|
||||
typeName = "OATH-TOTP"
|
||||
default:
|
||||
typeName = "UNKNOWN"
|
||||
}
|
||||
return fmt.Sprintf("%s, %d", typeName, otp.Size())
|
||||
}
|
||||
|
||||
// FromURL constructs a new OTP token from a URL string.
|
||||
func FromURL(URL string) (OTP, string, error) {
|
||||
u, err := url.Parse(URL)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
if u.Scheme != "otpauth" {
|
||||
return nil, "", ErrInvalidURL
|
||||
}
|
||||
|
||||
switch u.Host {
|
||||
case "totp":
|
||||
return totpFromURL(u)
|
||||
case "hotp":
|
||||
return hotpFromURL(u)
|
||||
default:
|
||||
return nil, "", ErrInvalidURL
|
||||
}
|
||||
}
|
||||
136
twofactor/otp_test.go
Normal file
136
twofactor/otp_test.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestHOTPString(t *testing.T) {
|
||||
hotp := NewHOTP(nil, 0, 6)
|
||||
hotpString := otpString(hotp)
|
||||
if hotpString != "OATH-HOTP, 6" {
|
||||
fmt.Println("twofactor: invalid OTP string")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// This test generates a new OTP, outputs the URL for that OTP,
|
||||
// and attempts to parse that URL. It verifies that the two OTPs
|
||||
// are the same, and that they produce the same output.
|
||||
func TestURL(t *testing.T) {
|
||||
var ident = "testuser@foo"
|
||||
otp := NewHOTP(testKey, 0, 6)
|
||||
url := otp.URL("testuser@foo")
|
||||
otp2, id, err := FromURL(url)
|
||||
if err != nil {
|
||||
fmt.Printf("hotp: failed to parse HOTP URL\n")
|
||||
t.FailNow()
|
||||
} else if id != ident {
|
||||
fmt.Printf("hotp: bad label\n")
|
||||
fmt.Printf("\texpected: %s\n", ident)
|
||||
fmt.Printf("\t actual: %s\n", id)
|
||||
t.FailNow()
|
||||
} else if otp2.Counter() != otp.Counter() {
|
||||
fmt.Printf("hotp: OTP counters aren't synced\n")
|
||||
fmt.Printf("\toriginal: %d\n", otp.Counter())
|
||||
fmt.Printf("\t second: %d\n", otp2.Counter())
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
code1 := otp.OTP()
|
||||
code2 := otp2.OTP()
|
||||
if code1 != code2 {
|
||||
fmt.Printf("hotp: mismatched OTPs\n")
|
||||
fmt.Printf("\texpected: %s\n", code1)
|
||||
fmt.Printf("\t actual: %s\n", code2)
|
||||
}
|
||||
|
||||
// There's not much we can do test the QR code, except to
|
||||
// ensure it doesn't fail.
|
||||
_, err = otp.QR(ident)
|
||||
if err != nil {
|
||||
fmt.Printf("hotp: failed to generate QR code PNG (%v)\n", err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// This should fail because the maximum size of an alphanumeric
|
||||
// QR code with the lowest-level of error correction should
|
||||
// max out at 4296 bytes. 8k may be a bit overkill... but it
|
||||
// gets the job done. The value is read from the PRNG to
|
||||
// increase the likelihood that the returned data is
|
||||
// uncompressible.
|
||||
var tooBigIdent = make([]byte, 8192)
|
||||
_, err = io.ReadFull(PRNG, tooBigIdent)
|
||||
if err != nil {
|
||||
fmt.Printf("hotp: failed to read identity (%v)\n", err)
|
||||
t.FailNow()
|
||||
} else if _, err = otp.QR(string(tooBigIdent)); err == nil {
|
||||
fmt.Println("hotp: QR code should fail to encode oversized URL")
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
||||
// This test makes sure we can generate codes for padded and non-padded
|
||||
// entries
|
||||
func TestPaddedURL(t *testing.T) {
|
||||
var urlList = []string{
|
||||
"otpauth://hotp/?secret=ME",
|
||||
"otpauth://hotp/?secret=MEFR",
|
||||
"otpauth://hotp/?secret=MFRGG",
|
||||
"otpauth://hotp/?secret=MFRGGZA",
|
||||
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by=======",
|
||||
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by",
|
||||
"otpauth://hotp/?secret=a6mryljlbufszudtjdt42nh5by%3D%3D%3D%3D%3D%3D%3D",
|
||||
}
|
||||
var codeList = []string{
|
||||
"413198",
|
||||
"770938",
|
||||
"670717",
|
||||
"402378",
|
||||
"069864",
|
||||
"069864",
|
||||
"069864",
|
||||
}
|
||||
|
||||
for i := range urlList {
|
||||
if o, id, err := FromURL(urlList[i]); err != nil {
|
||||
fmt.Println("hotp: URL should have parsed successfully (id=", id, ")")
|
||||
fmt.Printf("\turl was: %s\n", urlList[i])
|
||||
t.FailNow()
|
||||
fmt.Printf("\t%s, %s\n", o.OTP(), id)
|
||||
} else {
|
||||
code2 := o.OTP()
|
||||
if code2 != codeList[i] {
|
||||
fmt.Printf("hotp: mismatched OTPs\n")
|
||||
fmt.Printf("\texpected: %s\n", codeList[i])
|
||||
fmt.Printf("\t actual: %s\n", code2)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// This test attempts a variety of invalid urls against the parser
|
||||
// to ensure they fail.
|
||||
func TestBadURL(t *testing.T) {
|
||||
var urlList = []string{
|
||||
"http://google.com",
|
||||
"",
|
||||
"-",
|
||||
"foo",
|
||||
"otpauth:/foo/bar/baz",
|
||||
"://",
|
||||
"otpauth://hotp/?digits=",
|
||||
"otpauth://hotp/?secret=MFRGGZDF&digits=ABCD",
|
||||
"otpauth://hotp/?secret=MFRGGZDF&counter=ABCD",
|
||||
}
|
||||
|
||||
for i := range urlList {
|
||||
if _, _, err := FromURL(urlList[i]); err == nil {
|
||||
fmt.Println("hotp: URL should not have parsed successfully")
|
||||
fmt.Printf("\turl was: %s\n", urlList[i])
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
172
twofactor/totp.go
Normal file
172
twofactor/totp.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"crypto/sha1"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base32"
|
||||
"hash"
|
||||
"io"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
var timeSource = clock.New()
|
||||
|
||||
// TOTP represents an RFC 6238 Time-based One-Time Password instance.
|
||||
type TOTP struct {
|
||||
*OATH
|
||||
step uint64
|
||||
}
|
||||
|
||||
// Type returns OATH_TOTP.
|
||||
func (otp *TOTP) Type() Type {
|
||||
return OATH_TOTP
|
||||
}
|
||||
|
||||
func (otp *TOTP) otp(counter uint64) string {
|
||||
return otp.OATH.OTP(counter)
|
||||
}
|
||||
|
||||
// OTP returns the OTP for the current timestep.
|
||||
func (otp *TOTP) OTP() string {
|
||||
return otp.otp(otp.OTPCounter())
|
||||
}
|
||||
|
||||
// URL returns a TOTP URL (i.e. for putting in a QR code).
|
||||
func (otp *TOTP) URL(label string) string {
|
||||
return otp.OATH.URL(otp.Type(), label)
|
||||
}
|
||||
|
||||
// SetProvider sets up the provider component of the OTP URL.
|
||||
func (otp *TOTP) SetProvider(provider string) {
|
||||
otp.provider = provider
|
||||
}
|
||||
|
||||
func (otp *TOTP) otpCounter(t uint64) uint64 {
|
||||
return (t - otp.counter) / otp.step
|
||||
}
|
||||
|
||||
// OTPCounter returns the current time value for the OTP.
|
||||
func (otp *TOTP) OTPCounter() uint64 {
|
||||
return otp.otpCounter(uint64(timeSource.Now().Unix()))
|
||||
}
|
||||
|
||||
// NewTOTP takes a new key, a starting time, a step, the number of
|
||||
// digits of output (typically 6 or 8) and the hash algorithm to
|
||||
// use, and builds a new OTP.
|
||||
func NewTOTP(key []byte, start uint64, step uint64, digits int, algo crypto.Hash) *TOTP {
|
||||
h := hashFromAlgo(algo)
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &TOTP{
|
||||
OATH: &OATH{
|
||||
key: key,
|
||||
counter: start,
|
||||
size: digits,
|
||||
hash: h,
|
||||
algo: algo,
|
||||
},
|
||||
step: step,
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// NewTOTPSHA1 will build a new TOTP using SHA-1.
|
||||
func NewTOTPSHA1(key []byte, start uint64, step uint64, digits int) *TOTP {
|
||||
return NewTOTP(key, start, step, digits, crypto.SHA1)
|
||||
}
|
||||
|
||||
func hashFromAlgo(algo crypto.Hash) func() hash.Hash {
|
||||
switch algo {
|
||||
case crypto.SHA1:
|
||||
return sha1.New
|
||||
case crypto.SHA256:
|
||||
return sha256.New
|
||||
case crypto.SHA512:
|
||||
return sha512.New
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateGoogleTOTP produces a new TOTP token with the defaults expected by
|
||||
// Google Authenticator.
|
||||
func GenerateGoogleTOTP() *TOTP {
|
||||
key := make([]byte, sha1.Size)
|
||||
if _, err := io.ReadFull(PRNG, key); err != nil {
|
||||
return nil
|
||||
}
|
||||
return NewTOTP(key, 0, 30, 6, crypto.SHA1)
|
||||
}
|
||||
|
||||
// NewGoogleTOTP takes a secret as a base32-encoded string and
|
||||
// returns an appropriate Google Authenticator TOTP instance.
|
||||
func NewGoogleTOTP(secret string) (*TOTP, error) {
|
||||
key, err := base32.StdEncoding.DecodeString(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewTOTP(key, 0, 30, 6, crypto.SHA1), nil
|
||||
}
|
||||
|
||||
func totpFromURL(u *url.URL) (*TOTP, string, error) {
|
||||
label := u.Path[1:]
|
||||
v := u.Query()
|
||||
|
||||
secret := strings.ToUpper(v.Get("secret"))
|
||||
if secret == "" {
|
||||
return nil, "", ErrInvalidURL
|
||||
}
|
||||
|
||||
var algo = crypto.SHA1
|
||||
if algorithm := v.Get("algorithm"); algorithm != "" {
|
||||
if strings.EqualFold(algorithm, "SHA256") {
|
||||
algo = crypto.SHA256
|
||||
} else if strings.EqualFold(algorithm, "SHA512") {
|
||||
algo = crypto.SHA512
|
||||
} else if !strings.EqualFold(algorithm, "SHA1") {
|
||||
return nil, "", ErrInvalidAlgo
|
||||
}
|
||||
}
|
||||
|
||||
var digits = 6
|
||||
if sdigit := v.Get("digits"); sdigit != "" {
|
||||
tmpDigits, err := strconv.ParseInt(sdigit, 10, 8)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
digits = int(tmpDigits)
|
||||
}
|
||||
|
||||
var period uint64 = 30
|
||||
if speriod := v.Get("period"); speriod != "" {
|
||||
var err error
|
||||
period, err = strconv.ParseUint(speriod, 10, 64)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
|
||||
key, err := base32.StdEncoding.DecodeString(Pad(secret))
|
||||
if err != nil {
|
||||
// assume secret isn't base32 encoded
|
||||
key = []byte(secret)
|
||||
}
|
||||
otp := NewTOTP(key, 0, period, digits, algo)
|
||||
return otp, label, nil
|
||||
}
|
||||
|
||||
// QR generates a new TOTP QR code.
|
||||
func (otp *TOTP) QR(label string) ([]byte, error) {
|
||||
return otp.OATH.QR(otp.Type(), label)
|
||||
}
|
||||
|
||||
func SetClock(c clock.Clock) {
|
||||
timeSource = c
|
||||
}
|
||||
87
twofactor/totp_test.go
Normal file
87
twofactor/totp_test.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"crypto"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/benbjohnson/clock"
|
||||
)
|
||||
|
||||
var rfcTotpKey = map[crypto.Hash][]byte{
|
||||
crypto.SHA1: []byte("12345678901234567890"),
|
||||
crypto.SHA256: []byte("12345678901234567890123456789012"),
|
||||
crypto.SHA512: []byte("1234567890123456789012345678901234567890123456789012345678901234"),
|
||||
}
|
||||
var rfcTotpStep uint64 = 30
|
||||
|
||||
var rfcTotpTests = []struct {
|
||||
Time uint64
|
||||
Code string
|
||||
T uint64
|
||||
Algo crypto.Hash
|
||||
}{
|
||||
{59, "94287082", 1, crypto.SHA1},
|
||||
{59, "46119246", 1, crypto.SHA256},
|
||||
{59, "90693936", 1, crypto.SHA512},
|
||||
{1111111109, "07081804", 37037036, crypto.SHA1},
|
||||
{1111111109, "68084774", 37037036, crypto.SHA256},
|
||||
{1111111109, "25091201", 37037036, crypto.SHA512},
|
||||
{1111111111, "14050471", 37037037, crypto.SHA1},
|
||||
{1111111111, "67062674", 37037037, crypto.SHA256},
|
||||
{1111111111, "99943326", 37037037, crypto.SHA512},
|
||||
{1234567890, "89005924", 41152263, crypto.SHA1},
|
||||
{1234567890, "91819424", 41152263, crypto.SHA256},
|
||||
{1234567890, "93441116", 41152263, crypto.SHA512},
|
||||
{2000000000, "69279037", 66666666, crypto.SHA1},
|
||||
{2000000000, "90698825", 66666666, crypto.SHA256},
|
||||
{2000000000, "38618901", 66666666, crypto.SHA512},
|
||||
{20000000000, "65353130", 666666666, crypto.SHA1},
|
||||
{20000000000, "77737706", 666666666, crypto.SHA256},
|
||||
{20000000000, "47863826", 666666666, crypto.SHA512},
|
||||
}
|
||||
|
||||
func TestTotpRFC(t *testing.T) {
|
||||
for _, tc := range rfcTotpTests {
|
||||
otp := NewTOTP(rfcTotpKey[tc.Algo], 0, rfcTotpStep, 8, tc.Algo)
|
||||
if otp.otpCounter(tc.Time) != tc.T {
|
||||
fmt.Printf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
|
||||
fmt.Printf("\texpected: %d\n", tc.T)
|
||||
fmt.Printf("\t actual: %d\n", otp.otpCounter(tc.Time))
|
||||
t.Fail()
|
||||
}
|
||||
|
||||
if code := otp.otp(otp.otpCounter(tc.Time)); code != tc.Code {
|
||||
fmt.Printf("twofactor: invalid TOTP (t=%d, h=%d)\n", tc.Time, tc.Algo)
|
||||
fmt.Printf("\texpected: %s\n", tc.Code)
|
||||
fmt.Printf("\t actual: %s\n", code)
|
||||
t.Fail()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTOTPTime(t *testing.T) {
|
||||
otp := GenerateGoogleTOTP()
|
||||
|
||||
testClock := clock.NewMock()
|
||||
testClock.Add(2 * time.Minute)
|
||||
SetClock(testClock)
|
||||
|
||||
code := otp.OTP()
|
||||
|
||||
testClock.Add(-1 * time.Minute)
|
||||
if newCode := otp.OTP(); newCode == code {
|
||||
t.Errorf("twofactor: TOTP: previous code %s shouldn't match code %s", newCode, code)
|
||||
}
|
||||
|
||||
testClock.Add(2 * time.Minute)
|
||||
if newCode := otp.OTP(); newCode == code {
|
||||
t.Errorf("twofactor: TOTP: future code %s shouldn't match code %s", newCode, code)
|
||||
}
|
||||
|
||||
testClock.Add(-1 * time.Minute)
|
||||
if newCode := otp.OTP(); newCode != code {
|
||||
t.Errorf("twofactor: TOTP: current code %s shouldn't match code %s", newCode, code)
|
||||
}
|
||||
}
|
||||
16
twofactor/util.go
Normal file
16
twofactor/util.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Pad calculates the number of '='s to add to our encoded string
|
||||
// to make base32.StdEncoding.DecodeString happy
|
||||
func Pad(s string) string {
|
||||
if !strings.HasSuffix(s, "=") && len(s)%8 != 0 {
|
||||
for len(s)%8 != 0 {
|
||||
s += "="
|
||||
}
|
||||
}
|
||||
return s
|
||||
}
|
||||
53
twofactor/util_test.go
Normal file
53
twofactor/util_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package twofactor
|
||||
|
||||
import (
|
||||
"encoding/base32"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const letters = "1234567890!@#$%^&*()abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
|
||||
func randString() string {
|
||||
b := make([]byte, rand.Intn(len(letters)))
|
||||
for i := range b {
|
||||
b[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return base32.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
func TestPadding(t *testing.T) {
|
||||
for i := 0; i < 300; i++ {
|
||||
b := randString()
|
||||
origEncoding := string(b)
|
||||
modEncoding := strings.ReplaceAll(string(b), "=", "")
|
||||
str, err := base32.StdEncoding.DecodeString(origEncoding)
|
||||
if err != nil {
|
||||
fmt.Println("Can't decode: ", string(b))
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
paddedEncoding := Pad(modEncoding)
|
||||
if origEncoding != paddedEncoding {
|
||||
fmt.Println("Padding failed:")
|
||||
fmt.Printf("Expected: '%s'", origEncoding)
|
||||
fmt.Printf("Got: '%s'", paddedEncoding)
|
||||
t.FailNow()
|
||||
} else {
|
||||
mstr, err := base32.StdEncoding.DecodeString(paddedEncoding)
|
||||
if err != nil {
|
||||
fmt.Println("Can't decode: ", paddedEncoding)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
if string(mstr) != string(str) {
|
||||
fmt.Println("Re-padding failed:")
|
||||
fmt.Printf("Expected: '%s'", str)
|
||||
fmt.Printf("Got: '%s'", mstr)
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user