diff --git a/README.md b/README.md index 07c011a..fbb3e1e 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ Contents: in a bundle. stealchain/ Dump the verified chain from a TLS connection. + tlskeypair/ Check whether a TLS certificate and key file match. Each program should have a small README in the directory with more information. diff --git a/cmd/tlskeypair/README b/cmd/tlskeypair/README new file mode 100644 index 0000000..c3b6630 --- /dev/null +++ b/cmd/tlskeypair/README @@ -0,0 +1,17 @@ +tlskeypair + +This is a small utility that prints whether a certificate and private +key correspond. Specify the private key file with -k and the +certificate with -c. + +Example: + + $ tlskeypair -k server.key -c server.pem + Match. + $ echo $? + 0 + $ tlskeypair -k server.key -c client.pem + No match (public keys don't match). + + + diff --git a/cmd/tlskeypair/main.go b/cmd/tlskeypair/main.go new file mode 100644 index 0000000..c4dc9b9 --- /dev/null +++ b/cmd/tlskeypair/main.go @@ -0,0 +1,161 @@ +package main + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "flag" + "fmt" + "io/ioutil" + "log" + "os" + + "github.com/kisom/die" +) + +var validPEMs = map[string]bool{ + "PRIVATE KEY": true, + "RSA PRIVATE KEY": true, + "EC PRIVATE KEY": true, +} + +const ( + curveInvalid = iota // any invalid curve + curveRSA // indicates key is an RSA key, not an EC key + curveP256 + curveP384 + curveP521 +) + +func getECCurve(pub interface{}) int { + switch pub := pub.(type) { + case *ecdsa.PublicKey: + switch pub.Curve { + case elliptic.P256(): + return curveP256 + case elliptic.P384(): + return curveP384 + case elliptic.P521(): + return curveP521 + default: + return curveInvalid + } + case *rsa.PublicKey: + return curveRSA + default: + return curveInvalid + } +} + +func loadKey(path string) (crypto.Signer, error) { + in, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + + in = bytes.TrimSpace(in) + p, _ := pem.Decode(in) + if p != nil { + if !validPEMs[p.Type] { + return nil, errors.New("invalid private key file type " + p.Type) + } + in = p.Bytes + } + + priv, err := x509.ParsePKCS8PrivateKey(in) + if err != nil { + priv, err = x509.ParsePKCS1PrivateKey(in) + if err != nil { + priv, err = x509.ParseECPrivateKey(in) + if err != nil { + return nil, err + } + } + } + + switch priv.(type) { + case *rsa.PrivateKey: + return priv.(*rsa.PrivateKey), nil + case *ecdsa.PrivateKey: + return priv.(*ecdsa.PrivateKey), nil + } + + // should never reach here + return nil, errors.New("invalid private key") + +} + +func main() { + var keyFile, certFile string + flag.StringVar(&keyFile, "k", "", "TLS private `key` file") + flag.StringVar(&certFile, "c", "", "TLS `certificate` file") + flag.Parse() + + in, err := ioutil.ReadFile(certFile) + die.If(err) + + p, _ := pem.Decode(in) + if p != nil { + if p.Type != "CERTIFICATE" { + die.With("invalid certificate (type is %s)", p.Type) + } + in = p.Bytes + } + cert, err := x509.ParseCertificate(in) + die.If(err) + + priv, err := loadKey(keyFile) + die.If(err) + + switch pub := priv.Public().(type) { + case *rsa.PublicKey: + switch certPub := cert.PublicKey.(type) { + case *rsa.PublicKey: + if pub.N.Cmp(certPub.N) != 0 || pub.E != certPub.E { + fmt.Println("No match (public keys don't match).") + os.Exit(1) + } + fmt.Println("Match.") + return + case *ecdsa.PublicKey: + fmt.Println("No match (RSA private key, EC public key).") + os.Exit(1) + } + case *ecdsa.PublicKey: + privCurve := getECCurve(pub) + certCurve := getECCurve(cert.PublicKey) + log.Printf("priv: %d\tcert: %d\n", privCurve, certCurve) + + if certCurve == curveRSA { + fmt.Println("No match (private key is EC, certificate is RSA).") + os.Exit(1) + } else if privCurve == curveInvalid { + fmt.Println("No match (invalid private key curve).") + os.Exit(1) + } else if privCurve != certCurve { + fmt.Println("No match (EC curves don't match).") + os.Exit(1) + } + + certPub := cert.PublicKey.(*ecdsa.PublicKey) + if pub.X.Cmp(certPub.X) != 0 { + fmt.Println("No match (public keys don't match).") + os.Exit(1) + } + + if pub.Y.Cmp(certPub.Y) != 0 { + fmt.Println("No match (public keys don't match).") + os.Exit(1) + } + + fmt.Println("Match.") + default: + fmt.Printf("Unrecognised private key type: %T\n", priv.Public()) + os.Exit(1) + } +}