diff --git a/cmd/cert-bundler/main.go b/cmd/cert-bundler/main.go index c9e0e1f..be2f22e 100644 --- a/cmd/cert-bundler/main.go +++ b/cmd/cert-bundler/main.go @@ -101,7 +101,12 @@ func main() { } // Process each chain group - createdFiles := []string{} + // Pre-allocate createdFiles based on total number of formats across all groups + totalFormats := 0 + for _, group := range cfg.Chains { + totalFormats += len(group.Outputs.Formats) + } + createdFiles := make([]string, 0, totalFormats) for groupName, group := range cfg.Chains { files, err := processChainGroup(groupName, group, expiryDuration) if err != nil { @@ -168,68 +173,79 @@ func parseDuration(s string) (time.Duration, error) { } func processChainGroup(groupName string, group ChainGroup, expiryDuration time.Duration) ([]string, error) { - var createdFiles []string - // Default encoding to "pem" if not specified encoding := group.Outputs.Encoding if encoding == "" { encoding = "pem" } - // Collect data from all chains in the group + // Collect certificates from all chains in the group + singleFileCerts, individualCerts, err := loadAndCollectCerts(group.Certs, group.Outputs, expiryDuration) + if err != nil { + return nil, err + } + + // Prepare files for inclusion in archives + archiveFiles, err := prepareArchiveFiles(singleFileCerts, individualCerts, group.Outputs, encoding) + if err != nil { + return nil, err + } + + // Create archives for the entire group + createdFiles, err := createArchiveFiles(groupName, group.Outputs.Formats, archiveFiles) + if err != nil { + return nil, err + } + + return createdFiles, nil +} + +// loadAndCollectCerts loads all certificates from chains and collects them for processing +func loadAndCollectCerts(chains []CertChain, outputs Outputs, expiryDuration time.Duration) ([]*x509.Certificate, []certWithPath, error) { var singleFileCerts []*x509.Certificate var individualCerts []certWithPath - for _, chain := range group.Certs { - // Step 1: Load all certificates for this chain - allCerts := make(map[string]*x509.Certificate) - + for _, chain := range chains { // Load root certificate rootCert, err := certlib.LoadCertificate(chain.Root) if err != nil { - return nil, fmt.Errorf("failed to load root certificate %s: %v", chain.Root, err) + return nil, nil, fmt.Errorf("failed to load root certificate %s: %v", chain.Root, err) } - allCerts[chain.Root] = rootCert // Check expiry for root checkExpiry(chain.Root, rootCert, expiryDuration) - // Add root to single file if needed - if group.Outputs.IncludeSingle { + // Add root to collections if needed + if outputs.IncludeSingle { singleFileCerts = append(singleFileCerts, rootCert) } - - // Add root to individual files if needed - if group.Outputs.IncludeIndividual { + if outputs.IncludeIndividual { individualCerts = append(individualCerts, certWithPath{ cert: rootCert, path: chain.Root, }) } - // Step 2: Load and validate intermediates + // Load and validate intermediates for _, intPath := range chain.Intermediates { intCert, err := certlib.LoadCertificate(intPath) if err != nil { - return nil, fmt.Errorf("failed to load intermediate certificate %s: %v", intPath, err) + return nil, nil, fmt.Errorf("failed to load intermediate certificate %s: %v", intPath, err) } - allCerts[intPath] = intCert // Validate that intermediate is signed by root if err := intCert.CheckSignatureFrom(rootCert); err != nil { - return nil, fmt.Errorf("intermediate %s is not properly signed by root %s: %v", intPath, chain.Root, err) + return nil, nil, fmt.Errorf("intermediate %s is not properly signed by root %s: %v", intPath, chain.Root, err) } // Check expiry for intermediate checkExpiry(intPath, intCert, expiryDuration) - // Add intermediate to a single file if needed - if group.Outputs.IncludeSingle { + // Add intermediate to collections if needed + if outputs.IncludeSingle { singleFileCerts = append(singleFileCerts, intCert) } - - // Add intermediate to individual files if needed - if group.Outputs.IncludeIndividual { + if outputs.IncludeIndividual { individualCerts = append(individualCerts, certWithPath{ cert: intCert, path: intPath, @@ -238,11 +254,15 @@ func processChainGroup(groupName string, group ChainGroup, expiryDuration time.D } } - // Prepare files for inclusion in archives for the entire group. + return singleFileCerts, individualCerts, nil +} + +// prepareArchiveFiles prepares all files to be included in archives +func prepareArchiveFiles(singleFileCerts []*x509.Certificate, individualCerts []certWithPath, outputs Outputs, encoding string) ([]fileEntry, error) { var archiveFiles []fileEntry - // Handle a single bundle file. - if group.Outputs.IncludeSingle && len(singleFileCerts) > 0 { + // Handle a single bundle file + if outputs.IncludeSingle && len(singleFileCerts) > 0 { files, err := encodeCertsToFiles(singleFileCerts, "bundle", encoding, true) if err != nil { return nil, fmt.Errorf("failed to encode single bundle: %v", err) @@ -251,7 +271,7 @@ func processChainGroup(groupName string, group ChainGroup, expiryDuration time.D } // Handle individual files - if group.Outputs.IncludeIndividual { + if outputs.IncludeIndividual { for _, cp := range individualCerts { baseName := strings.TrimSuffix(filepath.Base(cp.path), filepath.Ext(cp.path)) files, err := encodeCertsToFiles([]*x509.Certificate{cp.cert}, baseName, encoding, false) @@ -263,7 +283,7 @@ func processChainGroup(groupName string, group ChainGroup, expiryDuration time.D } // Generate manifest if requested - if group.Outputs.Manifest { + if outputs.Manifest { manifestContent := generateManifest(archiveFiles) archiveFiles = append(archiveFiles, fileEntry{ name: "MANIFEST", @@ -271,8 +291,14 @@ func processChainGroup(groupName string, group ChainGroup, expiryDuration time.D }) } - // Create archives for the entire group - for _, format := range group.Outputs.Formats { + return archiveFiles, nil +} + +// createArchiveFiles creates archive files in the specified formats +func createArchiveFiles(groupName string, formats []string, archiveFiles []fileEntry) ([]string, error) { + createdFiles := make([]string, 0, len(formats)) + + for _, format := range formats { ext, ok := formatExtensions[format] if !ok { return nil, fmt.Errorf("unsupported format: %s", format) @@ -327,10 +353,7 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str switch encoding { case "pem": - pemContent, err := encodeCertsToPEM(certs, isSingle) - if err != nil { - return nil, err - } + pemContent := encodeCertsToPEM(certs) files = append(files, fileEntry{ name: baseName + ".pem", content: pemContent, @@ -357,10 +380,7 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str } case "both": // Add PEM version - pemContent, err := encodeCertsToPEM(certs, isSingle) - if err != nil { - return nil, err - } + pemContent := encodeCertsToPEM(certs) files = append(files, fileEntry{ name: baseName + ".pem", content: pemContent, @@ -391,7 +411,7 @@ func encodeCertsToFiles(certs []*x509.Certificate, baseName string, encoding str } // encodeCertsToPEM encodes certificates to PEM format -func encodeCertsToPEM(certs []*x509.Certificate, concatenate bool) ([]byte, error) { +func encodeCertsToPEM(certs []*x509.Certificate) []byte { var pemContent []byte for _, cert := range certs { pemBlock := &pem.Block{ @@ -400,7 +420,7 @@ func encodeCertsToPEM(certs []*x509.Certificate, concatenate bool) ([]byte, erro } pemContent = append(pemContent, pem.EncodeToMemory(pemBlock)...) } - return pemContent, nil + return pemContent } func generateManifest(files []fileEntry) []byte { @@ -420,22 +440,29 @@ func createZipArchive(path string, files []fileEntry) error { if err != nil { return err } - defer f.Close() w := zip.NewWriter(f) - defer w.Close() for _, file := range files { fw, err := w.Create(file.name) if err != nil { + w.Close() + f.Close() return err } if _, err := fw.Write(file.content); err != nil { + w.Close() + f.Close() return err } } - return nil + // Check errors on close operations + if err := w.Close(); err != nil { + f.Close() + return err + } + return f.Close() } func createTarGzArchive(path string, files []fileEntry) error { @@ -443,13 +470,9 @@ func createTarGzArchive(path string, files []fileEntry) error { if err != nil { return err } - defer f.Close() gw := gzip.NewWriter(f) - defer gw.Close() - tw := tar.NewWriter(gw) - defer tw.Close() for _, file := range files { hdr := &tar.Header{ @@ -458,14 +481,30 @@ func createTarGzArchive(path string, files []fileEntry) error { Size: int64(len(file.content)), } if err := tw.WriteHeader(hdr); err != nil { + tw.Close() + gw.Close() + f.Close() return err } if _, err := tw.Write(file.content); err != nil { + tw.Close() + gw.Close() + f.Close() return err } } - return nil + // Check errors on close operations in the correct order + if err := tw.Close(); err != nil { + gw.Close() + f.Close() + return err + } + if err := gw.Close(); err != nil { + f.Close() + return err + } + return f.Close() } func generateHashFile(path string, files []string) error { diff --git a/cmd/cert-bundler/prompt.txt b/cmd/cert-bundler/prompt.txt index 3651d66..9a3ea61 100644 --- a/cmd/cert-bundler/prompt.txt +++ b/cmd/cert-bundler/prompt.txt @@ -186,4 +186,9 @@ to provide the same detailed information. It may be easier to embed the README.txt in the program on build. +----- + +For the archive (tar.gz and zip) writers, make sure errors are +checked at the end, and don't just defer the close operations. +