cmd: switch programs over to certlib.Fetcher.

This commit is contained in:
2025-11-18 11:08:17 -08:00
parent 8d5406256f
commit 4560868688
8 changed files with 114 additions and 25 deletions

View File

@@ -12,6 +12,7 @@ import (
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
"sort"
"strings" "strings"
"time" "time"
@@ -483,11 +484,19 @@ func encodeCertsToPEM(certs []*x509.Certificate) []byte {
} }
func generateManifest(files []fileEntry) []byte { func generateManifest(files []fileEntry) []byte {
var manifest strings.Builder // Build a sorted list of files by filename to ensure deterministic manifest ordering
for _, file := range files { sorted := make([]fileEntry, 0, len(files))
if file.name == "MANIFEST" { for _, f := range files {
// Defensive: skip any existing manifest entry
if f.name == "MANIFEST" {
continue continue
} }
sorted = append(sorted, f)
}
sort.Slice(sorted, func(i, j int) bool { return sorted[i].name < sorted[j].name })
var manifest strings.Builder
for _, file := range sorted {
hash := sha256.Sum256(file.content) hash := sha256.Sum256(file.content)
manifest.WriteString(fmt.Sprintf("%x %s\n", hash, file.name)) manifest.WriteString(fmt.Sprintf("%x %s\n", hash, file.name))
} }

View File

@@ -396,6 +396,41 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
return certs, rest, nil return certs, rest, nil
} }
// LoadFullCertPool returns a certificate pool with roots and intermediates
// from disk. If no roots are provided, the system root pool will be used.
func LoadFullCertPool(roots, intermediates string) (*x509.CertPool, error) {
pool := x509.NewCertPool()
if roots == "" {
pool, err := x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("loading system cert pool: %w", err)
}
} else {
rootCerts, err := LoadCertificates(roots)
if err != nil {
return nil, fmt.Errorf("loading roots: %w", err)
}
for _, cert := range rootCerts {
pool.AddCert(cert)
}
}
if intermediates != "" {
intCerts, err := LoadCertificates(intermediates)
if err != nil {
return nil, fmt.Errorf("loading intermediates: %w", err)
}
for _, cert := range intCerts {
pool.AddCert(cert)
}
}
return pool, nil
}
// LoadPEMCertPool loads a pool of PEM certificates from file. // LoadPEMCertPool loads a pool of PEM certificates from file.
func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) { func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
if certsFile == "" { if certsFile == "" {

View File

@@ -26,8 +26,14 @@ func parseURL(host string) (string, int, error) {
return "", 0, fmt.Errorf("certlib/hosts: invalid host: %s", host) return "", 0, fmt.Errorf("certlib/hosts: invalid host: %s", host)
} }
if strings.ToLower(url.Scheme) != "https" { switch strings.ToLower(url.Scheme) {
case "https":
// OK
case "tls":
// OK
default:
return "", 0, errors.New("certlib/hosts: only https scheme supported") return "", 0, errors.New("certlib/hosts: only https scheme supported")
} }
if url.Port() == "" { if url.Port() == "" {
@@ -43,28 +49,28 @@ func parseURL(host string) (string, int, error) {
} }
func parseHostPort(host string) (string, int, error) { func parseHostPort(host string) (string, int, error) {
host, sport, err := net.SplitHostPort(host) shost, sport, err := net.SplitHostPort(host)
if err == nil { if err == nil {
portInt, err2 := strconv.ParseInt(sport, 10, 16) portInt, err2 := strconv.ParseInt(sport, 10, 16)
if err2 != nil { if err2 != nil {
return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport) return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport)
} }
return host, int(portInt), nil return shost, int(portInt), nil
} }
return host, defaultHTTPSPort, nil return host, defaultHTTPSPort, nil
} }
func ParseHost(host string) (*Target, error) { func ParseHost(host string) (*Target, error) {
host, port, err := parseURL(host) uhost, port, err := parseURL(host)
if err == nil { if err == nil {
return &Target{Host: host, Port: port}, nil return &Target{Host: uhost, Port: port}, nil
} }
host, port, err = parseHostPort(host) shost, port, err := parseHostPort(host)
if err == nil { if err == nil {
return &Target{Host: host, Port: port}, nil return &Target{Host: shost, Port: port}, nil
} }
return nil, fmt.Errorf("certlib/hosts: invalid host: %s", host) return nil, fmt.Errorf("certlib/hosts: invalid host: %s", host)

View File

@@ -0,0 +1,34 @@
package hosts_test
import (
"git.wntrmute.dev/kyle/goutils/certlib/hosts"
"testing"
)
type testCase struct {
Host string
Target hosts.Target
}
var testCases = []testCase{
{Host: "server-name", Target: hosts.Target{Host: "server-name", Port: 443}},
{Host: "server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}},
{Host: "tls://server-name", Target: hosts.Target{Host: "server-name", Port: 443}},
{Host: "https://server-name", Target: hosts.Target{Host: "server-name", Port: 443}},
{Host: "https://server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}},
{Host: "tls://server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}},
{Host: "https://server-name/something/else", Target: hosts.Target{Host: "server-name", Port: 443}},
}
func TestParseHost(t *testing.T) {
for i, tc := range testCases {
target, err := hosts.ParseHost(tc.Host)
if err != nil {
t.Fatalf("test case %d: %s", i+1, err)
}
if target.Host != tc.Target.Host {
t.Fatalf("test case %d: got host '%s', want host '%s'", i+1, target.Host, tc.Target.Host)
}
}
}

View File

@@ -112,7 +112,7 @@ func showBasicConstraints(cert *x509.Certificate) {
fmt.Fprint(os.Stdout, " (basic constraint failure)") fmt.Fprint(os.Stdout, " (basic constraint failure)")
} }
} else { } else {
fmt.Fprint(os.Stdout, "is not a CA certificate") fmt.Fprint(os.Stdout, ", is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 { if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)") fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
} }

View File

@@ -75,18 +75,15 @@ func checkCert(cert *x509.Certificate) {
} }
func main() { func main() {
opts := &certlib.FetcherOpts{}
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification")
flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs") flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs")
flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring") flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring")
flag.Parse() flag.Parse()
for _, file := range flag.Args() { for _, file := range flag.Args() {
in, err := os.ReadFile(file) certs, err := certlib.GetCertificateChain(file, opts)
if err != nil {
_, _ = lib.Warn(err, "failed to read file")
continue
}
certs, err := certlib.ParseCertificatesPEM(in)
if err != nil { if err != nil {
_, _ = lib.Warn(err, "while parsing certificates") _, _ = lib.Warn(err, "while parsing certificates")
continue continue

View File

@@ -32,14 +32,16 @@ func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string {
} }
func main() { func main() {
opts := &certlib.FetcherOpts{}
displayAs := flag.String("d", "int", "display mode (int, hex, uhex)") displayAs := flag.String("d", "int", "display mode (int, hex, uhex)")
showExpiry := flag.Bool("e", false, "show expiry date") showExpiry := flag.Bool("e", false, "show expiry date")
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification")
flag.Parse() flag.Parse()
displayMode := parseDisplayMode(*displayAs) displayMode := parseDisplayMode(*displayAs)
for _, arg := range flag.Args() { for _, arg := range flag.Args() {
cert, err := certlib.LoadCertificate(arg) cert, err := certlib.GetCertificate(arg, opts)
die.If(err) die.If(err)
fmt.Printf("%s: %s", arg, serialString(cert, displayMode)) fmt.Printf("%s: %s", arg, serialString(cert, displayMode))

View File

@@ -31,7 +31,7 @@ func printRevocation(cert *x509.Certificate) {
type appConfig struct { type appConfig struct {
caFile, intFile string caFile, intFile string
forceIntermediateBundle bool forceIntermediateBundle bool
revexp, verbose bool revexp, skipVerify, verbose bool
} }
func parseFlags() appConfig { func parseFlags() appConfig {
@@ -40,6 +40,7 @@ func parseFlags() appConfig {
flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`") flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`")
flag.BoolVar(&cfg.forceIntermediateBundle, "f", false, flag.BoolVar(&cfg.forceIntermediateBundle, "f", false,
"force the use of the intermediate bundle, ignoring any intermediates bundled with certificate") "force the use of the intermediate bundle, ignoring any intermediates bundled with certificate")
flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification")
flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information") flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&cfg.verbose, "v", false, "verbose") flag.BoolVar(&cfg.verbose, "v", false, "verbose")
flag.Parse() flag.Parse()
@@ -102,12 +103,17 @@ func run(cfg appConfig) error {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName()) fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
} }
fileData, err := os.ReadFile(flag.Arg(0)) combinedPool, err := certlib.LoadFullCertPool(cfg.caFile, cfg.intFile)
if err != nil { if err != nil {
return err return fmt.Errorf("failed to build combined pool: %w", err)
} }
chain, err := certlib.ParseCertificatesPEM(fileData) opts := &certlib.FetcherOpts{
Roots: combinedPool,
SkipVerify: cfg.skipVerify,
}
chain, err := certlib.GetCertificateChain(flag.Arg(0), opts)
if err != nil { if err != nil {
return err return err
} }