From 4560868688d7dd56847d5214c1e7f983e109620b Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Tue, 18 Nov 2025 11:08:17 -0800 Subject: [PATCH] cmd: switch programs over to certlib.Fetcher. --- certlib/bundler/bundler.go | 15 ++++++++++++--- certlib/helpers.go | 35 +++++++++++++++++++++++++++++++++++ certlib/hosts/hosts.go | 20 +++++++++++++------- certlib/hosts/hosts_test.go | 34 ++++++++++++++++++++++++++++++++++ cmd/certdump/main.go | 2 +- cmd/certexpiry/main.go | 11 ++++------- cmd/certser/main.go | 4 +++- cmd/certverify/main.go | 18 ++++++++++++------ 8 files changed, 114 insertions(+), 25 deletions(-) create mode 100644 certlib/hosts/hosts_test.go diff --git a/certlib/bundler/bundler.go b/certlib/bundler/bundler.go index e47f354..8287155 100644 --- a/certlib/bundler/bundler.go +++ b/certlib/bundler/bundler.go @@ -12,6 +12,7 @@ import ( "io" "os" "path/filepath" + "sort" "strings" "time" @@ -483,11 +484,19 @@ func encodeCertsToPEM(certs []*x509.Certificate) []byte { } func generateManifest(files []fileEntry) []byte { - var manifest strings.Builder - for _, file := range files { - if file.name == "MANIFEST" { + // Build a sorted list of files by filename to ensure deterministic manifest ordering + sorted := make([]fileEntry, 0, len(files)) + for _, f := range files { + // Defensive: skip any existing manifest entry + if f.name == "MANIFEST" { continue } + sorted = append(sorted, f) + } + sort.Slice(sorted, func(i, j int) bool { return sorted[i].name < sorted[j].name }) + + var manifest strings.Builder + for _, file := range sorted { hash := sha256.Sum256(file.content) manifest.WriteString(fmt.Sprintf("%x %s\n", hash, file.name)) } diff --git a/certlib/helpers.go b/certlib/helpers.go index 9fc4b91..bf525d4 100644 --- a/certlib/helpers.go +++ b/certlib/helpers.go @@ -396,6 +396,41 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e return certs, rest, nil } +// LoadFullCertPool returns a certificate pool with roots and intermediates +// from disk. If no roots are provided, the system root pool will be used. +func LoadFullCertPool(roots, intermediates string) (*x509.CertPool, error) { + pool := x509.NewCertPool() + + if roots == "" { + pool, err := x509.SystemCertPool() + if err != nil { + return nil, fmt.Errorf("loading system cert pool: %w", err) + } + } else { + rootCerts, err := LoadCertificates(roots) + if err != nil { + return nil, fmt.Errorf("loading roots: %w", err) + } + + for _, cert := range rootCerts { + pool.AddCert(cert) + } + } + + if intermediates != "" { + intCerts, err := LoadCertificates(intermediates) + if err != nil { + return nil, fmt.Errorf("loading intermediates: %w", err) + } + + for _, cert := range intCerts { + pool.AddCert(cert) + } + } + + return pool, nil +} + // LoadPEMCertPool loads a pool of PEM certificates from file. func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) { if certsFile == "" { diff --git a/certlib/hosts/hosts.go b/certlib/hosts/hosts.go index d1892af..80ac494 100644 --- a/certlib/hosts/hosts.go +++ b/certlib/hosts/hosts.go @@ -26,8 +26,14 @@ func parseURL(host string) (string, int, error) { return "", 0, fmt.Errorf("certlib/hosts: invalid host: %s", host) } - if strings.ToLower(url.Scheme) != "https" { + switch strings.ToLower(url.Scheme) { + case "https": + // OK + case "tls": + // OK + default: return "", 0, errors.New("certlib/hosts: only https scheme supported") + } if url.Port() == "" { @@ -43,28 +49,28 @@ func parseURL(host string) (string, int, error) { } func parseHostPort(host string) (string, int, error) { - host, sport, err := net.SplitHostPort(host) + shost, sport, err := net.SplitHostPort(host) if err == nil { portInt, err2 := strconv.ParseInt(sport, 10, 16) if err2 != nil { return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport) } - return host, int(portInt), nil + return shost, int(portInt), nil } return host, defaultHTTPSPort, nil } func ParseHost(host string) (*Target, error) { - host, port, err := parseURL(host) + uhost, port, err := parseURL(host) if err == nil { - return &Target{Host: host, Port: port}, nil + return &Target{Host: uhost, Port: port}, nil } - host, port, err = parseHostPort(host) + shost, port, err := parseHostPort(host) if err == nil { - return &Target{Host: host, Port: port}, nil + return &Target{Host: shost, Port: port}, nil } return nil, fmt.Errorf("certlib/hosts: invalid host: %s", host) diff --git a/certlib/hosts/hosts_test.go b/certlib/hosts/hosts_test.go new file mode 100644 index 0000000..09dab4b --- /dev/null +++ b/certlib/hosts/hosts_test.go @@ -0,0 +1,34 @@ +package hosts_test + +import ( + "git.wntrmute.dev/kyle/goutils/certlib/hosts" + "testing" +) + +type testCase struct { + Host string + Target hosts.Target +} + +var testCases = []testCase{ + {Host: "server-name", Target: hosts.Target{Host: "server-name", Port: 443}}, + {Host: "server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}}, + {Host: "tls://server-name", Target: hosts.Target{Host: "server-name", Port: 443}}, + {Host: "https://server-name", Target: hosts.Target{Host: "server-name", Port: 443}}, + {Host: "https://server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}}, + {Host: "tls://server-name:8443", Target: hosts.Target{Host: "server-name", Port: 8443}}, + {Host: "https://server-name/something/else", Target: hosts.Target{Host: "server-name", Port: 443}}, +} + +func TestParseHost(t *testing.T) { + for i, tc := range testCases { + target, err := hosts.ParseHost(tc.Host) + if err != nil { + t.Fatalf("test case %d: %s", i+1, err) + } + + if target.Host != tc.Target.Host { + t.Fatalf("test case %d: got host '%s', want host '%s'", i+1, target.Host, tc.Target.Host) + } + } +} diff --git a/cmd/certdump/main.go b/cmd/certdump/main.go index 4c33195..b2913da 100644 --- a/cmd/certdump/main.go +++ b/cmd/certdump/main.go @@ -112,7 +112,7 @@ func showBasicConstraints(cert *x509.Certificate) { fmt.Fprint(os.Stdout, " (basic constraint failure)") } } else { - fmt.Fprint(os.Stdout, "is not a CA certificate") + fmt.Fprint(os.Stdout, ", is not a CA certificate") if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 { fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)") } diff --git a/cmd/certexpiry/main.go b/cmd/certexpiry/main.go index a604fb2..71a680d 100644 --- a/cmd/certexpiry/main.go +++ b/cmd/certexpiry/main.go @@ -75,18 +75,15 @@ func checkCert(cert *x509.Certificate) { } func main() { + opts := &certlib.FetcherOpts{} + + flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification") flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs") flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring") flag.Parse() for _, file := range flag.Args() { - in, err := os.ReadFile(file) - if err != nil { - _, _ = lib.Warn(err, "failed to read file") - continue - } - - certs, err := certlib.ParseCertificatesPEM(in) + certs, err := certlib.GetCertificateChain(file, opts) if err != nil { _, _ = lib.Warn(err, "while parsing certificates") continue diff --git a/cmd/certser/main.go b/cmd/certser/main.go index d651b50..83243b8 100644 --- a/cmd/certser/main.go +++ b/cmd/certser/main.go @@ -32,14 +32,16 @@ func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string { } func main() { + opts := &certlib.FetcherOpts{} displayAs := flag.String("d", "int", "display mode (int, hex, uhex)") showExpiry := flag.Bool("e", false, "show expiry date") + flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification") flag.Parse() displayMode := parseDisplayMode(*displayAs) for _, arg := range flag.Args() { - cert, err := certlib.LoadCertificate(arg) + cert, err := certlib.GetCertificate(arg, opts) die.If(err) fmt.Printf("%s: %s", arg, serialString(cert, displayMode)) diff --git a/cmd/certverify/main.go b/cmd/certverify/main.go index 9e32e5a..95c07ec 100644 --- a/cmd/certverify/main.go +++ b/cmd/certverify/main.go @@ -29,9 +29,9 @@ func printRevocation(cert *x509.Certificate) { } type appConfig struct { - caFile, intFile string - forceIntermediateBundle bool - revexp, verbose bool + caFile, intFile string + forceIntermediateBundle bool + revexp, skipVerify, verbose bool } func parseFlags() appConfig { @@ -40,6 +40,7 @@ func parseFlags() appConfig { flag.StringVar(&cfg.intFile, "i", "", "intermediate `bundle`") flag.BoolVar(&cfg.forceIntermediateBundle, "f", false, "force the use of the intermediate bundle, ignoring any intermediates bundled with certificate") + flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification") flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information") flag.BoolVar(&cfg.verbose, "v", false, "verbose") flag.Parse() @@ -102,12 +103,17 @@ func run(cfg appConfig) error { fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName()) } - fileData, err := os.ReadFile(flag.Arg(0)) + combinedPool, err := certlib.LoadFullCertPool(cfg.caFile, cfg.intFile) if err != nil { - return err + return fmt.Errorf("failed to build combined pool: %w", err) } - chain, err := certlib.ParseCertificatesPEM(fileData) + opts := &certlib.FetcherOpts{ + Roots: combinedPool, + SkipVerify: cfg.skipVerify, + } + + chain, err := certlib.GetCertificateChain(flag.Arg(0), opts) if err != nil { return err }