Compare commits

...

14 Commits

Author SHA1 Message Date
8348c5fd65 Update CHANGELOG. 2025-11-16 11:09:02 -08:00
1eafb638a8 cmd: finish linting fixes 2025-11-16 11:03:12 -08:00
3ad562b6fa cmd: continuing linter fixes 2025-11-16 02:54:02 -08:00
0f77bd49dc cmd: continue lint fixes. 2025-11-16 01:32:19 -08:00
f31d74243f cmd: start linting fixes. 2025-11-16 00:36:19 -08:00
a573f1cd20 Update CHANGELOG. 2025-11-15 23:48:54 -08:00
f93cf5fa9c adding lru/mru cache. 2025-11-15 23:48:00 -08:00
b879d62384 cert-bundler: lint fixes 2025-11-15 23:27:50 -08:00
c99ffd4394 cmd: cleaning up programs 2025-11-15 23:17:40 -08:00
ed8c07c1c5 Add 'mru/' from commit '2899885c4220560df4f60e4c052a6ab9773a0386'
git-subtree-dir: mru
git-subtree-mainline: cf2b016433
git-subtree-split: 2899885c42
2025-11-15 22:54:26 -08:00
cf2b016433 certlib: complete overhaul. 2025-11-15 22:54:12 -08:00
2899885c42 linter fixes 2025-11-15 22:46:42 -08:00
b92e16fa4d Handle evictions properly when cache is empty. 2023-08-27 18:01:16 -07:00
6fbdece4be Initial import. 2022-02-24 21:39:10 -08:00
70 changed files with 2510 additions and 1196 deletions

View File

@@ -64,4 +64,4 @@ workflows:
testbuild: testbuild:
jobs: jobs:
- testbuild - testbuild
# - lint - lint

View File

@@ -18,6 +18,19 @@ issues:
# Default: 3 # Default: 3
max-same-issues: 50 max-same-issues: 50
# Exclude some lints for CLI programs under cmd/ (package main).
# The project allows fmt.Print* in command-line tools; keep forbidigo for libraries.
exclude-rules:
- path: ^cmd/
linters:
- forbidigo
- path: cmd/.*
linters:
- forbidigo
- path: .*/cmd/.*
linters:
- forbidigo
formatters: formatters:
enable: enable:
- goimports # checks if the code and import statements are formatted according to the 'goimports' command - goimports # checks if the code and import statements are formatted according to the 'goimports' command
@@ -73,7 +86,6 @@ linters:
- godoclint # checks Golang's documentation practice - godoclint # checks Golang's documentation practice
- godot # checks if comments end in a period - godot # checks if comments end in a period
- gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod - gomoddirectives # manages the use of 'replace', 'retract', and 'excludes' directives in go.mod
- goprintffuncname # checks that printf-like functions are named with f at the end
- gosec # inspects source code for security problems - gosec # inspects source code for security problems
- govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - govet # reports suspicious constructs, such as Printf calls whose arguments do not align with the format string
- iface # checks the incorrect use of interfaces, helping developers avoid interface pollution - iface # checks the incorrect use of interfaces, helping developers avoid interface pollution
@@ -230,6 +242,10 @@ linters:
check-type-assertions: true check-type-assertions: true
exclude-functions: exclude-functions:
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write - (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
- git.wntrmute.dev/kyle/goutils/lib.Warn
- git.wntrmute.dev/kyle/goutils/lib.Warnx
- git.wntrmute.dev/kyle/goutils/lib.Err
- git.wntrmute.dev/kyle/goutils/lib.Errx
exhaustive: exhaustive:
# Program elements to check for exhaustiveness. # Program elements to check for exhaustiveness.
@@ -321,6 +337,12 @@ linters:
# https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link # https://github.com/godoc-lint/godoc-lint?tab=readme-ov-file#no-unused-link
- no-unused-link - no-unused-link
gosec:
excludes:
- G104 # handled by errcheck
- G301
- G306
govet: govet:
# Enable all analyzers. # Enable all analyzers.
# Default: false # Default: false
@@ -356,6 +378,12 @@ linters:
- os.WriteFile - os.WriteFile
- prometheus.ExponentialBuckets.* - prometheus.ExponentialBuckets.*
- prometheus.LinearBuckets - prometheus.LinearBuckets
ignored-numbers:
- 1
- 2
- 3
- 4
- 8
nakedret: nakedret:
# Make an issue if func has more lines of code than this setting, and it has naked returns. # Make an issue if func has more lines of code than this setting, and it has naked returns.
@@ -424,6 +452,8 @@ linters:
# Omit embedded fields from selector expression. # Omit embedded fields from selector expression.
# https://staticcheck.dev/docs/checks/#QF1008 # https://staticcheck.dev/docs/checks/#QF1008
- -QF1008 - -QF1008
# We often explicitly enable old/deprecated ciphers for research.
- -SA1019
usetesting: usetesting:
# Enable/disable `os.TempDir()` detections. # Enable/disable `os.TempDir()` detections.
@@ -450,6 +480,10 @@ linters:
linters: [ forbidigo ] linters: [ forbidigo ]
- path: 'logging/example_test.go' - path: 'logging/example_test.go'
linters: [ testableexamples ] linters: [ testableexamples ]
- path: 'main.go'
linters: [ forbidigo, mnd, reassign ]
- path: 'cmd/cruntar/main.go'
linters: [ unparam ]
- source: 'TODO' - source: 'TODO'
linters: [ godot ] linters: [ godot ]
- text: 'should have a package comment' - text: 'should have a package comment'

View File

@@ -1,45 +1,64 @@
Unreleased - 2025-11-15 CHANGELOG
"Error handling modernization" (in progress) v1.11.1 - 2025-11-16
- Introduced typed, wrapped errors via certlib/certerr.Error (Source, Kind, Op, Err) with Unwrap. Changed
- Standardized helper constructors: DecodeError, ParsingError, VerifyError, LoadingError. - cmd: complete linting fixes across programs; no functional changes.
- Preserved sentinel errors (e.g., ErrEncryptedPrivateKey, ErrInvalidPEMType, ErrEmptyCertificate) for errors.Is.
- Refactored certlib to use certerr in key paths (CSR parsing/verification, PEM cert pool, certificate read/load).
- Migrated logging/file.go and cmd/kgz away from github.com/pkg/errors to stdlib wrapping.
- Removed dependency on github.com/pkg/errors; ran go mod tidy.
- Added package docs for certerr and a README section on error handling and matching.
- Added unit tests for certerr (Is/As and message formatting).
Planned next steps: v1.11.0 - 2025-11-15
- Continue refactoring remaining error paths for consistent wrapping.
- Add focused tests for key flows (encrypted private key, CSR invalid PEM types, etc.).
- Run golangci-lint (errorlint, errcheck) and address findings.
Release 1.2.1 - 2018-09-15 Added
- cache/mru: introduce MRU cache implementation with timestamp utilities.
+ Add missing format argument to Errorf call in kgz. Changed
- certlib: complete overhaul to simplify APIs and internals.
- repo: widespread linting cleanups across many packages (config, dbg, die,
fileutil, log/logging, mwc, sbuf, seekbuf, tee, testio, etc.).
- cmd: general program cleanups; `cert-bundler` lint fixes.
Release 1.2.0 - 2018-09-15 Removed
- rand: remove unused package.
- testutil: remove unused code.
+ Adds the kgz command line utility.
Release 1.1.0 - 2017-11-16 v1.10.1 — 2025-11-15
+ A number of new command line utilities were added Changed
- certlib: major overhaul and refactor.
- repo: linter autofixes ahead of release.
+ atping
+ cruntar
+ renfnv
+
+ ski
+ subjhash
+ yamll
+ new package: ahash v1.10.0 — 2025-11-14
+ package for loading hashes from an algorithm string
+ new certificate loading functions in the lib package Added
- cmd: add `cert-revcheck` command.
+ new package: tee Changed
+ emulates tee(1) - ci/lint: add golangci-lint stage and initial cleanup.
v1.9.1 — 2025-11-15
Fixed
- die: correct calls to `die.With`.
v1.9.0 — 2025-11-14
Added
- cmd: add `cert-bundler` tool.
Changed
- misc: minor updates and maintenance.
v1.8.1 — 2025-11-14
Added
- cmd: add `tlsinfo` tool.
v1.8.0 — 2025-11-14
Baseline
- Initial baseline for this changelog series.

View File

@@ -91,7 +91,7 @@ func TestReset(t *testing.T) {
} }
} }
const decay = 5 * time.Millisecond const decay = 25 * time.Millisecond
const maxDuration = 10 * time.Millisecond const maxDuration = 10 * time.Millisecond
const interval = time.Millisecond const interval = time.Millisecond

179
cache/lru/lru.go vendored Normal file
View File

