diff --git a/.golangci.yml b/.golangci.yml index d109533..a8158b5 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -451,7 +451,7 @@ linters: - path: 'logging/example_test.go' linters: [ testableexamples ] - path: 'main.go' - linters: [ forbidigo, mnd ] + linters: [ forbidigo, mnd, reassign ] - source: 'TODO' linters: [ godot ] - text: 'should have a package comment' diff --git a/cmd/cert-bundler/main.go b/cmd/cert-bundler/main.go index be2f22e..6a7a638 100644 --- a/cmd/cert-bundler/main.go +++ b/cmd/cert-bundler/main.go @@ -8,8 +8,10 @@ import ( "crypto/x509" _ "embed" "encoding/pem" + "errors" "flag" "fmt" + "io" "os" "path/filepath" "strings" @@ -19,7 +21,7 @@ import ( "gopkg.in/yaml.v2" ) -// Config represents the top-level YAML configuration +// Config represents the top-level YAML configuration. type Config struct { Config struct { Hashes string `yaml:"hashes"` @@ -28,19 +30,19 @@ type Config struct { Chains map[string]ChainGroup `yaml:"chains"` } -// ChainGroup represents a named group of certificate chains +// ChainGroup represents a named group of certificate chains. type ChainGroup struct { Certs []CertChain `yaml:"certs"` Outputs Outputs `yaml:"outputs"` } -// CertChain represents a root certificate and its intermediates +// CertChain represents a root certificate and its intermediates. type CertChain struct { Root string `yaml:"root"` Intermediates []string `yaml:"intermediates"` } -// Outputs defines output format options +// Outputs defines output format options. type Outputs struct { IncludeSingle bool `yaml:"include_single"` IncludeIndividual bool `yaml:"include_individual"` @@ -95,7 +97,8 @@ func main() { } // Create output directory if it doesn't exist - if err := os.MkdirAll(outputDir, 0755); err != nil { + err = os.MkdirAll(outputDir, 0750) + if err != nil { fmt.Fprintf(os.Stderr, "Error creating output directory: %v\n", err) os.Exit(1) } @@ -108,9 +111,9 @@ func main() { } createdFiles := make([]string, 0, totalFormats) for groupName, group := range cfg.Chains { - files, err := processChainGroup(groupName, group, expiryDuration) - if err != nil { - fmt.Fprintf(os.Stderr, "Error processing chain group %s: %v\n", groupName, err) + files, perr := processChainGroup(groupName, group, expiryDuration) + if perr != nil { + fmt.Fprintf(os.Stderr, "Error processing chain group %s: %v\n", groupName, perr) os.Exit(1) } createdFiles = append(createdFiles, files...) @@ -119,8 +122,8 @@ func main() { // Generate hash file for all created archives if cfg.Config.Hashes != "" { hashFile := filepath.Join(outputDir, cfg.Config.Hashes) - if err := generateHashFile(hashFile, createdFiles); err != nil { - fmt.Fprintf(os.Stderr, "Error generating hash file: %v\n", err) + if gerr := generateHashFile(hashFile, createdFiles); gerr != nil { + fmt.Fprintf(os.Stderr, "Error generating hash file: %v\n", gerr) os.Exit(1) } } @@ -135,8 +138,8 @@ func loadConfig(path string) (*Config, error) { } var cfg Config - if err := yaml.Unmarshal(data, &cfg); err != nil { - return nil, err + if uerr := yaml.Unmarshal(data, &cfg); uerr != nil { + return nil, uerr } return &cfg, nil @@ -200,72 +203,107 @@ func processChainGroup(groupName string, group ChainGroup, expiryDuration time.D return createdFiles, nil } -// loadAndCollectCerts loads all certificates from chains and collects them for processing -func loadAndCollectCerts(chains []CertChain, outputs Outputs, expiryDuration time.Duration) ([]*x509.Certificate, []certWithPath, error) { +// loadAndCollectCerts loads all certificates from chains and collects them for processing. +func loadAndCollectCerts( + chains []CertChain, + outputs Outputs, + expiryDuration time.Duration, +) ([]*x509.Certificate, []certWithPath, error) { var singleFileCerts []*x509.Certificate var individualCerts []certWithPath for _, chain := range chains { - // Load root certificate - rootCert, err := certlib.LoadCertificate(chain.Root) - if err != nil { - return nil, nil, fmt.Errorf("failed to load root certificate %s: %v", chain.Root, err) + s, i, cerr := collectFromChain(chain, outputs, expiryDuration) + if cerr != nil { + return nil, nil, cerr } - - // Check expiry for root - checkExpiry(chain.Root, rootCert, expiryDuration) - - // Add root to collections if needed - if outputs.IncludeSingle { - singleFileCerts = append(singleFileCerts, rootCert) + if len(s) > 0 { + singleFileCerts = append(singleFileCerts, s...) } - if outputs.IncludeIndividual { - individualCerts = append(individualCerts, certWithPath{ - cert: rootCert, - path: chain.Root, - }) - } - - // Load and validate intermediates - for _, intPath := range chain.Intermediates { - intCert, err := certlib.LoadCertificate(intPath) - if err != nil { - return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %v", intPath, err) - } - - // Validate that intermediate is signed by root - if err := intCert.CheckSignatureFrom(rootCert); err != nil { - return nil, nil, fmt.Errorf("intermediate %s is not properly signed by root %s: %v", intPath, chain.Root, err) - } - - // Check expiry for intermediate - checkExpiry(intPath, intCert, expiryDuration) - - // Add intermediate to collections if needed - if outputs.IncludeSingle { - singleFileCerts = append(singleFileCerts, intCert) - } - if outputs.IncludeIndividual { - individualCerts = append(individualCerts, certWithPath{ - cert: intCert, - path: intPath, - }) - } + if len(i) > 0 { + individualCerts = append(individualCerts, i...) } } return singleFileCerts, individualCerts, nil } -// prepareArchiveFiles prepares all files to be included in archives -func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []certWithPath, outputs Outputs, encoding string) ([]fileEntry, error) { +// collectFromChain loads a single chain, performs checks, and returns the certs to include. +func collectFromChain( + chain CertChain, + outputs Outputs, + expiryDuration time.Duration, +) ( + []*x509.Certificate, + []certWithPath, + error, +) { + var single []*x509.Certificate + var indiv []certWithPath + + // Load root certificate + rootCert, rerr := certlib.LoadCertificate(chain.Root) + if rerr != nil { + return nil, nil, fmt.Errorf("failed to load root certificate %s: %w", chain.Root, rerr) + } + + // Check expiry for root + checkExpiry(chain.Root, rootCert, expiryDuration) + + // Add root to collections if needed + if outputs.IncludeSingle { + single = append(single, rootCert) + } + if outputs.IncludeIndividual { + indiv = append(indiv, certWithPath{cert: rootCert, path: chain.Root}) + } + + // Load and validate intermediates + for _, intPath := range chain.Intermediates { + intCert, lerr := certlib.LoadCertificate(intPath) + if lerr != nil { + return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %w", intPath, lerr) + } + + // Validate that intermediate is signed by root + if sigErr := intCert.CheckSignatureFrom(rootCert); sigErr != nil { + return nil, nil, fmt.Errorf( + "intermediate %s is not properly signed by root %s: %w", + intPath, + chain.Root, + sigErr, + ) + } + + // Check expiry for intermediate + checkExpiry(intPath, intCert, expiryDuration) + + // Add intermediate to collections if needed + if outputs.IncludeSingle { + single = append(single, intCert) + } + if outputs.IncludeIndividual { + indiv = append(indiv, certWithPath{cert: intCert, path: intPath}) + } + } + + return single, indiv, nil +} + +// prepareArchiveFiles prepares all files to be included in archives. +func prepareArchiveFiles( + singleFileCerts []*x509.Certificate, + individualCerts []certWithPath, + outputs Outputs, + encoding string, +) ([]fileEntry, error) { var archiveFiles []fileEntry // Handle a single bundle file if outputs.IncludeSingle && len(singleFileCerts) > 0 { files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true) if err != nil { - return nil, fmt.Errorf("failed to encode single bundle: %v", err) + return nil, fmt.Errorf("failed to encode single bundle: %w", err) } archiveFiles = append(archiveFiles, files...) } @@ -276,7 +314,7 @@ func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts [] baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path)) files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false) if err != nil { - return nil, fmt.Errorf("failed to encode individual cert %s: %v", cp.path, err) + return nil, fmt.Errorf("failed to encode individual cert %s: %w", cp.path, err) } archiveFiles = append(archiveFiles, files...) } @@ -294,7 +332,7 @@ func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts [] return archiveFiles, nil } -// createArchiveFiles creates archive files in the specified formats +// createArchiveFiles creates archive files in the specified formats. func createArchiveFiles(groupName string, formats []string, archiveFiles []fileEntry) ([]string, error) { createdFiles := make([]string, 0, len(formats)) @@ -307,11 +345,11 @@ func createArchiveFiles(groupName string, formats []string, archiveFiles []fileE switch format { case "zip": if err := createZipArchive(archivePath, archiveFiles); err != nil { - return nil, fmt.Errorf("failed to create zip archive: %v", err) + return nil, fmt.Errorf("failed to create zip archive: %w", err) } case "tgz": if err := createTarGzArchive(archivePath, archiveFiles); err != nil { - return nil, fmt.Errorf("failed to create tar.gz archive: %v", err) + return nil, fmt.Errorf("failed to create tar.gz archive: %w", err) } default: return nil, fmt.Errorf("unsupported format: %s", format) @@ -329,7 +367,12 @@ func checkExpiry(path string, cert *x509.Certificate, expiryDuration time.Durati if cert.NotAfter.Before(expiryThreshold) { daysUntilExpiry := int(cert.NotAfter.Sub(now).Hours() / 24) if daysUntilExpiry < 0 { - fmt.Fprintf(os.Stderr, "WARNING: Certificate %s has EXPIRED (expired %d days ago)\n", path, -daysUntilExpiry) + fmt.Fprintf( + os.Stderr, + "WARNING: Certificate %s has EXPIRED (expired %d days ago)\n", + path, + -daysUntilExpiry, + ) } else { fmt.Fprintf(os.Stderr, "WARNING: Certificate %s will expire in %d days (on %s)\n", path, daysUntilExpiry, cert.NotAfter.Format("2006-01-02")) } @@ -347,8 +390,13 @@ type certWithPath struct { } // encodeCertsToFiles converts certificates to file entries based on encoding type -// If isSingle is true, certs are concatenated into a single file; otherwise one cert per file -func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding string, isSingle bool) ([]fileEntry, error) { +// If isSingle is true, certs are concatenated into a single file; otherwise one cert per file. +func encodeCertsToFiles( + certs []*x509.Certificate, + baseName string, + encoding string, + isSingle bool, +) ([]fileEntry, error) { var files []fileEntry switch encoding { @@ -369,14 +417,12 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str name: baseName + ".crt", content: derContent, }) - } else { + } else if len(certs) > 0 { // Individual DER file (should only have one cert) - if len(certs) > 0 { - files = append(files, fileEntry{ - name: baseName + ".crt", - content: certs[0].Raw, - }) - } + files = append(files, fileEntry{ + name: baseName + ".crt", + content: certs[0].Raw, + }) } case "both": // Add PEM version @@ -395,13 +441,11 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str name: baseName + ".crt", content: derContent, }) - } else { - if len(certs) > 0 { - files = append(files, fileEntry{ - name: baseName + ".crt", - content: certs[0].Raw, - }) - } + } else if len(certs) > 0 { + files = append(files, fileEntry{ + name: baseName + ".crt", + content: certs[0].Raw, + }) } default: return nil, fmt.Errorf("unsupported encoding: %s (must be 'pem', 'der', or 'both')", encoding) @@ -410,7 +454,7 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str return files, nil } -// encodeCertsToPEM encodes certificates to PEM format +// encodeCertsToPEM encodes certificates to PEM format. func encodeCertsToPEM(certs []*x509.Certificate) []byte { var pemContent []byte for _, cert := range certs { @@ -435,40 +479,49 @@ func generateManifest(files []fileEntry) []byte { return []byte(manifest.String()) } +// closeWithErr attempts to close all provided closers, joining any close errors with baseErr. +func closeWithErr(baseErr error, closers ...io.Closer) error { + for _, c := range closers { + if c == nil { + continue + } + if cerr := c.Close(); cerr != nil { + baseErr = errors.Join(baseErr, cerr) + } + } + return baseErr +} + func createZipArchive(path string, files []fileEntry) error { - f, err := os.Create(path) - if err != nil { - return err + f, zerr := os.Create(path) + if zerr != nil { + return zerr } w := zip.NewWriter(f) for _, file := range files { - fw, err := w.Create(file.name) - if err != nil { - w.Close() - f.Close() - return err + fw, werr := w.Create(file.name) + if werr != nil { + return closeWithErr(werr, w, f) } - if _, err := fw.Write(file.content); err != nil { - w.Close() - f.Close() - return err + if _, werr = fw.Write(file.content); werr != nil { + return closeWithErr(werr, w, f) } } // Check errors on close operations - if err := w.Close(); err != nil { - f.Close() - return err + if cerr := w.Close(); cerr != nil { + _ = f.Close() + return cerr } return f.Close() } func createTarGzArchive(path string, files []fileEntry) error { - f, err := os.Create(path) - if err != nil { - return err + f, terr := os.Create(path) + if terr != nil { + return terr } gw := gzip.NewWriter(f) @@ -480,29 +533,23 @@ func createTarGzArchive(path string, files []fileEntry) error { Mode: 0644, Size: int64(len(file.content)), } - if err := tw.WriteHeader(hdr); err != nil { - tw.Close() - gw.Close() - f.Close() - return err + if herr := tw.WriteHeader(hdr); herr != nil { + return closeWithErr(herr, tw, gw, f) } - if _, err := tw.Write(file.content); err != nil { - tw.Close() - gw.Close() - f.Close() - return err + if _, werr := tw.Write(file.content); werr != nil { + return closeWithErr(werr, tw, gw, f) } } // Check errors on close operations in the correct order - if err := tw.Close(); err != nil { - gw.Close() - f.Close() - return err + if cerr := tw.Close(); cerr != nil { + _ = gw.Close() + _ = f.Close() + return cerr } - if err := gw.Close(); err != nil { - f.Close() - return err + if cerr := gw.Close(); cerr != nil { + _ = f.Close() + return cerr } return f.Close() } @@ -515,9 +562,9 @@ func generateHashFile(path string, files []string) error { defer f.Close() for _, file := range files { - data, err := os.ReadFile(file) - if err != nil { - return err + data, rerr := os.ReadFile(file) + if rerr != nil { + return rerr } hash := sha256.Sum256(data)