certlib: complete overhaul.

This commit is contained in:
2025-11-15 22:54:12 -08:00
parent f3b4838cf6
commit cf2b016433
11 changed files with 246 additions and 177 deletions

View File

@@ -79,24 +79,23 @@ func (e *Error) Error() string {
func (e *Error) Unwrap() error { return e.Err } func (e *Error) Unwrap() error { return e.Err }
// InvalidPEMType is used to indicate that we were expecting one type of PEM // InvalidPEMTypeError is used to indicate that we were expecting one type of PEM
// file, but saw another. // file, but saw another.
type InvalidPEMType struct { type InvalidPEMTypeError struct {
have string have string
want []string want []string
} }
func (err *InvalidPEMType) Error() string { func (err *InvalidPEMTypeError) Error() string {
if len(err.want) == 1 { if len(err.want) == 1 {
return fmt.Sprintf("invalid PEM type: have %s, expected %s", err.have, err.want[0]) return fmt.Sprintf("invalid PEM type: have %s, expected %s", err.have, err.want[0])
} else {
return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", "))
} }
return fmt.Sprintf("invalid PEM type: have %s, expected one of %s", err.have, strings.Join(err.want, ", "))
} }
// ErrInvalidPEMType returns a new InvalidPEMType error. // ErrInvalidPEMType returns a new InvalidPEMTypeError error.
func ErrInvalidPEMType(have string, want ...string) error { func ErrInvalidPEMType(have string, want ...string) error {
return &InvalidPEMType{ return &InvalidPEMTypeError{
have: have, have: have,
want: want, want: want,
} }

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package certerr package certerr
import ( import (

View File

@@ -11,7 +11,7 @@ import (
// ReadCertificate reads a DER or PEM-encoded certificate from the // ReadCertificate reads a DER or PEM-encoded certificate from the
// byte slice. // byte slice.
func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error) { func ReadCertificate(in []byte) (*x509.Certificate, []byte, error) {
if len(in) == 0 { if len(in) == 0 {
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, certerr.ErrEmptyCertificate) return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, certerr.ErrEmptyCertificate)
} }
@@ -22,7 +22,7 @@ func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error)
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("invalid PEM file")) return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, errors.New("invalid PEM file"))
} }
rest = remaining rest := remaining
if p.Type != "CERTIFICATE" { if p.Type != "CERTIFICATE" {
return nil, rest, certerr.ParsingError( return nil, rest, certerr.ParsingError(
certerr.ErrorSourceCertificate, certerr.ErrorSourceCertificate,
@@ -31,19 +31,26 @@ func ReadCertificate(in []byte) (cert *x509.Certificate, rest []byte, err error)
} }
in = p.Bytes in = p.Bytes
} cert, err := x509.ParseCertificate(in)
cert, err = x509.ParseCertificate(in)
if err != nil { if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return nil, rest, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
} }
return cert, rest, nil return cert, rest, nil
}
cert, err := x509.ParseCertificate(in)
if err != nil {
return nil, nil, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
}
return cert, nil, nil
} }
// ReadCertificates tries to read all the certificates in a // ReadCertificates tries to read all the certificates in a
// PEM-encoded collection. // PEM-encoded collection.
func ReadCertificates(in []byte) (certs []*x509.Certificate, err error) { func ReadCertificates(in []byte) ([]*x509.Certificate, error) {
var cert *x509.Certificate var cert *x509.Certificate
var certs []*x509.Certificate
var err error
for { for {
cert, in, err = ReadCertificate(in) cert, in, err = ReadCertificate(in)
if err != nil { if err != nil {

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package certlib package certlib
import ( import (

View File

@@ -38,6 +38,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rsa" "crypto/rsa"
"crypto/x509" "crypto/x509"
"errors"
"fmt" "fmt"
"git.wntrmute.dev/kyle/goutils/certlib/certerr" "git.wntrmute.dev/kyle/goutils/certlib/certerr"
@@ -47,29 +48,36 @@ import (
// private key. The key must not be in PEM format. If an error is returned, it // private key. The key must not be in PEM format. If an error is returned, it
// may contain information about the private key, so care should be taken when // may contain information about the private key, so care should be taken when
// displaying it directly. // displaying it directly.
func ParsePrivateKeyDER(keyDER []byte) (key crypto.Signer, err error) { func ParsePrivateKeyDER(keyDER []byte) (crypto.Signer, error) {
generalKey, err := x509.ParsePKCS8PrivateKey(keyDER) // Try common encodings in order without deep nesting.
if err != nil { if k, err := x509.ParsePKCS8PrivateKey(keyDER); err == nil {
generalKey, err = x509.ParsePKCS1PrivateKey(keyDER) switch kk := k.(type) {
if err != nil { case *rsa.PrivateKey:
generalKey, err = x509.ParseECPrivateKey(keyDER) return kk, nil
if err != nil { case *ecdsa.PrivateKey:
generalKey, err = ParseEd25519PrivateKey(keyDER) return kk, nil
if err != nil { case ed25519.PrivateKey:
return kk, nil
default:
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
}
}
if k, err := x509.ParsePKCS1PrivateKey(keyDER); err == nil {
return k, nil
}
if k, err := x509.ParseECPrivateKey(keyDER); err == nil {
return k, nil
}
if k, err := ParseEd25519PrivateKey(keyDER); err == nil {
if kk, ok := k.(ed25519.PrivateKey); ok {
return kk, nil
}
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %T", k))
}
// If all parsers failed, return the last error from Ed25519 attempt (approximate cause).
if _, err := ParseEd25519PrivateKey(keyDER); err != nil {
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err) return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, err)
} }
} // Fallback (should be unreachable)
} return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, errors.New("unknown key encoding"))
}
switch generalKey := generalKey.(type) {
case *rsa.PrivateKey:
return generalKey, nil
case *ecdsa.PrivateKey:
return generalKey, nil
case ed25519.PrivateKey:
return generalKey, nil
default:
return nil, certerr.ParsingError(certerr.ErrorSourcePrivateKey, fmt.Errorf("unknown key type %t", generalKey))
}
} }

View File

@@ -65,12 +65,14 @@ func MarshalEd25519PublicKey(pk crypto.PublicKey) ([]byte, error) {
return nil, errEd25519WrongKeyType return nil, errEd25519WrongKeyType
} }
const bitsPerByte = 8
spki := subjectPublicKeyInfo{ spki := subjectPublicKeyInfo{
Algorithm: pkix.AlgorithmIdentifier{ Algorithm: pkix.AlgorithmIdentifier{
Algorithm: ed25519OID, Algorithm: ed25519OID,
}, },
PublicKey: asn1.BitString{ PublicKey: asn1.BitString{
BitLength: len(pub) * 8, BitLength: len(pub) * bitsPerByte,
Bytes: pub, Bytes: pub,
}, },
} }
@@ -91,7 +93,8 @@ func ParseEd25519PublicKey(der []byte) (crypto.PublicKey, error) {
return nil, errEd25519WrongID return nil, errEd25519WrongID
} }
if spki.PublicKey.BitLength != ed25519.PublicKeySize*8 { const bitsPerByte = 8
if spki.PublicKey.BitLength != ed25519.PublicKeySize*bitsPerByte {
return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch") return nil, errors.New("SubjectPublicKeyInfo PublicKey length mismatch")
} }

View File

@@ -49,14 +49,14 @@ import (
"strings" "strings"
"time" "time"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
ct "github.com/google/certificate-transparency-go" ct "github.com/google/certificate-transparency-go"
cttls "github.com/google/certificate-transparency-go/tls" cttls "github.com/google/certificate-transparency-go/tls"
ctx509 "github.com/google/certificate-transparency-go/x509" ctx509 "github.com/google/certificate-transparency-go/x509"
"golang.org/x/crypto/ocsp" "golang.org/x/crypto/ocsp"
"golang.org/x/crypto/pkcs12" "golang.org/x/crypto/pkcs12"
"git.wntrmute.dev/kyle/goutils/certlib/certerr"
"git.wntrmute.dev/kyle/goutils/certlib/pkcs7"
) )
// OneYear is a time.Duration representing a year's worth of seconds. // OneYear is a time.Duration representing a year's worth of seconds.
@@ -68,7 +68,7 @@ const OneDay = 24 * time.Hour
// DelegationUsage is the OID for the DelegationUseage extensions. // DelegationUsage is the OID for the DelegationUseage extensions.
var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44} var DelegationUsage = asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 44363, 44}
// DelegationExtension. // DelegationExtension is a non-critical extension marking delegation usage.
var DelegationExtension = pkix.Extension{ var DelegationExtension = pkix.Extension{
Id: DelegationUsage, Id: DelegationUsage,
Critical: false, Critical: false,
@@ -81,13 +81,19 @@ func InclusiveDate(year int, month time.Month, day int) time.Time {
return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond) return time.Date(year, month, day, 0, 0, 0, 0, time.UTC).Add(-1 * time.Nanosecond)
} }
const (
year2012 = 2012
year2015 = 2015
day1 = 1
)
// Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop // Jul2012 is the July 2012 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 5 years. // issuing certificates valid for more than 5 years.
var Jul2012 = InclusiveDate(2012, time.July, 01) var Jul2012 = InclusiveDate(year2012, time.July, day1)
// Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop // Apr2015 is the April 2015 CAB Forum deadline for when CAs must stop
// issuing certificates valid for more than 39 months. // issuing certificates valid for more than 39 months.
var Apr2015 = InclusiveDate(2015, time.April, 01) var Apr2015 = InclusiveDate(year2015, time.April, day1)
// KeyLength returns the bit size of ECDSA or RSA PublicKey. // KeyLength returns the bit size of ECDSA or RSA PublicKey.
func KeyLength(key any) int { func KeyLength(key any) int {
@@ -108,11 +114,11 @@ func KeyLength(key any) int {
} }
// ExpiryTime returns the time when the certificate chain is expired. // ExpiryTime returns the time when the certificate chain is expired.
func ExpiryTime(chain []*x509.Certificate) (notAfter time.Time) { func ExpiryTime(chain []*x509.Certificate) time.Time {
var notAfter time.Time
if len(chain) == 0 { if len(chain) == 0 {
return notAfter return notAfter
} }
notAfter = chain[0].NotAfter notAfter = chain[0].NotAfter
for _, cert := range chain { for _, cert := range chain {
if notAfter.After(cert.NotAfter) { if notAfter.After(cert.NotAfter) {
@@ -158,18 +164,23 @@ func ValidExpiry(c *x509.Certificate) bool {
// SignatureString returns the TLS signature string corresponding to // SignatureString returns the TLS signature string corresponding to
// an X509 signature algorithm. // an X509 signature algorithm.
var signatureString = map[x509.SignatureAlgorithm]string{ var signatureString = map[x509.SignatureAlgorithm]string{
x509.UnknownSignatureAlgorithm: "Unknown Signature",
x509.MD2WithRSA: "MD2WithRSA", x509.MD2WithRSA: "MD2WithRSA",
x509.MD5WithRSA: "MD5WithRSA", x509.MD5WithRSA: "MD5WithRSA",
x509.SHA1WithRSA: "SHA1WithRSA", x509.SHA1WithRSA: "SHA1WithRSA",
x509.SHA256WithRSA: "SHA256WithRSA", x509.SHA256WithRSA: "SHA256WithRSA",
x509.SHA384WithRSA: "SHA384WithRSA", x509.SHA384WithRSA: "SHA384WithRSA",
x509.SHA512WithRSA: "SHA512WithRSA", x509.SHA512WithRSA: "SHA512WithRSA",
x509.SHA256WithRSAPSS: "SHA256WithRSAPSS",
x509.SHA384WithRSAPSS: "SHA384WithRSAPSS",
x509.SHA512WithRSAPSS: "SHA512WithRSAPSS",
x509.DSAWithSHA1: "DSAWithSHA1", x509.DSAWithSHA1: "DSAWithSHA1",
x509.DSAWithSHA256: "DSAWithSHA256", x509.DSAWithSHA256: "DSAWithSHA256",
x509.ECDSAWithSHA1: "ECDSAWithSHA1", x509.ECDSAWithSHA1: "ECDSAWithSHA1",
x509.ECDSAWithSHA256: "ECDSAWithSHA256", x509.ECDSAWithSHA256: "ECDSAWithSHA256",
x509.ECDSAWithSHA384: "ECDSAWithSHA384", x509.ECDSAWithSHA384: "ECDSAWithSHA384",
x509.ECDSAWithSHA512: "ECDSAWithSHA512", x509.ECDSAWithSHA512: "ECDSAWithSHA512",
x509.PureEd25519: "PureEd25519",
} }
// SignatureString returns the TLS signature string corresponding to // SignatureString returns the TLS signature string corresponding to
@@ -184,18 +195,23 @@ func SignatureString(alg x509.SignatureAlgorithm) string {
// HashAlgoString returns the hash algorithm name contains in the signature // HashAlgoString returns the hash algorithm name contains in the signature
// method. // method.
var hashAlgoString = map[x509.SignatureAlgorithm]string{ var hashAlgoString = map[x509.SignatureAlgorithm]string{
x509.UnknownSignatureAlgorithm: "Unknown Hash Algorithm",
x509.MD2WithRSA: "MD2", x509.MD2WithRSA: "MD2",
x509.MD5WithRSA: "MD5", x509.MD5WithRSA: "MD5",
x509.SHA1WithRSA: "SHA1", x509.SHA1WithRSA: "SHA1",
x509.SHA256WithRSA: "SHA256", x509.SHA256WithRSA: "SHA256",
x509.SHA384WithRSA: "SHA384", x509.SHA384WithRSA: "SHA384",
x509.SHA512WithRSA: "SHA512", x509.SHA512WithRSA: "SHA512",
x509.SHA256WithRSAPSS: "SHA256",
x509.SHA384WithRSAPSS: "SHA384",
x509.SHA512WithRSAPSS: "SHA512",
x509.DSAWithSHA1: "SHA1", x509.DSAWithSHA1: "SHA1",
x509.DSAWithSHA256: "SHA256", x509.DSAWithSHA256: "SHA256",
x509.ECDSAWithSHA1: "SHA1", x509.ECDSAWithSHA1: "SHA1",
x509.ECDSAWithSHA256: "SHA256", x509.ECDSAWithSHA256: "SHA256",
x509.ECDSAWithSHA384: "SHA384", x509.ECDSAWithSHA384: "SHA384",
x509.ECDSAWithSHA512: "SHA512", x509.ECDSAWithSHA512: "SHA512",
x509.PureEd25519: "SHA512", // per x509 docs Ed25519 uses SHA-512 internally
} }
// HashAlgoString returns the hash algorithm name contains in the signature // HashAlgoString returns the hash algorithm name contains in the signature
@@ -273,7 +289,7 @@ func ParseCertificatesPEM(certsPEM []byte) ([]*x509.Certificate, error) {
// ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key, // ParseCertificatesDER parses a DER encoding of a certificate object and possibly private key,
// either PKCS #7, PKCS #12, or raw x509. // either PKCS #7, PKCS #12, or raw x509.
func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certificate, key crypto.Signer, err error) { func ParseCertificatesDER(certsDER []byte, password string) ([]*x509.Certificate, crypto.Signer, error) {
certsDER = bytes.TrimSpace(certsDER) certsDER = bytes.TrimSpace(certsDER)
// First, try PKCS #7 // First, try PKCS #7
@@ -284,7 +300,7 @@ func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certi
errors.New("can only extract certificates from signed data content info"), errors.New("can only extract certificates from signed data content info"),
) )
} }
certs = pkcs7data.Content.SignedData.Certificates certs := pkcs7data.Content.SignedData.Certificates
if certs == nil { if certs == nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded")) return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, errors.New("no certificates decoded"))
} }
@@ -304,7 +320,7 @@ func ParseCertificatesDER(certsDER []byte, password string) (certs []*x509.Certi
} }
// Finally, attempt to parse raw X.509 certificates // Finally, attempt to parse raw X.509 certificates
certs, err = x509.ParseCertificates(certsDER) certs, err := x509.ParseCertificates(certsDER)
if err != nil { if err != nil {
return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err) return nil, nil, certerr.DecodeError(certerr.ErrorSourceCertificate, err)
} }
@@ -318,7 +334,8 @@ func ParseSelfSignedCertificatePEM(certPEM []byte) (*x509.Certificate, error) {
return nil, err return nil, err
} }
if err := cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature); err != nil { err = cert.CheckSignature(cert.SignatureAlgorithm, cert.RawTBSCertificate, cert.Signature)
if err != nil {
return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err) return nil, certerr.VerifyError(certerr.ErrorSourceCertificate, err)
} }
return cert, nil return cert, nil
@@ -362,8 +379,8 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
cert, err := x509.ParseCertificate(block.Bytes) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
pkcs7data, err := pkcs7.ParsePKCS7(block.Bytes) pkcs7data, err2 := pkcs7.ParsePKCS7(block.Bytes)
if err != nil { if err2 != nil {
return nil, rest, err return nil, rest, err
} }
if pkcs7data.ContentInfo != "SignedData" { if pkcs7data.ContentInfo != "SignedData" {
@@ -382,7 +399,7 @@ func ParseOneCertificateFromPEM(certsPEM []byte) ([]*x509.Certificate, []byte, e
// LoadPEMCertPool loads a pool of PEM certificates from file. // LoadPEMCertPool loads a pool of PEM certificates from file.
func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) { func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
if certsFile == "" { if certsFile == "" {
return nil, nil return nil, nil //nolint:nilnil // no CA file provided -> treat as no pool and no error
} }
pemCerts, err := os.ReadFile(certsFile) pemCerts, err := os.ReadFile(certsFile)
if err != nil { if err != nil {
@@ -395,7 +412,7 @@ func LoadPEMCertPool(certsFile string) (*x509.CertPool, error) {
// PEMToCertPool concerts PEM certificates to a CertPool. // PEMToCertPool concerts PEM certificates to a CertPool.
func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) { func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
if len(pemCerts) == 0 { if len(pemCerts) == 0 {
return nil, nil return nil, nil //nolint:nilnil // empty input means no pool needed
} }
certPool := x509.NewCertPool() certPool := x509.NewCertPool()
@@ -409,14 +426,14 @@ func PEMToCertPool(pemCerts []byte) (*x509.CertPool, error) {
// ParsePrivateKeyPEM parses and returns a PEM-encoded private // ParsePrivateKeyPEM parses and returns a PEM-encoded private
// key. The private key may be either an unencrypted PKCS#8, PKCS#1, // key. The private key may be either an unencrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEM(keyPEM []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEM(keyPEM []byte) (crypto.Signer, error) {
return ParsePrivateKeyPEMWithPassword(keyPEM, nil) return ParsePrivateKeyPEMWithPassword(keyPEM, nil)
} }
// ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private // ParsePrivateKeyPEMWithPassword parses and returns a PEM-encoded private
// key. The private key may be a potentially encrypted PKCS#8, PKCS#1, // key. The private key may be a potentially encrypted PKCS#8, PKCS#1,
// or elliptic private key. // or elliptic private key.
func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (key crypto.Signer, err error) { func ParsePrivateKeyPEMWithPassword(keyPEM []byte, password []byte) (crypto.Signer, error) {
keyDER, err := GetKeyDERFromPEM(keyPEM, password) keyDER, err := GetKeyDERFromPEM(keyPEM, password)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -436,26 +453,35 @@ func GetKeyDERFromPEM(in []byte, password []byte) ([]byte, error) {
break break
} }
} }
if keyDER != nil { if keyDER == nil {
if procType, ok := keyDER.Headers["Proc-Type"]; ok { return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key"))
if strings.Contains(procType, "ENCRYPTED") { }
if procType, ok := keyDER.Headers["Proc-Type"]; ok && strings.Contains(procType, "ENCRYPTED") {
if password != nil { if password != nil {
// nolintlint requires rationale:
//nolint:staticcheck // legacy RFC1423 PEM encryption supported for backward compatibility when caller supplies a password
return x509.DecryptPEMBlock(keyDER, password) return x509.DecryptPEMBlock(keyDER, password)
} }
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey) return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, certerr.ErrEncryptedPrivateKey)
} }
}
return keyDER.Bytes, nil return keyDER.Bytes, nil
}
return nil, certerr.DecodeError(certerr.ErrorSourcePrivateKey, errors.New("failed to decode private key"))
} }
// ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request. // ParseCSR parses a PEM- or DER-encoded PKCS #10 certificate signing request.
func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error) { func ParseCSR(in []byte) (*x509.CertificateRequest, []byte, error) {
in = bytes.TrimSpace(in) in = bytes.TrimSpace(in)
p, rest := pem.Decode(in) p, rest := pem.Decode(in)
if p != nil { if p == nil {
csr, err := x509.ParseCertificateRequest(in)
if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
}
if sigErr := csr.CheckSignature(); sigErr != nil {
return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
}
return csr, rest, nil
}
if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" { if p.Type != "NEW CERTIFICATE REQUEST" && p.Type != "CERTIFICATE REQUEST" {
return nil, rest, certerr.ParsingError( return nil, rest, certerr.ParsingError(
certerr.ErrorSourceCSR, certerr.ErrorSourceCSR,
@@ -463,20 +489,13 @@ func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error)
) )
} }
csr, err = x509.ParseCertificateRequest(p.Bytes) csr, err := x509.ParseCertificateRequest(p.Bytes)
} else {
csr, err = x509.ParseCertificateRequest(in)
}
if err != nil { if err != nil {
return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err) return nil, rest, certerr.ParsingError(certerr.ErrorSourceCSR, err)
} }
if sigErr := csr.CheckSignature(); sigErr != nil {
err = csr.CheckSignature() return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, sigErr)
if err != nil {
return nil, rest, certerr.VerifyError(certerr.ErrorSourceCSR, err)
} }
return csr, rest, nil return csr, rest, nil
} }
@@ -484,7 +503,7 @@ func ParseCSR(in []byte) (csr *x509.CertificateRequest, rest []byte, err error)
// It does not check the signature. This is useful for dumping data from a CSR // It does not check the signature. This is useful for dumping data from a CSR
// locally. // locally.
func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) { func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
block, _ := pem.Decode([]byte(csrPEM)) block, _ := pem.Decode(csrPEM)
if block == nil { if block == nil {
return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty")) return nil, certerr.DecodeError(certerr.ErrorSourceCSR, errors.New("PEM block is empty"))
} }
@@ -499,15 +518,20 @@ func ParseCSRPEM(csrPEM []byte) (*x509.CertificateRequest, error) {
// SignerAlgo returns an X.509 signature algorithm from a crypto.Signer. // SignerAlgo returns an X.509 signature algorithm from a crypto.Signer.
func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm { func SignerAlgo(priv crypto.Signer) x509.SignatureAlgorithm {
const (
rsaBits2048 = 2048
rsaBits3072 = 3072
rsaBits4096 = 4096
)
switch pub := priv.Public().(type) { switch pub := priv.Public().(type) {
case *rsa.PublicKey: case *rsa.PublicKey:
bitLength := pub.N.BitLen() bitLength := pub.N.BitLen()
switch { switch {
case bitLength >= 4096: case bitLength >= rsaBits4096:
return x509.SHA512WithRSA return x509.SHA512WithRSA
case bitLength >= 3072: case bitLength >= rsaBits3072:
return x509.SHA384WithRSA return x509.SHA384WithRSA
case bitLength >= 2048: case bitLength >= rsaBits2048:
return x509.SHA256WithRSA return x509.SHA256WithRSA
default: default:
return x509.SHA1WithRSA return x509.SHA1WithRSA
@@ -537,7 +561,7 @@ func LoadClientCertificate(certFile string, keyFile string) (*tls.Certificate, e
} }
return &cert, nil return &cert, nil
} }
return nil, nil return nil, nil //nolint:nilnil // absence of client cert is not an error
} }
// CreateTLSConfig creates a tls.Config object from certs and roots. // CreateTLSConfig creates a tls.Config object from certs and roots.
@@ -549,6 +573,7 @@ func CreateTLSConfig(remoteCAs *x509.CertPool, cert *tls.Certificate) *tls.Confi
return &tls.Config{ return &tls.Config{
Certificates: certs, Certificates: certs,
RootCAs: remoteCAs, RootCAs: remoteCAs,
MinVersion: tls.VersionTLS12, // secure default
} }
} }
@@ -582,11 +607,11 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList)) list := make([]ct.SignedCertificateTimestamp, len(sctList.SCTList))
for i, serializedSCT := range sctList.SCTList { for i, serializedSCT := range sctList.SCTList {
var sct ct.SignedCertificateTimestamp var sct ct.SignedCertificateTimestamp
rest, err := cttls.Unmarshal(serializedSCT.Val, &sct) rest2, err2 := cttls.Unmarshal(serializedSCT.Val, &sct)
if err != nil { if err2 != nil {
return nil, err return nil, err2
} }
if len(rest) != 0 { if len(rest2) != 0 {
return nil, certerr.ParsingError( return nil, certerr.ParsingError(
certerr.ErrorSourceSCTList, certerr.ErrorSourceSCTList,
errors.New("serialized SCT list contained trailing garbage"), errors.New("serialized SCT list contained trailing garbage"),
@@ -602,12 +627,12 @@ func DeserializeSCTList(serializedSCTList []byte) ([]ct.SignedCertificateTimesta
// unmarshalled. // unmarshalled.
func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) { func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTimestamp, error) {
// This loop finds the SCTListExtension in the OCSP response. // This loop finds the SCTListExtension in the OCSP response.
var SCTListExtension, ext pkix.Extension var sctListExtension, ext pkix.Extension
for _, ext = range response.Extensions { for _, ext = range response.Extensions {
// sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp. // sctExtOid is the ObjectIdentifier of a Signed Certificate Timestamp.
sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5} sctExtOid := asn1.ObjectIdentifier{1, 3, 6, 1, 4, 1, 11129, 2, 4, 5}
if ext.Id.Equal(sctExtOid) { if ext.Id.Equal(sctExtOid) {
SCTListExtension = ext sctListExtension = ext
break break
} }
} }
@@ -615,10 +640,10 @@ func SCTListFromOCSPResponse(response *ocsp.Response) ([]ct.SignedCertificateTim
// This code block extracts the sctList from the SCT extension. // This code block extracts the sctList from the SCT extension.
var sctList []ct.SignedCertificateTimestamp var sctList []ct.SignedCertificateTimestamp
var err error var err error
if numBytes := len(SCTListExtension.Value); numBytes != 0 { if numBytes := len(sctListExtension.Value); numBytes != 0 {
var serializedSCTList []byte var serializedSCTList []byte
rest := make([]byte, numBytes) rest := make([]byte, numBytes)
copy(rest, SCTListExtension.Value) copy(rest, sctListExtension.Value)
for len(rest) != 0 { for len(rest) != 0 {
rest, err = asn1.Unmarshal(rest, &serializedSCTList) rest, err = asn1.Unmarshal(rest, &serializedSCTList)
if err != nil { if err != nil {

View File

@@ -9,6 +9,8 @@ import (
"strings" "strings"
) )
const defaultHTTPSPort = 443
type Target struct { type Target struct {
Host string Host string
Port int Port int
@@ -29,29 +31,29 @@ func parseURL(host string) (string, int, error) {
} }
if url.Port() == "" { if url.Port() == "" {
return url.Hostname(), 443, nil return url.Hostname(), defaultHTTPSPort, nil
} }
port, err := strconv.ParseInt(url.Port(), 10, 16) portInt, err2 := strconv.ParseInt(url.Port(), 10, 16)
if err != nil { if err2 != nil {
return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", url.Port()) return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", url.Port())
} }
return url.Hostname(), int(port), nil return url.Hostname(), int(portInt), nil
} }
func parseHostPort(host string) (string, int, error) { func parseHostPort(host string) (string, int, error) {
host, sport, err := net.SplitHostPort(host) host, sport, err := net.SplitHostPort(host)
if err == nil { if err == nil {
port, err := strconv.ParseInt(sport, 10, 16) portInt, err2 := strconv.ParseInt(sport, 10, 16)
if err != nil { if err2 != nil {
return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport) return "", 0, fmt.Errorf("certlib/hosts: invalid port: %s", sport)
} }
return host, int(port), nil return host, int(portInt), nil
} }
return host, 443, nil return host, defaultHTTPSPort, nil
} }
func ParseHost(host string) (*Target, error) { func ParseHost(host string) (*Target, error) {

View File

@@ -158,9 +158,9 @@ type EncryptedContentInfo struct {
EncryptedContent []byte `asn1:"tag:0,optional"` EncryptedContent []byte `asn1:"tag:0,optional"`
} }
func unmarshalInit(raw []byte) (init initPKCS7, err error) { func unmarshalInit(raw []byte) (initPKCS7, error) {
_, err = asn1.Unmarshal(raw, &init) var init initPKCS7
if err != nil { if _, err := asn1.Unmarshal(raw, &init); err != nil {
return initPKCS7{}, certerr.ParsingError(certerr.ErrorSourceCertificate, err) return initPKCS7{}, certerr.ParsingError(certerr.ErrorSourceCertificate, err)
} }
return init, nil return init, nil
@@ -218,28 +218,28 @@ func populateEncryptedData(msg *PKCS7, contentBytes []byte) error {
// ParsePKCS7 attempts to parse the DER encoded bytes of a // ParsePKCS7 attempts to parse the DER encoded bytes of a
// PKCS7 structure. // PKCS7 structure.
func ParsePKCS7(raw []byte) (msg *PKCS7, err error) { func ParsePKCS7(raw []byte) (*PKCS7, error) {
pkcs7, err := unmarshalInit(raw) pkcs7, err := unmarshalInit(raw)
if err != nil { if err != nil {
return nil, err return nil, err
} }
msg = new(PKCS7) msg := new(PKCS7)
msg.Raw = pkcs7.Raw msg.Raw = pkcs7.Raw
msg.ContentInfo = pkcs7.ContentType.String() msg.ContentInfo = pkcs7.ContentType.String()
switch msg.ContentInfo { switch msg.ContentInfo {
case ObjIDData: case ObjIDData:
if err := populateData(msg, pkcs7.Content); err != nil { if e := populateData(msg, pkcs7.Content); e != nil {
return nil, err return nil, e
} }
case ObjIDSignedData: case ObjIDSignedData:
if err := populateSignedData(msg, pkcs7.Content.Bytes); err != nil { if e := populateSignedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, err return nil, e
} }
case ObjIDEncryptedData: case ObjIDEncryptedData:
if err := populateEncryptedData(msg, pkcs7.Content.Bytes); err != nil { if e := populateEncryptedData(msg, pkcs7.Content.Bytes); e != nil {
return nil, err return nil, e
} }
default: default:
return nil, certerr.ParsingError( return nil, certerr.ParsingError(

View File

@@ -5,6 +5,7 @@ package revoke
import ( import (
"bytes" "bytes"
"context"
"crypto" "crypto"
"crypto/x509" "crypto/x509"
"encoding/base64" "encoding/base64"
@@ -90,34 +91,34 @@ func ldapURL(url string) bool {
// - false, true: the certificate was checked successfully, and it is not revoked. // - false, true: the certificate was checked successfully, and it is not revoked.
// - true, true: the certificate was checked successfully, and it is revoked. // - true, true: the certificate was checked successfully, and it is revoked.
// - true, false: failure to check revocation status causes verification to fail. // - true, false: failure to check revocation status causes verification to fail.
func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) { func revCheck(cert *x509.Certificate) (bool, bool, error) {
for _, url := range cert.CRLDistributionPoints { for _, url := range cert.CRLDistributionPoints {
if ldapURL(url) { if ldapURL(url) {
log.Infof("skipping LDAP CRL: %s", url) log.Infof("skipping LDAP CRL: %s", url)
continue continue
} }
if revoked, ok, err := certIsRevokedCRL(cert, url); !ok { if rvk, ok2, err2 := certIsRevokedCRL(cert, url); !ok2 {
log.Warning("error checking revocation via CRL") log.Warning("error checking revocation via CRL")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via CRL") log.Info("certificate is revoked via CRL")
return true, true, err return true, true, err2
} }
} }
if revoked, ok, err := certIsRevokedOCSP(cert, HardFail); !ok { if rvk, ok2, err2 := certIsRevokedOCSP(cert, HardFail); !ok2 {
log.Warning("error checking revocation via OCSP") log.Warning("error checking revocation via OCSP")
if HardFail { if HardFail {
return true, false, err return true, false, err2
} }
return false, false, err return false, false, err2
} else if revoked { } else if rvk {
log.Info("certificate is revoked via OCSP") log.Info("certificate is revoked via OCSP")
return true, true, err return true, true, err2
} }
return false, true, nil return false, true, nil
@@ -125,13 +126,17 @@ func revCheck(cert *x509.Certificate) (revoked, ok bool, err error) {
// fetchCRL fetches and parses a CRL. // fetchCRL fetches and parses a CRL.
func fetchCRL(url string) (*x509.RevocationList, error) { func fetchCRL(url string) (*x509.RevocationList, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode >= 300 { if resp.StatusCode >= http.StatusMultipleChoices {
return nil, errors.New("failed to retrieve CRL") return nil, errors.New("failed to retrieve CRL")
} }
@@ -158,7 +163,7 @@ func getIssuer(cert *x509.Certificate) *x509.Certificate {
// check a cert against a specific CRL. Returns the same bool pair // check a cert against a specific CRL. Returns the same bool pair
// as revCheck, plus an error if one occurred. // as revCheck, plus an error if one occurred.
func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err error) { func certIsRevokedCRL(cert *x509.Certificate, url string) (bool, bool, error) {
crlLock.Lock() crlLock.Lock()
crl, ok := CRLSet[url] crl, ok := CRLSet[url]
if ok && crl == nil { if ok && crl == nil {
@@ -186,10 +191,9 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
// check CRL signature // check CRL signature
if issuer != nil { if issuer != nil {
err = crl.CheckSignatureFrom(issuer) if sigErr := crl.CheckSignatureFrom(issuer); sigErr != nil {
if err != nil { log.Warningf("failed to verify CRL: %v", sigErr)
log.Warningf("failed to verify CRL: %v", err) return false, false, sigErr
return false, false, err
} }
} }
@@ -198,26 +202,26 @@ func certIsRevokedCRL(cert *x509.Certificate, url string) (revoked, ok bool, err
crlLock.Unlock() crlLock.Unlock()
} }
for _, revoked := range crl.RevokedCertificates { for _, entry := range crl.RevokedCertificateEntries {
if cert.SerialNumber.Cmp(revoked.SerialNumber) == 0 { if cert.SerialNumber.Cmp(entry.SerialNumber) == 0 {
log.Info("Serial number match: intermediate is revoked.") log.Info("Serial number match: intermediate is revoked.")
return true, true, err return true, true, nil
} }
} }
return false, true, err return false, true, nil
} }
// VerifyCertificate ensures that the certificate passed in hasn't // VerifyCertificate ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificate(cert *x509.Certificate) (revoked, ok bool) { func VerifyCertificate(cert *x509.Certificate) (bool, bool) {
revoked, ok, _ = VerifyCertificateError(cert) revoked, ok, _ := VerifyCertificateError(cert)
return revoked, ok return revoked, ok
} }
// VerifyCertificateError ensures that the certificate passed in hasn't // VerifyCertificateError ensures that the certificate passed in hasn't
// expired and checks the CRL for the server. // expired and checks the CRL for the server.
func VerifyCertificateError(cert *x509.Certificate) (revoked, ok bool, err error) { func VerifyCertificateError(cert *x509.Certificate) (bool, bool, error) {
if !time.Now().Before(cert.NotAfter) { if !time.Now().Before(cert.NotAfter) {
msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter) msg := fmt.Sprintf("Certificate expired %s\n", cert.NotAfter)
log.Info(msg) log.Info(msg)
@@ -231,7 +235,11 @@ func VerifyCertificateError(cert *x509.Certificate) (revoked, ok bool, err error
} }
func fetchRemote(url string) (*x509.Certificate, error) { func fetchRemote(url string) (*x509.Certificate, error) {
resp, err := HTTPClient.Get(url) req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil)
if err != nil {
return nil, err
}
resp, err := HTTPClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -254,8 +262,12 @@ var ocspOpts = ocsp.RequestOptions{
Hash: crypto.SHA1, Hash: crypto.SHA1,
} }
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e error) { const ocspGetURLMaxLen = 256
var err error
func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (bool, bool, error) {
var revoked bool
var ok bool
var lastErr error
ocspURLs := leaf.OCSPServer ocspURLs := leaf.OCSPServer
if len(ocspURLs) == 0 { if len(ocspURLs) == 0 {
@@ -271,15 +283,16 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts) ocspRequest, err := ocsp.CreateRequest(leaf, issuer, &ocspOpts)
if err != nil { if err != nil {
return revoked, ok, err return false, false, err
} }
for _, server := range ocspURLs { for _, server := range ocspURLs {
resp, err := sendOCSPRequest(server, ocspRequest, leaf, issuer) resp, e := sendOCSPRequest(server, ocspRequest, leaf, issuer)
if err != nil { if e != nil {
if strict { if strict {
return revoked, ok, err return false, false, e
} }
lastErr = e
continue continue
} }
@@ -291,9 +304,9 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
revoked = true revoked = true
} }
return revoked, ok, err return revoked, ok, nil
} }
return revoked, ok, err return revoked, ok, lastErr
} }
// sendOCSPRequest attempts to request an OCSP response from the // sendOCSPRequest attempts to request an OCSP response from the
@@ -302,12 +315,21 @@ func certIsRevokedOCSP(leaf *x509.Certificate, strict bool) (revoked, ok bool, e
func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) { func sendOCSPRequest(server string, req []byte, leaf, issuer *x509.Certificate) (*ocsp.Response, error) {
var resp *http.Response var resp *http.Response
var err error var err error
if len(req) > 256 { if len(req) > ocspGetURLMaxLen {
buf := bytes.NewBuffer(req) buf := bytes.NewBuffer(req)
resp, err = HTTPClient.Post(server, "application/ocsp-request", buf) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodPost, server, buf)
if e != nil {
return nil, e
}
httpReq.Header.Set("Content-Type", "application/ocsp-request")
resp, err = HTTPClient.Do(httpReq)
} else { } else {
reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req)) reqURL := server + "/" + neturl.QueryEscape(base64.StdEncoding.EncodeToString(req))
resp, err = HTTPClient.Get(reqURL) httpReq, e := http.NewRequestWithContext(context.Background(), http.MethodGet, reqURL, nil)
if e != nil {
return nil, e
}
resp, err = HTTPClient.Do(httpReq)
} }
if err != nil { if err != nil {

View File

@@ -1,3 +1,4 @@
//nolint:testpackage // keep tests in the same package for internal symbol access
package revoke package revoke
import ( import (
@@ -153,7 +154,7 @@ func mustParse(pemData string) *x509.Certificate {
panic("Invalid PEM type.") panic("Invalid PEM type.")
} }
cert, err := x509.ParseCertificate([]byte(block.Bytes)) cert, err := x509.ParseCertificate(block.Bytes)
if err != nil { if err != nil {
panic(err.Error()) panic(err.Error())
} }