@@ -0,0 +1,179 @@
// Package lru implements a Least Recently Used cache.
package lru
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/benbjohnson/clock"
)
type item[V any] struct {
V V
access int64
}
// A Cache is a map that retains a limited number of items. It must be
// initialized with New, providing a maximum capacity for the cache.
// Only the least recently used items are retained.
type Cache[K comparable, V any] struct {
store map[K]*item[V]
access *timestamps[K]
cap int
clock clock.Clock
// All public methods that have the possibility of modifying the
// cache should lock it.
mtx *sync.Mutex
}
// New must be used to create a new Cache.
func New[K comparable, V any](icap int) *Cache[K, V] {
return &Cache[K, V]{
store: map[K]*item[V]{},
access: newTimestamps[K](icap),
cap: icap,
clock: clock.New(),
mtx: &sync.Mutex{},
}
}
// StringKeyCache is a convenience wrapper for cache keyed by string.
type StringKeyCache[V any] struct {
*Cache[string, V]
}
// NewStringKeyCache creates a new LRU cache keyed by string.
func NewStringKeyCache[V any](icap int) *StringKeyCache[V] {
return &StringKeyCache[V]{Cache: New[string, V](icap)}
}
func (c *Cache[K, V]) lock() {
c.mtx.Lock()
}
func (c *Cache[K, V]) unlock() {
c.mtx.Unlock()
}
// Len returns the number of items currently in the cache.
func (c *Cache[K, V]) Len() int {
return len(c.store)
}
// evict should remove the least-recently-used cache item.
func (c *Cache[K, V]) evict() {
if c.access.Len() == 0 {
return
}
k := c.access.K(0)
c.evictKey(k)
}
// evictKey should remove the entry given by the key item.
func (c *Cache[K, V]) evictKey(k K) {
delete(c.store, k)
i, ok := c.access.Find(k)
if !ok {
return
}
c.access.Delete(i)
}
func (c *Cache[K, V]) sanityCheck() {
if len(c.store) != c.access.Len() {
panic(fmt.Sprintf("LRU cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len()))
}
}
// ConsistencyCheck runs a series of checks to ensure that the cache's
// data structures are consistent. It is not normally required, and it
// is primarily used in testing.
func (c *Cache[K, V]) ConsistencyCheck() error {
c.lock()
defer c.unlock()
if err := c.access.ConsistencyCheck(); err != nil {
return err
}
if len(c.store) != c.access.Len() {
return fmt.Errorf("lru: cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len())
}
for i := range c.access.ts {
itm, ok := c.store[c.access.K(i)]
if !ok {
return errors.New("lru: key in access is not in store")
}
if c.access.T(i) != itm.access {
return fmt.Errorf("timestamps are out of sync (%d != %d)",
itm.access, c.access.T(i))
}
}
if !sort.IsSorted(c.access) {
return errors.New("lru: timestamps aren't sorted")
}
return nil
}
// Store adds the value v to the cache under the k.
func (c *Cache[K, V]) Store(k K, v V) {
c.lock()
defer c.unlock()
c.sanityCheck()
if len(c.store) == c.cap {
c.evict()
}
if _, ok := c.store[k]; ok {
c.evictKey(k)
}
itm := &item[V]{
V: v,
access: c.clock.Now().UnixNano(),
}
c.store[k] = itm
c.access.Update(k, itm.access)
}
// Get returns the value stored in the cache. If the item isn't present,
// it will return false.
func (c *Cache[K, V]) Get(k K) (V, bool) {
c.lock()
defer c.unlock()
c.sanityCheck()
itm, ok := c.store[k]
if !ok {
var zero V
return zero, false
}
c.store[k].access = c.clock.Now().UnixNano()
c.access.Update(k, itm.access)
return itm.V, true
}
// Has returns true if the cache has an entry for k. It will not update
// the timestamp on the item.
func (c *Cache[K, V]) Has(k K) bool {
// Don't need to lock as we don't modify anything.
c.sanityCheck()
_, ok := c.store[k]
return ok
}

87
cache/lru/lru_internal_test.go vendored Normal file
View File

@@ -0,0 +1,87 @@
package lru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
// These tests mirror the MRU-style behavior present in this LRU package
// implementation (eviction removes the most-recently-used entry).
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
// Since this implementation evicts the most-recently-used item, inserting
// "goat" when full evicts "owl" (the most recent at that time).
mock.Add(time.Second)
if _, ok := c.Get("owl"); ok {
t.Fatal("store should not have an entry for owl (MRU-evicted)")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
// Before storing elk, keys were: raven (older), goat (newer). Evict MRU -> goat.
if !c.Has("raven") {
t.Fatal("store should contain an entry for 'raven'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
}

101
cache/lru/timestamps.go vendored Normal file
View File

@@ -0,0 +1,101 @@
package lru
import (
"errors"
"fmt"
"io"
"sort"
)
// timestamps contains datastructures for maintaining a list of keys sortable
// by timestamp.
type timestamp[K comparable] struct {
t int64
k K
}
type timestamps[K comparable] struct {
ts []timestamp[K]
cap int
}
func newTimestamps[K comparable](icap int) *timestamps[K] {
return &timestamps[K]{
ts: make([]timestamp[K], 0, icap),
cap: icap,
}
}
func (ts *timestamps[K]) K(i int) K {
return ts.ts[i].k
}
func (ts *timestamps[K]) T(i int) int64 {
return ts.ts[i].t
}
func (ts *timestamps[K]) Len() int {
return len(ts.ts)
}
func (ts *timestamps[K]) Less(i, j int) bool {
return ts.ts[i].t > ts.ts[j].t
}
func (ts *timestamps[K]) Swap(i, j int) {
ts.ts[i], ts.ts[j] = ts.ts[j], ts.ts[i]
}
func (ts *timestamps[K]) Find(k K) (int, bool) {
for i := range ts.ts {
if ts.ts[i].k == k {
return i, true
}
}
return -1, false
}
func (ts *timestamps[K]) Update(k K, t int64) bool {
i, ok := ts.Find(k)
if !ok {
ts.ts = append(ts.ts, timestamp[K]{t, k})
sort.Sort(ts)
return false
}
ts.ts[i].t = t
sort.Sort(ts)
return true
}
func (ts *timestamps[K]) ConsistencyCheck() error {
if !sort.IsSorted(ts) {
return errors.New("lru: timestamps are not sorted")
}
keys := map[K]bool{}
for i := range ts.ts {
if keys[ts.ts[i].k] {
return fmt.Errorf("lru: duplicate key %v detected", ts.ts[i].k)
}
keys[ts.ts[i].k] = true
}
if len(keys) != len(ts.ts) {
return fmt.Errorf("lru: timestamp contains %d duplicate keys",
len(ts.ts)-len(keys))
}
return nil
}
func (ts *timestamps[K]) Delete(i int) {
ts.ts = append(ts.ts[:i], ts.ts[i+1:]...)
}
func (ts *timestamps[K]) Dump(w io.Writer) {
for i := range ts.ts {
fmt.Fprintf(w, "%d: %v, %d\n", i, ts.K(i), ts.T(i))
}
}

50
cache/lru/timestamps_internal_test.go vendored Normal file
View File

@@ -0,0 +1,50 @@
package lru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
// These tests validate timestamps ordering semantics for the LRU package.
// Note: The LRU timestamps are sorted with most-recent-first (descending by t).
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// make owl the most recent
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// For LRU timestamps: most recent first. Expected order: owl, goat, raven.
if ts.K(0) != "owl" {
t.Fatalf("first key should be owl, have %s", ts.K(0))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(2) != "raven" {
t.Fatalf("third key should be raven, have %s", ts.K(2))
}
}

178
cache/mru/mru.go vendored Normal file
View File

@@ -0,0 +1,178 @@
package mru
import (
"errors"
"fmt"
"sort"
"sync"
"github.com/benbjohnson/clock"
)
type item[V any] struct {
V V
access int64
}
// A Cache is a map that retains a limited number of items. It must be
// initialized with New, providing a maximum capacity for the cache.
// Only the most recently used items are retained.
type Cache[K comparable, V any] struct {
store map[K]*item[V]
access *timestamps[K]
cap int
clock clock.Clock
// All public methods that have the possibility of modifying the
// cache should lock it.
mtx *sync.Mutex
}
// New must be used to create a new Cache.
func New[K comparable, V any](icap int) *Cache[K, V] {
return &Cache[K, V]{
store: map[K]*item[V]{},
access: newTimestamps[K](icap),
cap: icap,
clock: clock.New(),
mtx: &sync.Mutex{},
}
}
// StringKeyCache is a convenience wrapper for cache keyed by string.
type StringKeyCache[V any] struct {
*Cache[string, V]
}
// NewStringKeyCache creates a new MRU cache keyed by string.
func NewStringKeyCache[V any](icap int) *StringKeyCache[V] {
return &StringKeyCache[V]{Cache: New[string, V](icap)}
}
func (c *Cache[K, V]) lock() {
c.mtx.Lock()
}
func (c *Cache[K, V]) unlock() {
c.mtx.Unlock()
}
// Len returns the number of items currently in the cache.
func (c *Cache[K, V]) Len() int {
return len(c.store)
}
// evict should remove the least-recently-used cache item.
func (c *Cache[K, V]) evict() {
if c.access.Len() == 0 {
return
}
k := c.access.K(0)
c.evictKey(k)
}
// evictKey should remove the entry given by the key item.
func (c *Cache[K, V]) evictKey(k K) {
delete(c.store, k)
i, ok := c.access.Find(k)
if !ok {
return
}
c.access.Delete(i)
}
func (c *Cache[K, V]) sanityCheck() {
if len(c.store) != c.access.Len() {
panic(fmt.Sprintf("MRU cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len()))
}
}
// ConsistencyCheck runs a series of checks to ensure that the cache's
// data structures are consistent. It is not normally required, and it
// is primarily used in testing.
func (c *Cache[K, V]) ConsistencyCheck() error {
c.lock()
defer c.unlock()
if err := c.access.ConsistencyCheck(); err != nil {
return err
}
if len(c.store) != c.access.Len() {
return fmt.Errorf("mru: cache is out of sync; store len = %d, access len = %d",
len(c.store), c.access.Len())
}
for i := range c.access.ts {
itm, ok := c.store[c.access.K(i)]
if !ok {
return errors.New("mru: key in access is not in store")
}
if c.access.T(i) != itm.access {
return fmt.Errorf("timestamps are out of sync (%d != %d)",
itm.access, c.access.T(i))
}
}
if !sort.IsSorted(c.access) {
return errors.New("mru: timestamps aren't sorted")
}
return nil
}
// Store adds the value v to the cache under the k.
func (c *Cache[K, V]) Store(k K, v V) {
c.lock()
defer c.unlock()
c.sanityCheck()
if len(c.store) == c.cap {
c.evict()
}
if _, ok := c.store[k]; ok {
c.evictKey(k)
}
itm := &item[V]{
V: v,
access: c.clock.Now().UnixNano(),
}
c.store[k] = itm
c.access.Update(k, itm.access)
}
// Get returns the value stored in the cache. If the item isn't present,
// it will return false.
func (c *Cache[K, V]) Get(k K) (V, bool) {
c.lock()
defer c.unlock()
c.sanityCheck()
itm, ok := c.store[k]
if !ok {
var zero V
return zero, false
}
c.store[k].access = c.clock.Now().UnixNano()
c.access.Update(k, itm.access)
return itm.V, true
}
// Has returns true if the cache has an entry for k. It will not update
// the timestamp on the item.
func (c *Cache[K, V]) Has(k K) bool {
// Don't need to lock as we don't modify anything.
c.sanityCheck()
_, ok := c.store[k]
return ok
}

92
cache/mru/mru_internal_test.go vendored Normal file
View File

@@ -0,0 +1,92 @@
package mru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
func TestBasicCacheEviction(t *testing.T) {
mock := clock.NewMock()
c := NewStringKeyCache[int](2)
c.clock = mock
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if c.Len() != 0 {
t.Fatal("cache should have size 0")
}
c.evict()
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
c.Store("raven", 1)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 1 {
t.Fatalf("store should have length=1, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("owl", 2)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
c.Store("goat", 3)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if len(c.store) != 2 {
t.Fatalf("store should have length=2, have length=%d", len(c.store))
}
mock.Add(time.Second)
v, ok := c.Get("owl")
if !ok {
t.Fatal("store should have an entry for owl")
}
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
itm := v
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if itm != 2 {
t.Fatalf("stored item should be 2, have %d", itm)
}
mock.Add(time.Second)
c.Store("elk", 4)
if err := c.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
if !c.Has("elk") {
t.Fatal("store should contain an entry for 'elk'")
}
if !c.Has("owl") {
t.Fatal("store should contain an entry for 'owl'")
}
if c.Has("goat") {
t.Fatal("store should not contain an entry for 'goat'")
}
}

101
cache/mru/timestamps.go vendored Normal file
View File

@@ -0,0 +1,101 @@
package mru
import (
"errors"
"fmt"
"io"
"sort"
)
// timestamps contains datastructures for maintaining a list of keys sortable
// by timestamp.
type timestamp[K comparable] struct {
t int64
k K
}
type timestamps[K comparable] struct {
ts []timestamp[K]
cap int
}
func newTimestamps[K comparable](icap int) *timestamps[K] {
return &timestamps[K]{
ts: make([]timestamp[K], 0, icap),
cap: icap,
}
}
func (ts *timestamps[K]) K(i int) K {
return ts.ts[i].k
}
func (ts *timestamps[K]) T(i int) int64 {
return ts.ts[i].t
}
func (ts *timestamps[K]) Len() int {
return len(ts.ts)
}
func (ts *timestamps[K]) Less(i, j int) bool {
return ts.ts[i].t < ts.ts[j].t
}
func (ts *timestamps[K]) Swap(i, j int) {
ts.ts[i], ts.ts[j] = ts.ts[j], ts.ts[i]
}
func (ts *timestamps[K]) Find(k K) (int, bool) {
for i := range ts.ts {
if ts.ts[i].k == k {
return i, true
}
}
return -1, false
}
func (ts *timestamps[K]) Update(k K, t int64) bool {
i, ok := ts.Find(k)
if !ok {
ts.ts = append(ts.ts, timestamp[K]{t, k})
sort.Sort(ts)
return false
}
ts.ts[i].t = t
sort.Sort(ts)
return true
}
func (ts *timestamps[K]) ConsistencyCheck() error {
if !sort.IsSorted(ts) {
return errors.New("mru: timestamps are not sorted")
}
keys := map[K]bool{}
for i := range ts.ts {
if keys[ts.ts[i].k] {
return fmt.Errorf("duplicate key %v detected", ts.ts[i].k)
}
keys[ts.ts[i].k] = true
}
if len(keys) != len(ts.ts) {
return fmt.Errorf("mru: timestamp contains %d duplicate keys",
len(ts.ts)-len(keys))
}
return nil
}
func (ts *timestamps[K]) Delete(i int) {
ts.ts = append(ts.ts[:i], ts.ts[i+1:]...)
}
func (ts *timestamps[K]) Dump(w io.Writer) {
for i := range ts.ts {
fmt.Fprintf(w, "%d: %v, %d\n", i, ts.K(i), ts.T(i))
}
}

49
cache/mru/timestamps_internal_test.go vendored Normal file
View File

@@ -0,0 +1,49 @@
package mru
import (
"testing"
"time"
"github.com/benbjohnson/clock"
)
func TestTimestamps(t *testing.T) {
ts := newTimestamps[string](3)
mock := clock.NewMock()
// raven
ts.Update("raven", mock.Now().UnixNano())
// raven, owl
mock.Add(time.Millisecond)
ts.Update("owl", mock.Now().UnixNano())
// raven, owl, goat
mock.Add(time.Second)
ts.Update("goat", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
mock.Add(time.Millisecond)
// raven, goat, owl
ts.Update("owl", mock.Now().UnixNano())
if err := ts.ConsistencyCheck(); err != nil {
t.Fatal(err)
}
// at this point, the keys should be raven, goat, owl.
if ts.K(0) != "raven" {
t.Fatalf("first key should be raven, have %s", ts.K(0))
}
if ts.K(1) != "goat" {
t.Fatalf("second key should be goat, have %s", ts.K(1))
}
if ts.K(2) != "owl" {
t.Fatalf("third key should be owl, have %s", ts.K(2))
}
}

View File

@@ -79,24 +79,23 @@ func (e *Error) Error() string {
func (e *Error) Unwrap() error { return e.Err } func (e *Error) Unwrap() error { return e.Err }
// InvalidPEMType is used to indicate that we were expecting one type of PEM // InvalidPEMTypeError is used to indicate that we were expecting one type of PEM
// file, but saw another. // file, but saw another.
type InvalidPEMType struct { type InvalidPEMTypeError struct {
have string have string
want []string want []string
} }
func (err *InvalidPEMType) Error() string { func (err *InvalidPEMTypeError) Error() string {
if len(err.want) == 1 { if len(err.want) == 1 {
return fmt.Sprintf("invalid PEM type: have %s, expected %s", err.have, err.want[0]) return fmt.Sprintf("invalid PEM type: have %s, expected %s", err.have, err.want[0])
} else { }
return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", ")) return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", "))
} }
}
// ErrInvalidPEMType returns a new InvalidPEMType error. // ErrInvalidPEMType returns a new InvalidPEMTypeError error.
func ErrInvalidPEMType(have string, want ...string) error { func ErrInvalidPEMType(have string, want ...string) error {
return &InvalidPEMType{ return &InvalidPEMTypeError{
have: have, have: have,
want: want, want: want,
} }

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package certerr package certerr
import ( import (

View File

@@ -11,7 +11,7 @@ import (
// ReadCertificate reads a DER or PEM-encoded certificate from the // ReadCertificate reads a DER or PEM-encoded certificate from the
// byte slice. // byte slice.
func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error) { func ReadCertificate(in []byte) (*x509.Certificate, []byte, error) {
if len(in) == 0 { if len(in) == 0 {
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, certerr.ErrEmptyCertificate) return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, certerr.ErrEmptyCertificate)
} }
@@ -22,7 +22,7 @@ func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error)
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("invalid PEM file")) return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("invalid PEM file"))
} }
rest = remaining rest := remaining
if p.Type != "CERTIFICATE" { if p.Type != "CERTIFICATE" {
return nil, rest, certerr.ParsingError( return nil, rest, certerr.ParsingError(
certerr.ErrorSourceCertificate, certerr.ErrorSourceCertificate,
@@ -31,19 +31,26 @@ func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error)
} }
in = p.Bytes in = p.Bytes
} cert, err := x509.ParseCertificate(in)
cert, err = x509.ParseCertificate(in)
if err != nil { if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return nil, rest, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
} }
return cert, rest, nil return cert, rest, nil
} }
cert, err := x509.ParseCertificate(in)
if err != nil {
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
return cert, nil, nil
}
// ReadCertificates tries to read all the certificates in a // ReadCertificates tries to read all the certificates in a
// PEM-encoded collection. // PEM-encoded collection.
func ReadCertificates(in []byte) (certs []*x509.Certificate, err error) { func ReadCertificates(in []byte) ([]*x509.Certificate, error) {
var cert *x509.Certificate var cert *x509.Certificate
var certs []*x509.Certificate
var err error
for { for {
cert, in, err = ReadCertificate(in) cert, in, err = ReadCertificate(in)
if err != nil { if err != nil {

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package certlib package certlib
import ( import (

View File

@@ -38,6 +38,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"git.wntrmute.dev/kyle/goutils/certlib/certerr" "git.wntrmute.dev/kyle/goutils/certlib/certerr"
@@ -47,29 +48,36 @@ import (
// private key. The key must not be in PEM format. If an error is returned, it // private key. The key must not be in PEM format. If an error is returned, it
// may contain information about the private key, so care should be taken when // may contain information about the private key, so care should be taken when
// displaying it directly. // displaying it directly.
func ParsePrivateKeyDER(keyDER []byte) (key crypto.Signer, err error) { func ParsePrivateKeyDER(keyDER []byte) (crypto.Signer, error) {
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER) // Try common encodings in order without deep nesting.
if err != nil { if k, err := x509.ParsePKCS8PrivateKey(keyDER); err == nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER) switch kk := k.(type) {
if err != nil { case *rsa.PrivateKey:
generalKey, err = x509.ParseECPrivateKey(keyDER) return kk, nil
if err != nil { case *ecdsa.PrivateKey:
generalKey, err = ParseEd25519PrivateKey(keyDER) return kk, nil
if err != nil { case ed25519.PrivateKey:
return kk, nil
default:
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
}
}
if k, err := x509.ParsePKCS1PrivateKey(keyDER); err == nil {
return k, nil
}
if k, err := x509.ParseECPrivateKey(keyDER); err == nil {
return k, nil
}
if k, err := ParseEd25519PrivateKey(keyDER); err == nil {
if kk, ok := k.(ed25519.PrivateKey); ok {
return kk, nil
}
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
}
// If all parsers failed, return the last error from Ed25519 attempt (approximate cause).
if _, err := ParseEd25519PrivateKey(keyDER); err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err) return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err)
} }
} // Fallback (should be unreachable)
} return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, errors.New("unknown key encoding"))
}
switch generalKey := generalKey.(type) {
case *rsa.PrivateKey:
return generalKey, nil
case *ecdsa.PrivateKey:
return generalKey, nil
case ed25519.PrivateKey:
return generalKey, nil
default:
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %t", generalKey))
}
} }

View File

@@ -65,12 +65,14 @@ func MarshalEd25519PublicKey(pk crypto.PublicKey) ([]byte, error) {
return nil, errEd25519WrongKeyType return nil, errEd25519WrongKeyType
} }
const bitsPerByte = 8
spki := subjectPublicKeyInfo{ spki := subjectPublicKeyInfo{
Algorithm: pkix.AlgorithmIdentifier{ Algorithm: pkix.AlgorithmIdentifier{
Algorithm: ed25519OID, Algorithm: ed25519OID,
}, },
PublicKey: asn1.BitString{ PublicKey: asn1.BitString{
BitLength: len(pub) * 8, BitLength: len(pub) * bitsPerByte,
Bytes: pub, Bytes: pub,
}, },
} }
@@ -91,7 +93,8 @@ func ParseEd25519PublicKey(der []byte) (crypto.PublicKey, error) {
return nil, errEd25519WrongID return nil, errEd25519WrongID
} }
if spki.PublicKey.BitLength != ed25519.PublicKeySize*8 { const bitsPerByte = 8
if spki.PublicKey.BitLength != ed25519.PublicKeySize*bitsPerByte {
return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch") return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch")
} }

View File

@@ -49,14 +49,14 @@ import (
"strings" "strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
ct "github.com/google/certificate-transparency-go" ct "github.com/google/certificate-transparency-go"
cttls "github.com/google/certificate-transparency-go/tls" cttls "github.com/google/certificate-transparency-go/tls"
ctx509 "github.com/google/certificate-transparency-go/x509" ctx509 "github.com/google/certificate-transparency-go/x509"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"golang.org/x/crypto/pkcs12" "golang.org/x/crypto/pkcs12"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
) )
// OneYear is a time.Duration representing a year's worth of seconds. // OneYear is a time.Duration representing a year's worth of seconds.
@@ -68,7 +68,7 @@ const OneDay = 24 * time.Hour
// DelegationUsage is the OID for the DelegationUseage extensions. // DelegationUsage is the OID for the DelegationUseage extensions.
var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44} var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44}
// DelegationExtension. // DelegationExtension is a non-critical extension marking delegation usage.
var DelegationExtension = pkix.Extension{ var DelegationExtension = pkix.Extension{
Id: DelegationUsage, Id: DelegationUsage,
Critical: false, Critical: false,
@@ -81,13 +81,19 @@ func InclusiveDate(year int, month time.Month, day int) time.Time {
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond) return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond)
} }
const (
year2012 = 2012
year2015 = 2015
day1 = 1
)
// Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop // Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 5 years. // issuing certificates valid for more than 5 years.
var Jul2012 = InclusiveDate(2012, time.July, 01) var Jul2012 = InclusiveDate(year2012, time.July, day1)
// Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop // Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 39 months. // issuing certificates valid for more than 39 months.
var Apr2015 = InclusiveDate(2015, time.April, 01) var Apr2015 = InclusiveDate(year2015, time.April, day1)
// KeyLength returns the bit size of ECDSA or RSA PublicKey. // KeyLength returns the bit size of ECDSA or RSA PublicKey.
func KeyLength(key any) int { func KeyLength(key any) int {
@@ -108,11 +114,11 @@ func KeyLength(key any) int {
} }
// ExpiryTime returns the time when the certificate chain is expired. // ExpiryTime returns the time when the certificate chain is expired.
func ExpiryTime(chain []*x509.Certificate) (notAfter time.Time) { func ExpiryTime(chain []*x509.Certificate) time.Time {
var notAfter time.Time
if len(chain) == 0 { if len(chain) == 0 {
return notAfter return notAfter
} }
notAfter = chain[0].NotAfter notAfter = chain[0].NotAfter
for _, cert := range chain { for _, cert := range chain {
if notAfter.After(cert.NotAfter) { if notAfter.After(cert.NotAfter) {
@@ -158,18 +164,23 @@ func ValidExpiry(c *x509.Certificate) bool {
// SignatureString returns the TLS signature string corresponding to // SignatureString returns the TLS signature string corresponding to
// an X509 signature algorithm. // an X509 signature algorithm.
var signatureString = map[x509.SignatureAlgorithm]string{ var signatureString = map[x509.SignatureAlgorithm]string{
x509.UnknownSignatureAlgorithm: "Unknown Signature",
x509.MD2WithRSA: "MD2WithRSA", x509.MD2WithRSA: "MD2WithRSA",
x509.MD5WithRSA: "MD5WithRSA", x509.MD5WithRSA: "MD5WithRSA",
x509.SHA1WithRSA: "SHA1WithRSA", x509.SHA1WithRSA: "SHA1WithRSA",
x509.SHA256WithRSA: "SHA256WithRSA", x509.SHA256WithRSA: "SHA256WithRSA",
x509.SHA384WithRSA: "SHA384WithRSA", x509.SHA384WithRSA: "SHA384WithRSA",
x509.SHA512WithRSA: "SHA512WithRSA", x509.SHA512WithRSA: "SHA512WithRSA",
x509.SHA256WithRSAPSS: "SHA256WithRSAPSS",
x509.SHA384WithRSAPSS: "SHA384WithRSAPSS",
x509.SHA512WithRSAPSS: "SHA512WithRSAPSS",
x509.DSAWithSHA1: "DSAWithSHA1", x509.DSAWithSHA1: "DSAWithSHA1",
x509.DSAWithSHA256: "DSAWithSHA256", x509.DSAWithSHA256: "DSAWithSHA256",
x509.ECDSAWithSHA1: "ECDSAWithSHA1", x509.ECDSAWithSHA1: "ECDSAWithSHA1",
x509.ECDSAWithSHA256: "ECDSAWithSHA256", x509.ECDSAWithSHA256: "ECDSAWithSHA256",
x509.ECDSAWithSHA384: "ECDSAWithSHA384", x509.ECDSAWithSHA384: "ECDSAWithSHA384",
x509.ECDSAWithSHA512: "ECDSAWithSHA512", x509.ECDSAWithSHA512: "ECDSAWithSHA512",
x509.PureEd25519: "PureEd25519",
} }
// SignatureString returns the TLS signature string corresponding to // SignatureString returns the TLS signature string corresponding to
@@ -184,18 +195,23 @@ func SignatureString(alg x509.SignatureAlgorithm) string {
// HashAlgoString returns the hash algorithm name contains in the signature // HashAlgoString returns the hash algorithm name contains in the signature
// method. // method.
var hashAlgoString = map[x509.SignatureAlgorithm]string{ var hashAlgoString = map[x509.SignatureAlgorithm]string{
x509.UnknownSignatureAlgorithm: "Unknown Hash Algorithm",
x509.MD2WithRSA: "MD2", x509.MD2WithRSA: "MD2",
x509.MD5WithRSA: "MD5", x509.MD5WithRSA: "MD5",
x509.SHA1WithRSA: "SHA1", x509.SHA1WithRSA: "SHA1",
x509.SHA256WithRSA: "SHA256", x509.SHA256WithRSA: "SHA256",
x509.SHA384WithRSA: "SHA384", x509.SHA384WithRSA: "SHA384",
x509.SHA512WithRSA: "SHA512", x509.SHA512WithRSA: "SHA512",
x509.SHA256WithRSAPSS: "SHA256",
x509.SHA384WithRSAPSS: "SHA384",
x509.SHA512WithRSAPSS: "SHA512",
x509.DSAWithSHA1: "SHA1", x509.DSAWithSHA1: "SHA1",
x509.DSAWithSHA256: "SHA256", x509.DSAWithSHA256: "SHA256",
x509.ECDSAWithSHA1: "SHA1", x509.ECDSAWithSHA1: "SHA1",
x509.ECDSAWithSHA256: "SHA256", x509.ECDSAWithSHA256: "SHA256",
x509.ECDSAWithSHA384: "SHA384", x509.ECDSAWithSHA384: "SHA384",
x509.ECDSAWithSHA512: "SHA512", x509.ECDSAWithSHA512: "SHA512",
x509.PureEd25519: "SHA512", // per x509 docs Ed25519 uses SHA-512 internally
} }
// HashAlgoString returns the hash algorithm name contains in the signature // HashAlgoString returns the hash algorithm name contains in the signature
@@ -273,7 +289,7 @@ func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
// ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key, // ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key,
// either PKCS #7, PKCS #12, or raw x509. // either PKCS #7, PKCS #12, or raw x509.
func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certificate, key crypto.Signer, err error) { func ParseCertificatesDER(certsDER []byte, password string) ([]*x509.Certificate, crypto.Signer, error) {
certsDER = bytes.TrimSpace(certsDER) certsDER = bytes.TrimSpace(certsDER)
// First, try PKCS #7 // First, try PKCS #7
@@ -284,7 +300,7 @@ func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certi
errors.New("can only extract certificates from signed data content info"), errors.New("can only extract certificates from signed data content info"),
) )
} }
certs = pkcs7data.Content.SignedData.Certificates certs := pkcs7data.Content.SignedData.Certificates
if certs == nil { if certs == nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded")) return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded"))
} }
@@ -304,7 +320,7 @@ func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certi
} }
// Finally, attempt to parse raw X.509 certificates // Finally, attempt to parse raw X.509 certificates
certs, err = x509.ParseCertificates(certsDER) certs, err := x509.ParseCertificates(certsDER)
if err != nil { if err != nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err) return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err)
} }
@@ -318,7 +334,8 @@ func ParseSelfSignedCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
return nil, err return nil, err
} }
if err := cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature); err != nil { err = cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature)
if err != nil {
return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err) return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err)
} }
return cert, nil return cert, nil
@@ -362,8 +379,8 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
cert, err := x509.ParseCertificate(block.Bytes) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
pkcs7data, err := pkcs7.ParsePKCS7(block.Bytes) pkcs7data, err2 := pkcs7.ParsePKCS7(block.Bytes)
if err != nil { if err2 != nil {
return nil, rest, err return nil, rest, err
} }
if pkcs7data.ContentInfo != "SignedData" { if pkcs7data.ContentInfo != "SignedData" {
@@ -382,7 +399,7 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
// LoadPEMCertPool loads a pool of PEM certificates from file. // LoadPEMCertPool loads a pool of PEM certificates from file.
func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) { func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
if certsFile == "" { if certsFile == "" {
return nil, nil return nil, nil //nolint:nilnil // no CA file provided -> treat as no pool and no error
} }
pemCerts, err := os.ReadFile(certsFile) pemCerts, err := os.ReadFile(certsFile)
if err != nil { if err != nil {
@@ -395,7 +412,7 @@ func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
// PEMToCertPool concerts PEM certificates to a CertPool. // PEMToCertPool concerts PEM certificates to a CertPool.
func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) { func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
if len(pemCerts) == 0 { if len(pemCerts) == 0 {
return nil, nil return nil, nil //nolint:nilnil // empty input means no pool needed
} }
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
@@ -409,14 +426,14 @@ func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
// ParsePrivateKeyPEM parses and returns a PEM-encoded private // ParsePrivateKeyPEM parses and returns a PEM-encoded private
// key. The private key may be either an unencrypted PKCS#8, PKCS#1, // key. The private key may be either an unencrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEM(keyPEM []byte) (crypto.Signer, error) {
return ParsePrivateKeyPEMWithPassword(keyPEM, nil) return ParsePrivateKeyPEMWithPassword(keyPEM, nil)
} }
// ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private // ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private
// key. The private key may be a potentially encrypted PKCS#8, PKCS#1, // key. The private key may be a potentially encrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (crypto.Signer, error) {
keyDER, err := GetKeyDERFromPEM(keyPEM, password) keyDER, err := GetKeyDERFromPEM(keyPEM, password)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -436,26 +453,33 @@ func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
break break
} }
} }
if keyDER != nil { if keyDER == nil {
if procType, ok := keyDER.Headers["Proc-Type"]; ok { return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key"))
if strings.Contains(procType, "ENCRYPTED") { }
if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") {
if password != nil { if password != nil {
return x509.DecryptPEMBlock(keyDER, password) return x509.DecryptPEMBlock(keyDER, password)
} }
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey) return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)
} }
}
return keyDER.Bytes, nil return keyDER.Bytes, nil
} }
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key"))
}
// ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request. // ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request.
func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error) { func ParseCSR(in []byte) (*x509.CertificateRequest, []byte, error) {
in = bytes.TrimSpace(in) in = bytes.TrimSpace(in)
p, rest := pem.Decode(in) p, rest := pem.Decode(in)
if p != nil { if p == nil {
csr, err := x509.ParseCertificateRequest(in)
if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
}
if sigErr := csr.CheckSignature(); sigErr != nil {
return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
}
return csr, rest, nil
}
if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" { if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" {
return nil, rest, certerr.ParsingError( return nil, rest, certerr.ParsingError(
certerr.ErrorSourceCSR, certerr.ErrorSourceCSR,
@@ -463,20 +487,13 @@ func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error)
) )
} }
csr, err = x509.ParseCertificateRequest(p.Bytes) csr, err := x509.ParseCertificateRequest(p.Bytes)
} else {
csr, err = x509.ParseCertificateRequest(in)
}
if err != nil { if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err) return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
} }
if sigErr := csr.CheckSignature(); sigErr != nil {
err = csr.CheckSignature() return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
if err != nil {
return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, err)
} }
return csr, rest, nil return csr, rest, nil
} }
@@ -484,7 +501,7 @@ func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error)
// It does not check the signature. This is useful for dumping data from a CSR // It does not check the signature. This is useful for dumping data from a CSR
// locally. // locally.
func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) { func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
block, _ := pem.Decode([]byte(csrPEM)) block, _ := pem.Decode(csrPEM)
if block == nil { if block == nil {
return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty")) return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty"))
} }
@@ -499,15 +516,20 @@ func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
// SignerAlgo returns an X.509 signature algorithm from a crypto.Signer. // SignerAlgo returns an X.509 signature algorithm from a crypto.Signer.
func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm { func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm {
const (
rsaBits2048 = 2048
rsaBits3072 = 3072
rsaBits4096 = 4096
)
switch pub := priv.Public().(type) { switch pub := priv.Public().(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
bitLength := pub.N.BitLen() bitLength := pub.N.BitLen()
switch { switch {
case bitLength >= 4096: case bitLength >= rsaBits4096:
return x509.SHA512WithRSA return x509.SHA512WithRSA
case bitLength >= 3072: case bitLength >= rsaBits3072:
return x509.SHA384WithRSA return x509.SHA384WithRSA
case bitLength >= 2048: case bitLength >= rsaBits2048:
return x509.SHA256WithRSA return x509.SHA256WithRSA
default: default:
return x509.SHA1WithRSA return x509.SHA1WithRSA
@@ -537,7 +559,7 @@ func LoadClientCertificate(certFile string, keyFile string) (*tls.Certificate, e
} }
return &cert, nil return &cert, nil
} }
return nil, nil return nil, nil //nolint:nilnil // absence of client cert is not an error
} }
// CreateTLSConfig creates a tls.Config object from certs and roots. // CreateTLSConfig creates a tls.Config object from certs and roots.
@@ -549,6 +571,7 @@ func CreateTLSConfig(remoteCAs *x509.CertPool, cert *tls.Certificate) *tls.Confi
return &tls.Config{ return &tls.Config{
Certificates: certs, Certificates: certs,
RootCAs: remoteCAs, RootCAs: remoteCAs,
MinVersion: tls.VersionTLS12, // secure default
} }
} }
@@ -582,11 +605,11 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList)) list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList))
for i, serializedSCT := range sctList.SCTList { for i, serializedSCT := range sctList.SCTList {
var sct ct.SignedCertificateTimestamp var sct ct.SignedCertificateTimestamp
rest, err := cttls.Unmarshal(serializedSCT.Val, &sct) rest2, err2 := cttls.Unmarshal(serializedSCT.Val, &sct)
if err != nil { if err2 != nil {
return nil, err return nil, err2
} }
if len(rest) != 0 { if len(rest2) != 0 {
return nil, certerr.ParsingError( return nil, certerr.ParsingError(
certerr.ErrorSourceSCTList, certerr.ErrorSourceSCTList,
errors.New("serialized SCT list contained trailing garbage"), errors.New("serialized SCT list contained trailing garbage"),
@@ -602,12 +625,12 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
// unmarshalled. // unmarshalled.
func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) { func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) {
// This loop finds the SCTListExtension in the OCSP response. // This loop finds the SCTListExtension in the OCSP response.
var SCTListExtension, ext pkix.Extension var sctListExtension, ext pkix.Extension
for _, ext = range response.Extensions { for _, ext = range response.Extensions {
// sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp. // sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp.
sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5} sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5}
if ext.Id.Equal(sctExtOid) { if ext.Id.Equal(sctExtOid) {
SCTListExtension = ext sctListExtension = ext
break break
} }
} }
@@ -615,10 +638,10 @@ func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTim
// This code block extracts the sctList from the SCT extension. // This code block extracts the sctList from the SCT extension.
var sctList []ct.SignedCertificateTimestamp var sctList []ct.SignedCertificateTimestamp
var err error var err error
if numBytes := len(SCTListExtension.Value); numBytes != 0 { if numBytes := len(sctListExtension.Value); numBytes != 0 {
var serializedSCTList []byte var serializedSCTList []byte
rest := make([]byte, numBytes) rest := make([]byte, numBytes)
copy(rest, SCTListExtension.Value) copy(rest, sctListExtension.Value)
for len(rest) != 0 { for len(rest) != 0 {
rest, err = asn1.Unmarshal(rest, &serializedSCTList) rest, err = asn1.Unmarshal(rest, &serializedSCTList)
if err != nil { if err != nil {

View File

@@ -9,6 +9,8 @@ import (
"strings" "strings"
) )
const defaultHTTPSPort = 443
type Target struct { type Target struct {
Host string Host string
Port int Port int
@@ -29,29 +31,29 @@ func parseURL(host string) (string, int, error) {
} }
if url.Port() == "" { if url.Port() == "" {
return url.Hostname(), 443, nil return url.Hostname(), defaultHTTPSPort, nil
} }
port, err := strconv.ParseInt(url.Port(), 10, 16) portInt, err2 := strconv.ParseInt(url.Port(), 10, 16)
if err != nil { if err2 != nil {
return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", url.Port()) return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", url.Port())
} }
return url.Hostname(), int(port), nil return url.Hostname(), int(portInt), nil
} }
func parseHostPort(host string) (string, int, error) { func parseHostPort(host string) (string, int, error) {
host, sport, err := net.SplitHostPort(host) host, sport, err := net.SplitHostPort(host)
if err == nil { if err == nil {
port, err := strconv.ParseInt(sport, 10, 16) portInt, err2 := strconv.ParseInt(sport, 10, 16)
if err != nil { if err2 != nil {
return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport) return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport)
} }
return host, int(port), nil return host, int(portInt), nil
} }
return host, 443, nil return host, defaultHTTPSPort, nil
} }
func ParseHost(host string) (*Target, error) { func ParseHost(host string) (*Target, error) {

View File

@@ -158,9 +158,9 @@ type EncryptedContentInfo struct {
EncryptedContent []byte `asn1:"tag:0,optional"` EncryptedContent []byte `asn1:"tag:0,optional"`
} }
func unmarshalInit(raw []byte) (init initPKCS7, err error) { func unmarshalInit(raw []byte) (initPKCS7, error) {
_, err = asn1.Unmarshal(raw, &init) var init initPKCS7
if err != nil { if _, err := asn1.Unmarshal(raw, &init); err != nil {
return initPKCS7{}, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return initPKCS7{}, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
} }
return init, nil return init, nil
@@ -218,28 +218,28 @@ func populateEncryptedData(msg *PKCS7, contentBytes []byte) error {
// ParsePKCS7 attempts to parse the DER encoded bytes of a // ParsePKCS7 attempts to parse the DER encoded bytes of a
// PKCS7 structure. // PKCS7 structure.
func ParsePKCS7(raw []byte) (msg *PKCS7, err error) { func ParsePKCS7(raw []byte) (*PKCS7, error) {
pkcs7, err := unmarshalInit(raw) pkcs7, err := unmarshalInit(raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg = new(PKCS7) msg := new(PKCS7)
msg.Raw = pkcs7.Raw msg.Raw = pkcs7.Raw
msg.ContentInfo = pkcs7.ContentType.String() msg.ContentInfo = pkcs7.ContentType.String()
switch msg.ContentInfo { switch msg.ContentInfo {
case ObjIDData: case ObjIDData:
if err := populateData(msg, pkcs7.Content); err != nil { if e := populateData(msg, pkcs7.Content); e != nil {
return nil, err return nil, e
} }
case ObjIDSignedData: case ObjIDSignedData:
if err := populateSignedData(msg, pkcs7.Content.Bytes); err != nil { if e := populateSignedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, err return nil, e
} }
case ObjIDEncryptedData: case ObjIDEncryptedData:
if err := populateEncryptedData(msg, pkcs7.Content.Bytes); err != nil { if e := populateEncryptedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, err return nil, e
} }
default: default:
return nil, certerr.ParsingError( return nil, certerr.ParsingError(

View File

@@ -5,6 +5,7 @@ package revoke
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@@ -90,34 +91,34 @@ func ldapURL(url string) bool {
// - false, true: the certificate was checked successfully, and it is not revoked. // - false, true: the certificate was checked successfully, and it is not revoked.
// - true, true: the certificate was checked successfully, and it is revoked. // - true, true: the certificate was checked successfully, and it is revoked.
// - true, false: failure to check revocation status causes verification to fail. // - true, false: failure to check revocation status causes verification to fail.
func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) { func revCheck(cert *x509.Certificate) (bool, bool, error) {
for _, url := range cert.CRLDistributionPoints { for _, url := range cert.CRLDistributionPoints {
if ldapURL(url) { if ldapURL(url) {
log.Infof("skipping LDAP CRL: %s", url) log.Infof("skipping LDAP CRL: %s", url)
continue continue
} }
if revoked, ok, err := certIsRevokedCRL(cert, url); !ok { if rvk, ok2, err2 := certIsRevokedCRL(cert, url); !ok2 {
log.Warning("error checking revocation via CRL") log.Warning("error checking revocation via CRL")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via CRL") log.Info("certificate is revoked via CRL")
return true, true, err return true, true, err2
} }
} }
if revoked, ok, err := certIsRevokedOCSP(cert, HardFail); !ok { if rvk, ok2, err2 := certIsRevokedOCSP(cert, HardFail); !ok2 {
log.Warning("error checking revocation via OCSP") log.Warning("error checking revocation via OCSP")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via OCSP") log.Info("certificate is revoked via OCSP")
return true, true, err return true, true, err2
} }
return false, true, nil return false, true, nil
@@ -125,13 +126,17 @@ func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) {
// fetchCRL fetches and parses a CRL. // fetchCRL fetches and parses a CRL.
func fetchCRL(url string) (*x509.RevocationList, error) { func fetchCRL(url string) (*x509.RevocationList, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 300 { if resp.StatusCode >= http.StatusMultipleChoices {
return nil, errors.New("failed to retrieve CRL") return nil, errors.New("failed to retrieve CRL")
} }
@@ -158,7 +163,7 @@ func getIssuer(cert *x509.Certificate) *x509.Certificate {
// check a cert against a specific CRL. Returns the same bool pair // check a cert against a specific CRL. Returns the same bool pair
// as revCheck, plus an error if one occurred. // as revCheck, plus an error if one occurred.
func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err error) { func certIsRevokedCRL(cert *x509.Certificate, url string) (bool, bool, error) {
crlLock.Lock() crlLock.Lock()
crl, ok := CRLSet[url] crl, ok := CRLSet[url]
if ok && crl == nil { if ok && crl == nil {
@@ -186,10 +191,9 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
// check CRL signature // check CRL signature
if issuer != nil { if issuer != nil {
err = crl.CheckSignatureFrom(issuer) if sigErr := crl.CheckSignatureFrom(issuer); sigErr != nil {
if err != nil { log.Warningf("failed to verify CRL: %v", sigErr)
log.Warningf("failed to verify CRL: %v", err) return false, false, sigErr
return false, false, err
} }
} }
@@ -198,26 +202,26 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
crlLock.Unlock() crlLock.Unlock()
} }
for _, revoked := range crl.RevokedCertificates { for _, entry := range crl.RevokedCertificateEntries {
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 { if cert.SerialNumber.Cmp(entry.SerialNumber) == 0 {
log.Info("Serial number match: intermediate is revoked.") log.Info("Serial number match: intermediate is revoked.")
return true, true, err return true, true, nil
} }
} }
return false, true, err return false, true, nil
} }
// VerifyCertificate ensures that the certificate passed in hasn't // VerifyCertificate ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificate(cert *x509.Certificate) (revoked, ok bool) { func VerifyCertificate(cert *x509.Certificate) (bool, bool) {
revoked, ok, _ = VerifyCertificateError(cert) revoked, ok, _ := VerifyCertificateError(cert)
return revoked, ok return revoked, ok
} }
// VerifyCertificateError ensures that the certificate passed in hasn't // VerifyCertificateError ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificateError(cert *x509.Certificate) (revoked, ok bool, err error) { func VerifyCertificateError(cert *x509.Certificate) (bool, bool, error) {
if !time.Now().Before(cert.NotAfter) { if !time.Now().Before(cert.NotAfter) {
msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter) msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter)
log.Info(msg) log.Info(msg)
@@ -231,7 +235,11 @@ func VerifyCertificateError(cert *x509.Certificate) (revoked, ok bool, err error
} }
func fetchRemote(url string) (*x509.Certificate, error) { func fetchRemote(url string) (*x509.Certificate, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -254,8 +262,12 @@ var ocspOpts = ocsp.RequestOptions{
Hash: crypto.SHA1, Hash: crypto.SHA1,
} }
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e error) { const ocspGetURLMaxLen = 256
var err error
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (bool, bool, error) {
var revoked bool
var ok bool
var lastErr error
ocspURLs := leaf.OCSPServer ocspURLs := leaf.OCSPServer
if len(ocspURLs) == 0 { if len(ocspURLs) == 0 {
@@ -271,15 +283,16 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts) ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts)
if err != nil { if err != nil {
return revoked, ok, err return false, false, err
} }
for _, server := range ocspURLs { for _, server := range ocspURLs {
resp, err := sendOCSPRequest(server, ocspRequest, leaf, issuer) resp, e := sendOCSPRequest(server, ocspRequest, leaf, issuer)
if err != nil { if e != nil {
if strict { if strict {
return revoked, ok, err return false, false, e
} }
lastErr = e
continue continue
} }
@@ -291,9 +304,9 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
revoked = true revoked = true
} }
return revoked, ok, err return revoked, ok, nil
} }
return revoked, ok, err return revoked, ok, lastErr
} }
// sendOCSPRequest attempts to request an OCSP response from the // sendOCSPRequest attempts to request an OCSP response from the
@@ -302,12 +315,21 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) { func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) {
var resp *http.Response var resp *http.Response
var err error var err error
if len(req) > 256 { if len(req) > ocspGetURLMaxLen {
buf := bytes.NewBuffer(req) buf := bytes.NewBuffer(req)
resp, err = HTTPClient.Post(server, "application/ocsp-request", buf) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodPost, server, buf)
if e != nil {
return nil, e
}
httpReq.Header.Set("Content-Type", "application/ocsp-request")
resp, err = HTTPClient.Do(httpReq)
} else { } else {
reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req)) reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req))
resp, err = HTTPClient.Get(reqURL) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodGet, reqURL, nil)
if e != nil {
return nil, e
}
resp, err = HTTPClient.Do(httpReq)
} }
if err != nil { if err != nil {

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package revoke package revoke
import ( import (
@@ -153,7 +154,7 @@ func mustParse(pemData string) *x509.Certificate {
panic("Invalid PEM type.") panic("Invalid PEM type.")
} }
cert, err := x509.ParseCertificate([]byte(block.Bytes)) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"net" "net"
@@ -28,10 +29,16 @@ func connect(addr string, dport string, six bool, timeout time.Duration) error {
if verbose { if verbose {
fmt.Printf("connecting to %s/%s... ", addr, proto) fmt.Printf("connecting to %s/%s... ", addr, proto)
os.Stdout.Sync() if err = os.Stdout.Sync(); err != nil {
return err
}
} }
conn, err := net.DialTimeout(proto, addr, timeout) dialer := &net.Dialer{
Timeout: timeout,
}
conn, err := dialer.DialContext(context.Background(), proto, addr)
if err != nil { if err != nil {
if verbose { if verbose {
fmt.Println("failed.") fmt.Println("failed.")
@@ -42,8 +49,8 @@ func connect(addr string, dport string, six bool, timeout time.Duration) error {
if verbose { if verbose {
fmt.Println("OK") fmt.Println("OK")
} }
conn.Close()
return nil return conn.Close()
} }
func main() { func main() {

View File

@@ -3,6 +3,7 @@ package main
import ( import (
"crypto/x509" "crypto/x509"
"embed" "embed"
"errors"
"fmt" "fmt"
"os" "os"
"path/filepath" "path/filepath"
@@ -14,22 +15,22 @@ import (
// loadCertsFromFile attempts to parse certificates from a file that may be in // loadCertsFromFile attempts to parse certificates from a file that may be in
// PEM or DER/PKCS#7 format. Returns the parsed certificates or an error. // PEM or DER/PKCS#7 format. Returns the parsed certificates or an error.
func loadCertsFromFile(path string) ([]*x509.Certificate, error) { func loadCertsFromFile(path string) ([]*x509.Certificate, error) {
var certs []*x509.Certificate
data, err := os.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Try PEM first if certs, err = certlib.ParseCertificatesPEM(data); err == nil {
if certs, err := certlib.ParseCertificatesPEM(data); err == nil {
return certs, nil return certs, nil
} }
// Try DER/PKCS7/PKCS12 (with no password) if certs, _, err = certlib.ParseCertificatesDER(data, ""); err == nil {
if certs, _, err := certlib.ParseCertificatesDER(data, ""); err == nil {
return certs, nil return certs, nil
} else {
return nil, err
} }
return nil, err
} }
func makePoolFromFile(path string) (*x509.CertPool, error) { func makePoolFromFile(path string) (*x509.CertPool, error) {
@@ -56,22 +57,23 @@ var embeddedTestdata embed.FS
// loadCertsFromBytes attempts to parse certificates from bytes that may be in // loadCertsFromBytes attempts to parse certificates from bytes that may be in
// PEM or DER/PKCS#7 format. // PEM or DER/PKCS#7 format.
func loadCertsFromBytes(data []byte) ([]*x509.Certificate, error) { func loadCertsFromBytes(data []byte) ([]*x509.Certificate, error) {
// Try PEM first certs, err := certlib.ParseCertificatesPEM(data)
if certs, err := certlib.ParseCertificatesPEM(data); err == nil { if err == nil {
return certs, nil return certs, nil
} }
// Try DER/PKCS7/PKCS12 (with no password)
if certs, _, err := certlib.ParseCertificatesDER(data, ""); err == nil { certs, _, err = certlib.ParseCertificatesDER(data, "")
if err == nil {
return certs, nil return certs, nil
} else { }
return nil, err return nil, err
} }
}
func makePoolFromBytes(data []byte) (*x509.CertPool, error) { func makePoolFromBytes(data []byte) (*x509.CertPool, error) {
certs, err := loadCertsFromBytes(data) certs, err := loadCertsFromBytes(data)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return nil, fmt.Errorf("failed to load CA certificates from embedded bytes") return nil, errors.New("failed to load CA certificates from embedded bytes")
} }
pool := x509.NewCertPool() pool := x509.NewCertPool()
for _, c := range certs { for _, c := range certs {
@@ -98,7 +100,7 @@ func isSelfSigned(cert *x509.Certificate) bool {
return true return true
} }
func verifyAgainstCA(caPool *x509.CertPool, path string) (ok bool, expiry string) { func verifyAgainstCA(caPool *x509.CertPool, path string) (bool, string) {
certs, err := loadCertsFromFile(path) certs, err := loadCertsFromFile(path)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return false, "" return false, ""
@@ -117,14 +119,14 @@ func verifyAgainstCA(caPool *x509.CertPool, path string) (ok bool, expiry string
Intermediates: ints, Intermediates: ints,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
if _, err := leaf.Verify(opts); err != nil { if _, err = leaf.Verify(opts); err != nil {
return false, "" return false, ""
} }
return true, leaf.NotAfter.Format("2006-01-02") return true, leaf.NotAfter.Format("2006-01-02")
} }
func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (ok bool, expiry string) { func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (bool, string) {
certs, err := loadCertsFromBytes(certData) certs, err := loadCertsFromBytes(certData)
if err != nil || len(certs) == 0 { if err != nil || len(certs) == 0 {
return false, "" return false, ""
@@ -143,15 +145,13 @@ func verifyAgainstCABytes(caPool *x509.CertPool, certData []byte) (ok bool, expi
Intermediates: ints, Intermediates: ints,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
if _, err := leaf.Verify(opts); err != nil { if _, err = leaf.Verify(opts); err != nil {
return false, "" return false, ""
} }
return true, leaf.NotAfter.Format("2006-01-02") return true, leaf.NotAfter.Format("2006-01-02")
} }
// selftest runs built-in validation using embedded certificates.
func selftest() int {
type testCase struct { type testCase struct {
name string name string
caFile string caFile string
@@ -159,43 +159,72 @@ func selftest() int {
expectOK bool expectOK bool
} }
cases := []testCase{ func (tc testCase) Run() error {
{name: "ISRG Root X1 validates LE E7", caFile: "testdata/isrg-root-x1.pem", certFile: "testdata/le-e7.pem", expectOK: true},
{name: "ISRG Root X1 does NOT validate Google WR2", caFile: "testdata/isrg-root-x1.pem", certFile: "testdata/goog-wr2.pem", expectOK: false},
{name: "GTS R1 validates Google WR2", caFile: "testdata/gts-r1.pem", certFile: "testdata/goog-wr2.pem", expectOK: true},
{name: "GTS R1 does NOT validate LE E7", caFile: "testdata/gts-r1.pem", certFile: "testdata/le-e7.pem", expectOK: false},
}
failures := 0
for _, tc := range cases {
caBytes, err := embeddedTestdata.ReadFile(tc.caFile) caBytes, err := embeddedTestdata.ReadFile(tc.caFile)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", tc.caFile, err) return fmt.Errorf("selftest: failed to read embedded %s: %w", tc.caFile, err)
failures++
continue
} }
certBytes, err := embeddedTestdata.ReadFile(tc.certFile) certBytes, err := embeddedTestdata.ReadFile(tc.certFile)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "selftest: failed to read embedded %s: %v\n", tc.certFile, err) return fmt.Errorf("selftest: failed to read embedded %s: %w", tc.certFile, err)
failures++
continue
} }
pool, err := makePoolFromBytes(caBytes) pool, err := makePoolFromBytes(caBytes)
if err != nil || pool == nil { if err != nil || pool == nil {
fmt.Fprintf(os.Stderr, "selftest: failed to build CA pool for %s: %v\n", tc.caFile, err) return fmt.Errorf("selftest: failed to build CA pool for %s: %w", tc.caFile, err)
failures++
continue
} }
ok, exp := verifyAgainstCABytes(pool, certBytes) ok, exp := verifyAgainstCABytes(pool, certBytes)
if ok != tc.expectOK { if ok != tc.expectOK {
fmt.Printf("%s: unexpected result: got %v, want %v\n", tc.name, ok, tc.expectOK) return fmt.Errorf("%s: unexpected result: got %v, want %v", tc.name, ok, tc.expectOK)
failures++ }
} else {
if ok { if ok {
fmt.Printf("%s: OK (expires %s)\n", tc.name, exp) fmt.Printf("%s: OK (expires %s)\n", tc.name, exp)
} else {
fmt.Printf("%s: INVALID (as expected)\n", tc.name)
} }
fmt.Printf("%s: INVALID (as expected)\n", tc.name)
return nil
}
var cases = []testCase{
{
name: "ISRG Root X1 validates LE E7",
caFile: "testdata/isrg-root-x1.pem",
certFile: "testdata/le-e7.pem",
expectOK: true,
},
{
name: "ISRG Root X1 does NOT validate Google WR2",
caFile: "testdata/isrg-root-x1.pem",
certFile: "testdata/goog-wr2.pem",
expectOK: false,
},
{
name: "GTS R1 validates Google WR2",
caFile: "testdata/gts-r1.pem",
certFile: "testdata/goog-wr2.pem",
expectOK: true,
},
{
name: "GTS R1 does NOT validate LE E7",
caFile: "testdata/gts-r1.pem",
certFile: "testdata/le-e7.pem",
expectOK: false,
},
}
// selftest runs built-in validation using embedded certificates.
func selftest() int {
failures := 0
for _, tc := range cases {
err := tc.Run()
if err != nil {
fmt.Fprintln(os.Stderr, err)
failures++
continue
} }
} }
@@ -231,6 +260,46 @@ func selftest() int {
return 1 return 1
} }
// expiryString returns a YYYY-MM-DD date string to display for certificate
// expiry. If an explicit exp string is provided, it is used. Otherwise, if a
// leaf certificate is available, its NotAfter is formatted. As a last resort,
// it falls back to today's date (should not normally happen).
func expiryString(leaf *x509.Certificate, exp string) string {
if exp != "" {
return exp
}
if leaf != nil {
return leaf.NotAfter.Format("2006-01-02")
}
return time.Now().Format("2006-01-02")
}
// processCert verifies a single certificate file against the provided CA pool
// and prints the result in the required format, handling self-signed
// certificates specially.
func processCert(caPool *x509.CertPool, certPath string) {
ok, exp := verifyAgainstCA(caPool, certPath)
name := filepath.Base(certPath)
// Try to load the leaf cert for self-signed detection and expiry fallback
var leaf *x509.Certificate
if certs, err := loadCertsFromFile(certPath); err == nil && len(certs) > 0 {
leaf = certs[0]
}
// Prefer the SELF-SIGNED label if applicable
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED\n", name)
return
}
if ok {
fmt.Printf("%s: OK (expires %s)\n", name, expiryString(leaf, exp))
return
}
fmt.Printf("%s: INVALID\n", name)
}
func main() { func main() {
// Special selftest mode: single argument "selftest" // Special selftest mode: single argument "selftest"
if len(os.Args) == 2 && os.Args[1] == "selftest" { if len(os.Args) == 2 && os.Args[1] == "selftest" {
@@ -251,37 +320,6 @@ func main() {
} }
for _, certPath := range os.Args[2:] { for _, certPath := range os.Args[2:] {
ok, exp := verifyAgainstCA(caPool, certPath) processCert(caPool, certPath)
name := filepath.Base(certPath)
// Load the leaf once for self-signed detection and potential expiry fallback
var leaf *x509.Certificate
if certs, err := loadCertsFromFile(certPath); err == nil && len(certs) > 0 {
leaf = certs[0]
}
// If the certificate is self-signed, prefer the SELF-SIGNED label
if isSelfSigned(leaf) {
fmt.Printf("%s: SELF-SIGNED\n", name)
continue
}
if ok {
// Display with the requested format
// Example: file: OK (expires 2031-01-01)
// Ensure deterministic date formatting
// Note: no timezone displayed; date only as per example
// If exp ended up empty for some reason, recompute safely
if exp == "" {
if leaf != nil {
exp = leaf.NotAfter.Format("2006-01-02")
} else {
// fallback to the current date to avoid empty; though shouldn't happen
exp = time.Now().Format("2006-01-02")
}
}
fmt.Printf("%s: OK (expires %s)\n", name, exp)
} else {
fmt.Printf("%s: INVALID\n", name)
}
} }
} }

View File

@@ -8,8 +8,10 @@ import (
"crypto/x509" "crypto/x509"
_ "embed" _ "embed"
"encoding/pem" "encoding/pem"
"errors"
"flag" "flag"
"fmt" "fmt"
"io"
"os" "os"
"path/filepath" "path/filepath"
"strings" "strings"
@@ -19,7 +21,7 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
// Config represents the top-level YAML configuration // Config represents the top-level YAML configuration.
type Config struct { type Config struct {
Config struct { Config struct {
Hashes string `yaml:"hashes"` Hashes string `yaml:"hashes"`
@@ -28,19 +30,19 @@ type Config struct {
Chains map[string]ChainGroup `yaml:"chains"` Chains map[string]ChainGroup `yaml:"chains"`
} }
// ChainGroup represents a named group of certificate chains // ChainGroup represents a named group of certificate chains.
type ChainGroup struct { type ChainGroup struct {
Certs []CertChain `yaml:"certs"` Certs []CertChain `yaml:"certs"`
Outputs Outputs `yaml:"outputs"` Outputs Outputs `yaml:"outputs"`
} }
// CertChain represents a root certificate and its intermediates // CertChain represents a root certificate and its intermediates.
type CertChain struct { type CertChain struct {
Root string `yaml:"root"` Root string `yaml:"root"`
Intermediates []string `yaml:"intermediates"` Intermediates []string `yaml:"intermediates"`
} }
// Outputs defines output format options // Outputs defines output format options.
type Outputs struct { type Outputs struct {
IncludeSingle bool `yaml:"include_single"` IncludeSingle bool `yaml:"include_single"`
IncludeIndividual bool `yaml:"include_individual"` IncludeIndividual bool `yaml:"include_individual"`
@@ -95,7 +97,8 @@ func main() {
} }
// Create output directory if it doesn't exist // Create output directory if it doesn't exist
if err := os.MkdirAll(outputDir, 0755); err != nil { err = os.MkdirAll(outputDir, 0750)
if err != nil {
fmt.Fprintf(os.Stderr, "Error creating output directory: %v\n", err) fmt.Fprintf(os.Stderr, "Error creating output directory: %v\n", err)
os.Exit(1) os.Exit(1)
} }
@@ -108,9 +111,9 @@ func main() {
} }
createdFiles := make([]string, 0, totalFormats) createdFiles := make([]string, 0, totalFormats)
for groupName, group := range cfg.Chains { for groupName, group := range cfg.Chains {
files, err := processChainGroup(groupName, group, expiryDuration) files, perr := processChainGroup(groupName, group, expiryDuration)
if err != nil { if perr != nil {
fmt.Fprintf(os.Stderr, "Error processing chain group %s: %v\n", groupName, err) fmt.Fprintf(os.Stderr, "Error processing chain group %s: %v\n", groupName, perr)
os.Exit(1) os.Exit(1)
} }
createdFiles = append(createdFiles, files...) createdFiles = append(createdFiles, files...)
@@ -119,8 +122,8 @@ func main() {
// Generate hash file for all created archives // Generate hash file for all created archives
if cfg.Config.Hashes != "" { if cfg.Config.Hashes != "" {
hashFile := filepath.Join(outputDir, cfg.Config.Hashes) hashFile := filepath.Join(outputDir, cfg.Config.Hashes)
if err := generateHashFile(hashFile, createdFiles); err != nil { if gerr := generateHashFile(hashFile, createdFiles); gerr != nil {
fmt.Fprintf(os.Stderr, "Error generating hash file: %v\n", err) fmt.Fprintf(os.Stderr, "Error generating hash file: %v\n", gerr)
os.Exit(1) os.Exit(1)
} }
} }
@@ -135,8 +138,8 @@ func loadConfig(path string) (*Config, error) {
} }
var cfg Config var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil { if uerr := yaml.Unmarshal(data, &cfg); uerr != nil {
return nil, err return nil, uerr
} }
return &cfg, nil return &cfg, nil
@@ -200,16 +203,48 @@ func processChainGroup(groupName string, group ChainGroup, expiryDuration time.D
return createdFiles, nil return createdFiles, nil
} }
// loadAndCollectCerts loads all certificates from chains and collects them for processing // loadAndCollectCerts loads all certificates from chains and collects them for processing.
func loadAndCollectCerts(chains []CertChain, outputs Outputs, expiryDuration time.Duration) ([]*x509.Certificate, []certWithPath, error) { func loadAndCollectCerts(
chains []CertChain,
outputs Outputs,
expiryDuration time.Duration,
) ([]*x509.Certificate, []certWithPath, error) {
var singleFileCerts []*x509.Certificate var singleFileCerts []*x509.Certificate
var individualCerts []certWithPath var individualCerts []certWithPath
for _, chain := range chains { for _, chain := range chains {
s, i, cerr := collectFromChain(chain, outputs, expiryDuration)
if cerr != nil {
return nil, nil, cerr
}
if len(s) > 0 {
singleFileCerts = append(singleFileCerts, s...)
}
if len(i) > 0 {
individualCerts = append(individualCerts, i...)
}
}
return singleFileCerts, individualCerts, nil
}
// collectFromChain loads a single chain, performs checks, and returns the certs to include.
func collectFromChain(
chain CertChain,
outputs Outputs,
expiryDuration time.Duration,
) (
[]*x509.Certificate,
[]certWithPath,
error,
) {
var single []*x509.Certificate
var indiv []certWithPath
// Load root certificate // Load root certificate
rootCert, err := certlib.LoadCertificate(chain.Root) rootCert, rerr := certlib.LoadCertificate(chain.Root)
if err != nil { if rerr != nil {
return nil, nil, fmt.Errorf("failed to load root certificate %s: %v", chain.Root, err) return nil, nil, fmt.Errorf("failed to load root certificate %s: %w", chain.Root, rerr)
} }
// Check expiry for root // Check expiry for root
@@ -217,25 +252,27 @@ func loadAndCollectCerts(chains []CertChain, outputs Outputs, expiryDuration tim
// Add root to collections if needed // Add root to collections if needed
if outputs.IncludeSingle { if outputs.IncludeSingle {
singleFileCerts = append(singleFileCerts, rootCert) single = append(single, rootCert)
} }
if outputs.IncludeIndividual { if outputs.IncludeIndividual {
individualCerts = append(individualCerts, certWithPath{ indiv = append(indiv, certWithPath{cert: rootCert, path: chain.Root})
cert: rootCert,
path: chain.Root,
})
} }
// Load and validate intermediates // Load and validate intermediates
for _, intPath := range chain.Intermediates { for _, intPath := range chain.Intermediates {
intCert, err := certlib.LoadCertificate(intPath) intCert, lerr := certlib.LoadCertificate(intPath)
if err != nil { if lerr != nil {
return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %v", intPath, err) return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %w", intPath, lerr)
} }
// Validate that intermediate is signed by root // Validate that intermediate is signed by root
if err := intCert.CheckSignatureFrom(rootCert); err != nil { if sigErr := intCert.CheckSignatureFrom(rootCert); sigErr != nil {
return nil, nil, fmt.Errorf("intermediate %s is not properly signed by root %s: %v", intPath, chain.Root, err) return nil, nil, fmt.Errorf(
"intermediate %s is not properly signed by root %s: %w",
intPath,
chain.Root,
sigErr,
)
} }
// Check expiry for intermediate // Check expiry for intermediate
@@ -243,29 +280,30 @@ func loadAndCollectCerts(chains []CertChain, outputs Outputs, expiryDuration tim
// Add intermediate to collections if needed // Add intermediate to collections if needed
if outputs.IncludeSingle { if outputs.IncludeSingle {
singleFileCerts = append(singleFileCerts, intCert) single = append(single, intCert)
} }
if outputs.IncludeIndividual { if outputs.IncludeIndividual {
individualCerts = append(individualCerts, certWithPath{ indiv = append(indiv, certWithPath{cert: intCert, path: intPath})
cert: intCert,
path: intPath,
})
}
} }
} }
return singleFileCerts, individualCerts, nil return single, indiv, nil
} }
// prepareArchiveFiles prepares all files to be included in archives // prepareArchiveFiles prepares all files to be included in archives.
func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []certWithPath, outputs Outputs, encoding string) ([]fileEntry, error) { func prepareArchiveFiles(
singleFileCerts []*x509.Certificate,
individualCerts []certWithPath,
outputs Outputs,
encoding string,
) ([]fileEntry, error) {
var archiveFiles []fileEntry var archiveFiles []fileEntry
// Handle a single bundle file // Handle a single bundle file
if outputs.IncludeSingle && len(singleFileCerts) > 0 { if outputs.IncludeSingle && len(singleFileCerts) > 0 {
files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true) files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encode single bundle: %v", err) return nil, fmt.Errorf("failed to encode single bundle: %w", err)
} }
archiveFiles = append(archiveFiles, files...) archiveFiles = append(archiveFiles, files...)
} }
@@ -276,7 +314,7 @@ func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []
baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path)) baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path))
files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false) files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to encode individual cert %s: %v", cp.path, err) return nil, fmt.Errorf("failed to encode individual cert %s: %w", cp.path, err)
} }
archiveFiles = append(archiveFiles, files...) archiveFiles = append(archiveFiles, files...)
} }
@@ -294,7 +332,7 @@ func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []
return archiveFiles, nil return archiveFiles, nil
} }
// createArchiveFiles creates archive files in the specified formats // createArchiveFiles creates archive files in the specified formats.
func createArchiveFiles(groupName string, formats []string, archiveFiles []fileEntry) ([]string, error) { func createArchiveFiles(groupName string, formats []string, archiveFiles []fileEntry) ([]string, error) {
createdFiles := make([]string, 0, len(formats)) createdFiles := make([]string, 0, len(formats))
@@ -307,11 +345,11 @@ func createArchiveFiles(groupName string, formats []string, archiveFiles []fileE
switch format { switch format {
case "zip": case "zip":
if err := createZipArchive(archivePath, archiveFiles); err != nil { if err := createZipArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create zip archive: %v", err) return nil, fmt.Errorf("failed to create zip archive: %w", err)
} }
case "tgz": case "tgz":
if err := createTarGzArchive(archivePath, archiveFiles); err != nil { if err := createTarGzArchive(archivePath, archiveFiles); err != nil {
return nil, fmt.Errorf("failed to create tar.gz archive: %v", err) return nil, fmt.Errorf("failed to create tar.gz archive: %w", err)
} }
default: default:
return nil, fmt.Errorf("unsupported format: %s", format) return nil, fmt.Errorf("unsupported format: %s", format)
@@ -329,7 +367,12 @@ func checkExpiry(path string, cert *x509.Certificate, expiryDuration time.Durati
if cert.NotAfter.Before(expiryThreshold) { if cert.NotAfter.Before(expiryThreshold) {
daysUntilExpiry := int(cert.NotAfter.Sub(now).Hours() / 24) daysUntilExpiry := int(cert.NotAfter.Sub(now).Hours() / 24)
if daysUntilExpiry < 0 { if daysUntilExpiry < 0 {
fmt.Fprintf(os.Stderr, "WARNING: Certificate %s has EXPIRED (expired %d days ago)\n", path, -daysUntilExpiry) fmt.Fprintf(
os.Stderr,
"WARNING: Certificate %s has EXPIRED (expired %d days ago)\n",
path,
-daysUntilExpiry,
)
} else { } else {
fmt.Fprintf(os.Stderr, "WARNING: Certificate %s will expire in %d days (on %s)\n", path, daysUntilExpiry, cert.NotAfter.Format("2006-01-02")) fmt.Fprintf(os.Stderr, "WARNING: Certificate %s will expire in %d days (on %s)\n", path, daysUntilExpiry, cert.NotAfter.Format("2006-01-02"))
} }
@@ -347,8 +390,13 @@ type certWithPath struct {
} }
// encodeCertsToFiles converts certificates to file entries based on encoding type // encodeCertsToFiles converts certificates to file entries based on encoding type
// If isSingle is true, certs are concatenated into a single file; otherwise one cert per file // If isSingle is true, certs are concatenated into a single file; otherwise one cert per file.
func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding string, isSingle bool) ([]fileEntry, error) { func encodeCertsToFiles(
certs []*x509.Certificate,
baseName string,
encoding string,
isSingle bool,
) ([]fileEntry, error) {
var files []fileEntry var files []fileEntry
switch encoding { switch encoding {
@@ -369,15 +417,13 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str
name: baseName + ".crt", name: baseName + ".crt",
content: derContent, content: derContent,
}) })
} else { } else if len(certs) > 0 {
// Individual DER file (should only have one cert) // Individual DER file (should only have one cert)
if len(certs) > 0 {
files = append(files, fileEntry{ files = append(files, fileEntry{
name: baseName + ".crt", name: baseName + ".crt",
content: certs[0].Raw, content: certs[0].Raw,
}) })
} }
}
case "both": case "both":
// Add PEM version // Add PEM version
pemContent := encodeCertsToPEM(certs) pemContent := encodeCertsToPEM(certs)
@@ -395,14 +441,12 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str
name: baseName + ".crt", name: baseName + ".crt",
content: derContent, content: derContent,
}) })
} else { } else if len(certs) > 0 {
if len(certs) > 0 {
files = append(files, fileEntry{ files = append(files, fileEntry{
name: baseName + ".crt", name: baseName + ".crt",
content: certs[0].Raw, content: certs[0].Raw,
}) })
} }
}
default: default:
return nil, fmt.Errorf("unsupported encoding: %s (must be 'pem', 'der', or 'both')", encoding) return nil, fmt.Errorf("unsupported encoding: %s (must be 'pem', 'der', or 'both')", encoding)
} }
@@ -410,7 +454,7 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str
return files, nil return files, nil
} }
// encodeCertsToPEM encodes certificates to PEM format // encodeCertsToPEM encodes certificates to PEM format.
func encodeCertsToPEM(certs []*x509.Certificate) []byte { func encodeCertsToPEM(certs []*x509.Certificate) []byte {
var pemContent []byte var pemContent []byte
for _, cert := range certs { for _, cert := range certs {
@@ -435,40 +479,49 @@ func generateManifest(files []fileEntry) []byte {
return []byte(manifest.String()) return []byte(manifest.String())
} }
// closeWithErr attempts to close all provided closers, joining any close errors with baseErr.
func closeWithErr(baseErr error, closers ...io.Closer) error {
for _, c := range closers {
if c == nil {
continue
}
if cerr := c.Close(); cerr != nil {
baseErr = errors.Join(baseErr, cerr)
}
}
return baseErr
}
func createZipArchive(path string, files []fileEntry) error { func createZipArchive(path string, files []fileEntry) error {
f, err := os.Create(path) f, zerr := os.Create(path)
if err != nil { if zerr != nil {
return err return zerr
} }
w := zip.NewWriter(f) w := zip.NewWriter(f)
for _, file := range files { for _, file := range files {
fw, err := w.Create(file.name) fw, werr := w.Create(file.name)
if err != nil { if werr != nil {
w.Close() return closeWithErr(werr, w, f)
f.Close()
return err
} }
if _, err := fw.Write(file.content); err != nil { if _, werr = fw.Write(file.content); werr != nil {
w.Close() return closeWithErr(werr, w, f)
f.Close()
return err
} }
} }
// Check errors on close operations // Check errors on close operations
if err := w.Close(); err != nil { if cerr := w.Close(); cerr != nil {
f.Close() _ = f.Close()
return err return cerr
} }
return f.Close() return f.Close()
} }
func createTarGzArchive(path string, files []fileEntry) error { func createTarGzArchive(path string, files []fileEntry) error {
f, err := os.Create(path) f, terr := os.Create(path)
if err != nil { if terr != nil {
return err return terr
} }
gw := gzip.NewWriter(f) gw := gzip.NewWriter(f)
@@ -480,29 +533,23 @@ func createTarGzArchive(path string, files []fileEntry) error {
Mode: 0644, Mode: 0644,
Size: int64(len(file.content)), Size: int64(len(file.content)),
} }
if err := tw.WriteHeader(hdr); err != nil { if herr := tw.WriteHeader(hdr); herr != nil {
tw.Close() return closeWithErr(herr, tw, gw, f)
gw.Close()
f.Close()
return err
} }
if _, err := tw.Write(file.content); err != nil { if _, werr := tw.Write(file.content); werr != nil {
tw.Close() return closeWithErr(werr, tw, gw, f)
gw.Close()
f.Close()
return err
} }
} }
// Check errors on close operations in the correct order // Check errors on close operations in the correct order
if err := tw.Close(); err != nil { if cerr := tw.Close(); cerr != nil {
gw.Close() _ = gw.Close()
f.Close() _ = f.Close()
return err return cerr
} }
if err := gw.Close(); err != nil { if cerr := gw.Close(); cerr != nil {
f.Close() _ = f.Close()
return err return cerr
} }
return f.Close() return f.Close()
} }
@@ -515,9 +562,9 @@ func generateHashFile(path string, files []string) error {
defer f.Close() defer f.Close()
for _, file := range files { for _, file := range files {
data, err := os.ReadFile(file) data, rerr := os.ReadFile(file)
if err != nil { if rerr != nil {
return err return rerr
} }
hash := sha256.Sum256(data) hash := sha256.Sum256(data)

View File

@@ -1,14 +1,15 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"flag"
"errors" "errors"
"flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
"strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
@@ -23,6 +24,13 @@ var (
verbose bool verbose bool
) )
var (
strOK = "OK"
strExpired = "EXPIRED"
strRevoked = "REVOKED"
strUnknown = "UNKNOWN"
)
func main() { func main() {
flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal") flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal")
flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects") flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects")
@@ -42,16 +50,16 @@ func main() {
for _, target := range flag.Args() { for _, target := range flag.Args() {
status, err := processTarget(target) status, err := processTarget(target)
switch status { switch status {
case "OK": case strOK:
fmt.Printf("%s: OK\n", target) fmt.Printf("%s: %s\n", target, strOK)
case "EXPIRED": case strExpired:
fmt.Printf("%s: EXPIRED: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strExpired, err)
exitCode = 1 exitCode = 1
case "REVOKED": case strRevoked:
fmt.Printf("%s: REVOKED\n", target) fmt.Printf("%s: %s\n", target, strRevoked)
exitCode = 1 exitCode = 1
case "UNKNOWN": case strUnknown:
fmt.Printf("%s: UNKNOWN: %v\n", target, err) fmt.Printf("%s: %s: %v\n", target, strUnknown, err)
if hardfail { if hardfail {
// In hardfail, treat unknown as failure // In hardfail, treat unknown as failure
exitCode = 1 exitCode = 1
@@ -67,74 +75,77 @@ func processTarget(target string) (string, error) {
return checkFile(target) return checkFile(target)
} }
// Not a file; treat as site
return checkSite(target) return checkSite(target)
} }
func checkFile(path string) (string, error) { func checkFile(path string) (string, error) {
in, err := ioutil.ReadFile(path) // Prefer high-level helpers from certlib to load certificates from disk
if err != nil { if certs, err := certlib.LoadCertificates(path); err == nil && len(certs) > 0 {
return "UNKNOWN", err
}
// Try PEM first; if that fails, try single DER cert
certs, err := certlib.ReadCertificates(in)
if err != nil || len(certs) == 0 {
cert, _, derr := certlib.ReadCertificate(in)
if derr != nil || cert == nil {
if err == nil {
err = derr
}
return "UNKNOWN", err
}
return evaluateCert(cert)
}
// Evaluate the first certificate (leaf) by default // Evaluate the first certificate (leaf) by default
return evaluateCert(certs[0]) return evaluateCert(certs[0])
} }
cert, err := certlib.LoadCertificate(path)
if err != nil || cert == nil {
return strUnknown, err
}
return evaluateCert(cert)
}
func checkSite(hostport string) (string, error) { func checkSite(hostport string) (string, error) {
// Use certlib/hosts to parse host/port (supports https URLs and host:port) // Use certlib/hosts to parse host/port (supports https URLs and host:port)
target, err := hosts.ParseHost(hostport) target, err := hosts.ParseHost(hostport)
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
d := &net.Dialer{Timeout: timeout} d := &net.Dialer{Timeout: timeout}
conn, err := tls.DialWithDialer(d, "tcp", target.String(), &tls.Config{InsecureSkipVerify: true, ServerName: target.Host}) tcfg := &tls.Config{
InsecureSkipVerify: true,
ServerName: target.Host,
} // #nosec G402 -- CLI tool only verifies revocation
td := &tls.Dialer{NetDialer: d, Config: tcfg}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
conn, err := td.DialContext(ctx, "tcp", target.String())
if err != nil { if err != nil {
return "UNKNOWN", err return strUnknown, err
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() tconn, ok := conn.(*tls.Conn)
if !ok {
return strUnknown, errors.New("connection is not TLS")
}
state := tconn.ConnectionState()
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
return "UNKNOWN", errors.New("no peer certificates presented") return strUnknown, errors.New("no peer certificates presented")
} }
return evaluateCert(state.PeerCertificates[0]) return evaluateCert(state.PeerCertificates[0])
} }
func evaluateCert(cert *x509.Certificate) (string, error) { func evaluateCert(cert *x509.Certificate) (string, error) {
// Expiry check // Delegate validity and revocation checks to certlib/revoke helper.
now := time.Now() // It returns revoked=true for both revoked and expired/not-yet-valid.
if !now.Before(cert.NotAfter) { // Map those cases back to our statuses using the returned error text.
return "EXPIRED", fmt.Errorf("expired at %s", cert.NotAfter)
}
if !now.After(cert.NotBefore) {
return "EXPIRED", fmt.Errorf("not valid until %s", cert.NotBefore)
}
// Revocation check using certlib/revoke
revoked, ok, err := revoke.VerifyCertificateError(cert) revoked, ok, err := revoke.VerifyCertificateError(cert)
if revoked { if revoked {
// If revoked is true, ok will be true per implementation, err may describe why if err != nil {
return "REVOKED", err msg := err.Error()
if strings.Contains(msg, "expired") || strings.Contains(msg, "isn't valid until") ||
strings.Contains(msg, "not valid until") {
return strExpired, err
}
}
return strRevoked, err
} }
if !ok { if !ok {
// Revocation status could not be determined // Revocation status could not be determined
return "UNKNOWN", err return strUnknown, err
} }
return "OK", nil return strOK, nil
} }

View File

@@ -1,11 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"os"
"regexp" "regexp"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -20,20 +23,26 @@ func main() {
server += ":443" server += ":443"
} }
var chain string d := &tls.Dialer{Config: &tls.Config{}} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", server)
conn, err := tls.Dial("tcp", server, nil)
die.If(err) die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
}
defer conn.Close()
details := conn.ConnectionState() details := conn.ConnectionState()
var chain strings.Builder
for _, cert := range details.PeerCertificates { for _, cert := range details.PeerCertificates {
p := pem.Block{ p := pem.Block{
Type: "CERTIFICATE", Type: "CERTIFICATE",
Bytes: cert.Raw, Bytes: cert.Raw,
} }
chain += string(pem.EncodeToMemory(&p)) chain.Write(pem.EncodeToMemory(&p))
} }
fmt.Println(chain) fmt.Fprintln(os.Stdout, chain.String())
} }
} }

View File

@@ -1,7 +1,9 @@
//lint:file-ignore SA1019 allow strict compatibility for old certs
package main package main
import ( import (
"bytes" "bytes"
"context"
"crypto/dsa" "crypto/dsa"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/elliptic" "crypto/elliptic"
@@ -101,30 +103,30 @@ func extUsage(ext []x509.ExtKeyUsage) string {
} }
func showBasicConstraints(cert *x509.Certificate) { func showBasicConstraints(cert *x509.Certificate) {
fmt.Printf("\tBasic constraints: ") fmt.Fprint(os.Stdout, "\tBasic constraints: ")
if cert.BasicConstraintsValid { if cert.BasicConstraintsValid {
fmt.Printf("valid") fmt.Fprint(os.Stdout, "valid")
} else { } else {
fmt.Printf("invalid") fmt.Fprint(os.Stdout, "invalid")
} }
if cert.IsCA { if cert.IsCA {
fmt.Printf(", is a CA certificate") fmt.Fprint(os.Stdout, ", is a CA certificate")
if !cert.BasicConstraintsValid { if !cert.BasicConstraintsValid {
fmt.Printf(" (basic constraint failure)") fmt.Fprint(os.Stdout, " (basic constraint failure)")
} }
} else { } else {
fmt.Printf("is not a CA certificate") fmt.Fprint(os.Stdout, "is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 { if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Printf(" (key encipherment usage enabled!)") fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
} }
} }
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) { if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Printf(", max path length %d", cert.MaxPathLen) fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
} }
fmt.Printf("\n") fmt.Fprintln(os.Stdout)
} }
const oneTrueDateFormat = "2006-01-02T15:04:05-0700" const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
@@ -136,39 +138,41 @@ var (
func wrapPrint(text string, indent int) { func wrapPrint(text string, indent int) {
tabs := "" tabs := ""
for i := 0; i < indent; i++ { var tabsSb140 strings.Builder
tabs += "\t" for range indent {
tabsSb140.WriteString("\t")
} }
tabs += tabsSb140.String()
fmt.Printf(tabs+"%s\n", wrap(text, indent)) fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
} }
func displayCert(cert *x509.Certificate) { func displayCert(cert *x509.Certificate) {
fmt.Println("CERTIFICATE") fmt.Fprintln(os.Stdout, "CERTIFICATE")
if showHash { if showHash {
fmt.Println(wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0)) fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
} }
fmt.Println(wrap("Subject: "+displayName(cert.Subject), 0)) fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Println(wrap("Issuer: "+displayName(cert.Issuer), 0)) fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Printf("\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm), fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm)) sigAlgoHash(cert.SignatureAlgorithm))
fmt.Println("Details:") fmt.Fprintln(os.Stdout, "Details:")
wrapPrint("Public key: "+certPublic(cert), 1) wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Printf("\tSerial number: %s\n", cert.SerialNumber) fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 { if len(cert.AuthorityKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1)) fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
} }
if len(cert.SubjectKeyId) > 0 { if len(cert.SubjectKeyId) > 0 {
fmt.Printf("\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1)) fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
} }
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1) wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Printf("\t until: %s\n", cert.NotAfter.Format(dateFormat)) fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Printf("\tKey usages: %s\n", keyUsages(cert.KeyUsage)) fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 { if len(cert.ExtKeyUsage) > 0 {
fmt.Printf("\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage)) fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
} }
showBasicConstraints(cert) showBasicConstraints(cert)
@@ -221,13 +225,13 @@ func displayAllCerts(in []byte, leafOnly bool) {
if err != nil { if err != nil {
certs, _, err = certlib.ParseCertificatesDER(in, "") certs, _, err = certlib.ParseCertificatesDER(in, "")
if err != nil { if err != nil {
lib.Warn(err, "failed to parse certificates") _, _ = lib.Warn(err, "failed to parse certificates")
return return
} }
} }
if len(certs) == 0 { if len(certs) == 0 {
lib.Warnx("no certificates found") _, _ = lib.Warnx("no certificates found")
return return
} }
@@ -243,29 +247,45 @@ func displayAllCerts(in []byte, leafOnly bool) {
func displayAllCertsWeb(uri string, leafOnly bool) { func displayAllCertsWeb(uri string, leafOnly bool) {
ci := getConnInfo(uri) ci := getConnInfo(uri)
conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig()) d := &tls.Dialer{Config: permissiveConfig()}
nc, err := d.DialContext(context.Background(), "tcp", ci.Addr)
if err != nil { if err != nil {
lib.Warn(err, "couldn't connect to %s", ci.Addr) _, _ = lib.Warn(err, "couldn't connect to %s", ci.Addr)
return
}
conn, ok := nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return return
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() state := conn.ConnectionState()
conn.Close() if err = conn.Close(); err != nil {
_, _ = lib.Warn(err, "couldn't close TLS connection")
}
conn, err = tls.Dial("tcp", ci.Addr, verifyConfig(ci.Host)) d = &tls.Dialer{Config: verifyConfig(ci.Host)}
nc, err = d.DialContext(context.Background(), "tcp", ci.Addr)
if err == nil { if err == nil {
conn, ok = nc.(*tls.Conn)
if !ok {
_, _ = lib.Warnx("invalid TLS connection (not a *tls.Conn)")
return
}
err = conn.VerifyHostname(ci.Host) err = conn.VerifyHostname(ci.Host)
if err == nil { if err == nil {
state = conn.ConnectionState() state = conn.ConnectionState()
} }
conn.Close() conn.Close()
} else { } else {
lib.Warn(err, "TLS verification error with server name %s", ci.Host) _, _ = lib.Warn(err, "TLS verification error with server name %s", ci.Host)
} }
if len(state.PeerCertificates) == 0 { if len(state.PeerCertificates) == 0 {
lib.Warnx("no certificates found") _, _ = lib.Warnx("no certificates found")
return return
} }
@@ -275,14 +295,14 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
} }
if len(state.VerifiedChains) == 0 { if len(state.VerifiedChains) == 0 {
lib.Warnx("no verified chains found; using peer chain") _, _ = lib.Warnx("no verified chains found; using peer chain")
for i := range state.PeerCertificates { for i := range state.PeerCertificates {
displayCert(state.PeerCertificates[i]) displayCert(state.PeerCertificates[i])
} }
} else { } else {
fmt.Println("TLS chain verified successfully.") fmt.Fprintln(os.Stdout, "TLS chain verified successfully.")
for i := range state.VerifiedChains { for i := range state.VerifiedChains {
fmt.Printf("--- Verified certificate chain %d ---\n", i+1) fmt.Fprintf(os.Stdout, "--- Verified certificate chain %d ---%s", i+1, "\n")
for j := range state.VerifiedChains[i] { for j := range state.VerifiedChains[i] {
displayCert(state.VerifiedChains[i][j]) displayCert(state.VerifiedChains[i][j])
} }
@@ -290,6 +310,32 @@ func displayAllCertsWeb(uri string, leafOnly bool) {
} }
} }
func shouldReadStdin(argc int, argv []string) bool {
if argc == 0 {
return true
}
if argc == 1 && argv[0] == "-" {
return true
}
return false
}
func readStdin(leafOnly bool) {
certs, err := io.ReadAll(os.Stdin)
if err != nil {
_, _ = lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
}
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.ReplaceAll(certs, []byte(`\n`), []byte{0xa})
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
}
func main() { func main() {
var leafOnly bool var leafOnly bool
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents") flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
@@ -297,27 +343,19 @@ func main() {
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate") flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.Parse() flag.Parse()
if flag.NArg() == 0 || (flag.NArg() == 1 && flag.Arg(0) == "-") { if shouldReadStdin(flag.NArg(), flag.Args()) {
certs, err := io.ReadAll(os.Stdin) readStdin(leafOnly)
if err != nil { return
lib.Warn(err, "couldn't read certificates from standard input")
os.Exit(1)
} }
// This is needed for getting certs from JSON/jq.
certs = bytes.TrimSpace(certs)
certs = bytes.Replace(certs, []byte(`\n`), []byte{0xa}, -1)
certs = bytes.Trim(certs, `"`)
displayAllCerts(certs, leafOnly)
} else {
for _, filename := range flag.Args() { for _, filename := range flag.Args() {
fmt.Printf("--%s ---\n", filename) fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
if strings.HasPrefix(filename, "https://") { if strings.HasPrefix(filename, "https://") {
displayAllCertsWeb(filename, leafOnly) displayAllCertsWeb(filename, leafOnly)
} else { } else {
in, err := os.ReadFile(filename) in, err := os.ReadFile(filename)
if err != nil { if err != nil {
lib.Warn(err, "couldn't read certificate") _, _ = lib.Warn(err, "couldn't read certificate")
continue continue
} }
@@ -325,4 +363,3 @@ func main() {
} }
} }
} }
}

View File

@@ -13,6 +13,11 @@ import (
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\)," // following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
// "\2: \1,") // "\2: \1,")
const (
sSHA256 = "SHA256"
sSHA512 = "SHA512"
)
var keyUsage = map[x509.KeyUsage]string{ var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature", x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content committment", x509.KeyUsageContentCommitment: "content committment",
@@ -38,30 +43,24 @@ var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageOCSPSigning: "ocsp signing", x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc", x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc", x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
} x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
func pubKeyAlgo(a x509.PublicKeyAlgorithm) string {
switch a {
case x509.RSA:
return "RSA"
case x509.ECDSA:
return "ECDSA"
case x509.DSA:
return "DSA"
default:
return "unknown public key algorithm"
}
} }
func sigAlgoPK(a x509.SignatureAlgorithm) string { func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a { switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA" return "RSA"
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
return "RSA-PSS"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512: case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA" return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256: case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA" return "DSA"
case x509.PureEd25519:
return "Ed25519"
case x509.UnknownSignatureAlgorithm:
return "unknown public key algorithm"
default: default:
return "unknown public key algorithm" return "unknown public key algorithm"
} }
@@ -76,11 +75,21 @@ func sigAlgoHash(a x509.SignatureAlgorithm) string {
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1: case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1" return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256: case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return "SHA256" return sSHA256
case x509.SHA256WithRSAPSS:
return sSHA256
case x509.SHA384WithRSA, x509.ECDSAWithSHA384: case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384" return "SHA384"
case x509.SHA384WithRSAPSS:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512: case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return "SHA512" return sSHA512
case x509.SHA512WithRSAPSS:
return sSHA512
case x509.PureEd25519:
return sSHA512
case x509.UnknownSignatureAlgorithm:
return "unknown hash algorithm"
default: default:
return "unknown hash algorithm" return "unknown hash algorithm"
} }
@@ -90,9 +99,11 @@ const maxLine = 78
func makeIndent(n int) string { func makeIndent(n int) string {
s := " " s := " "
for i := 0; i < n; i++ { var sSb97 strings.Builder
s += " " for range n {
sSb97.WriteString(" ")
} }
s += sSb97.String()
return s return s
} }
@@ -100,7 +111,7 @@ func indentLen(n int) int {
return 4 + (8 * n) return 4 + (8 * n)
} }
// this isn't real efficient, but that's not a problem here // this isn't real efficient, but that's not a problem here.
func wrap(s string, indent int) string { func wrap(s string, indent int) string {
if indent > 3 { if indent > 3 {
indent = 3 indent = 3
@@ -123,9 +134,11 @@ func wrap(s string, indent int) string {
func dumpHex(in []byte) string { func dumpHex(in []byte) string {
var s string var s string
var sSb130 strings.Builder
for i := range in { for i := range in {
s += fmt.Sprintf("%02X:", in[i]) sSb130.WriteString(fmt.Sprintf("%02X:", in[i]))
} }
s += sSb130.String()
return strings.Trim(s, ":") return strings.Trim(s, ":")
} }
@@ -136,14 +149,14 @@ func dumpHex(in []byte) string {
func permissiveConfig() *tls.Config { func permissiveConfig() *tls.Config {
return &tls.Config{ return &tls.Config{
InsecureSkipVerify: true, InsecureSkipVerify: true,
} } // #nosec G402
} }
// verifyConfig returns a config that will verify the connection. // verifyConfig returns a config that will verify the connection.
func verifyConfig(hostname string) *tls.Config { func verifyConfig(hostname string) *tls.Config {
return &tls.Config{ return &tls.Config{
ServerName: hostname, ServerName: hostname,
} } // #nosec G402
} }
type connInfo struct { type connInfo struct {

View File

@@ -5,7 +5,6 @@ import (
"crypto/x509/pkix" "crypto/x509/pkix"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"strings" "strings"
"time" "time"
@@ -54,7 +53,7 @@ func displayName(name pkix.Name) string {
} }
func expires(cert *x509.Certificate) time.Duration { func expires(cert *x509.Certificate) time.Duration {
return cert.NotAfter.Sub(time.Now()) return time.Until(cert.NotAfter)
} }
func inDanger(cert *x509.Certificate) bool { func inDanger(cert *x509.Certificate) bool {
@@ -81,15 +80,15 @@ func main() {
flag.Parse() flag.Parse()
for _, file := range flag.Args() { for _, file := range flag.Args() {
in, err := ioutil.ReadFile(file) in, err := os.ReadFile(file)
if err != nil { if err != nil {
lib.Warn(err, "failed to read file") _, _ = lib.Warn(err, "failed to read file")
continue continue
} }
certs, err := certlib.ParseCertificatesPEM(in) certs, err := certlib.ParseCertificatesPEM(in)
if err != nil { if err != nil {
lib.Warn(err, "while parsing certificates") _, _ = lib.Warn(err, "while parsing certificates")
continue continue
} }

View File

@@ -4,13 +4,11 @@ import (
"crypto/x509" "crypto/x509"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/certlib/revoke" "git.wntrmute.dev/kyle/goutils/certlib/revoke"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
@@ -30,83 +28,116 @@ func printRevocation(cert *x509.Certificate) {
} }
} }
func main() { type appConfig struct {
var caFile, intFile string caFile, intFile string
var forceIntermediateBundle, revexp, verbose bool forceIntermediateBundle bool
flag.StringVar(&caFile, "ca", "", "CA certificate `bundle`") revexp, verbose bool
flag.StringVar(&intFile, "i", "", "intermediate `bundle`") }
flag.BoolVar(&forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate") func parseFlags() appConfig {
flag.BoolVar(&revexp, "r", false, "print revocation and expiry information") var cfg appConfig
flag.BoolVar(&verbose, "v", false, "verbose") flag.StringVar(&cfg.caFile, "ca", "", "CA certificate `bundle`")
flag.Parse() flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`")
flag.BoolVar(&cfg.forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate")
flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&cfg.verbose, "v", false, "verbose")
flag.Parse()
return cfg
}
func loadRoots(caFile string, verbose bool) (*x509.CertPool, error) {
if caFile == "" {
return x509.SystemCertPool()
}
var roots *x509.CertPool
if caFile != "" {
var err error
if verbose { if verbose {
fmt.Println("[+] loading root certificates from", caFile) fmt.Println("[+] loading root certificates from", caFile)
} }
roots, err = certlib.LoadPEMCertPool(caFile) return certlib.LoadPEMCertPool(caFile)
die.If(err)
} }
var ints *x509.CertPool func loadIntermediates(intFile string, verbose bool) (*x509.CertPool, error) {
if intFile != "" { if intFile == "" {
var err error return x509.NewCertPool(), nil
}
if verbose { if verbose {
fmt.Println("[+] loading intermediate certificates from", intFile) fmt.Println("[+] loading intermediate certificates from", intFile)
} }
ints, err = certlib.LoadPEMCertPool(caFile) // Note: use intFile here (previously used caFile mistakenly)
die.If(err) return certlib.LoadPEMCertPool(intFile)
} else {
ints = x509.NewCertPool()
} }
if flag.NArg() != 1 { func addBundledIntermediates(chain []*x509.Certificate, pool *x509.CertPool, verbose bool) {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert",
lib.ProgName())
}
fileData, err := ioutil.ReadFile(flag.Arg(0))
die.If(err)
chain, err := certlib.ParseCertificatesPEM(fileData)
die.If(err)
if verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 {
if !forceIntermediateBundle {
for _, intermediate := range chain[1:] { for _, intermediate := range chain[1:] {
if verbose { if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId) fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
} }
pool.AddCert(intermediate)
ints.AddCert(intermediate)
}
} }
} }
func verifyCert(cert *x509.Certificate, roots, ints *x509.CertPool) error {
opts := x509.VerifyOptions{ opts := x509.VerifyOptions{
Intermediates: ints, Intermediates: ints,
Roots: roots, Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
} }
_, err := cert.Verify(opts)
_, err = cert.Verify(opts) return err
if err != nil {
fmt.Fprintf(os.Stderr, "Verification failed: %v\n", err)
os.Exit(1)
} }
if verbose { func run(cfg appConfig) error {
roots, err := loadRoots(cfg.caFile, cfg.verbose)
if err != nil {
return err
}
ints, err := loadIntermediates(cfg.intFile, cfg.verbose)
if err != nil {
return err
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
}
fileData, err := os.ReadFile(flag.Arg(0))
if err != nil {
return err
}
chain, err := certlib.ParseCertificatesPEM(fileData)
if err != nil {
return err
}
if cfg.verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 && !cfg.forceIntermediateBundle {
addBundledIntermediates(chain, ints, cfg.verbose)
}
if err = verifyCert(cert, roots, ints); err != nil {
return fmt.Errorf("certificate verification failed: %w", err)
}
if cfg.verbose {
fmt.Println("OK") fmt.Println("OK")
} }
if revexp { if cfg.revexp {
printRevocation(cert) printRevocation(cert)
} }
return nil
}
func main() {
cfg := parseFlags()
if err := run(cfg); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
os.Exit(1)
}
} }

View File

@@ -2,6 +2,8 @@ package main
import ( import (
"bufio" "bufio"
"context"
"errors"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -56,7 +58,7 @@ var modes = ssh.TerminalModes{
} }
func sshAgent() ssh.AuthMethod { func sshAgent() ssh.AuthMethod {
a, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")) a, err := (&net.Dialer{}).DialContext(context.Background(), "unix", os.Getenv("SSH_AUTH_SOCK"))
if err == nil { if err == nil {
return ssh.PublicKeysCallback(agent.NewClient(a).Signers) return ssh.PublicKeysCallback(agent.NewClient(a).Signers)
} }
@@ -82,7 +84,7 @@ func scanner(host string, in io.Reader, out io.Writer) {
} }
} }
func logError(host string, err error, format string, args ...interface{}) { func logError(host string, err error, format string, args ...any) {
msg := fmt.Sprintf(format, args...) msg := fmt.Sprintf(format, args...)
log.Printf("[%s] FAILED: %s: %v\n", host, msg, err) log.Printf("[%s] FAILED: %s: %v\n", host, msg, err)
} }
@@ -93,7 +95,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -115,7 +117,7 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) {
} }
shutdown = append(shutdown, session.Close) shutdown = append(shutdown, session.Close)
if err := session.RequestPty("xterm", 80, 40, modes); err != nil { if err = session.RequestPty("xterm", 80, 40, modes); err != nil {
session.Close() session.Close()
logError(host, err, "request for pty failed") logError(host, err, "request for pty failed")
return return
@@ -150,7 +152,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -199,7 +201,7 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
} }
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
logError(host, err, "reading chunk") logError(host, err, "reading chunk")
@@ -215,7 +217,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
defer func() { defer func() {
for i := len(shutdown) - 1; i >= 0; i-- { for i := len(shutdown) - 1; i >= 0; i-- {
err := shutdown[i]() err := shutdown[i]()
if err != nil && err != io.EOF { if err != nil && !errors.Is(err, io.EOF) {
logError(host, err, "shutting down") logError(host, err, "shutting down")
} }
} }
@@ -265,7 +267,7 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) {
fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) fmt.Printf("[%s] wrote %d-byte chunk\n", host, n)
} }
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} else if err != nil { } else if err != nil {
logError(host, err, "reading chunk") logError(host, err, "reading chunk")

View File

@@ -10,6 +10,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/fileutil" "git.wntrmute.dev/kyle/goutils/fileutil"
@@ -26,7 +27,7 @@ func setupFile(hdr *tar.Header, file *os.File) error {
if verbose { if verbose {
fmt.Printf("\tchmod %0#o\n", hdr.Mode) fmt.Printf("\tchmod %0#o\n", hdr.Mode)
} }
err := file.Chmod(os.FileMode(hdr.Mode)) err := file.Chmod(os.FileMode(hdr.Mode & 0xFFFFFFFF)) // #nosec G115
if err != nil { if err != nil {
return err return err
} }
@@ -48,54 +49,71 @@ func linkTarget(target, top string) string {
return target return target
} }
return filepath.Clean(filepath.Join(target, top)) return filepath.Clean(filepath.Join(top, target))
} }
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error { // safeJoin joins base and elem and ensures the resulting path does not escape base.
if verbose { func safeJoin(base, elem string) (string, error) {
fmt.Println(hdr.Name) cleanBase := filepath.Clean(base)
joined := filepath.Clean(filepath.Join(cleanBase, elem))
absBase, err := filepath.Abs(cleanBase)
if err != nil {
return "", err
} }
filePath := filepath.Clean(filepath.Join(top, hdr.Name)) absJoined, err := filepath.Abs(joined)
switch hdr.Typeflag { if err != nil {
case tar.TypeReg: return "", err
}
rel, err := filepath.Rel(absBase, absJoined)
if err != nil {
return "", err
}
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
return "", fmt.Errorf("path traversal detected: %s escapes %s", elem, base)
}
return joined, nil
}
func handleTypeReg(tfr *tar.Reader, hdr *tar.Header, filePath string) error {
file, err := os.Create(filePath) file, err := os.Create(filePath)
if err != nil { if err != nil {
return err return err
} }
defer file.Close()
_, err = io.Copy(file, tfr) if _, err = io.Copy(file, tfr); err != nil {
if err != nil {
return err return err
} }
return setupFile(hdr, file)
err = setupFile(hdr, file)
if err != nil {
return err
} }
case tar.TypeLink:
func handleTypeLink(hdr *tar.Header, top, filePath string) error {
file, err := os.Create(filePath) file, err := os.Create(filePath)
if err != nil { if err != nil {
return err return err
} }
defer file.Close()
source, err := os.Open(hdr.Linkname) srcPath, err := safeJoin(top, hdr.Linkname)
if err != nil { if err != nil {
return err return err
} }
source, err := os.Open(srcPath)
_, err = io.Copy(file, source)
if err != nil { if err != nil {
return err return err
} }
defer source.Close()
err = setupFile(hdr, file) if _, err = io.Copy(file, source); err != nil {
if err != nil {
return err return err
} }
case tar.TypeSymlink: return setupFile(hdr, file)
}
func handleTypeSymlink(hdr *tar.Header, top, filePath string) error {
if !fileutil.ValidateSymlink(hdr.Linkname, top) { if !fileutil.ValidateSymlink(hdr.Linkname, top) {
return fmt.Errorf("symlink %s is outside the top-level %s", return fmt.Errorf("symlink %s is outside the top-level %s", hdr.Linkname, top)
hdr.Linkname, top)
} }
path := linkTarget(hdr.Linkname, top) path := linkTarget(hdr.Linkname, top)
if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok { if ok, err := filepath.Match(top+"/*", filepath.Clean(path)); !ok {
@@ -103,18 +121,33 @@ func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
} else if err != nil { } else if err != nil {
return err return err
} }
return os.Symlink(linkTarget(hdr.Linkname, top), filePath)
}
err := os.Symlink(linkTarget(hdr.Linkname, top), filePath) func handleTypeDir(hdr *tar.Header, filePath string) error {
return os.MkdirAll(filePath, os.FileMode(hdr.Mode&0xFFFFFFFF)) // #nosec G115
}
func processFile(tfr *tar.Reader, hdr *tar.Header, top string) error {
if verbose {
fmt.Println(hdr.Name)
}
filePath, err := safeJoin(top, hdr.Name)
if err != nil { if err != nil {
return err return err
} }
switch hdr.Typeflag {
case tar.TypeReg:
return handleTypeReg(tfr, hdr, filePath)
case tar.TypeLink:
return handleTypeLink(hdr, top, filePath)
case tar.TypeSymlink:
return handleTypeSymlink(hdr, top, filePath)
case tar.TypeDir: case tar.TypeDir:
err := os.MkdirAll(filePath, os.FileMode(hdr.Mode)) return handleTypeDir(hdr, filePath)
if err != nil {
return err
} }
}
return nil return nil
} }
@@ -261,16 +294,16 @@ func main() {
die.If(err) die.If(err)
tfr := tar.NewReader(r) tfr := tar.NewReader(r)
var hdr *tar.Header
for { for {
hdr, err := tfr.Next() hdr, err = tfr.Next()
if err == io.EOF { if errors.Is(err, io.EOF) {
break break
} }
die.If(err) die.If(err)
err = processFile(tfr, hdr, top) err = processFile(tfr, hdr, top)
die.If(err) die.If(err)
} }
r.Close() r.Close()

View File

@@ -7,8 +7,7 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "os"
"log"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
@@ -17,12 +16,12 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
in, err := ioutil.ReadFile(fileName) in, err := os.ReadFile(fileName)
die.If(err) die.If(err)
if p, _ := pem.Decode(in); p != nil { if p, _ := pem.Decode(in); p != nil {
if p.Type != "CERTIFICATE REQUEST" { if p.Type != "CERTIFICATE REQUEST" {
log.Fatal("INVALID FILE TYPE") die.With("INVALID FILE TYPE")
} }
in = p.Bytes in = p.Bytes
} }
@@ -48,8 +47,8 @@ func main() {
Bytes: out, Bytes: out,
} }
err = ioutil.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0644) err = os.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0o644) // #nosec G306
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.\n", fileName+".pub") fmt.Fprintf(os.Stdout, "[+] wrote %s.\n", fileName+".pub")
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -152,7 +153,7 @@ func rsync(syncDir, target, excludeFile string, verboseRsync bool) error {
return err return err
} }
cmd := exec.Command(path, args...) cmd := exec.CommandContext(context.Background(), path, args...)
cmd.Stdout = os.Stdout cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr cmd.Stderr = os.Stderr
return cmd.Run() return cmd.Run()
@@ -163,7 +164,6 @@ func init() {
} }
func main() { func main() {
var logLevel, mountDir, syncDir, target string var logLevel, mountDir, syncDir, target string
var dryRun, quietMode, noSyslog, verboseRsync bool var dryRun, quietMode, noSyslog, verboseRsync bool
@@ -219,7 +219,7 @@ func main() {
if excludeFile != "" { if excludeFile != "" {
defer func() { defer func() {
log.Infof("removing exclude file %s", excludeFile) log.Infof("removing exclude file %s", excludeFile)
if err := os.Remove(excludeFile); err != nil { if rmErr := os.Remove(excludeFile); rmErr != nil {
log.Warningf("failed to remove temp file %s", excludeFile) log.Warningf("failed to remove temp file %s", excludeFile)
} }
}() }()

View File

@@ -19,39 +19,37 @@ var (
debug = dbg.New() debug = dbg.New()
) )
func openImage(imageFile string) (*os.File, []byte, error) {
func openImage(imageFile string) (image *os.File, hash []byte, err error) { f, err := os.Open(imageFile)
image, err = os.Open(imageFile)
if err != nil { if err != nil {
return return nil, nil, err
} }
hash, err = ahash.SumReader(hAlgo, image) h, err := ahash.SumReader(hAlgo, f)
if err != nil { if err != nil {
return return nil, nil, err
} }
_, err = image.Seek(0, 0) if _, err = f.Seek(0, 0); err != nil {
if err != nil { return nil, nil, err
return
} }
debug.Printf("%s %x\n", imageFile, hash) debug.Printf("%s %x\n", imageFile, h)
return return f, h, nil
} }
func openDevice(devicePath string) (device *os.File, err error) { func openDevice(devicePath string) (*os.File, error) {
fi, err := os.Stat(devicePath) fi, err := os.Stat(devicePath)
if err != nil { if err != nil {
return return nil, err
} }
device, err = os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode()) device, err := os.OpenFile(devicePath, os.O_RDWR|os.O_SYNC, fi.Mode())
if err != nil { if err != nil {
return return nil, err
} }
return return device, nil
} }
func main() { func main() {
@@ -105,12 +103,12 @@ func main() {
die.If(err) die.If(err)
if !bytes.Equal(deviceHash, hash) { if !bytes.Equal(deviceHash, hash) {
fmt.Fprintln(os.Stderr, "Hash mismatch:") buf := &bytes.Buffer{}
fmt.Fprintf(os.Stderr, "\t%s: %s\n", imageFile, hash) fmt.Fprintln(buf, "Hash mismatch:")
fmt.Fprintf(os.Stderr, "\t%s: %s\n", devicePath, deviceHash) fmt.Fprintf(buf, "\t%s: %s\n", imageFile, hash)
os.Exit(1) fmt.Fprintf(buf, "\t%s: %s\n", devicePath, deviceHash)
die.With(buf.String())
} }
debug.Println("OK") debug.Println("OK")
os.Exit(0)
} }

View File

@@ -1,30 +1,33 @@
package main package main
import ( import (
"errors"
"flag" "flag"
"fmt" "fmt"
"git.wntrmute.dev/kyle/goutils/die"
"io" "io"
"os" "os"
"strings"
"git.wntrmute.dev/kyle/goutils/die"
) )
func usage(w io.Writer, exc int) { func usage(w io.Writer, exc int) {
fmt.Fprintln(w, `usage: dumpbytes <file>`) fmt.Fprintln(w, `usage: dumpbytes -n tabs <file>`)
os.Exit(exc) os.Exit(exc)
} }
func printBytes(buf []byte) { func printBytes(buf []byte) {
fmt.Printf("\t") fmt.Printf("\t")
for i := 0; i < len(buf); i++ { for i := range buf {
fmt.Printf("0x%02x, ", buf[i]) fmt.Printf("0x%02x, ", buf[i])
} }
fmt.Println() fmt.Println()
} }
func dumpFile(path string, indentLevel int) error { func dumpFile(path string, indentLevel int) error {
indent := "" var indent strings.Builder
for i := 0; i < indentLevel; i++ { for range indentLevel {
indent += "\t" indent.WriteByte('\t')
} }
file, err := os.Open(path) file, err := os.Open(path)
@@ -34,13 +37,14 @@ func dumpFile(path string, indentLevel int) error {
defer file.Close() defer file.Close()
fmt.Printf("%svar buffer = []byte{\n", indent) fmt.Printf("%svar buffer = []byte{\n", indent.String())
var n int
for { for {
buf := make([]byte, 8) buf := make([]byte, 8)
n, err := file.Read(buf) n, err = file.Read(buf)
if err == io.EOF { if errors.Is(err, io.EOF) {
if n > 0 { if n > 0 {
fmt.Printf("%s", indent) fmt.Printf("%s", indent.String())
printBytes(buf[:n]) printBytes(buf[:n])
} }
break break
@@ -50,11 +54,11 @@ func dumpFile(path string, indentLevel int) error {
return err return err
} }
fmt.Printf("%s", indent) fmt.Printf("%s", indent.String())
printBytes(buf[:n]) printBytes(buf[:n])
} }
fmt.Printf("%s}\n", indent) fmt.Printf("%s}\n", indent.String())
return nil return nil
} }

View File

@@ -7,7 +7,7 @@ import (
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
) )
// size of a kilobit in bytes // size of a kilobit in bytes.
const kilobit = 128 const kilobit = 128
const pageSize = 4096 const pageSize = 4096
@@ -26,10 +26,10 @@ func main() {
path = flag.Arg(0) path = flag.Arg(0)
} }
fillByte := uint8(*fill) fillByte := uint8(*fill & 0xff) // #nosec G115 clearing out of bounds bits
buf := make([]byte, pageSize) buf := make([]byte, pageSize)
for i := 0; i < pageSize; i++ { for i := range pageSize {
buf[i] = fillByte buf[i] = fillByte
} }
@@ -40,7 +40,7 @@ func main() {
die.If(err) die.If(err)
defer file.Close() defer file.Close()
for i := 0; i < pages; i++ { for range pages {
_, err = file.Write(buf) _, err = file.Write(buf)
die.If(err) die.If(err)
} }

View File

@@ -72,15 +72,13 @@ func main() {
if end < start { if end < start {
fmt.Fprintln(os.Stderr, "[!] end < start, swapping values") fmt.Fprintln(os.Stderr, "[!] end < start, swapping values")
tmp := end start, end = end, start
end = start
start = tmp
} }
var fmtStr string var fmtStr string
if !*quiet { if !*quiet {
maxLine := fmt.Sprintf("%d", len(lines)) maxLine := strconv.Itoa(len(lines))
fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine)) fmtStr = fmt.Sprintf("%%0%dd: %%s", len(maxLine))
} }
@@ -98,9 +96,9 @@ func main() {
fmtStr += "\n" fmtStr += "\n"
for i := start; !endFunc(i); i++ { for i := start; !endFunc(i); i++ {
if *quiet { if *quiet {
fmt.Println(lines[i]) fmt.Fprintln(os.Stdout, lines[i])
} else { } else {
fmt.Printf(fmtStr, i, lines[i]) fmt.Fprintf(os.Stdout, fmtStr, i, lines[i])
} }
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"log" "log"
@@ -8,7 +9,8 @@ import (
) )
func lookupHost(host string) error { func lookupHost(host string) error {
cname, err := net.LookupCNAME(host) r := &net.Resolver{}
cname, err := r.LookupCNAME(context.Background(), host)
if err != nil { if err != nil {
return err return err
} }
@@ -18,7 +20,7 @@ func lookupHost(host string) error {
host = cname host = cname
} }
addrs, err := net.LookupHost(host) addrs, err := r.LookupHost(context.Background(), host)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -5,7 +5,7 @@ import (
"encoding/json" "encoding/json"
"flag" "flag"
"fmt" "fmt"
"io/ioutil" "io"
"os" "os"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
@@ -16,20 +16,20 @@ func prettify(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Indent(buf, in, "", " ") err = json.Indent(buf, in, "", " ")
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -40,11 +40,11 @@ func prettify(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -55,20 +55,20 @@ func compact(file string, validateOnly bool) error {
var err error var err error
if file == "-" { if file == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(file) in, err = os.ReadFile(file)
} }
if err != nil { if err != nil {
lib.Warn(err, "ReadFile") _, _ = lib.Warn(err, "ReadFile")
return err return err
} }
var buf = &bytes.Buffer{} var buf = &bytes.Buffer{}
err = json.Compact(buf, in) err = json.Compact(buf, in)
if err != nil { if err != nil {
lib.Warn(err, "%s", file) _, _ = lib.Warn(err, "%s", file)
return err return err
} }
@@ -79,11 +79,11 @@ func compact(file string, validateOnly bool) error {
if file == "-" { if file == "-" {
_, err = os.Stdout.Write(buf.Bytes()) _, err = os.Stdout.Write(buf.Bytes())
} else { } else {
err = ioutil.WriteFile(file, buf.Bytes(), 0644) err = os.WriteFile(file, buf.Bytes(), 0o644)
} }
if err != nil { if err != nil {
lib.Warn(err, "WriteFile") _, _ = lib.Warn(err, "WriteFile")
} }
return err return err
@@ -91,7 +91,7 @@ func compact(file string, validateOnly bool) error {
func usage() { func usage() {
progname := lib.ProgName() progname := lib.ProgName()
fmt.Printf(`Usage: %s [-h] files... fmt.Fprintf(os.Stdout, `Usage: %s [-h] files...
%s is used to lint and prettify (or compact) JSON files. The %s is used to lint and prettify (or compact) JSON files. The
files will be updated in-place. files will be updated in-place.
@@ -100,7 +100,6 @@ func usage() {
-h Print this help message. -h Print this help message.
-n Don't prettify; only perform validation. -n Don't prettify; only perform validation.
`, progname, progname) `, progname, progname)
} }
func init() { func init() {

View File

@@ -12,6 +12,9 @@ based on whether the source filename ends in ".gz".
Flags: Flags:
-l level Compression level (0-9). Only meaninful when -l level Compression level (0-9). Only meaninful when
compressing a file. compressing a file.
-u Do not restrict the size during decompression. As
a safeguard against gzip bombs, the maximum size
allowed is 32 * the compressed file size.

View File

@@ -40,26 +40,42 @@ func compress(path, target string, level int) error {
return nil return nil
} }
func uncompress(path, target string) error { func uncompress(path, target string, unrestrict bool) error {
sourceFile, err := os.Open(path) sourceFile, err := os.Open(path)
if err != nil { if err != nil {
return fmt.Errorf("opening file for read: %w", err) return fmt.Errorf("opening file for read: %w", err)
} }
defer sourceFile.Close() defer sourceFile.Close()
fi, err := sourceFile.Stat()
if err != nil {
return fmt.Errorf("reading file stats: %w", err)
}
maxDecompressionSize := fi.Size() * 32
gzipUncompressor, err := gzip.NewReader(sourceFile) gzipUncompressor, err := gzip.NewReader(sourceFile)
if err != nil { if err != nil {
return fmt.Errorf("reading gzip headers: %w", err) return fmt.Errorf("reading gzip headers: %w", err)
} }
defer gzipUncompressor.Close() defer gzipUncompressor.Close()
var reader io.Reader = &io.LimitedReader{
R: gzipUncompressor,
N: maxDecompressionSize,
}
if unrestrict {
reader = gzipUncompressor
}
destFile, err := os.Create(target) destFile, err := os.Create(target)
if err != nil { if err != nil {
return fmt.Errorf("opening file for write: %w", err) return fmt.Errorf("opening file for write: %w", err)
} }
defer destFile.Close() defer destFile.Close()
_, err = io.Copy(destFile, gzipUncompressor) _, err = io.Copy(destFile, reader)
if err != nil { if err != nil {
return fmt.Errorf("uncompressing file: %w", err) return fmt.Errorf("uncompressing file: %w", err)
} }
@@ -87,8 +103,8 @@ func isDir(path string) bool {
file, err := os.Open(path) file, err := os.Open(path)
if err == nil { if err == nil {
defer file.Close() defer file.Close()
stat, err := file.Stat() stat, err2 := file.Stat()
if err != nil { if err2 != nil {
return false return false
} }
@@ -132,8 +148,11 @@ func main() {
var level int var level int
var path string var path string
var target = "." var target = "."
var err error
var unrestrict bool
flag.IntVar(&level, "l", flate.DefaultCompression, "compression level") flag.IntVar(&level, "l", flate.DefaultCompression, "compression level")
flag.BoolVar(&unrestrict, "u", false, "do not restrict decompression")
flag.Parse() flag.Parse()
if flag.NArg() < 1 || flag.NArg() > 2 { if flag.NArg() < 1 || flag.NArg() > 2 {
@@ -147,20 +166,22 @@ func main() {
} }
if strings.HasSuffix(path, gzipExt) { if strings.HasSuffix(path, gzipExt) {
target, err := pathForUncompressing(path, target) target, err = pathForUncompressing(path, target)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
err = uncompress(path, target) err = uncompress(path, target, unrestrict)
if err != nil { if err != nil {
os.Remove(target) os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
} }
} else { return
target, err := pathForCompressing(path, target) }
target, err = pathForCompressing(path, target)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "%s\n", err) fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1) os.Exit(1)
@@ -173,4 +194,3 @@ func main() {
os.Exit(1) os.Exit(1)
} }
} }
}

View File

@@ -40,14 +40,14 @@ func main() {
usage() usage()
} }
min, err := strconv.Atoi(flag.Arg(1)) minVal, err := strconv.Atoi(flag.Arg(1))
dieIf(err) dieIf(err)
max, err := strconv.Atoi(flag.Arg(2)) maxVal, err := strconv.Atoi(flag.Arg(2))
dieIf(err) dieIf(err)
code := kind << 6 code := kind << 6
code += (min << 3) code += (minVal << 3)
code += max code += maxVal
fmt.Printf("%0o\n", code) fmt.Fprintf(os.Stdout, "%0o\n", code)
} }

View File

@@ -5,7 +5,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"sort" "sort"
@@ -47,7 +46,7 @@ func help(w io.Writer) {
} }
func loadDatabase() { func loadDatabase() {
data, err := ioutil.ReadFile(dbFile) data, err := os.ReadFile(dbFile)
if err != nil && os.IsNotExist(err) { if err != nil && os.IsNotExist(err) {
partsDB = &database{ partsDB = &database{
Version: dbVersion, Version: dbVersion,
@@ -74,7 +73,7 @@ func writeDB() {
data, err := json.Marshal(partsDB) data, err := json.Marshal(partsDB)
die.If(err) die.If(err)
err = ioutil.WriteFile(dbFile, data, 0644) err = os.WriteFile(dbFile, data, 0644)
die.If(err) die.If(err)
} }

View File

@@ -4,14 +4,13 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
var ext = ".bin" var ext = ".bin"
func stripPEM(path string) error { func stripPEM(path string) error {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
if err != nil { if err != nil {
return err return err
} }
@@ -22,7 +21,7 @@ func stripPEM(path string) error {
fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n") fmt.Fprintf(os.Stderr, " (only the first object will be decoded)\n")
} }
return ioutil.WriteFile(path+ext, p.Bytes, 0644) return os.WriteFile(path+ext, p.Bytes, 0644)
} }
func main() { func main() {

View File

@@ -3,8 +3,7 @@ package main
import ( import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "io"
"io/ioutil"
"os" "os"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
@@ -21,9 +20,9 @@ func main() {
path := flag.Arg(0) path := flag.Arg(0)
if path == "-" { if path == "-" {
in, err = ioutil.ReadAll(os.Stdin) in, err = io.ReadAll(os.Stdin)
} else { } else {
in, err = ioutil.ReadFile(flag.Arg(0)) in, err = os.ReadFile(flag.Arg(0))
} }
if err != nil { if err != nil {
lib.Err(lib.ExitFailure, err, "couldn't read file") lib.Err(lib.ExitFailure, err, "couldn't read file")
@@ -33,5 +32,7 @@ func main() {
if p == nil { if p == nil {
lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0)) lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0))
} }
fmt.Printf("%s", p.Bytes) if _, err = os.Stdout.Write(p.Bytes); err != nil {
lib.Err(lib.ExitFailure, err, "writing body")
}
} }

View File

@@ -70,7 +70,7 @@ func main() {
lib.Err(lib.ExitFailure, err, "failed to read input") lib.Err(lib.ExitFailure, err, "failed to read input")
} }
case argc > 1: case argc > 1:
for i := 0; i < argc; i++ { for i := range argc {
path := flag.Arg(i) path := flag.Arg(i)
err = copyFile(path, buf) err = copyFile(path, buf)
if err != nil { if err != nil {

View File

@@ -5,7 +5,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"os" "os"
) )
@@ -13,14 +12,14 @@ func main() {
flag.Parse() flag.Parse()
for _, fileName := range flag.Args() { for _, fileName := range flag.Args() {
data, err := ioutil.ReadFile(fileName) data, err := os.ReadFile(fileName)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
continue continue
} }
fmt.Printf("[+] %s:\n", fileName) fmt.Fprintf(os.Stdout, "[+] %s:\n", fileName)
rest := data[:] rest := data
for { for {
var p *pem.Block var p *pem.Block
p, rest = pem.Decode(rest) p, rest = pem.Decode(rest)
@@ -28,13 +27,14 @@ func main() {
break break
} }
cert, err := x509.ParseCertificate(p.Bytes) var cert *x509.Certificate
cert, err = x509.ParseCertificate(p.Bytes)
if err != nil { if err != nil {
fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err)
break break
} }
fmt.Printf("\t%+v\n", cert.Subject.CommonName) fmt.Fprintf(os.Stdout, "\t%+v\n", cert.Subject.CommonName)
} }
} }
} }

View File

@@ -43,7 +43,7 @@ func newName(path string) (string, error) {
return hashName(path, encodedHash), nil return hashName(path, encodedHash), nil
} }
func move(dst, src string, force bool) (err error) { func move(dst, src string, force bool) error {
if fileutil.FileDoesExist(dst) && !force { if fileutil.FileDoesExist(dst) && !force {
return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst) return fmt.Errorf("%s exists (pass the -f flag to overwrite)", dst)
} }
@@ -52,21 +52,23 @@ func move(dst, src string, force bool) (err error) {
return err return err
} }
defer func(e error) { var retErr error
defer func(e *error) {
dstFile.Close() dstFile.Close()
if e != nil { if *e != nil {
os.Remove(dst) os.Remove(dst)
} }
}(err) }(&retErr)
srcFile, err := os.Open(src) srcFile, err := os.Open(src)
if err != nil { if err != nil {
retErr = err
return err return err
} }
defer srcFile.Close() defer srcFile.Close()
_, err = io.Copy(dstFile, srcFile) if _, err = io.Copy(dstFile, srcFile); err != nil {
if err != nil { retErr = err
return err return err
} }
@@ -94,6 +96,44 @@ func init() {
flag.Usage = func() { usage(os.Stdout) } flag.Usage = func() { usage(os.Stdout) }
} }
type options struct {
dryRun, force, printChanged, verbose bool
}
func processOne(file string, opt options) error {
renamed, err := newName(file)
if err != nil {
_, _ = lib.Warn(err, "failed to get new file name")
return err
}
if opt.verbose && !opt.printChanged {
fmt.Fprintln(os.Stdout, file)
}
if renamed == file {
return nil
}
if !opt.dryRun {
if err = move(renamed, file, opt.force); err != nil {
_, _ = lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
return err
}
}
if opt.printChanged && !opt.verbose {
fmt.Fprintln(os.Stdout, file, "->", renamed)
}
return nil
}
func run(dryRun, force, printChanged, verbose bool, files []string) {
if verbose && printChanged {
printChanged = false
}
opt := options{dryRun: dryRun, force: force, printChanged: printChanged, verbose: verbose}
for _, file := range files {
_ = processOne(file, opt)
}
}
func main() { func main() {
var dryRun, force, printChanged, verbose bool var dryRun, force, printChanged, verbose bool
flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision") flag.BoolVar(&force, "f", false, "force overwriting of files if there is a collision")
@@ -102,34 +142,5 @@ func main() {
flag.BoolVar(&verbose, "v", false, "list all processed files") flag.BoolVar(&verbose, "v", false, "list all processed files")
flag.Parse() flag.Parse()
run(dryRun, force, printChanged, verbose, flag.Args())
if verbose && printChanged {
printChanged = false
}
for _, file := range flag.Args() {
renamed, err := newName(file)
if err != nil {
lib.Warn(err, "failed to get new file name")
continue
}
if verbose && !printChanged {
fmt.Println(file)
}
if renamed != file {
if !dryRun {
err = move(renamed, file, force)
if err != nil {
lib.Warn(err, "failed to rename file from %s to %s", file, renamed)
continue
}
}
if printChanged && !verbose {
fmt.Println(file, "->", renamed)
}
}
}
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"flag" "flag"
"fmt" "fmt"
"io" "io"
@@ -66,24 +67,25 @@ func main() {
for _, remote := range flag.Args() { for _, remote := range flag.Args() {
u, err := url.Parse(remote) u, err := url.Parse(remote)
if err != nil { if err != nil {
lib.Warn(err, "parsing %s", remote) _, _ = lib.Warn(err, "parsing %s", remote)
continue continue
} }
name := filepath.Base(u.Path) name := filepath.Base(u.Path)
if name == "" { if name == "" {
lib.Warnx("source URL doesn't appear to name a file") _, _ = lib.Warnx("source URL doesn't appear to name a file")
continue continue
} }
resp, err := http.Get(remote) req, reqErr := http.NewRequestWithContext(context.Background(), http.MethodGet, remote, nil)
if err != nil { if reqErr != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(reqErr, "building request for %s", remote)
continue continue
} }
client := &http.Client{}
resp, err := client.Do(req)
if err != nil { if err != nil {
lib.Warn(err, "fetching %s", remote) _, _ = lib.Warn(err, "fetching %s", remote)
continue continue
} }

View File

@@ -3,7 +3,7 @@ package main
import ( import (
"flag" "flag"
"fmt" "fmt"
"math/rand" "math/rand/v2"
"os" "os"
"regexp" "regexp"
"strconv" "strconv"
@@ -17,8 +17,8 @@ func rollDie(count, sides int) []int {
sum := 0 sum := 0
var rolls []int var rolls []int
for i := 0; i < count; i++ { for range count {
roll := rand.Intn(sides) + 1 roll := rand.IntN(sides) + 1 // #nosec G404
sum += roll sum += roll
rolls = append(rolls, roll) rolls = append(rolls, roll)
} }

View File

@@ -53,7 +53,7 @@ func init() {
project = wd[len(gopath):] project = wd[len(gopath):]
} }
func walkFile(path string, info os.FileInfo, err error) error { func walkFile(path string, _ os.FileInfo, err error) error {
if ignores[path] { if ignores[path] {
return filepath.SkipDir return filepath.SkipDir
} }
@@ -62,22 +62,27 @@ func walkFile(path string, info os.FileInfo, err error) error {
return nil return nil
} }
debug.Println(path)
f, err := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err != nil { if err != nil {
return err return err
} }
debug.Println(path)
f, err2 := parser.ParseFile(fset, path, nil, parser.ImportsOnly)
if err2 != nil {
return err2
}
for _, importSpec := range f.Imports { for _, importSpec := range f.Imports {
importPath := strings.Trim(importSpec.Path.Value, `"`) importPath := strings.Trim(importSpec.Path.Value, `"`)
if stdLibRegexp.MatchString(importPath) { switch {
case stdLibRegexp.MatchString(importPath):
debug.Println("standard lib:", importPath) debug.Println("standard lib:", importPath)
continue continue
} else if strings.HasPrefix(importPath, project) { case strings.HasPrefix(importPath, project):
debug.Println("internal import:", importPath) debug.Println("internal import:", importPath)
continue continue
} else if strings.HasPrefix(importPath, "golang.org/") { case strings.HasPrefix(importPath, "golang.org/"):
debug.Println("extended lib:", importPath) debug.Println("extended lib:", importPath)
continue continue
} }
@@ -102,7 +107,7 @@ func main() {
ignores["vendor"] = true ignores["vendor"] = true
} }
for _, word := range strings.Split(ignoreLine, ",") { for word := range strings.SplitSeq(ignoreLine, ",") {
ignores[strings.TrimSpace(word)] = true ignores[strings.TrimSpace(word)] = true
} }

View File

@@ -5,7 +5,7 @@ import (
"crypto" "crypto"
"crypto/ecdsa" "crypto/ecdsa"
"crypto/rsa" "crypto/rsa"
"crypto/sha1" "crypto/sha1" // #nosec G505
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/asn1" "encoding/asn1"
@@ -13,7 +13,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"strings" "strings"
@@ -21,6 +20,11 @@ import (
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
const (
keyTypeRSA = "RSA"
keyTypeECDSA = "ECDSA"
)
func usage(w io.Writer) { func usage(w io.Writer) {
fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files
@@ -39,14 +43,14 @@ func init() {
flag.Usage = func() { usage(os.Stderr) } flag.Usage = func() { usage(os.Stderr) }
} }
func parse(path string) (public []byte, kt, ft string) { func parse(path string) ([]byte, string, string) {
data, err := ioutil.ReadFile(path) data, err := os.ReadFile(path)
die.If(err) die.If(err)
data = bytes.TrimSpace(data) data = bytes.TrimSpace(data)
p, rest := pem.Decode(data) p, rest := pem.Decode(data)
if len(rest) > 0 { if len(rest) > 0 {
lib.Warnx("trailing data in PEM file") _, _ = lib.Warnx("trailing data in PEM file")
} }
if p == nil { if p == nil {
@@ -55,6 +59,12 @@ func parse(path string) (public []byte, kt, ft string) {
data = p.Bytes data = p.Bytes
var (
public []byte
kt string
ft string
)
switch p.Type { switch p.Type {
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY": case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
public, kt = parseKey(data) public, kt = parseKey(data)
@@ -69,10 +79,10 @@ func parse(path string) (public []byte, kt, ft string) {
die.With("unknown PEM type %s", p.Type) die.With("unknown PEM type %s", p.Type)
} }
return return public, kt, ft
} }
func parseKey(data []byte) (public []byte, kt string) { func parseKey(data []byte) ([]byte, string) {
privInterface, err := x509.ParsePKCS8PrivateKey(data) privInterface, err := x509.ParsePKCS8PrivateKey(data)
if err != nil { if err != nil {
privInterface, err = x509.ParsePKCS1PrivateKey(data) privInterface, err = x509.ParsePKCS1PrivateKey(data)
@@ -85,66 +95,71 @@ func parseKey(data []byte) (public []byte, kt string) {
} }
var priv crypto.Signer var priv crypto.Signer
switch privInterface.(type) { var kt string
switch p := privInterface.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
priv = privInterface.(*rsa.PrivateKey) priv = p
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
priv = privInterface.(*ecdsa.PrivateKey) priv = p
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown private key type %T", privInterface) die.With("unknown private key type %T", privInterface)
} }
public, err = x509.MarshalPKIXPublicKey(priv.Public()) public, err := x509.MarshalPKIXPublicKey(priv.Public())
die.If(err) die.If(err)
return return public, kt
} }
func parseCertificate(data []byte) (public []byte, kt string) { func parseCertificate(data []byte) ([]byte, string) {
cert, err := x509.ParseCertificate(data) cert, err := x509.ParseCertificate(data)
die.If(err) die.If(err)
pub := cert.PublicKey pub := cert.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func parseCSR(data []byte) (public []byte, kt string) { func parseCSR(data []byte) ([]byte, string) {
csr, err := x509.ParseCertificateRequest(data) csr, err := x509.ParseCertificateRequest(data)
die.If(err) die.If(err)
pub := csr.PublicKey pub := csr.PublicKey
var kt string
switch pub.(type) { switch pub.(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
kt = "RSA" kt = keyTypeRSA
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
kt = "ECDSA" kt = keyTypeECDSA
default: default:
die.With("unknown public key type %T", pub) die.With("unknown public key type %T", pub)
} }
public, err = x509.MarshalPKIXPublicKey(pub) public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err) die.If(err)
return return public, kt
} }
func dumpHex(in []byte) string { func dumpHex(in []byte) string {
var s string var s string
var sSb153 strings.Builder
for i := range in { for i := range in {
s += fmt.Sprintf("%02X:", in[i]) sSb153.WriteString(fmt.Sprintf("%02X:", in[i]))
} }
s += sSb153.String()
return strings.Trim(s, ":") return strings.Trim(s, ":")
} }
@@ -172,18 +187,18 @@ func main() {
var subPKI subjectPublicKeyInfo var subPKI subjectPublicKeyInfo
_, err := asn1.Unmarshal(public, &subPKI) _, err := asn1.Unmarshal(public, &subPKI)
if err != nil { if err != nil {
lib.Warn(err, "failed to get subject PKI") _, _ = lib.Warn(err, "failed to get subject PKI")
continue continue
} }
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
pubHashString := dumpHex(pubHash[:]) pubHashString := dumpHex(pubHash[:])
if ski == "" { if ski == "" {
ski = pubHashString ski = pubHashString
} }
if shouldMatch && ski != pubHashString { if shouldMatch && ski != pubHashString {
lib.Warnx("%s: SKI mismatch (%s != %s)", _, _ = lib.Warnx("%s: SKI mismatch (%s != %s)",
path, ski, pubHashString) path, ski, pubHashString)
} }
fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft) fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft)

View File

@@ -1,16 +1,17 @@
package main package main
import ( import (
"context"
"flag" "flag"
"io" "io"
"log"
"net" "net"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
) )
func proxy(conn net.Conn, inside string) error { func proxy(conn net.Conn, inside string) error {
proxyConn, err := net.Dial("tcp", inside) proxyConn, err := (&net.Dialer{}).DialContext(context.Background(), "tcp", inside)
if err != nil { if err != nil {
return err return err
} }
@@ -19,7 +20,7 @@ func proxy(conn net.Conn, inside string) error {
defer conn.Close() defer conn.Close()
go func() { go func() {
io.Copy(conn, proxyConn) _, _ = io.Copy(conn, proxyConn)
}() }()
_, err = io.Copy(proxyConn, conn) _, err = io.Copy(proxyConn, conn)
return err return err
@@ -31,16 +32,22 @@ func main() {
flag.StringVar(&inside, "p", "4000", "inside port") flag.StringVar(&inside, "p", "4000", "inside port")
flag.Parse() flag.Parse()
l, err := net.Listen("tcp", "0.0.0.0:"+outside) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", "0.0.0.0:"+outside)
die.If(err) die.If(err)
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
log.Println(err) _, _ = lib.Warn(err, "accept failed")
continue continue
} }
go proxy(conn, "127.0.0.1:"+inside) go func() {
if err = proxy(conn, "127.0.0.1:"+inside); err != nil {
_, _ = lib.Warn(err, "proxy error")
}
}()
} }
} }

View File

@@ -1,6 +1,7 @@
package main package main
import ( import (
"context"
"crypto/rand" "crypto/rand"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
@@ -8,7 +9,6 @@ import (
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -16,7 +16,7 @@ import (
) )
func main() { func main() {
cfg := &tls.Config{} cfg := &tls.Config{} // #nosec G402
var sysRoot, listenAddr, certFile, keyFile string var sysRoot, listenAddr, certFile, keyFile string
var verify bool var verify bool
@@ -47,7 +47,8 @@ func main() {
} }
cfg.Certificates = append(cfg.Certificates, cert) cfg.Certificates = append(cfg.Certificates, cert)
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) var pemList []byte
pemList, err = os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -59,48 +60,54 @@ func main() {
cfg.RootCAs = roots cfg.RootCAs = roots
} }
l, err := net.Listen("tcp", listenAddr) lc := &net.ListenConfig{}
l, err := lc.Listen(context.Background(), "tcp", listenAddr)
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
os.Exit(1) os.Exit(1)
} }
for { for {
conn, err := l.Accept() var conn net.Conn
conn, err = l.Accept()
if err != nil { if err != nil {
fmt.Println(err.Error()) fmt.Println(err.Error())
continue
}
handleConn(conn, cfg)
}
} }
// handleConn performs a TLS handshake, extracts the peer chain, and writes it to a file.
func handleConn(conn net.Conn, cfg *tls.Config) {
defer conn.Close()
raddr := conn.RemoteAddr() raddr := conn.RemoteAddr()
tconn := tls.Server(conn, cfg) tconn := tls.Server(conn, cfg)
err = tconn.Handshake() if err := tconn.HandshakeContext(context.Background()); err != nil {
if err != nil {
fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err) fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err)
continue return
} }
cs := tconn.ConnectionState() cs := tconn.ConnectionState()
if len(cs.PeerCertificates) == 0 { if len(cs.PeerCertificates) == 0 {
fmt.Printf("[+] %v: no chain presented\n", raddr) fmt.Printf("[+] %v: no chain presented\n", raddr)
continue return
} }
var chain []byte var chain []byte
for _, cert := range cs.PeerCertificates { for _, cert := range cs.PeerCertificates {
p := &pem.Block{ p := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw}
Type: "CERTIFICATE",
Bytes: cert.Raw,
}
chain = append(chain, pem.EncodeToMemory(p)...) chain = append(chain, pem.EncodeToMemory(p)...)
} }
var nonce [16]byte var nonce [16]byte
_, err = rand.Read(nonce[:]) if _, err := rand.Read(nonce[:]); err != nil {
if err != nil { fmt.Printf("[+] %v: failed to generate filename nonce: %v\n", raddr, err)
panic(err) return
} }
fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:])) fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:]))
err = ioutil.WriteFile(fname, chain, 0644) if err := os.WriteFile(fname, chain, 0o644); err != nil {
die.If(err) fmt.Printf("[+] %v: failed to write %v: %v\n", raddr, fname, err)
return
}
fmt.Printf("%v: [+] wrote %v.\n", raddr, fname) fmt.Printf("%v: [+] wrote %v.\n", raddr, fname)
} }
}

View File

@@ -1,12 +1,12 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"net" "net"
"os" "os"
@@ -14,7 +14,7 @@ import (
) )
func main() { func main() {
var cfg = &tls.Config{} var cfg = &tls.Config{} // #nosec G402
var sysRoot, serverName string var sysRoot, serverName string
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle") flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
@@ -23,7 +23,7 @@ func main() {
flag.Parse() flag.Parse()
if sysRoot != "" { if sysRoot != "" {
pemList, err := ioutil.ReadFile(sysRoot) pemList, err := os.ReadFile(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool() roots := x509.NewCertPool()
@@ -44,10 +44,13 @@ func main() {
if err != nil { if err != nil {
site += ":443" site += ":443"
} }
conn, err := tls.Dial("tcp", site, cfg) d := &tls.Dialer{Config: cfg}
if err != nil { nc, err := d.DialContext(context.Background(), "tcp", site)
fmt.Println(err.Error()) die.If(err)
os.Exit(1)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
cs := conn.ConnectionState() cs := conn.ConnectionState()
@@ -61,8 +64,9 @@ func main() {
chain = append(chain, pem.EncodeToMemory(p)...) chain = append(chain, pem.EncodeToMemory(p)...)
} }
err = ioutil.WriteFile(site+".pem", chain, 0644) err = os.WriteFile(site+".pem", chain, 0644)
die.If(err) die.If(err)
fmt.Printf("[+] wrote %s.pem.\n", site) fmt.Printf("[+] wrote %s.pem.\n", site)
} }
} }

View File

@@ -60,7 +60,7 @@ func printDigests(paths []string, issuer bool) {
for _, path := range paths { for _, path := range paths {
cert, err := certlib.LoadCertificate(path) cert, err := certlib.LoadCertificate(path)
if err != nil { if err != nil {
lib.Warn(err, "failed to load certificate from %s", path) _, _ = lib.Warn(err, "failed to load certificate from %s", path)
continue continue
} }
@@ -75,20 +75,19 @@ func matchDigests(paths []string, issuer bool) {
} }
var invalid int var invalid int
for { for len(paths) > 0 {
if len(paths) == 0 {
break
}
fst := paths[0] fst := paths[0]
snd := paths[1] snd := paths[1]
paths = paths[2:] paths = paths[2:]
fstCert, err := certlib.LoadCertificate(fst) fstCert, err := certlib.LoadCertificate(fst)
die.If(err) die.If(err)
sndCert, err := certlib.LoadCertificate(snd) sndCert, err := certlib.LoadCertificate(snd)
die.If(err) die.If(err)
if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) { if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) {
lib.Warnx("certificates don't match: %s and %s", fst, snd) _, _ = lib.Warnx("certificates don't match: %s and %s", fst, snd)
invalid++ invalid++
} }
} }

View File

@@ -1,10 +1,14 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"fmt" "fmt"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
"git.wntrmute.dev/kyle/goutils/die"
) )
func main() { func main() {
@@ -13,16 +17,23 @@ func main() {
os.Exit(1) os.Exit(1)
} }
hostPort := os.Args[1] hostPort, err := hosts.ParseHost(os.Args[1])
conn, err := tls.Dial("tcp", hostPort, &tls.Config{ die.If(err)
InsecureSkipVerify: true,
})
if err != nil { d := &tls.Dialer{Config: &tls.Config{
fmt.Printf("Failed to connect to the TLS server: %v\n", err) InsecureSkipVerify: true,
os.Exit(1) }} // #nosec G402
nc, err := d.DialContext(context.Background(), "tcp", hostPort.String())
die.If(err)
conn, ok := nc.(*tls.Conn)
if !ok {
die.With("invalid TLS connection (not a *tls.Conn)")
} }
defer conn.Close() defer conn.Close()
state := conn.ConnectionState() state := conn.ConnectionState()
printConnectionDetails(state) printConnectionDetails(state)
} }
@@ -37,7 +48,6 @@ func printConnectionDetails(state tls.ConnectionState) {
func tlsVersion(version uint16) string { func tlsVersion(version uint16) string {
switch version { switch version {
case tls.VersionTLS13: case tls.VersionTLS13:
return "TLS 1.3" return "TLS 1.3"
case tls.VersionTLS12: case tls.VersionTLS12:

View File

@@ -11,8 +11,6 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"io/ioutil"
"log"
"os" "os"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
@@ -32,7 +30,7 @@ const (
curveP521 curveP521
) )
func getECCurve(pub interface{}) int { func getECCurve(pub any) int {
switch pub := pub.(type) { switch pub := pub.(type) {
case *ecdsa.PublicKey: case *ecdsa.PublicKey:
switch pub.Curve { switch pub.Curve {
@@ -52,8 +50,75 @@ func getECCurve(pub interface{}) int {
} }
} }
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
// It returns true on match.
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
}
// matchECDSA compares ECDSA public keys for equality and compatible curve.
// It returns match=true when they are on the same curve and have the same X/Y.
// If curves mismatch, match is false.
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
if getECCurve(certPub) != getECCurve(keyPub) {
return false
}
if keyPub.X.Cmp(certPub.X) != 0 {
return false
}
if keyPub.Y.Cmp(certPub.Y) != 0 {
return false
}
return true
}
// matchKeys determines whether the certificate's public key matches the given private key.
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
switch keyPub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if matchRSA(certPub, keyPub) {
return true, ""
}
return false, "public keys don't match"
case *ecdsa.PublicKey:
return false, "RSA private key, EC public key"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
case *ecdsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *ecdsa.PublicKey:
if matchECDSA(certPub, keyPub) {
return true, ""
}
// Determine a more precise reason
kc := getECCurve(keyPub)
cc := getECCurve(certPub)
if kc == curveInvalid {
return false, "invalid private key curve"
}
if cc == curveRSA {
return false, "private key is EC, certificate is RSA"
}
if kc != cc {
return false, "EC curves don't match"
}
return false, "public keys don't match"
case *rsa.PublicKey:
return false, "private key is EC, certificate is RSA"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
default:
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
}
}
func loadKey(path string) (crypto.Signer, error) { func loadKey(path string) (crypto.Signer, error) {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -78,16 +143,15 @@ func loadKey(path string) (crypto.Signer, error) {
} }
} }
switch priv.(type) { switch p := priv.(type) {
case *rsa.PrivateKey: case *rsa.PrivateKey:
return priv.(*rsa.PrivateKey), nil return p, nil
case *ecdsa.PrivateKey: case *ecdsa.PrivateKey:
return priv.(*ecdsa.PrivateKey), nil return p, nil
} default:
// should never reach here // should never reach here
return nil, errors.New("invalid private key") return nil, errors.New("invalid private key")
}
} }
func main() { func main() {
@@ -96,7 +160,7 @@ func main() {
flag.StringVar(&certFile, "c", "", "TLS `certificate` file") flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
flag.Parse() flag.Parse()
in, err := ioutil.ReadFile(certFile) in, err := os.ReadFile(certFile)
die.If(err) die.If(err)
p, _ := pem.Decode(in) p, _ := pem.Decode(in)
@@ -112,50 +176,11 @@ func main() {
priv, err := loadKey(keyFile) priv, err := loadKey(keyFile)
die.If(err) die.If(err)
switch pub := priv.Public().(type) { matched, reason := matchKeys(cert, priv)
case *rsa.PublicKey: if matched {
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.") fmt.Println("Match.")
return return
case *ecdsa.PublicKey: }
fmt.Println("No match (RSA private key, EC public key).") fmt.Printf("No match (%s).\n", reason)
os.Exit(1) os.Exit(1)
} }
case *ecdsa.PublicKey:
privCurve := getECCurve(pub)
certCurve := getECCurve(cert.PublicKey)
log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve)
if certCurve == curveRSA {
fmt.Println("No match (private key is EC, certificate is RSA).")
os.Exit(1)
} else if privCurve == curveInvalid {
fmt.Println("No match (invalid private key curve).")
os.Exit(1)
} else if privCurve != certCurve {
fmt.Println("No match (EC curves don't match).")
os.Exit(1)
}
certPub := cert.PublicKey.(*ecdsa.PublicKey)
if pub.X.Cmp(certPub.X) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
if pub.Y.Cmp(certPub.Y) != 0 {
fmt.Println("No match (public keys don't match).")
os.Exit(1)
}
fmt.Println("Match.")
default:
fmt.Printf("Unrecognised private key type: %T\n", priv.Public())
os.Exit(1)
}
}

View File

@@ -201,10 +201,6 @@ func init() {
os.Exit(1) os.Exit(1)
} }
if fromLoc == time.UTC {
}
toLoc = time.UTC toLoc = time.UTC
} }
@@ -257,15 +253,16 @@ func main() {
showTime(time.Now()) showTime(time.Now())
os.Exit(0) os.Exit(0)
case 1: case 1:
if flag.Arg(0) == "-" { switch {
case flag.Arg(0) == "-":
s := bufio.NewScanner(os.Stdin) s := bufio.NewScanner(os.Stdin)
for s.Scan() { for s.Scan() {
times = append(times, s.Text()) times = append(times, s.Text())
} }
} else if flag.Arg(0) == "help" { case flag.Arg(0) == "help":
usageExamples() usageExamples()
} else { default:
times = flag.Args() times = flag.Args()
} }
default: default:

View File

@@ -4,7 +4,6 @@ import (
"flag" "flag"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@@ -12,9 +11,8 @@ import (
type empty struct{} type empty struct{}
func errorf(format string, args ...interface{}) { func errorf(path string, err error) {
format += "\n" fmt.Fprintf(os.Stderr, "%s FAILED: %s\n", path, err)
fmt.Fprintf(os.Stderr, format, args...)
} }
func usage(w io.Writer) { func usage(w io.Writer) {
@@ -44,16 +42,16 @@ func main() {
if flag.NArg() == 1 && flag.Arg(0) == "-" { if flag.NArg() == 1 && flag.Arg(0) == "-" {
path := "stdin" path := "stdin"
in, err := ioutil.ReadAll(os.Stdin) in, err := io.ReadAll(os.Stdin)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
os.Exit(1) os.Exit(1)
} }
@@ -65,16 +63,16 @@ func main() {
} }
for _, path := range flag.Args() { for _, path := range flag.Args() {
in, err := ioutil.ReadFile(path) in, err := os.ReadFile(path)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }
var e empty var e empty
err = yaml.Unmarshal(in, &e) err = yaml.Unmarshal(in, &e)
if err != nil { if err != nil {
errorf("%s FAILED: %s", path, err) errorf(path, err)
continue continue
} }

View File

@@ -14,16 +14,16 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"git.wntrmute.dev/kyle/goutils/lib"
) )
const defaultDirectory = ".git/objects" const defaultDirectory = ".git/objects"
func errorf(format string, a ...interface{}) { // maxDecompressedSize limits how many bytes we will decompress from a zlib
fmt.Fprintf(os.Stderr, format, a...) // stream to mitigate decompression bombs (gosec G110).
if format[len(format)-1] != '\n' { // Increase this if you expect larger objects.
fmt.Fprintf(os.Stderr, "\n") const maxDecompressedSize int64 = 64 << 30 // 64 GiB
}
}
func isDir(path string) bool { func isDir(path string) bool {
fi, err := os.Stat(path) fi, err := os.Stat(path)
@@ -48,17 +48,21 @@ func loadFile(path string) ([]byte, error) {
} }
defer zread.Close() defer zread.Close()
_, err = io.Copy(buf, zread) // Protect against decompression bombs by limiting how much we read.
if err != nil { lr := io.LimitReader(zread, maxDecompressedSize+1)
if _, err = buf.ReadFrom(lr); err != nil {
return nil, err return nil, err
} }
if int64(buf.Len()) > maxDecompressedSize {
return nil, fmt.Errorf("decompressed size exceeds limit (%d bytes)", maxDecompressedSize)
}
return buf.Bytes(), nil return buf.Bytes(), nil
} }
func showFile(path string) { func showFile(path string) {
fileData, err := loadFile(path) fileData, err := loadFile(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to load %s", path)
return return
} }
@@ -68,37 +72,69 @@ func showFile(path string) {
func searchFile(path string, search *regexp.Regexp) error { func searchFile(path string, search *regexp.Regexp) error {
file, err := os.Open(path) file, err := os.Open(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to open %s", path)
return err return err
} }
defer file.Close() defer file.Close()
zread, err := zlib.NewReader(file) zread, err := zlib.NewReader(file)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to decompress %s", path)
return err return err
} }
defer zread.Close() defer zread.Close()
zbuf := bufio.NewReader(zread) // Limit how much we scan to avoid DoS via huge decompression.
if search.MatchReader(zbuf) { lr := io.LimitReader(zread, maxDecompressedSize+1)
zbuf := bufio.NewReader(lr)
if !search.MatchReader(zbuf) {
return nil
}
fileData, err := loadFile(path) fileData, err := loadFile(path)
if err != nil { if err != nil {
errorf("%v", err) lib.Warn(err, "failed to load %s", path)
return err return err
} }
fmt.Printf("%s:\n%s\n", path, fileData) fmt.Printf("%s:\n%s\n", path, fileData)
}
return nil return nil
} }
func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc { func buildWalker(searchExpr *regexp.Regexp) filepath.WalkFunc {
return func(path string, info os.FileInfo, err error) error { return func(path string, info os.FileInfo, _ error) error {
if info.Mode().IsRegular() { if !info.Mode().IsRegular() {
return searchFile(path, searchExpr)
}
return nil return nil
} }
return searchFile(path, searchExpr)
}
}
// runSearch compiles the search expression and processes the provided paths.
// It returns an error for fatal conditions; per-file errors are logged.
func runSearch(expr string) error {
search, err := regexp.Compile(expr)
if err != nil {
return fmt.Errorf("invalid regexp: %w", err)
}
pathList := flag.Args()
if len(pathList) == 0 {
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
if err2 := filepath.Walk(path, buildWalker(search)); err2 != nil {
return err2
}
continue
}
if err2 := searchFile(path, search); err2 != nil {
// Non-fatal: keep going, but report it.
lib.Warn(err2, "non-fatal error while searching files")
}
}
return nil
} }
func main() { func main() {
@@ -109,28 +145,10 @@ func main() {
for _, path := range flag.Args() { for _, path := range flag.Args() {
showFile(path) showFile(path)
} }
} else {
search, err := regexp.Compile(*flSearch)
if err != nil {
errorf("Bad regexp: %v", err)
return return
} }
pathList := flag.Args() if err := runSearch(*flSearch); err != nil {
if len(pathList) == 0 { lib.Err(lib.ExitFailure, err, "failed to run search")
pathList = []string{defaultDirectory}
}
for _, path := range pathList {
if isDir(path) {
err := filepath.Walk(path, buildWalker(search))
if err != nil {
errorf("%v", err)
return
}
} else {
searchFile(path, search)
}
}
} }
} }

1
go.mod
View File

@@ -12,6 +12,7 @@ require (
) )
require ( require (
github.com/benbjohnson/clock v1.3.5
github.com/davecgh/go-spew v1.1.1 github.com/davecgh/go-spew v1.1.1
github.com/google/certificate-transparency-go v1.0.21 github.com/google/certificate-transparency-go v1.0.21
) )

2
go.sum
View File

@@ -1,3 +1,5 @@
github.com/benbjohnson/clock v1.3.5 h1:VvXlSJBzZpA/zum6Sj74hxwYI2DIxRWuNIoXAzHZz5o=
github.com/benbjohnson/clock v1.3.5/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

View File

@@ -1,4 +1,4 @@
// +build freebsd darwin,386 netbsd //go:build bsd
package lib package lib

View File

@@ -1,4 +1,4 @@
// +build unix linux openbsd darwin,amd64 //go:build unix || linux || openbsd || (darwin && amd64)
package lib package lib
@@ -18,7 +18,7 @@ type FileTime struct {
func timeSpecToTime(ts unix.Timespec) time.Time { func timeSpecToTime(ts unix.Timespec) time.Time {
// The casts to int64 are needed because on 386, these are int32s. // The casts to int64 are needed because on 386, these are int32s.
return time.Unix(int64(ts.Sec), int64(ts.Nsec)) return time.Unix(ts.Sec, ts.Nsec)
} }
// LoadFileTime returns a FileTime associated with the file. // LoadFileTime returns a FileTime associated with the file.

View File

@@ -10,6 +10,12 @@ import (
var progname = filepath.Base(os.Args[0]) var progname = filepath.Base(os.Args[0])
const (
daysInYear = 365
digitWidth = 10
hoursInQuarterDay = 6
)
// ProgName returns what lib thinks the program name is, namely the // ProgName returns what lib thinks the program name is, namely the
// basename of argv0. // basename of argv0.
// //
@@ -20,7 +26,7 @@ func ProgName() string {
// Warnx displays a formatted error message to standard error, à la // Warnx displays a formatted error message to standard error, à la
// warnx(3). // warnx(3).
func Warnx(format string, a ...interface{}) (int, error) { func Warnx(format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n" format += "\n"
return fmt.Fprintf(os.Stderr, format, a...) return fmt.Fprintf(os.Stderr, format, a...)
@@ -28,7 +34,7 @@ func Warnx(format string, a ...interface{}) (int, error) {
// Warn displays a formatted error message to standard output, // Warn displays a formatted error message to standard output,
// appending the error string, à la warn(3). // appending the error string, à la warn(3).
func Warn(err error, format string, a ...interface{}) (int, error) { func Warn(err error, format string, a ...any) (int, error) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n" format += ": %v\n"
a = append(a, err) a = append(a, err)
@@ -37,7 +43,7 @@ func Warn(err error, format string, a ...interface{}) (int, error) {
// Errx displays a formatted error message to standard error and exits // Errx displays a formatted error message to standard error and exits
// with the status code from `exit`, à la errx(3). // with the status code from `exit`, à la errx(3).
func Errx(exit int, format string, a ...interface{}) { func Errx(exit int, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += "\n" format += "\n"
fmt.Fprintf(os.Stderr, format, a...) fmt.Fprintf(os.Stderr, format, a...)
@@ -47,7 +53,7 @@ func Errx(exit int, format string, a ...interface{}) {
// Err displays a formatting error message to standard error, // Err displays a formatting error message to standard error,
// appending the error string, and exits with the status code from // appending the error string, and exits with the status code from
// `exit`, à la err(3). // `exit`, à la err(3).
func Err(exit int, err error, format string, a ...interface{}) { func Err(exit int, err error, format string, a ...any) {
format = fmt.Sprintf("[%s] %s", progname, format) format = fmt.Sprintf("[%s] %s", progname, format)
format += ": %v\n" format += ": %v\n"
a = append(a, err) a = append(a, err)
@@ -62,30 +68,30 @@ func Itoa(i int, wid int) string {
// Assemble decimal in reverse order. // Assemble decimal in reverse order.
var b [20]byte var b [20]byte
bp := len(b) - 1 bp := len(b) - 1
for i >= 10 || wid > 1 { for i >= digitWidth || wid > 1 {
wid-- wid--
q := i / 10 q := i / digitWidth
b[bp] = byte('0' + i - q*10) b[bp] = byte('0' + i - q*digitWidth)
bp-- bp--
i = q i = q
} }
// i < 10
b[bp] = byte('0' + i) b[bp] = byte('0' + i)
return string(b[bp:]) return string(b[bp:])
} }
var ( var (
dayDuration = 24 * time.Hour dayDuration = 24 * time.Hour
yearDuration = (365 * dayDuration) + (6 * time.Hour) yearDuration = (daysInYear * dayDuration) + (hoursInQuarterDay * time.Hour)
) )
// Duration returns a prettier string for time.Durations. // Duration returns a prettier string for time.Durations.
func Duration(d time.Duration) string { func Duration(d time.Duration) string {
var s string var s string
if d >= yearDuration { if d >= yearDuration {
years := d / yearDuration years := int64(d / yearDuration)
s += fmt.Sprintf("%dy", years) s += fmt.Sprintf("%dy", years)
d -= years * yearDuration d -= time.Duration(years) * yearDuration
} }
if d >= dayDuration { if d >= dayDuration {
@@ -98,8 +104,8 @@ func Duration(d time.Duration) string {
} }
d %= 1 * time.Second d %= 1 * time.Second
hours := d / time.Hour hours := int64(d / time.Hour)
d -= hours * time.Hour d -= time.Duration(hours) * time.Hour
s += fmt.Sprintf("%dh%s", hours, d) s += fmt.Sprintf("%dh%s", hours, d)
return s return s
} }

View File

@@ -1,6 +1,7 @@
package logging package logging
import ( import (
"errors"
"fmt" "fmt"
"os" "os"
) )
@@ -61,9 +62,9 @@ func NewSplitFile(outpath, errpath string, overwrite bool) (*File, error) {
if err != nil { if err != nil {
if closeErr := fl.Close(); closeErr != nil { if closeErr := fl.Close(); closeErr != nil {
return nil, fmt.Errorf("failed to open error log: cleanup close failed: %v: %w", closeErr, err) return nil, fmt.Errorf("failed to open error log: %w", errors.Join(closeErr, err))
} }
return nil, err return nil, fmt.Errorf("failed to open error log: %w", err)
} }
fl.LogWriter = NewLogWriter(fl.fo, fl.fe) fl.LogWriter = NewLogWriter(fl.fo, fl.fe)