diff --git a/.golangci.yml b/.golangci.yml index a8158b5..76095d4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -18,6 +18,19 @@ issues: # Default: 3 max-same-issues: 50 + # Exclude some lints for CLI programs under cmd/ (package main). + # The project allows fmt.Print* in command-line tools; keep forbidigo for libraries. + exclude-rules: + - path: ^cmd/ + linters: + - forbidigo + - path: cmd/.* + linters: + - forbidigo + - path: .*/cmd/.* + linters: + - forbidigo + formatters: enable: - goimports # checks if the code and import statements are formatted according to the 'goimports' command diff --git a/cmd/cert-revcheck/main.go b/cmd/cert-revcheck/main.go index a1ecd11..037bea7 100644 --- a/cmd/cert-revcheck/main.go +++ b/cmd/cert-revcheck/main.go @@ -1,20 +1,20 @@ package main import ( - "crypto/tls" - "crypto/x509" - "flag" - "errors" - "fmt" - "io/ioutil" - "net" - "os" - "time" + "crypto/tls" + "crypto/x509" + "errors" + "flag" + "fmt" + "net" + "os" + "strings" + "time" - "git.wntrmute.dev/kyle/goutils/certlib" - hosts "git.wntrmute.dev/kyle/goutils/certlib/hosts" - "git.wntrmute.dev/kyle/goutils/certlib/revoke" - "git.wntrmute.dev/kyle/goutils/fileutil" + "git.wntrmute.dev/kyle/goutils/certlib" + hosts "git.wntrmute.dev/kyle/goutils/certlib/hosts" + "git.wntrmute.dev/kyle/goutils/certlib/revoke" + "git.wntrmute.dev/kyle/goutils/fileutil" ) var ( @@ -23,6 +23,13 @@ var ( verbose bool ) +var ( + strOK = "OK" + strExpired = "EXPIRED" + strRevoked = "REVOKED" + strUnknown = "UNKNOWN" +) + func main() { flag.BoolVar(&hardfail, "hardfail", false, "treat revocation check failures as fatal") flag.DurationVar(&timeout, "timeout", 10*time.Second, "network timeout for OCSP/CRL fetches and TLS site connects") @@ -42,16 +49,16 @@ func main() { for _, target := range flag.Args() { status, err := processTarget(target) switch status { - case "OK": - fmt.Printf("%s: OK\n", target) - case "EXPIRED": - fmt.Printf("%s: EXPIRED: %v\n", target, err) + case strOK: + fmt.Printf("%s: %s\n", target, strOK) + case strExpired: + fmt.Printf("%s: %s: %v\n", target, strExpired, err) exitCode = 1 - case "REVOKED": - fmt.Printf("%s: REVOKED\n", target) + case strRevoked: + fmt.Printf("%s: %s\n", target, strRevoked) exitCode = 1 - case "UNKNOWN": - fmt.Printf("%s: UNKNOWN: %v\n", target, err) + case strUnknown: + fmt.Printf("%s: %s: %v\n", target, strUnknown, err) if hardfail { // In hardfail, treat unknown as failure exitCode = 1 @@ -67,74 +74,67 @@ func processTarget(target string) (string, error) { return checkFile(target) } - // Not a file; treat as site return checkSite(target) } func checkFile(path string) (string, error) { - in, err := ioutil.ReadFile(path) - if err != nil { - return "UNKNOWN", err - } + // Prefer high-level helpers from certlib to load certificates from disk + if certs, err := certlib.LoadCertificates(path); err == nil && len(certs) > 0 { + // Evaluate the first certificate (leaf) by default + return evaluateCert(certs[0]) + } - // Try PEM first; if that fails, try single DER cert - certs, err := certlib.ReadCertificates(in) - if err != nil || len(certs) == 0 { - cert, _, derr := certlib.ReadCertificate(in) - if derr != nil || cert == nil { - if err == nil { - err = derr - } - return "UNKNOWN", err - } - return evaluateCert(cert) - } - - // Evaluate the first certificate (leaf) by default - return evaluateCert(certs[0]) + cert, err := certlib.LoadCertificate(path) + if err != nil || cert == nil { + return strUnknown, err + } + return evaluateCert(cert) } func checkSite(hostport string) (string, error) { // Use certlib/hosts to parse host/port (supports https URLs and host:port) target, err := hosts.ParseHost(hostport) if err != nil { - return "UNKNOWN", err + return strUnknown, err } d := &net.Dialer{Timeout: timeout} - conn, err := tls.DialWithDialer(d, "tcp", target.String(), &tls.Config{InsecureSkipVerify: true, ServerName: target.Host}) + conn, err := tls.DialWithDialer( + d, + "tcp", + target.String(), + &tls.Config{InsecureSkipVerify: true, ServerName: target.Host}, // #nosec G402 + ) if err != nil { - return "UNKNOWN", err + return strUnknown, err } defer conn.Close() state := conn.ConnectionState() if len(state.PeerCertificates) == 0 { - return "UNKNOWN", errors.New("no peer certificates presented") + return strUnknown, errors.New("no peer certificates presented") } return evaluateCert(state.PeerCertificates[0]) } func evaluateCert(cert *x509.Certificate) (string, error) { - // Expiry check - now := time.Now() - if !now.Before(cert.NotAfter) { - return "EXPIRED", fmt.Errorf("expired at %s", cert.NotAfter) - } - if !now.After(cert.NotBefore) { - return "EXPIRED", fmt.Errorf("not valid until %s", cert.NotBefore) - } + // Delegate validity and revocation checks to certlib/revoke helper. + // It returns revoked=true for both revoked and expired/not-yet-valid. + // Map those cases back to our statuses using the returned error text. + revoked, ok, err := revoke.VerifyCertificateError(cert) + if revoked { + if err != nil { + msg := err.Error() + if strings.Contains(msg, "expired") || strings.Contains(msg, "isn't valid until") || strings.Contains(msg, "not valid until") { + return strExpired, err + } + } + return strRevoked, err + } + if !ok { + // Revocation status could not be determined + return strUnknown, err + } - // Revocation check using certlib/revoke - revoked, ok, err := revoke.VerifyCertificateError(cert) - if revoked { - // If revoked is true, ok will be true per implementation, err may describe why - return "REVOKED", err - } - if !ok { - // Revocation status could not be determined - return "UNKNOWN", err - } - - return "OK", nil + return strOK, nil } diff --git a/cmd/certchain/certchain.go b/cmd/certchain/certchain.go index 8815f1f..45fb44a 100644 --- a/cmd/certchain/certchain.go +++ b/cmd/certchain/certchain.go @@ -5,6 +5,7 @@ import ( "encoding/pem" "flag" "fmt" + "os" "regexp" "git.wntrmute.dev/kyle/goutils/die" @@ -34,6 +35,6 @@ func main() { chain += string(pem.EncodeToMemory(&p)) } - fmt.Println(chain) + fmt.Fprintln(os.Stdout, chain) } } diff --git a/cmd/certdump/certdump.go b/cmd/certdump/certdump.go index de0c66e..392353c 100644 --- a/cmd/certdump/certdump.go +++ b/cmd/certdump/certdump.go @@ -101,30 +101,30 @@ func extUsage(ext []x509.ExtKeyUsage) string { } func showBasicConstraints(cert *x509.Certificate) { - fmt.Printf("\tBasic constraints: ") - if cert.BasicConstraintsValid { - fmt.Printf("valid") - } else { - fmt.Printf("invalid") - } + fmt.Fprint(os.Stdout, "\tBasic constraints: ") + if cert.BasicConstraintsValid { + fmt.Fprint(os.Stdout, "valid") + } else { + fmt.Fprint(os.Stdout, "invalid") + } - if cert.IsCA { - fmt.Printf(", is a CA certificate") - if !cert.BasicConstraintsValid { - fmt.Printf(" (basic constraint failure)") - } - } else { - fmt.Printf("is not a CA certificate") - if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 { - fmt.Printf(" (key encipherment usage enabled!)") - } - } + if cert.IsCA { + fmt.Fprint(os.Stdout, ", is a CA certificate") + if !cert.BasicConstraintsValid { + fmt.Fprint(os.Stdout, " (basic constraint failure)") + } + } else { + fmt.Fprint(os.Stdout, "is not a CA certificate") + if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 { + fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)") + } + } - if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) { - fmt.Printf(", max path length %d", cert.MaxPathLen) - } + if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) { + fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen) + } - fmt.Printf("\n") + fmt.Fprintln(os.Stdout) } const oneTrueDateFormat = "2006-01-02T15:04:05-0700" @@ -135,41 +135,41 @@ var ( ) func wrapPrint(text string, indent int) { - tabs := "" - for i := 0; i < indent; i++ { - tabs += "\t" - } + tabs := "" + for i := 0; i < indent; i++ { + tabs += "\t" + } - fmt.Printf(tabs+"%s\n", wrap(text, indent)) + fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent)) } func displayCert(cert *x509.Certificate) { - fmt.Println("CERTIFICATE") - if showHash { - fmt.Println(wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0)) - } - fmt.Println(wrap("Subject: "+displayName(cert.Subject), 0)) - fmt.Println(wrap("Issuer: "+displayName(cert.Issuer), 0)) - fmt.Printf("\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm), - sigAlgoHash(cert.SignatureAlgorithm)) - fmt.Println("Details:") - wrapPrint("Public key: "+certPublic(cert), 1) - fmt.Printf("\tSerial number: %s\n", cert.SerialNumber) + fmt.Fprintln(os.Stdout, "CERTIFICATE") + if showHash { + fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0)) + } + fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0)) + fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0)) + fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm), + sigAlgoHash(cert.SignatureAlgorithm)) + fmt.Fprintln(os.Stdout, "Details:") + wrapPrint("Public key: "+certPublic(cert), 1) + fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber) - if len(cert.AuthorityKeyId) > 0 { - fmt.Printf("\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1)) - } - if len(cert.SubjectKeyId) > 0 { - fmt.Printf("\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1)) - } + if len(cert.AuthorityKeyId) > 0 { + fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1)) + } + if len(cert.SubjectKeyId) > 0 { + fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1)) + } wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1) - fmt.Printf("\t until: %s\n", cert.NotAfter.Format(dateFormat)) - fmt.Printf("\tKey usages: %s\n", keyUsages(cert.KeyUsage)) + fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat)) + fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage)) - if len(cert.ExtKeyUsage) > 0 { - fmt.Printf("\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage)) - } + if len(cert.ExtKeyUsage) > 0 { + fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage)) + } showBasicConstraints(cert) @@ -217,19 +217,19 @@ func displayCert(cert *x509.Certificate) { } func displayAllCerts(in []byte, leafOnly bool) { - certs, err := certlib.ParseCertificatesPEM(in) - if err != nil { - certs, _, err = certlib.ParseCertificatesDER(in, "") - if err != nil { - lib.Warn(err, "failed to parse certificates") - return - } - } + certs, err := certlib.ParseCertificatesPEM(in) + if err != nil { + certs, _, err = certlib.ParseCertificatesDER(in, "") + if err != nil { + _, _ = lib.Warn(err, "failed to parse certificates") + return + } + } if len(certs) == 0 { - lib.Warnx("no certificates found") - return - } + _, _ = lib.Warnx("no certificates found") + return + } if leafOnly { displayCert(certs[0]) @@ -243,11 +243,11 @@ func displayAllCerts(in []byte, leafOnly bool) { func displayAllCertsWeb(uri string, leafOnly bool) { ci := getConnInfo(uri) - conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig()) - if err != nil { - lib.Warn(err, "couldn't connect to %s", ci.Addr) - return - } + conn, err := tls.Dial("tcp", ci.Addr, permissiveConfig()) + if err != nil { + _, _ = lib.Warn(err, "couldn't connect to %s", ci.Addr) + return + } defer conn.Close() state := conn.ConnectionState() @@ -260,34 +260,34 @@ func displayAllCertsWeb(uri string, leafOnly bool) { state = conn.ConnectionState() } conn.Close() - } else { - lib.Warn(err, "TLS verification error with server name %s", ci.Host) - } + } else { + _, _ = lib.Warn(err, "TLS verification error with server name %s", ci.Host) + } - if len(state.PeerCertificates) == 0 { - lib.Warnx("no certificates found") - return - } + if len(state.PeerCertificates) == 0 { + _, _ = lib.Warnx("no certificates found") + return + } if leafOnly { displayCert(state.PeerCertificates[0]) return } - if len(state.VerifiedChains) == 0 { - lib.Warnx("no verified chains found; using peer chain") - for i := range state.PeerCertificates { - displayCert(state.PeerCertificates[i]) - } - } else { - fmt.Println("TLS chain verified successfully.") - for i := range state.VerifiedChains { - fmt.Printf("--- Verified certificate chain %d ---\n", i+1) - for j := range state.VerifiedChains[i] { - displayCert(state.VerifiedChains[i][j]) - } - } - } + if len(state.VerifiedChains) == 0 { + _, _ = lib.Warnx("no verified chains found; using peer chain") + for i := range state.PeerCertificates { + displayCert(state.PeerCertificates[i]) + } + } else { + fmt.Fprintln(os.Stdout, "TLS chain verified successfully.") + for i := range state.VerifiedChains { + fmt.Fprintf(os.Stdout, "--- Verified certificate chain %d ---%s", i+1, "\n") + for j := range state.VerifiedChains[i] { + displayCert(state.VerifiedChains[i][j]) + } + } + } } func main() { @@ -298,11 +298,11 @@ func main() { flag.Parse() if flag.NArg() == 0 || (flag.NArg() == 1 && flag.Arg(0) == "-") { - certs, err := io.ReadAll(os.Stdin) - if err != nil { - lib.Warn(err, "couldn't read certificates from standard input") - os.Exit(1) - } + certs, err := io.ReadAll(os.Stdin) + if err != nil { + _, _ = lib.Warn(err, "couldn't read certificates from standard input") + os.Exit(1) + } // This is needed for getting certs from JSON/jq. certs = bytes.TrimSpace(certs) @@ -311,15 +311,15 @@ func main() { displayAllCerts(certs, leafOnly) } else { for _, filename := range flag.Args() { - fmt.Printf("--%s ---\n", filename) + fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n") if strings.HasPrefix(filename, "https://") { displayAllCertsWeb(filename, leafOnly) } else { - in, err := os.ReadFile(filename) - if err != nil { - lib.Warn(err, "couldn't read certificate") - continue - } + in, err := os.ReadFile(filename) + if err != nil { + _, _ = lib.Warn(err, "couldn't read certificate") + continue + } displayAllCerts(in, leafOnly) } diff --git a/cmd/certdump/util.go b/cmd/certdump/util.go index c445dd3..f4c818d 100644 --- a/cmd/certdump/util.go +++ b/cmd/certdump/util.go @@ -26,64 +26,85 @@ var keyUsage = map[x509.KeyUsage]string{ } var extKeyUsages = map[x509.ExtKeyUsage]string{ - x509.ExtKeyUsageAny: "any", - x509.ExtKeyUsageServerAuth: "server auth", - x509.ExtKeyUsageClientAuth: "client auth", - x509.ExtKeyUsageCodeSigning: "code signing", - x509.ExtKeyUsageEmailProtection: "s/mime", - x509.ExtKeyUsageIPSECEndSystem: "ipsec end system", - x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel", - x509.ExtKeyUsageIPSECUser: "ipsec user", - x509.ExtKeyUsageTimeStamping: "timestamping", - x509.ExtKeyUsageOCSPSigning: "ocsp signing", - x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc", - x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc", + x509.ExtKeyUsageAny: "any", + x509.ExtKeyUsageServerAuth: "server auth", + x509.ExtKeyUsageClientAuth: "client auth", + x509.ExtKeyUsageCodeSigning: "code signing", + x509.ExtKeyUsageEmailProtection: "s/mime", + x509.ExtKeyUsageIPSECEndSystem: "ipsec end system", + x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel", + x509.ExtKeyUsageIPSECUser: "ipsec user", + x509.ExtKeyUsageTimeStamping: "timestamping", + x509.ExtKeyUsageOCSPSigning: "ocsp signing", + x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc", + x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc", + x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing", + x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing", } func pubKeyAlgo(a x509.PublicKeyAlgorithm) string { - switch a { - case x509.RSA: - return "RSA" - case x509.ECDSA: - return "ECDSA" - case x509.DSA: - return "DSA" - default: - return "unknown public key algorithm" - } + switch a { + case x509.UnknownPublicKeyAlgorithm: + return "unknown public key algorithm" + case x509.RSA: + return "RSA" + case x509.ECDSA: + return "ECDSA" + case x509.DSA: + return "DSA" + case x509.Ed25519: + return "Ed25519" + default: + return "unknown public key algorithm" + } } func sigAlgoPK(a x509.SignatureAlgorithm) string { - switch a { - - case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: - return "RSA" - case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512: - return "ECDSA" - case x509.DSAWithSHA1, x509.DSAWithSHA256: - return "DSA" - default: - return "unknown public key algorithm" - } + switch a { + case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: + return "RSA" + case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS: + return "RSA-PSS" + case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512: + return "ECDSA" + case x509.DSAWithSHA1, x509.DSAWithSHA256: + return "DSA" + case x509.PureEd25519: + return "Ed25519" + case x509.UnknownSignatureAlgorithm: + return "unknown public key algorithm" + default: + return "unknown public key algorithm" + } } func sigAlgoHash(a x509.SignatureAlgorithm) string { - switch a { - case x509.MD2WithRSA: - return "MD2" - case x509.MD5WithRSA: - return "MD5" - case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1: - return "SHA1" - case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256: - return "SHA256" - case x509.SHA384WithRSA, x509.ECDSAWithSHA384: - return "SHA384" - case x509.SHA512WithRSA, x509.ECDSAWithSHA512: - return "SHA512" - default: - return "unknown hash algorithm" - } + switch a { + case x509.MD2WithRSA: + return "MD2" + case x509.MD5WithRSA: + return "MD5" + case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1: + return "SHA1" + case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256: + return "SHA256" + case x509.SHA256WithRSAPSS: + return "SHA256" + case x509.SHA384WithRSA, x509.ECDSAWithSHA384: + return "SHA384" + case x509.SHA384WithRSAPSS: + return "SHA384" + case x509.SHA512WithRSA, x509.ECDSAWithSHA512: + return "SHA512" + case x509.SHA512WithRSAPSS: + return "SHA512" + case x509.PureEd25519: + return "SHA512" + case x509.UnknownSignatureAlgorithm: + return "unknown hash algorithm" + default: + return "unknown hash algorithm" + } } const maxLine = 78 diff --git a/cmd/certexpiry/main.go b/cmd/certexpiry/main.go index 1ce62a8..c0ed5c7 100644 --- a/cmd/certexpiry/main.go +++ b/cmd/certexpiry/main.go @@ -1,18 +1,17 @@ package main import ( - "crypto/x509" - "crypto/x509/pkix" - "flag" - "fmt" - "io/ioutil" - "os" - "strings" - "time" + "crypto/x509" + "crypto/x509/pkix" + "flag" + "fmt" + "os" + "strings" + "time" - "git.wntrmute.dev/kyle/goutils/certlib" - "git.wntrmute.dev/kyle/goutils/die" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/certlib" + "git.wntrmute.dev/kyle/goutils/die" + "git.wntrmute.dev/kyle/goutils/lib" ) var warnOnly bool @@ -80,21 +79,21 @@ func main() { flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring") flag.Parse() - for _, file := range flag.Args() { - in, err := ioutil.ReadFile(file) - if err != nil { - lib.Warn(err, "failed to read file") - continue - } + for _, file := range flag.Args() { + in, err := os.ReadFile(file) + if err != nil { + _, _ = lib.Warn(err, "failed to read file") + continue + } - certs, err := certlib.ParseCertificatesPEM(in) - if err != nil { - lib.Warn(err, "while parsing certificates") - continue - } + certs, err := certlib.ParseCertificatesPEM(in) + if err != nil { + _, _ = lib.Warn(err, "while parsing certificates") + continue + } - for _, cert := range certs { - checkCert(cert) - } - } + for _, cert := range certs { + checkCert(cert) + } + } } diff --git a/cmd/clustersh/main.go b/cmd/clustersh/main.go index 559f357..f168b8f 100644 --- a/cmd/clustersh/main.go +++ b/cmd/clustersh/main.go @@ -1,15 +1,16 @@ package main import ( - "bufio" - "flag" - "fmt" - "io" - "log" - "net" - "os" - "strings" - "sync" + "bufio" + "flag" + "fmt" + "io" + "errors" + "log" + "net" + "os" + "strings" + "sync" "git.wntrmute.dev/kyle/goutils/lib" "github.com/pkg/sftp" @@ -92,12 +93,12 @@ func exec(wg *sync.WaitGroup, user, host string, commands []string) { defer func() { for i := len(shutdown) - 1; i >= 0; i-- { - err := shutdown[i]() - if err != nil && err != io.EOF { - logError(host, err, "shutting down") - } - } - }() + err := shutdown[i]() + if err != nil && !errors.Is(err, io.EOF) { + logError(host, err, "shutting down") + } + } + }() defer wg.Done() conf := sshConfig(user) @@ -149,12 +150,12 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) { defer func() { for i := len(shutdown) - 1; i >= 0; i-- { - err := shutdown[i]() - if err != nil && err != io.EOF { - logError(host, err, "shutting down") - } - } - }() + err := shutdown[i]() + if err != nil && !errors.Is(err, io.EOF) { + logError(host, err, "shutting down") + } + } + }() defer wg.Done() conf := sshConfig(user) @@ -199,13 +200,13 @@ func upload(wg *sync.WaitGroup, user, host, local, remote string) { fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) } - if err == io.EOF { - break - } else if err != nil { - logError(host, err, "reading chunk") - return - } - } + if errors.Is(err, io.EOF) { + break + } else if err != nil { + logError(host, err, "reading chunk") + return + } + } fmt.Printf("[%s] %s uploaded to %s\n", host, remote, local) } @@ -214,12 +215,12 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) { defer func() { for i := len(shutdown) - 1; i >= 0; i-- { - err := shutdown[i]() - if err != nil && err != io.EOF { - logError(host, err, "shutting down") - } - } - }() + err := shutdown[i]() + if err != nil && !errors.Is(err, io.EOF) { + logError(host, err, "shutting down") + } + } + }() defer wg.Done() conf := sshConfig(user) @@ -265,12 +266,12 @@ func download(wg *sync.WaitGroup, user, host, local, remote string) { fmt.Printf("[%s] wrote %d-byte chunk\n", host, n) } - if err == io.EOF { - break - } else if err != nil { - logError(host, err, "reading chunk") - return - } + if errors.Is(err, io.EOF) { + break + } else if err != nil { + logError(host, err, "reading chunk") + return + } } fmt.Printf("[%s] %s downloaded to %s\n", host, remote, local) } diff --git a/cmd/cruntar/main.go b/cmd/cruntar/main.go index 95dcd56..c6ea64e 100644 --- a/cmd/cruntar/main.go +++ b/cmd/cruntar/main.go @@ -262,11 +262,11 @@ func main() { tfr := tar.NewReader(r) for { - hdr, err := tfr.Next() - if err == io.EOF { - break - } - die.If(err) + hdr, err := tfr.Next() + if errors.Is(err, io.EOF) { + break + } + die.If(err) err = processFile(tfr, hdr, top) die.If(err) diff --git a/cmd/csrpubdump/pubdump.go b/cmd/csrpubdump/pubdump.go index a91ce85..a608621 100644 --- a/cmd/csrpubdump/pubdump.go +++ b/cmd/csrpubdump/pubdump.go @@ -7,8 +7,7 @@ import ( "encoding/pem" "flag" "fmt" - "io/ioutil" - "log" + "os" "git.wntrmute.dev/kyle/goutils/die" ) @@ -17,12 +16,12 @@ func main() { flag.Parse() for _, fileName := range flag.Args() { - in, err := ioutil.ReadFile(fileName) + in, err := os.ReadFile(fileName) die.If(err) if p, _ := pem.Decode(in); p != nil { if p.Type != "CERTIFICATE REQUEST" { - log.Fatal("INVALID FILE TYPE") + die.With("INVALID FILE TYPE") } in = p.Bytes } @@ -48,8 +47,8 @@ func main() { Bytes: out, } - err = ioutil.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0644) - die.If(err) - fmt.Printf("[+] wrote %s.\n", fileName+".pub") + err = os.WriteFile(fileName+".pub", pem.EncodeToMemory(p), 0o644) + die.If(err) + fmt.Fprintf(os.Stdout, "[+] wrote %s.\n", fileName+".pub") } } diff --git a/cmd/fragment/fragment.go b/cmd/fragment/fragment.go index 3d0eca9..c64fe69 100644 --- a/cmd/fragment/fragment.go +++ b/cmd/fragment/fragment.go @@ -95,12 +95,12 @@ func main() { return false } - fmtStr += "\n" - for i := start; !endFunc(i); i++ { - if *quiet { - fmt.Println(lines[i]) - } else { - fmt.Printf(fmtStr, i, lines[i]) - } - } + fmtStr += "\n" + for i := start; !endFunc(i); i++ { + if *quiet { + fmt.Fprintln(os.Stdout, lines[i]) + } else { + fmt.Fprintf(os.Stdout, fmtStr, i, lines[i]) + } + } } diff --git a/cmd/jlp/jlp.go b/cmd/jlp/jlp.go index 4e74b02..49b1c62 100644 --- a/cmd/jlp/jlp.go +++ b/cmd/jlp/jlp.go @@ -1,51 +1,51 @@ package main import ( - "bytes" - "encoding/json" - "flag" - "fmt" - "io/ioutil" - "os" + "bytes" + "encoding/json" + "flag" + "fmt" + "io" + "os" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib" ) func prettify(file string, validateOnly bool) error { var in []byte var err error - if file == "-" { - in, err = ioutil.ReadAll(os.Stdin) - } else { - in, err = ioutil.ReadFile(file) - } + if file == "-" { + in, err = io.ReadAll(os.Stdin) + } else { + in, err = os.ReadFile(file) + } - if err != nil { - lib.Warn(err, "ReadFile") - return err - } + if err != nil { + _, _ = lib.Warn(err, "ReadFile") + return err + } var buf = &bytes.Buffer{} err = json.Indent(buf, in, "", " ") - if err != nil { - lib.Warn(err, "%s", file) - return err - } + if err != nil { + _, _ = lib.Warn(err, "%s", file) + return err + } if validateOnly { return nil } - if file == "-" { - _, err = os.Stdout.Write(buf.Bytes()) - } else { - err = ioutil.WriteFile(file, buf.Bytes(), 0644) - } + if file == "-" { + _, err = os.Stdout.Write(buf.Bytes()) + } else { + err = os.WriteFile(file, buf.Bytes(), 0o644) + } - if err != nil { - lib.Warn(err, "WriteFile") - } + if err != nil { + _, _ = lib.Warn(err, "WriteFile") + } return err } @@ -54,44 +54,44 @@ func compact(file string, validateOnly bool) error { var in []byte var err error - if file == "-" { - in, err = ioutil.ReadAll(os.Stdin) - } else { - in, err = ioutil.ReadFile(file) - } + if file == "-" { + in, err = io.ReadAll(os.Stdin) + } else { + in, err = os.ReadFile(file) + } - if err != nil { - lib.Warn(err, "ReadFile") - return err - } + if err != nil { + _, _ = lib.Warn(err, "ReadFile") + return err + } var buf = &bytes.Buffer{} err = json.Compact(buf, in) - if err != nil { - lib.Warn(err, "%s", file) - return err - } + if err != nil { + _, _ = lib.Warn(err, "%s", file) + return err + } if validateOnly { return nil } - if file == "-" { - _, err = os.Stdout.Write(buf.Bytes()) - } else { - err = ioutil.WriteFile(file, buf.Bytes(), 0644) - } + if file == "-" { + _, err = os.Stdout.Write(buf.Bytes()) + } else { + err = os.WriteFile(file, buf.Bytes(), 0o644) + } - if err != nil { - lib.Warn(err, "WriteFile") - } + if err != nil { + _, _ = lib.Warn(err, "WriteFile") + } return err } func usage() { - progname := lib.ProgName() - fmt.Printf(`Usage: %s [-h] files... + progname := lib.ProgName() + fmt.Fprintf(os.Stdout, `Usage: %s [-h] files... %s is used to lint and prettify (or compact) JSON files. The files will be updated in-place. diff --git a/cmd/minmax/minmax.go b/cmd/minmax/minmax.go index 146b369..d369d88 100644 --- a/cmd/minmax/minmax.go +++ b/cmd/minmax/minmax.go @@ -46,8 +46,8 @@ func main() { max, err := strconv.Atoi(flag.Arg(2)) dieIf(err) - code := kind << 6 - code += (min << 3) - code += max - fmt.Printf("%0o\n", code) + code := kind << 6 + code += (min << 3) + code += max + fmt.Fprintf(os.Stdout, "%0o\n", code) } diff --git a/cmd/pembody/pembody.go b/cmd/pembody/pembody.go index b26fbe3..f885531 100644 --- a/cmd/pembody/pembody.go +++ b/cmd/pembody/pembody.go @@ -3,8 +3,7 @@ package main import ( "encoding/pem" "flag" - "fmt" - "io/ioutil" + "io" "os" "git.wntrmute.dev/kyle/goutils/lib" @@ -19,19 +18,21 @@ func main() { var in []byte var err error - path := flag.Arg(0) - if path == "-" { - in, err = ioutil.ReadAll(os.Stdin) - } else { - in, err = ioutil.ReadFile(flag.Arg(0)) - } + path := flag.Arg(0) + if path == "-" { + in, err = io.ReadAll(os.Stdin) + } else { + in, err = os.ReadFile(flag.Arg(0)) + } if err != nil { lib.Err(lib.ExitFailure, err, "couldn't read file") } p, _ := pem.Decode(in) - if p == nil { - lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0)) - } - fmt.Printf("%s", p.Bytes) + if p == nil { + lib.Errx(lib.ExitFailure, "%s isn't a PEM-encoded file", flag.Arg(0)) + } + if _, err := os.Stdout.Write(p.Bytes); err != nil { + lib.Err(lib.ExitFailure, err, "writing body") + } } diff --git a/cmd/readchain/chain.go b/cmd/readchain/chain.go index db87e92..89d5b4d 100644 --- a/cmd/readchain/chain.go +++ b/cmd/readchain/chain.go @@ -1,25 +1,24 @@ package main import ( - "crypto/x509" - "encoding/pem" - "flag" - "fmt" - "io/ioutil" - "os" + "crypto/x509" + "encoding/pem" + "flag" + "fmt" + "os" ) func main() { flag.Parse() for _, fileName := range flag.Args() { - data, err := ioutil.ReadFile(fileName) + data, err := os.ReadFile(fileName) if err != nil { fmt.Fprintf(os.Stderr, "[!] %s: %v\n", fileName, err) continue } - fmt.Printf("[+] %s:\n", fileName) + fmt.Fprintf(os.Stdout, "[+] %s:\n", fileName) rest := data[:] for { var p *pem.Block @@ -34,7 +33,7 @@ func main() { break } - fmt.Printf("\t%+v\n", cert.Subject.CommonName) + fmt.Fprintf(os.Stdout, "\t%+v\n", cert.Subject.CommonName) } } } diff --git a/cmd/renfnv/renfnv.go b/cmd/renfnv/renfnv.go index 8f6fc3e..1f9008c 100644 --- a/cmd/renfnv/renfnv.go +++ b/cmd/renfnv/renfnv.go @@ -109,27 +109,27 @@ func main() { for _, file := range flag.Args() { renamed, err := newName(file) - if err != nil { - lib.Warn(err, "failed to get new file name") - continue - } + if err != nil { + _, _ = lib.Warn(err, "failed to get new file name") + continue + } - if verbose && !printChanged { - fmt.Println(file) - } + if verbose && !printChanged { + fmt.Fprintln(os.Stdout, file) + } if renamed != file { if !dryRun { err = move(renamed, file, force) - if err != nil { - lib.Warn(err, "failed to rename file from %s to %s", file, renamed) - continue - } + if err != nil { + _, _ = lib.Warn(err, "failed to rename file from %s to %s", file, renamed) + continue + } } - if printChanged && !verbose { - fmt.Println(file, "->", renamed) - } + if printChanged && !verbose { + fmt.Fprintln(os.Stdout, file, "->", renamed) + } } } } diff --git a/cmd/rhash/main.go b/cmd/rhash/main.go index f8aaaf6..3df2e46 100644 --- a/cmd/rhash/main.go +++ b/cmd/rhash/main.go @@ -66,24 +66,19 @@ func main() { for _, remote := range flag.Args() { u, err := url.Parse(remote) if err != nil { - lib.Warn(err, "parsing %s", remote) + _, _ = lib.Warn(err, "parsing %s", remote) continue } name := filepath.Base(u.Path) if name == "" { - lib.Warnx("source URL doesn't appear to name a file") + _, _ = lib.Warnx("source URL doesn't appear to name a file") continue } resp, err := http.Get(remote) if err != nil { - lib.Warn(err, "fetching %s", remote) - continue - } - - if err != nil { - lib.Warn(err, "fetching %s", remote) + _, _ = lib.Warn(err, "fetching %s", remote) continue } diff --git a/cmd/rolldie/main.go b/cmd/rolldie/main.go index b1ce3af..aef7ce1 100644 --- a/cmd/rolldie/main.go +++ b/cmd/rolldie/main.go @@ -3,7 +3,7 @@ package main import ( "flag" "fmt" - "math/rand" + "math/rand/v2" "os" "regexp" "strconv" @@ -17,11 +17,11 @@ func rollDie(count, sides int) []int { sum := 0 var rolls []int - for i := 0; i < count; i++ { - roll := rand.Intn(sides) + 1 - sum += roll - rolls = append(rolls, roll) - } + for i := 0; i < count; i++ { + roll := rand.IntN(sides) + 1 + sum += roll + rolls = append(rolls, roll) + } rolls = append(rolls, sum) return rolls diff --git a/cmd/ski/main.go b/cmd/ski/main.go index e9cdc86..ac61e6d 100644 --- a/cmd/ski/main.go +++ b/cmd/ski/main.go @@ -1,24 +1,23 @@ package main import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/rsa" - "crypto/sha1" - "crypto/x509" - "crypto/x509/pkix" - "encoding/asn1" - "encoding/pem" - "flag" - "fmt" - "io" - "io/ioutil" - "os" - "strings" + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "crypto/x509/pkix" + "encoding/asn1" + "encoding/pem" + "flag" + "fmt" + "io" + "os" + "strings" - "git.wntrmute.dev/kyle/goutils/die" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/die" + "git.wntrmute.dev/kyle/goutils/lib" ) func usage(w io.Writer) { @@ -40,14 +39,14 @@ func init() { } func parse(path string) (public []byte, kt, ft string) { - data, err := ioutil.ReadFile(path) + data, err := os.ReadFile(path) die.If(err) data = bytes.TrimSpace(data) p, rest := pem.Decode(data) - if len(rest) > 0 { - lib.Warnx("trailing data in PEM file") - } + if len(rest) > 0 { + _, _ = lib.Warnx("trailing data in PEM file") + } if p == nil { die.With("no PEM data found") @@ -73,7 +72,7 @@ func parse(path string) (public []byte, kt, ft string) { } func parseKey(data []byte) (public []byte, kt string) { - privInterface, err := x509.ParsePKCS8PrivateKey(data) + privInterface, err := x509.ParsePKCS8PrivateKey(data) if err != nil { privInterface, err = x509.ParsePKCS1PrivateKey(data) if err != nil { @@ -85,12 +84,12 @@ func parseKey(data []byte) (public []byte, kt string) { } var priv crypto.Signer - switch privInterface.(type) { + switch p := privInterface.(type) { case *rsa.PrivateKey: - priv = privInterface.(*rsa.PrivateKey) + priv = p kt = "RSA" case *ecdsa.PrivateKey: - priv = privInterface.(*ecdsa.PrivateKey) + priv = p kt = "ECDSA" default: die.With("unknown private key type %T", privInterface) @@ -171,10 +170,10 @@ func main() { var subPKI subjectPublicKeyInfo _, err := asn1.Unmarshal(public, &subPKI) - if err != nil { - lib.Warn(err, "failed to get subject PKI") - continue - } + if err != nil { + _, _ = lib.Warn(err, "failed to get subject PKI") + continue + } pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) pubHashString := dumpHex(pubHash[:]) @@ -182,10 +181,10 @@ func main() { ski = pubHashString } - if shouldMatch && ski != pubHashString { - lib.Warnx("%s: SKI mismatch (%s != %s)", - path, ski, pubHashString) - } + if shouldMatch && ski != pubHashString { + _, _ = lib.Warnx("%s: SKI mismatch (%s != %s)", + path, ski, pubHashString) + } fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft) } } diff --git a/cmd/sprox/main.go b/cmd/sprox/main.go index e020a85..eaee652 100644 --- a/cmd/sprox/main.go +++ b/cmd/sprox/main.go @@ -3,26 +3,26 @@ package main import ( "flag" "io" - "log" "net" "git.wntrmute.dev/kyle/goutils/die" + "git.wntrmute.dev/kyle/goutils/lib" ) func proxy(conn net.Conn, inside string) error { - proxyConn, err := net.Dial("tcp", inside) - if err != nil { - return err - } + proxyConn, err := net.Dial("tcp", inside) + if err != nil { + return err + } defer proxyConn.Close() defer conn.Close() - go func() { - io.Copy(conn, proxyConn) - }() - _, err = io.Copy(proxyConn, conn) - return err + go func() { + _, _ = io.Copy(conn, proxyConn) + }() + _, err = io.Copy(proxyConn, conn) + return err } func main() { @@ -34,13 +34,17 @@ func main() { l, err := net.Listen("tcp", "0.0.0.0:"+outside) die.If(err) - for { - conn, err := l.Accept() - if err != nil { - log.Println(err) - continue - } + for { + conn, err := l.Accept() + if err != nil { + _, _ = lib.Warn(err, "accept failed") + continue + } - go proxy(conn, "127.0.0.1:"+inside) - } + go func() { + if err := proxy(conn, "127.0.0.1:"+inside); err != nil { + _, _ = lib.Warn(err, "proxy error") + } + }() + } } diff --git a/cmd/stealchain-server/main.go b/cmd/stealchain-server/main.go index 26b1666..badbb86 100644 --- a/cmd/stealchain-server/main.go +++ b/cmd/stealchain-server/main.go @@ -8,7 +8,6 @@ import ( "encoding/pem" "flag" "fmt" - "io/ioutil" "net" "os" @@ -46,18 +45,18 @@ func main() { os.Exit(1) } cfg.Certificates = append(cfg.Certificates, cert) - if sysRoot != "" { - pemList, err := ioutil.ReadFile(sysRoot) - die.If(err) + if sysRoot != "" { + pemList, err := os.ReadFile(sysRoot) + die.If(err) - roots := x509.NewCertPool() - if !roots.AppendCertsFromPEM(pemList) { - fmt.Printf("[!] no valid roots found") - roots = nil - } + roots := x509.NewCertPool() + if !roots.AppendCertsFromPEM(pemList) { + fmt.Printf("[!] no valid roots found") + roots = nil + } - cfg.RootCAs = roots - } + cfg.RootCAs = roots + } l, err := net.Listen("tcp", listenAddr) if err != nil { @@ -65,42 +64,46 @@ func main() { os.Exit(1) } - for { - conn, err := l.Accept() - if err != nil { - fmt.Println(err.Error()) - } - - raddr := conn.RemoteAddr() - tconn := tls.Server(conn, cfg) - err = tconn.Handshake() - if err != nil { - fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err) - continue - } - cs := tconn.ConnectionState() - if len(cs.PeerCertificates) == 0 { - fmt.Printf("[+] %v: no chain presented\n", raddr) - continue - } - - var chain []byte - for _, cert := range cs.PeerCertificates { - p := &pem.Block{ - Type: "CERTIFICATE", - Bytes: cert.Raw, - } - chain = append(chain, pem.EncodeToMemory(p)...) - } - - var nonce [16]byte - _, err = rand.Read(nonce[:]) - if err != nil { - panic(err) - } - fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:])) - err = ioutil.WriteFile(fname, chain, 0644) - die.If(err) - fmt.Printf("%v: [+] wrote %v.\n", raddr, fname) - } + for { + conn, err := l.Accept() + if err != nil { + fmt.Println(err.Error()) + continue + } + handleConn(conn, cfg) + } +} + +// handleConn performs a TLS handshake, extracts the peer chain, and writes it to a file. +func handleConn(conn net.Conn, cfg *tls.Config) { + defer conn.Close() + raddr := conn.RemoteAddr() + tconn := tls.Server(conn, cfg) + if err := tconn.Handshake(); err != nil { + fmt.Printf("[+] %v: failed to complete handshake: %v\n", raddr, err) + return + } + cs := tconn.ConnectionState() + if len(cs.PeerCertificates) == 0 { + fmt.Printf("[+] %v: no chain presented\n", raddr) + return + } + + var chain []byte + for _, cert := range cs.PeerCertificates { + p := &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw} + chain = append(chain, pem.EncodeToMemory(p)...) + } + + var nonce [16]byte + if _, err := rand.Read(nonce[:]); err != nil { + fmt.Printf("[+] %v: failed to generate filename nonce: %v\n", raddr, err) + return + } + fname := fmt.Sprintf("%v-%v.pem", raddr, hex.EncodeToString(nonce[:])) + if err := os.WriteFile(fname, chain, 0o644); err != nil { + fmt.Printf("[+] %v: failed to write %v: %v\n", raddr, fname, err) + return + } + fmt.Printf("%v: [+] wrote %v.\n", raddr, fname) } diff --git a/cmd/subjhash/main.go b/cmd/subjhash/main.go index d034949..7baf42b 100644 --- a/cmd/subjhash/main.go +++ b/cmd/subjhash/main.go @@ -57,16 +57,16 @@ func getSubjectInfoHash(cert *x509.Certificate, issuer bool) []byte { } func printDigests(paths []string, issuer bool) { - for _, path := range paths { - cert, err := certlib.LoadCertificate(path) - if err != nil { - lib.Warn(err, "failed to load certificate from %s", path) - continue - } + for _, path := range paths { + cert, err := certlib.LoadCertificate(path) + if err != nil { + _, _ = lib.Warn(err, "failed to load certificate from %s", path) + continue + } digest := getSubjectInfoHash(cert, issuer) - fmt.Printf("%x %s\n", digest, path) - } + fmt.Printf("%x %s\n", digest, path) + } } func matchDigests(paths []string, issuer bool) { @@ -87,10 +87,10 @@ func matchDigests(paths []string, issuer bool) { die.If(err) sndCert, err := certlib.LoadCertificate(snd) die.If(err) - if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) { - lib.Warnx("certificates don't match: %s and %s", fst, snd) - invalid++ - } + if !bytes.Equal(getSubjectInfoHash(fstCert, issuer), getSubjectInfoHash(sndCert, issuer)) { + _, _ = lib.Warnx("certificates don't match: %s and %s", fst, snd) + invalid++ + } } if invalid > 0 { diff --git a/cmd/tlskeypair/main.go b/cmd/tlskeypair/main.go index 87b2b2e..4992852 100644 --- a/cmd/tlskeypair/main.go +++ b/cmd/tlskeypair/main.go @@ -1,21 +1,19 @@ package main import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "flag" - "fmt" - "io/ioutil" - "log" - "os" + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "flag" + "fmt" + "os" - "git.wntrmute.dev/kyle/goutils/die" + "git.wntrmute.dev/kyle/goutils/die" ) var validPEMs = map[string]bool{ @@ -52,8 +50,75 @@ func getECCurve(pub interface{}) int { } } +// matchRSA compares an RSA public key from certificate against RSA public key from private key. +// It returns true on match. +func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool { + return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E +} + +// matchECDSA compares ECDSA public keys for equality and compatible curve. +// It returns match=true when they are on the same curve and have the same X/Y. +// If curves mismatch, match is false. +func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool { + if getECCurve(certPub) != getECCurve(keyPub) { + return false + } + if keyPub.X.Cmp(certPub.X) != 0 { + return false + } + if keyPub.Y.Cmp(certPub.Y) != 0 { + return false + } + return true +} + +// matchKeys determines whether the certificate's public key matches the given private key. +// It returns true if they match; otherwise, it returns false and a human-friendly reason. +func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) { + switch keyPub := priv.Public().(type) { + case *rsa.PublicKey: + switch certPub := cert.PublicKey.(type) { + case *rsa.PublicKey: + if matchRSA(certPub, keyPub) { + return true, "" + } + return false, "public keys don't match" + case *ecdsa.PublicKey: + return false, "RSA private key, EC public key" + default: + return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey) + } + case *ecdsa.PublicKey: + switch certPub := cert.PublicKey.(type) { + case *ecdsa.PublicKey: + if matchECDSA(certPub, keyPub) { + return true, "" + } + // Determine a more precise reason + kc := getECCurve(keyPub) + cc := getECCurve(certPub) + if kc == curveInvalid { + return false, "invalid private key curve" + } + if cc == curveRSA { + return false, "private key is EC, certificate is RSA" + } + if kc != cc { + return false, "EC curves don't match" + } + return false, "public keys don't match" + case *rsa.PublicKey: + return false, "private key is EC, certificate is RSA" + default: + return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey) + } + default: + return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public()) + } +} + func loadKey(path string) (crypto.Signer, error) { - in, err := ioutil.ReadFile(path) + in, err := os.ReadFile(path) if err != nil { return nil, err } @@ -67,7 +132,7 @@ func loadKey(path string) (crypto.Signer, error) { in = p.Bytes } - priv, err := x509.ParsePKCS8PrivateKey(in) + priv, err := x509.ParsePKCS8PrivateKey(in) if err != nil { priv, err = x509.ParsePKCS1PrivateKey(in) if err != nil { @@ -78,15 +143,15 @@ func loadKey(path string) (crypto.Signer, error) { } } - switch priv.(type) { - case *rsa.PrivateKey: - return priv.(*rsa.PrivateKey), nil - case *ecdsa.PrivateKey: - return priv.(*ecdsa.PrivateKey), nil - } - - // should never reach here - return nil, errors.New("invalid private key") + switch p := priv.(type) { + case *rsa.PrivateKey: + return p, nil + case *ecdsa.PrivateKey: + return p, nil + default: + // should never reach here + return nil, errors.New("invalid private key") + } } @@ -96,7 +161,7 @@ func main() { flag.StringVar(&certFile, "c", "", "TLS `certificate` file") flag.Parse() - in, err := ioutil.ReadFile(certFile) + in, err := os.ReadFile(certFile) die.If(err) p, _ := pem.Decode(in) @@ -112,50 +177,11 @@ func main() { priv, err := loadKey(keyFile) die.If(err) - switch pub := priv.Public().(type) { - case *rsa.PublicKey: - switch certPub := cert.PublicKey.(type) { - case *rsa.PublicKey: - if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E { - fmt.Println("No match (public keys don't match).") - os.Exit(1) - } - fmt.Println("Match.") - return - case *ecdsa.PublicKey: - fmt.Println("No match (RSA private key, EC public key).") - os.Exit(1) - } - case *ecdsa.PublicKey: - privCurve := getECCurve(pub) - certCurve := getECCurve(cert.PublicKey) - log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve) - - if certCurve == curveRSA { - fmt.Println("No match (private key is EC, certificate is RSA).") - os.Exit(1) - } else if privCurve == curveInvalid { - fmt.Println("No match (invalid private key curve).") - os.Exit(1) - } else if privCurve != certCurve { - fmt.Println("No match (EC curves don't match).") - os.Exit(1) - } - - certPub := cert.PublicKey.(*ecdsa.PublicKey) - if pub.X.Cmp(certPub.X) != 0 { - fmt.Println("No match (public keys don't match).") - os.Exit(1) - } - - if pub.Y.Cmp(certPub.Y) != 0 { - fmt.Println("No match (public keys don't match).") - os.Exit(1) - } - - fmt.Println("Match.") - default: - fmt.Printf("Unrecognised private key type: %T\n", priv.Public()) - os.Exit(1) - } + matched, reason := matchKeys(cert, priv) + if matched { + fmt.Println("Match.") + return + } + fmt.Printf("No match (%s).\n", reason) + os.Exit(1) } diff --git a/cmd/zsearch/main.go b/cmd/zsearch/main.go index af6e6aa..82d765c 100644 --- a/cmd/zsearch/main.go +++ b/cmd/zsearch/main.go @@ -123,13 +123,14 @@ func main() { for _, path := range pathList { if isDir(path) { - err := filepath.Walk(path, buildWalker(search)) - if err != nil { + if err := filepath.Walk(path, buildWalker(search)); err != nil { errorf("%v", err) return } } else { - searchFile(path, search) + if err := searchFile(path, search); err != nil { + errorf("%v", err) + } } } } diff --git a/lib/lib.go b/lib/lib.go index b0a7a82..216b0cc 100644 --- a/lib/lib.go +++ b/lib/lib.go @@ -81,12 +81,12 @@ var ( // Duration returns a prettier string for time.Durations. func Duration(d time.Duration) string { - var s string - if d >= yearDuration { - years := d / yearDuration - s += fmt.Sprintf("%dy", years) - d -= years * yearDuration - } + var s string + if d >= yearDuration { + years := int64(d / yearDuration) + s += fmt.Sprintf("%dy", years) + d -= time.Duration(years) * yearDuration + } if d >= dayDuration { days := d / dayDuration @@ -97,9 +97,9 @@ func Duration(d time.Duration) string { return s } - d %= 1 * time.Second - hours := d / time.Hour - d -= hours * time.Hour + d %= 1 * time.Second + hours := int64(d / time.Hour) + d -= time.Duration(hours) * time.Hour s += fmt.Sprintf("%dh%s", hours, d) return s } diff --git a/logging/file.go b/logging/file.go index a65f18b..655b81c 100644 --- a/logging/file.go +++ b/logging/file.go @@ -1,6 +1,7 @@ package logging import ( + "errors" "fmt" "os" ) @@ -61,9 +62,9 @@ func NewSplitFile(outpath, errpath string, overwrite bool) (*File, error) { if err != nil { if closeErr := fl.Close(); closeErr != nil { - return nil, fmt.Errorf("failed to open error log: cleanup close failed: %v: %w", closeErr, err) + return nil, fmt.Errorf("failed to open error log: %w", errors.Join(closeErr, err)) } - return nil, err + return nil, fmt.Errorf("failed to open error log: %w", err) } fl.LogWriter = NewLogWriter(fl.fo, fl.fe)