diff --git a/certlib/keymatch.go b/certlib/keymatch.go index 20126c8..19f00d7 100644 --- a/certlib/keymatch.go +++ b/certlib/keymatch.go @@ -1,135 +1,135 @@ package certlib import ( - "bytes" - "crypto" - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rsa" - "crypto/x509" - "encoding/pem" - "errors" - "fmt" - "os" + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "fmt" + "os" ) // LoadPrivateKey loads a private key from disk. It accepts both PEM and DER // encodings and supports RSA and ECDSA keys. If the file contains a PEM block, // the block type must be one of the recognised private key types. func LoadPrivateKey(path string) (crypto.Signer, error) { - in, err := os.ReadFile(path) - if err != nil { - return nil, err - } + in, err := os.ReadFile(path) + if err != nil { + return nil, err + } - in = bytes.TrimSpace(in) - if p, _ := pem.Decode(in); p != nil { - if !validPEMs[p.Type] { - return nil, errors.New("invalid private key file type " + p.Type) - } - return ParsePrivateKeyPEM(in) - } + in = bytes.TrimSpace(in) + if p, _ := pem.Decode(in); p != nil { + if !validPEMs[p.Type] { + return nil, errors.New("invalid private key file type " + p.Type) + } + return ParsePrivateKeyPEM(in) + } - return ParsePrivateKeyDER(in) + return ParsePrivateKeyDER(in) } var validPEMs = map[string]bool{ - "PRIVATE KEY": true, - "RSA PRIVATE KEY": true, - "EC PRIVATE KEY": true, + "PRIVATE KEY": true, + "RSA PRIVATE KEY": true, + "EC PRIVATE KEY": true, } const ( - curveInvalid = iota // any invalid curve - curveRSA // indicates key is an RSA key, not an EC key - curveP256 - curveP384 - curveP521 + curveInvalid = iota // any invalid curve + curveRSA // indicates key is an RSA key, not an EC key + curveP256 + curveP384 + curveP521 ) func getECCurve(pub any) int { - switch pub := pub.(type) { - case *ecdsa.PublicKey: - switch pub.Curve { - case elliptic.P256(): - return curveP256 - case elliptic.P384(): - return curveP384 - case elliptic.P521(): - return curveP521 - default: - return curveInvalid - } - case *rsa.PublicKey: - return curveRSA - default: - return curveInvalid - } + switch pub := pub.(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + return curveP256 + case elliptic.P384(): + return curveP384 + case elliptic.P521(): + return curveP521 + default: + return curveInvalid + } + case *rsa.PublicKey: + return curveRSA + default: + return curveInvalid + } } // matchRSA compares an RSA public key from certificate against RSA public key from private key. // It returns true on match. func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool { - return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E + return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E } // matchECDSA compares ECDSA public keys for equality and compatible curve. // It returns match=true when they are on the same curve and have the same X/Y. // If curves mismatch, match is false. func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool { - if getECCurve(certPub) != getECCurve(keyPub) { - return false - } - if keyPub.X.Cmp(certPub.X) != 0 { - return false - } - if keyPub.Y.Cmp(certPub.Y) != 0 { - return false - } - return true + if getECCurve(certPub) != getECCurve(keyPub) { + return false + } + if keyPub.X.Cmp(certPub.X) != 0 { + return false + } + if keyPub.Y.Cmp(certPub.Y) != 0 { + return false + } + return true } // MatchKeys determines whether the certificate's public key matches the given private key. // It returns true if they match; otherwise, it returns false and a human-friendly reason. func MatchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) { - switch keyPub := priv.Public().(type) { - case *rsa.PublicKey: - switch certPub := cert.PublicKey.(type) { - case *rsa.PublicKey: - if matchRSA(certPub, keyPub) { - return true, "" - } - return false, "public keys don't match" - case *ecdsa.PublicKey: - return false, "RSA private key, EC public key" - default: - return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey) - } - case *ecdsa.PublicKey: - switch certPub := cert.PublicKey.(type) { - case *ecdsa.PublicKey: - if matchECDSA(certPub, keyPub) { - return true, "" - } - // Determine a more precise reason - kc := getECCurve(keyPub) - cc := getECCurve(certPub) - if kc == curveInvalid { - return false, "invalid private key curve" - } - if cc == curveRSA { - return false, "private key is EC, certificate is RSA" - } - if kc != cc { - return false, "EC curves don't match" - } - return false, "public keys don't match" - case *rsa.PublicKey: - return false, "private key is EC, certificate is RSA" - default: - return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey) - } - default: - return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public()) - } + switch keyPub := priv.Public().(type) { + case *rsa.PublicKey: + switch certPub := cert.PublicKey.(type) { + case *rsa.PublicKey: + if matchRSA(certPub, keyPub) { + return true, "" + } + return false, "public keys don't match" + case *ecdsa.PublicKey: + return false, "RSA private key, EC public key" + default: + return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey) + } + case *ecdsa.PublicKey: + switch certPub := cert.PublicKey.(type) { + case *ecdsa.PublicKey: + if matchECDSA(certPub, keyPub) { + return true, "" + } + // Determine a more precise reason + kc := getECCurve(keyPub) + cc := getECCurve(certPub) + if kc == curveInvalid { + return false, "invalid private key curve" + } + if cc == curveRSA { + return false, "private key is EC, certificate is RSA" + } + if kc != cc { + return false, "EC curves don't match" + } + return false, "public keys don't match" + case *rsa.PublicKey: + return false, "private key is EC, certificate is RSA" + default: + return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey) + } + default: + return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public()) + } } diff --git a/cmd/ca-signed/main.go b/cmd/ca-signed/main.go index 384e97c..c5f2fd6 100644 --- a/cmd/ca-signed/main.go +++ b/cmd/ca-signed/main.go @@ -182,7 +182,7 @@ func main() { continue } - if _, err := verify.CertWith(cert, roots, nil, false); err != nil { + if _, err = verify.CertWith(cert, roots, nil, false); err != nil { fmt.Printf("%s: INVALID\n", arg) } else { fmt.Printf("%s: OK (expires %s)\n", arg, cert.NotAfter.Format(lib.DateShortFormat)) diff --git a/cmd/tlskeypair/main.go b/cmd/tlskeypair/main.go index 9461fe3..391e280 100644 --- a/cmd/tlskeypair/main.go +++ b/cmd/tlskeypair/main.go @@ -1,33 +1,33 @@ package main import ( - "flag" - "fmt" - "os" + "flag" + "fmt" + "os" - "git.wntrmute.dev/kyle/goutils/certlib" - "git.wntrmute.dev/kyle/goutils/die" + "git.wntrmute.dev/kyle/goutils/certlib" + "git.wntrmute.dev/kyle/goutils/die" ) // functionality refactored into certlib func main() { - var keyFile, certFile string - flag.StringVar(&keyFile, "k", "", "TLS private `key` file") - flag.StringVar(&certFile, "c", "", "TLS `certificate` file") - flag.Parse() + var keyFile, certFile string + flag.StringVar(&keyFile, "k", "", "TLS private `key` file") + flag.StringVar(&certFile, "c", "", "TLS `certificate` file") + flag.Parse() - cert, err := certlib.LoadCertificate(certFile) - die.If(err) + cert, err := certlib.LoadCertificate(certFile) + die.If(err) - priv, err := certlib.LoadPrivateKey(keyFile) - die.If(err) + priv, err := certlib.LoadPrivateKey(keyFile) + die.If(err) - matched, reason := certlib.MatchKeys(cert, priv) - if matched { - fmt.Println("Match.") - return - } - fmt.Printf("No match (%s).\n", reason) - os.Exit(1) + matched, reason := certlib.MatchKeys(cert, priv) + if matched { + fmt.Println("Match.") + return + } + fmt.Printf("No match (%s).\n", reason) + os.Exit(1) } diff --git a/lib/lib.go b/lib/lib.go index 76f8634..feca520 100644 --- a/lib/lib.go +++ b/lib/lib.go @@ -1,6 +1,7 @@ package lib import ( + "encoding/base64" "encoding/hex" "fmt" "os" @@ -126,6 +127,8 @@ const ( HexEncodeUpperColon // HexEncodeBytes prints the string as a sequence of []byte. HexEncodeBytes + // HexEncodeBase64 prints the string as a base64-encoded string. + HexEncodeBase64 ) func (m HexEncodeMode) String() string { @@ -140,6 +143,8 @@ func (m HexEncodeMode) String() string { return "ucolon" case HexEncodeBytes: return "bytes" + case HexEncodeBase64: + return "base64" default: panic("invalid hex encode mode") } @@ -157,6 +162,8 @@ func ParseHexEncodeMode(s string) HexEncodeMode { return HexEncodeUpperColon case "bytes": return HexEncodeBytes + case "base64": + return HexEncodeBase64 } panic("invalid hex encode mode") @@ -218,21 +225,22 @@ func bytesAsByteSliceString(buf []byte) string { return sb.String() } -// HexEncode encodes the given bytes as a hexadecimal string. +// HexEncode encodes the given bytes as a hexadecimal string. It +// also supports a few other binary-encoding formats as well. func HexEncode(b []byte, mode HexEncodeMode) string { - str := hexEncode(b) - switch mode { case HexEncodeLower: - return str + return hexEncode(b) case HexEncodeUpper: - return strings.ToUpper(str) + return strings.ToUpper(hexEncode(b)) case HexEncodeLowerColon: - return hexColons(str) + return hexColons(hexEncode(b)) case HexEncodeUpperColon: - return strings.ToUpper(hexColons(str)) + return strings.ToUpper(hexColons(hexEncode(b))) case HexEncodeBytes: return bytesAsByteSliceString(b) + case HexEncodeBase64: + return base64.StdEncoding.EncodeToString(b) default: panic("invalid hex encode mode") }