Compare commits

...

12 Commits

21 changed files with 1424 additions and 922 deletions

View File

@@ -247,11 +247,12 @@ linters:
# Default: false
check-type-assertions: true
exclude-functions:
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
- (*git.wntrmute.dev/kyle/goutils/dbg.DebugPrinter).Write
- git.wntrmute.dev/kyle/goutils/lib.Warn
- git.wntrmute.dev/kyle/goutils/lib.Warnx
- git.wntrmute.dev/kyle/goutils/lib.Err
- git.wntrmute.dev/kyle/goutils/lib.Errx
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
exhaustive:
# Program elements to check for exhaustiveness.

View File

@@ -1,5 +1,47 @@
CHANGELOG
v1.14.6 - 2025-11-18
Added:
- certlib: move tlskeypair functions into certlib.
v1.14.5 - 2025-11-18
Changed:
- certlib/verify: fix a nil-pointer dereference.
v1.14.4 - 2025-11-18
Added:
- certlib/ski: add support for return certificate SKI.
- certlib/verify: add support for verifying certificates.
Changed:
- certlib/dump: moved more functions into the dump package.
- cmd: many certificate-related commands had their functionality moved into
certlib.
v1.14.3 - 2025-11-18
Added:
- certlib/dump: the certificate dumping functions have been moved into
their own package.
Changed:
- cmd/certdump: refactor out most of the functionality into certlib/dump.
- cmd/kgz: add extended metadata support.
v1.14.2 - 2025-11-18
Added:
- lib: add tooling for generating baseline TLS configs.
Changed:
- cmd: update all commands to allow the use strict TLS configs. Note that
many of these tools are intended for debugging broken or insecure TLS
systems, and the ability to support insecure TLS configurations is
important in this regard.
v1.14.1 - 2025-11-18
Added:

339
certlib/dump/dump.go Normal file
View File

@@ -0,0 +1,339 @@
// Package dump implements tooling for dumping certificate information.
package dump
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"io"
"os"
"sort"
"strings"
"github.com/kr/text"
"git.wntrmute.dev/kyle/goutils/lib"
)
const (
sSHA256 = "SHA256"
sSHA512 = "SHA512"
)
var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content commitment",
x509.KeyUsageKeyEncipherment: "key encipherment",
x509.KeyUsageKeyAgreement: "key agreement",
x509.KeyUsageDataEncipherment: "data encipherment",
x509.KeyUsageCertSign: "cert sign",
x509.KeyUsageCRLSign: "crl sign",
x509.KeyUsageEncipherOnly: "encipher only",
x509.KeyUsageDecipherOnly: "decipher only",
}
var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageAny: "any",
x509.ExtKeyUsageServerAuth: "server auth",
x509.ExtKeyUsageClientAuth: "client auth",
x509.ExtKeyUsageCodeSigning: "code signing",
x509.ExtKeyUsageEmailProtection: "s/mime",
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
x509.ExtKeyUsageIPSECUser: "ipsec user",
x509.ExtKeyUsageTimeStamping: "timestamping",
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
}
func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA"
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
return "RSA-PSS"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA"
case x509.PureEd25519:
return "Ed25519"
case x509.UnknownSignatureAlgorithm:
return "unknown public key algorithm"
default:
return "unknown public key algorithm"
}
}
func sigAlgoHash(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA:
return "MD2"
case x509.MD5WithRSA:
return "MD5"
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return sSHA256
case x509.SHA256WithRSAPSS:
return sSHA256
case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384"
case x509.SHA384WithRSAPSS:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return sSHA512
case x509.SHA512WithRSAPSS:
return sSHA512
case x509.PureEd25519:
return sSHA512
case x509.UnknownSignatureAlgorithm:
return "unknown hash algorithm"
default:
return "unknown hash algorithm"
}
}
const maxLine = 78
func makeIndent(n int) string {
s := " "
var sSb97 strings.Builder
for range n {
sSb97.WriteString(" ")
}
s += sSb97.String()
return s
}
func indentLen(n int) int {
return 4 + (8 * n)
}
// this isn't real efficient, but that's not a problem here.
func wrap(s string, indent int) string {
if indent > 3 {
indent = 3
}
wrapped := text.Wrap(s, maxLine)
lines := strings.SplitN(wrapped, "\n", 2)
if len(lines) == 1 {
return lines[0]
}
if (maxLine - indentLen(indent)) <= 0 {
panic("too much indentation")
}
rest := strings.Join(lines[1:], " ")
wrapped = text.Wrap(rest, maxLine-indentLen(indent))
return lines[0] + "\n" + text.Indent(wrapped, makeIndent(indent))
}
func dumpHex(in []byte) string {
return lib.HexEncode(in, lib.HexEncodeUpperColon)
}
func certPublic(cert *x509.Certificate) string {
switch pub := cert.PublicKey.(type) {
case *rsa.PublicKey:
return fmt.Sprintf("RSA-%d", pub.N.BitLen())
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
return "ECDSA-prime256v1"
case elliptic.P384():
return "ECDSA-secp384r1"
case elliptic.P521():
return "ECDSA-secp521r1"
default:
return "ECDSA (unknown curve)"
}
case *dsa.PublicKey:
return "DSA"
default:
return "Unknown"
}
}
func DisplayName(name pkix.Name) string {
var ns []string
if name.CommonName != "" {
ns = append(ns, name.CommonName)
}
for i := range name.Country {
ns = append(ns, fmt.Sprintf("C=%s", name.Country[i]))
}
for i := range name.Organization {
ns = append(ns, fmt.Sprintf("O=%s", name.Organization[i]))
}
for i := range name.OrganizationalUnit {
ns = append(ns, fmt.Sprintf("OU=%s", name.OrganizationalUnit[i]))
}
for i := range name.Locality {
ns = append(ns, fmt.Sprintf("L=%s", name.Locality[i]))
}
for i := range name.Province {
ns = append(ns, fmt.Sprintf("ST=%s", name.Province[i]))
}
if len(ns) > 0 {
return "/" + strings.Join(ns, "/")
}
return "*** no subject information ***"
}
func keyUsages(ku x509.KeyUsage) string {
var uses []string
for u, s := range keyUsage {
if (ku & u) != 0 {
uses = append(uses, s)
}
}
sort.Strings(uses)
return strings.Join(uses, ", ")
}
func extUsage(ext []x509.ExtKeyUsage) string {
ns := make([]string, 0, len(ext))
for i := range ext {
ns = append(ns, extKeyUsages[ext[i]])
}
sort.Strings(ns)
return strings.Join(ns, ", ")
}
func showBasicConstraints(cert *x509.Certificate) {
fmt.Fprint(os.Stdout, "\tBasic constraints: ")
if cert.BasicConstraintsValid {
fmt.Fprint(os.Stdout, "valid")
} else {
fmt.Fprint(os.Stdout, "invalid")
}
if cert.IsCA {
fmt.Fprint(os.Stdout, ", is a CA certificate")
if !cert.BasicConstraintsValid {
fmt.Fprint(os.Stdout, " (basic constraint failure)")
}
} else {
fmt.Fprint(os.Stdout, ", is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
}
}
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
}
fmt.Fprintln(os.Stdout)
}
var (
dateFormat string
showHash bool // if true, print a SHA256 hash of the certificate's Raw field
)
func wrapPrint(text string, indent int) {
tabs := ""
var tabsSb140 strings.Builder
for range indent {
tabsSb140.WriteString("\t")
}
tabs += tabsSb140.String()
fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
}
func DisplayCert(w io.Writer, cert *x509.Certificate) {
fmt.Fprintln(w, "CERTIFICATE")
if showHash {
fmt.Fprintln(w, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
}
fmt.Fprintln(w, wrap("Subject: "+DisplayName(cert.Subject), 0))
fmt.Fprintln(w, wrap("Issuer: "+DisplayName(cert.Issuer), 0))
fmt.Fprintf(w, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm))
fmt.Fprintln(w, "Details:")
wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Fprintf(w, "\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 {
fmt.Fprintf(w, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
}
if len(cert.SubjectKeyId) > 0 {
fmt.Fprintf(w, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
}
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Fprintf(w, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Fprintf(w, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 {
fmt.Fprintf(w, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
}
showBasicConstraints(cert)
validNames := make([]string, 0, len(cert.DNSNames)+len(cert.EmailAddresses)+len(cert.IPAddresses))
for i := range cert.DNSNames {
validNames = append(validNames, "dns:"+cert.DNSNames[i])
}
for i := range cert.EmailAddresses {
validNames = append(validNames, "email:"+cert.EmailAddresses[i])
}
for i := range cert.IPAddresses {
validNames = append(validNames, "ip:"+cert.IPAddresses[i].String())
}
sans := fmt.Sprintf("SANs (%d): %s\n", len(validNames), strings.Join(validNames, ", "))
wrapPrint(sans, 1)
l := len(cert.IssuingCertificateURL)
if l != 0 {
var aia string
if l == 1 {
aia = "AIA"
} else {
aia = "AIAs"
}
wrapPrint(fmt.Sprintf("%d %s:", l, aia), 1)
for _, url := range cert.IssuingCertificateURL {
wrapPrint(url, 2)
}
}
l = len(cert.OCSPServer)
if l > 0 {
title := "OCSP server"
if l > 1 {
title += "s"
}
wrapPrint(title+":\n", 1)
for _, ocspServer := range cert.OCSPServer {
wrapPrint(fmt.Sprintf("- %s\n", ocspServer), 2)
}
}
}

135
certlib/keymatch.go Normal file
View File

@@ -0,0 +1,135 @@
package certlib
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"os"
)
// LoadPrivateKey loads a private key from disk. It accepts both PEM and DER
// encodings and supports RSA and ECDSA keys. If the file contains a PEM block,
// the block type must be one of the recognised private key types.
func LoadPrivateKey(path string) (crypto.Signer, error) {
in, err := os.ReadFile(path)
if err != nil {
return nil, err
}
in = bytes.TrimSpace(in)
if p, _ := pem.Decode(in); p != nil {
if !validPEMs[p.Type] {
return nil, errors.New("invalid private key file type " + p.Type)
}
return ParsePrivateKeyPEM(in)
}
return ParsePrivateKeyDER(in)
}
var validPEMs = map[string]bool{
"PRIVATE KEY": true,
"RSA PRIVATE KEY": true,
"EC PRIVATE KEY": true,
}
const (
curveInvalid = iota // any invalid curve
curveRSA // indicates key is an RSA key, not an EC key
curveP256
curveP384
curveP521
)
func getECCurve(pub any) int {
switch pub := pub.(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
return curveP256
case elliptic.P384():
return curveP384
case elliptic.P521():
return curveP521
default:
return curveInvalid
}
case *rsa.PublicKey:
return curveRSA
default:
return curveInvalid
}
}
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
// It returns true on match.
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
}
// matchECDSA compares ECDSA public keys for equality and compatible curve.
// It returns match=true when they are on the same curve and have the same X/Y.
// If curves mismatch, match is false.
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
if getECCurve(certPub) != getECCurve(keyPub) {
return false
}
if keyPub.X.Cmp(certPub.X) != 0 {
return false
}
if keyPub.Y.Cmp(certPub.Y) != 0 {
return false
}
return true
}
// MatchKeys determines whether the certificate's public key matches the given private key.
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
func MatchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
switch keyPub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if matchRSA(certPub, keyPub) {
return true, ""
}
return false, "public keys don't match"
case *ecdsa.PublicKey:
return false, "RSA private key, EC public key"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
case *ecdsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *ecdsa.PublicKey:
if matchECDSA(certPub, keyPub) {
return true, ""
}
// Determine a more precise reason
kc := getECCurve(keyPub)
cc := getECCurve(certPub)
if kc == curveInvalid {
return false, "invalid private key curve"
}
if cc == curveRSA {
return false, "private key is EC, certificate is RSA"
}
if kc != cc {
return false, "EC curves don't match"
}
return false, "public keys don't match"
case *rsa.PublicKey:
return false, "private key is EC, certificate is RSA"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
default:
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
}
}

157
certlib/ski/ski.go Normal file
View File

@@ -0,0 +1,157 @@
package ski
import (
"bytes"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/rsa"
"crypto/sha1" // #nosec G505 this is the standard
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
"fmt"
"os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
const (
keyTypeRSA = "RSA"
keyTypeECDSA = "ECDSA"
keyTypeEd25519 = "Ed25519"
)
type subjectPublicKeyInfo struct {
Algorithm pkix.AlgorithmIdentifier
SubjectPublicKey asn1.BitString
}
type KeyInfo struct {
PublicKey []byte
KeyType string
FileType string
}
func (k *KeyInfo) String() string {
return fmt.Sprintf("%s (%s)", lib.HexEncode(k.PublicKey, lib.HexEncodeLowerColon), k.KeyType)
}
func (k *KeyInfo) SKI(displayMode lib.HexEncodeMode) (string, error) {
var subPKI subjectPublicKeyInfo
_, err := asn1.Unmarshal(k.PublicKey, &subPKI)
if err != nil {
return "", fmt.Errorf("serializing SKI: %w", err)
}
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
pubHashString := lib.HexEncode(pubHash[:], displayMode)
return pubHashString, nil
}
// ParsePEM parses a PEM file and returns the public key and its type.
func ParsePEM(path string) (*KeyInfo, error) {
material := &KeyInfo{}
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("parsing X.509 material %s: %w", path, err)
}
data = bytes.TrimSpace(data)
p, rest := pem.Decode(data)
if len(rest) > 0 {
lib.Warnx("trailing data in PEM file")
}
if p == nil {
return nil, fmt.Errorf("no PEM data in %s", path)
}
data = p.Bytes
switch p.Type {
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
material.PublicKey, material.KeyType = parseKey(data)
material.FileType = "private key"
case "CERTIFICATE":
material.PublicKey, material.KeyType = parseCertificate(data)
material.FileType = "certificate"
case "CERTIFICATE REQUEST":
material.PublicKey, material.KeyType = parseCSR(data)
material.FileType = "certificate request"
default:
return nil, fmt.Errorf("unknown PEM type %s", p.Type)
}
return material, nil
}
func parseKey(data []byte) ([]byte, string) {
priv, err := certlib.ParsePrivateKeyDER(data)
if err != nil {
die.If(err)
}
var kt string
switch priv.Public().(type) {
case *rsa.PublicKey:
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = keyTypeECDSA
default:
die.With("unknown private key type %T", priv)
}
public, err := x509.MarshalPKIXPublicKey(priv.Public())
die.If(err)
return public, kt
}
func parseCertificate(data []byte) ([]byte, string) {
cert, err := x509.ParseCertificate(data)
die.If(err)
pub := cert.PublicKey
var kt string
switch pub.(type) {
case *rsa.PublicKey:
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = keyTypeECDSA
case *ed25519.PublicKey:
kt = keyTypeEd25519
default:
die.With("unknown public key type %T", pub)
}
public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err)
return public, kt
}
func parseCSR(data []byte) ([]byte, string) {
// Use certlib to support both PEM and DER and to centralize validation.
csr, _, err := certlib.ParseCSR(data)
die.If(err)
pub := csr.PublicKey
var kt string
switch pub.(type) {
case *rsa.PublicKey:
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = keyTypeECDSA
default:
die.With("unknown public key type %T", pub)
}
public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err)
return public, kt
}

49
certlib/verify/check.go Normal file
View File

@@ -0,0 +1,49 @@
package verify
import (
"crypto/x509"
"fmt"
"time"
"git.wntrmute.dev/kyle/goutils/certlib/dump"
)
const DefaultLeeway = 2160 * time.Hour // three months
type CertCheck struct {
Cert *x509.Certificate
leeway time.Duration
}
func NewCertCheck(cert *x509.Certificate, leeway time.Duration) *CertCheck {
return &CertCheck{
Cert: cert,
leeway: leeway,
}
}
func (c CertCheck) Expiry() time.Duration {
return time.Until(c.Cert.NotAfter)
}
func (c CertCheck) IsExpiring(leeway time.Duration) bool {
return c.Expiry() < leeway
}
// Err returns nil if the certificate is not expiring within the leeway period.
func (c CertCheck) Err() error {
if !c.IsExpiring(c.leeway) {
return nil
}
return fmt.Errorf("%s expires in %s", dump.DisplayName(c.Cert.Subject), c.Expiry())
}
func (c CertCheck) Name() string {
return fmt.Sprintf("%s/SN=%s", dump.DisplayName(c.Cert.Subject),
c.Cert.SerialNumber)
}
func (c CertCheck) String() string {
return fmt.Sprintf("%s expires on %s (in %s)\n", c.Name(), c.Cert.NotAfter, c.Expiry())
}

143
certlib/verify/verify.go Normal file
View File

@@ -0,0 +1,143 @@
package verify
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"git.wntrmute.dev/kyle/goutils/certlib/revoke"
"git.wntrmute.dev/kyle/goutils/lib"
)
func bundleIntermediates(w io.Writer, chain []*x509.Certificate, pool *x509.CertPool, verbose bool) *x509.CertPool {
for _, intermediate := range chain[1:] {
if verbose {
fmt.Fprintf(w, "[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
}
pool.AddCert(intermediate)
}
return pool
}
type Opts struct {
Verbose bool
Config *tls.Config
Intermediates *x509.CertPool
ForceIntermediates bool
CheckRevocation bool
KeyUsages []x509.ExtKeyUsage
}
type verifyResult struct {
chain []*x509.Certificate
roots *x509.CertPool
ints *x509.CertPool
}
func prepareVerification(w io.Writer, target string, opts *Opts) (*verifyResult, error) {
var (
roots, ints *x509.CertPool
err error
)
if opts == nil {
opts = &Opts{
Config: lib.StrictBaselineTLSConfig(),
ForceIntermediates: false,
}
}
if opts.Config.RootCAs == nil {
roots, err = x509.SystemCertPool()
if err != nil {
return nil, fmt.Errorf("couldn't load system cert pool: %w", err)
}
opts.Config.RootCAs = roots
}
if opts.Intermediates == nil {
ints = x509.NewCertPool()
} else {
ints = opts.Intermediates.Clone()
}
roots = opts.Config.RootCAs.Clone()
chain, err := lib.GetCertificateChain(target, opts.Config)
if err != nil {
return nil, fmt.Errorf("fetching certificate chain: %w", err)
}
if opts.Verbose {
fmt.Fprintf(w, "[+] %s has %d certificates\n", target, len(chain))
}
if len(chain) > 1 && opts.ForceIntermediates {
ints = bundleIntermediates(w, chain, ints, opts.Verbose)
}
return &verifyResult{
chain: chain,
roots: roots,
ints: ints,
}, nil
}
// Chain fetches the certificate chain for a target and verifies it.
func Chain(w io.Writer, target string, opts *Opts) ([]*x509.Certificate, error) {
result, err := prepareVerification(w, target, opts)
if err != nil {
return nil, fmt.Errorf("certificate verification failed: %w", err)
}
chains, err := CertWith(result.chain[0], result.roots, result.ints, opts.CheckRevocation, opts.KeyUsages...)
if err != nil {
return nil, fmt.Errorf("certificate verification failed: %w", err)
}
return chains, nil
}
// CertWith verifies a certificate against a set of roots and intermediates.
func CertWith(
cert *x509.Certificate,
roots, ints *x509.CertPool,
checkRevocation bool,
keyUses ...x509.ExtKeyUsage,
) ([]*x509.Certificate, error) {
if len(keyUses) == 0 {
keyUses = []x509.ExtKeyUsage{x509.ExtKeyUsageAny}
}
opts := x509.VerifyOptions{
Intermediates: ints,
Roots: roots,
KeyUsages: keyUses,
}
chains, err := cert.Verify(opts)
if err != nil {
return nil, err
}
if checkRevocation {
revoked, ok := revoke.VerifyCertificate(cert)
if !ok {
return nil, errors.New("failed to check certificate revocation status")
}
if revoked {
return nil, errors.New("certificate is revoked")
}
}
if len(chains) == 0 {
return nil, errors.New("no valid certificate chain found")
}
return chains[0], nil
}

View File

@@ -2,374 +2,44 @@
package main
import (
"crypto/dsa"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/sha256"
"crypto/x509"
"crypto/x509/pkix"
"crypto/tls"
"flag"
"fmt"
"os"
"sort"
"strings"
"github.com/kr/text"
"git.wntrmute.dev/kyle/goutils/certlib/dump"
"git.wntrmute.dev/kyle/goutils/lib"
)
// following two lifted from CFSSL, (replace-regexp "\(.+\): \(.+\),"
// "\2: \1,")
const (
sSHA256 = "SHA256"
sSHA512 = "SHA512"
)
var keyUsage = map[x509.KeyUsage]string{
x509.KeyUsageDigitalSignature: "digital signature",
x509.KeyUsageContentCommitment: "content committment",
x509.KeyUsageKeyEncipherment: "key encipherment",
x509.KeyUsageKeyAgreement: "key agreement",
x509.KeyUsageDataEncipherment: "data encipherment",
x509.KeyUsageCertSign: "cert sign",
x509.KeyUsageCRLSign: "crl sign",
x509.KeyUsageEncipherOnly: "encipher only",
x509.KeyUsageDecipherOnly: "decipher only",
}
var extKeyUsages = map[x509.ExtKeyUsage]string{
x509.ExtKeyUsageAny: "any",
x509.ExtKeyUsageServerAuth: "server auth",
x509.ExtKeyUsageClientAuth: "client auth",
x509.ExtKeyUsageCodeSigning: "code signing",
x509.ExtKeyUsageEmailProtection: "s/mime",
x509.ExtKeyUsageIPSECEndSystem: "ipsec end system",
x509.ExtKeyUsageIPSECTunnel: "ipsec tunnel",
x509.ExtKeyUsageIPSECUser: "ipsec user",
x509.ExtKeyUsageTimeStamping: "timestamping",
x509.ExtKeyUsageOCSPSigning: "ocsp signing",
x509.ExtKeyUsageMicrosoftServerGatedCrypto: "microsoft sgc",
x509.ExtKeyUsageNetscapeServerGatedCrypto: "netscape sgc",
x509.ExtKeyUsageMicrosoftCommercialCodeSigning: "microsoft commercial code signing",
x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing",
}
func sigAlgoPK(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA:
return "RSA"
case x509.SHA256WithRSAPSS, x509.SHA384WithRSAPSS, x509.SHA512WithRSAPSS:
return "RSA-PSS"
case x509.ECDSAWithSHA1, x509.ECDSAWithSHA256, x509.ECDSAWithSHA384, x509.ECDSAWithSHA512:
return "ECDSA"
case x509.DSAWithSHA1, x509.DSAWithSHA256:
return "DSA"
case x509.PureEd25519:
return "Ed25519"
case x509.UnknownSignatureAlgorithm:
return "unknown public key algorithm"
default:
return "unknown public key algorithm"
}
}
func sigAlgoHash(a x509.SignatureAlgorithm) string {
switch a {
case x509.MD2WithRSA:
return "MD2"
case x509.MD5WithRSA:
return "MD5"
case x509.SHA1WithRSA, x509.ECDSAWithSHA1, x509.DSAWithSHA1:
return "SHA1"
case x509.SHA256WithRSA, x509.ECDSAWithSHA256, x509.DSAWithSHA256:
return sSHA256
case x509.SHA256WithRSAPSS:
return sSHA256
case x509.SHA384WithRSA, x509.ECDSAWithSHA384:
return "SHA384"
case x509.SHA384WithRSAPSS:
return "SHA384"
case x509.SHA512WithRSA, x509.ECDSAWithSHA512:
return sSHA512
case x509.SHA512WithRSAPSS:
return sSHA512
case x509.PureEd25519:
return sSHA512
case x509.UnknownSignatureAlgorithm:
return "unknown hash algorithm"
default:
return "unknown hash algorithm"
}
}
const maxLine = 78
func makeIndent(n int) string {
s := " "
var sSb97 strings.Builder
for range n {
sSb97.WriteString(" ")
}
s += sSb97.String()
return s
}
func indentLen(n int) int {
return 4 + (8 * n)
}
// this isn't real efficient, but that's not a problem here.
func wrap(s string, indent int) string {
if indent > 3 {
indent = 3
}
wrapped := text.Wrap(s, maxLine)
lines := strings.SplitN(wrapped, "\n", 2)
if len(lines) == 1 {
return lines[0]
}
if (maxLine - indentLen(indent)) <= 0 {
panic("too much indentation")
}
rest := strings.Join(lines[1:], " ")
wrapped = text.Wrap(rest, maxLine-indentLen(indent))
return lines[0] + "\n" + text.Indent(wrapped, makeIndent(indent))
}
func dumpHex(in []byte) string {
return lib.HexEncode(in, lib.HexEncodeUpperColon)
}
func certPublic(cert *x509.Certificate) string {
switch pub := cert.PublicKey.(type) {
case *rsa.PublicKey:
return fmt.Sprintf("RSA-%d", pub.N.BitLen())
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
return "ECDSA-prime256v1"
case elliptic.P384():
return "ECDSA-secp384r1"
case elliptic.P521():
return "ECDSA-secp521r1"
default:
return "ECDSA (unknown curve)"
}
case *dsa.PublicKey:
return "DSA"
default:
return "Unknown"
}
}
func displayName(name pkix.Name) string {
var ns []string
if name.CommonName != "" {
ns = append(ns, name.CommonName)
}
for i := range name.Country {
ns = append(ns, fmt.Sprintf("C=%s", name.Country[i]))
}
for i := range name.Organization {
ns = append(ns, fmt.Sprintf("O=%s", name.Organization[i]))
}
for i := range name.OrganizationalUnit {
ns = append(ns, fmt.Sprintf("OU=%s", name.OrganizationalUnit[i]))
}
for i := range name.Locality {
ns = append(ns, fmt.Sprintf("L=%s", name.Locality[i]))
}
for i := range name.Province {
ns = append(ns, fmt.Sprintf("ST=%s", name.Province[i]))
}
if len(ns) > 0 {
return "/" + strings.Join(ns, "/")
}
return "*** no subject information ***"
}
func keyUsages(ku x509.KeyUsage) string {
var uses []string
for u, s := range keyUsage {
if (ku & u) != 0 {
uses = append(uses, s)
}
}
sort.Strings(uses)
return strings.Join(uses, ", ")
}
func extUsage(ext []x509.ExtKeyUsage) string {
ns := make([]string, 0, len(ext))
for i := range ext {
ns = append(ns, extKeyUsages[ext[i]])
}
sort.Strings(ns)
return strings.Join(ns, ", ")
}
func showBasicConstraints(cert *x509.Certificate) {
fmt.Fprint(os.Stdout, "\tBasic constraints: ")
if cert.BasicConstraintsValid {
fmt.Fprint(os.Stdout, "valid")
} else {
fmt.Fprint(os.Stdout, "invalid")
}
if cert.IsCA {
fmt.Fprint(os.Stdout, ", is a CA certificate")
if !cert.BasicConstraintsValid {
fmt.Fprint(os.Stdout, " (basic constraint failure)")
}
} else {
fmt.Fprint(os.Stdout, ", is not a CA certificate")
if cert.KeyUsage&x509.KeyUsageKeyEncipherment != 0 {
fmt.Fprint(os.Stdout, " (key encipherment usage enabled!)")
}
}
if (cert.MaxPathLen == 0 && cert.MaxPathLenZero) || (cert.MaxPathLen > 0) {
fmt.Fprintf(os.Stdout, ", max path length %d", cert.MaxPathLen)
}
fmt.Fprintln(os.Stdout)
}
const oneTrueDateFormat = "2006-01-02T15:04:05-0700"
var (
var config struct {
showHash bool
dateFormat string
showHash bool // if true, print a SHA256 hash of the certificate's Raw field
)
func wrapPrint(text string, indent int) {
tabs := ""
var tabsSb140 strings.Builder
for range indent {
tabsSb140.WriteString("\t")
}
tabs += tabsSb140.String()
fmt.Fprintf(os.Stdout, tabs+"%s\n", wrap(text, indent))
}
func displayCert(cert *x509.Certificate) {
fmt.Fprintln(os.Stdout, "CERTIFICATE")
if showHash {
fmt.Fprintln(os.Stdout, wrap(fmt.Sprintf("SHA256: %x", sha256.Sum256(cert.Raw)), 0))
}
fmt.Fprintln(os.Stdout, wrap("Subject: "+displayName(cert.Subject), 0))
fmt.Fprintln(os.Stdout, wrap("Issuer: "+displayName(cert.Issuer), 0))
fmt.Fprintf(os.Stdout, "\tSignature algorithm: %s / %s\n", sigAlgoPK(cert.SignatureAlgorithm),
sigAlgoHash(cert.SignatureAlgorithm))
fmt.Fprintln(os.Stdout, "Details:")
wrapPrint("Public key: "+certPublic(cert), 1)
fmt.Fprintf(os.Stdout, "\tSerial number: %s\n", cert.SerialNumber)
if len(cert.AuthorityKeyId) > 0 {
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("AKI: "+dumpHex(cert.AuthorityKeyId), 1))
}
if len(cert.SubjectKeyId) > 0 {
fmt.Fprintf(os.Stdout, "\t%s\n", wrap("SKI: "+dumpHex(cert.SubjectKeyId), 1))
}
wrapPrint("Valid from: "+cert.NotBefore.Format(dateFormat), 1)
fmt.Fprintf(os.Stdout, "\t until: %s\n", cert.NotAfter.Format(dateFormat))
fmt.Fprintf(os.Stdout, "\tKey usages: %s\n", keyUsages(cert.KeyUsage))
if len(cert.ExtKeyUsage) > 0 {
fmt.Fprintf(os.Stdout, "\tExtended usages: %s\n", extUsage(cert.ExtKeyUsage))
}
showBasicConstraints(cert)
validNames := make([]string, 0, len(cert.DNSNames)+len(cert.EmailAddresses)+len(cert.IPAddresses))
for i := range cert.DNSNames {
validNames = append(validNames, "dns:"+cert.DNSNames[i])
}
for i := range cert.EmailAddresses {
validNames = append(validNames, "email:"+cert.EmailAddresses[i])
}
for i := range cert.IPAddresses {
validNames = append(validNames, "ip:"+cert.IPAddresses[i].String())
}
sans := fmt.Sprintf("SANs (%d): %s\n", len(validNames), strings.Join(validNames, ", "))
wrapPrint(sans, 1)
l := len(cert.IssuingCertificateURL)
if l != 0 {
var aia string
if l == 1 {
aia = "AIA"
} else {
aia = "AIAs"
}
wrapPrint(fmt.Sprintf("%d %s:", l, aia), 1)
for _, url := range cert.IssuingCertificateURL {
wrapPrint(url, 2)
}
}
l = len(cert.OCSPServer)
if l > 0 {
title := "OCSP server"
if l > 1 {
title += "s"
}
wrapPrint(title+":\n", 1)
for _, ocspServer := range cert.OCSPServer {
wrapPrint(fmt.Sprintf("- %s\n", ocspServer), 2)
}
}
leafOnly bool
}
func main() {
var leafOnly bool
flag.BoolVar(&showHash, "d", false, "show hashes of raw DER contents")
flag.StringVar(&dateFormat, "s", oneTrueDateFormat, "date `format` in Go time format")
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.BoolVar(&config.showHash, "d", false, "show hashes of raw DER contents")
flag.StringVar(&config.dateFormat, "s", lib.OneTrueDateFormat, "date `format` in Go time format")
flag.BoolVar(&config.leafOnly, "l", false, "only show the leaf certificate")
flag.Parse()
opts := &lib.FetcherOpts{
SkipVerify: true,
Roots: nil,
}
tlsCfg := &tls.Config{InsecureSkipVerify: true} // #nosec G402 - tool intentionally inspects broken TLS
for _, filename := range flag.Args() {
fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
certs, err := lib.GetCertificateChain(filename, opts)
certs, err := lib.GetCertificateChain(filename, tlsCfg)
if err != nil {
_, _ = lib.Warn(err, "couldn't read certificate")
lib.Warn(err, "couldn't read certificate")
continue
}
if leafOnly {
displayCert(certs[0])
if config.leafOnly {
dump.DisplayCert(os.Stdout, certs[0])
continue
}
for i := range certs {
displayCert(certs[i])
dump.DisplayCert(os.Stdout, certs[i])
}
}
}

View File

@@ -2,94 +2,52 @@ package main
import (
"crypto/x509"
"crypto/x509/pkix"
"flag"
"fmt"
"os"
"strings"
"time"
"git.wntrmute.dev/kyle/goutils/certlib/verify"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
var warnOnly bool
var leeway = 2160 * time.Hour // three months
func displayName(name pkix.Name) string {
var ns []string
if name.CommonName != "" {
ns = append(ns, name.CommonName)
}
for i := range name.Country {
ns = append(ns, fmt.Sprintf("C=%s", name.Country[i]))
}
for i := range name.Organization {
ns = append(ns, fmt.Sprintf("O=%s", name.Organization[i]))
}
for i := range name.OrganizationalUnit {
ns = append(ns, fmt.Sprintf("OU=%s", name.OrganizationalUnit[i]))
}
for i := range name.Locality {
ns = append(ns, fmt.Sprintf("L=%s", name.Locality[i]))
}
for i := range name.Province {
ns = append(ns, fmt.Sprintf("ST=%s", name.Province[i]))
}
if len(ns) > 0 {
return "/" + strings.Join(ns, "/")
}
die.With("no subject information in root")
return ""
}
func expires(cert *x509.Certificate) time.Duration {
return time.Until(cert.NotAfter)
}
func inDanger(cert *x509.Certificate) bool {
return expires(cert) < leeway
}
func checkCert(cert *x509.Certificate) {
warn := inDanger(cert)
name := displayName(cert.Subject)
name = fmt.Sprintf("%s/SN=%s", name, cert.SerialNumber)
expiry := expires(cert)
if warnOnly {
if warn {
fmt.Fprintf(os.Stderr, "%s expires on %s (in %s)\n", name, cert.NotAfter, expiry)
}
} else {
fmt.Printf("%s expires on %s (in %s)\n", name, cert.NotAfter, expiry)
}
}
func main() {
opts := &lib.FetcherOpts{}
var (
skipVerify bool
strictTLS bool
leeway = verify.DefaultLeeway
warnOnly bool
)
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification")
lib.StrictTLSFlag(&strictTLS)
flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402
flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs")
flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring")
flag.Parse()
tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS)
die.If(err)
for _, file := range flag.Args() {
certs, err := lib.GetCertificateChain(file, opts)
var certs []*x509.Certificate
certs, err = lib.GetCertificateChain(file, tlsCfg)
if err != nil {
_, _ = lib.Warn(err, "while parsing certificates")
continue
}
for _, cert := range certs {
checkCert(cert)
check := verify.NewCertCheck(cert, leeway)
if warnOnly {
if err = check.Err(); err != nil {
lib.Warn(err, "certificate is expiring")
}
} else {
fmt.Printf("%s expires on %s (in %s)\n", check.Name(),
cert.NotAfter, check.Expiry())
}
}
}
}

View File

@@ -31,16 +31,23 @@ func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string {
}
func main() {
opts := &lib.FetcherOpts{}
var skipVerify bool
var strictTLS bool
lib.StrictTLSFlag(&strictTLS)
displayAs := flag.String("d", "int", "display mode (int, hex, uhex)")
showExpiry := flag.Bool("e", false, "show expiry date")
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification")
flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402
flag.Parse()
tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS)
die.If(err)
displayMode := parseDisplayMode(*displayAs)
for _, arg := range flag.Args() {
cert, err := lib.GetCertificate(arg, opts)
var cert *x509.Certificate
cert, err = lib.GetCertificate(arg, tlsCfg)
die.If(err)
fmt.Printf("%s: %s", arg, serialString(cert, displayMode))

View File

@@ -5,33 +5,18 @@ import (
"flag"
"fmt"
"os"
"time"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/certlib/revoke"
"git.wntrmute.dev/kyle/goutils/certlib/verify"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
func printRevocation(cert *x509.Certificate) {
remaining := time.Until(cert.NotAfter)
fmt.Printf("certificate expires in %s.\n", lib.Duration(remaining))
revoked, ok := revoke.VerifyCertificate(cert)
if !ok {
fmt.Fprintf(os.Stderr, "[!] the revocation check failed (failed to determine whether certificate\nwas revoked)")
return
}
if revoked {
fmt.Fprintf(os.Stderr, "[!] the certificate has been revoked\n")
return
}
}
type appConfig struct {
caFile, intFile string
forceIntermediateBundle bool
revexp, skipVerify, verbose bool
strictTLS bool
}
func parseFlags() appConfig {
@@ -43,107 +28,66 @@ func parseFlags() appConfig {
flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification")
flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&cfg.verbose, "v", false, "verbose")
lib.StrictTLSFlag(&cfg.strictTLS)
flag.Parse()
if flag.NArg() == 0 {
die.With("usage: certverify targets...")
}
return cfg
}
func loadRoots(caFile string, verbose bool) (*x509.CertPool, error) {
if caFile == "" {
return x509.SystemCertPool()
}
if verbose {
fmt.Println("[+] loading root certificates from", caFile)
}
return certlib.LoadPEMCertPool(caFile)
}
func loadIntermediates(intFile string, verbose bool) (*x509.CertPool, error) {
if intFile == "" {
return x509.NewCertPool(), nil
}
if verbose {
fmt.Println("[+] loading intermediate certificates from", intFile)
}
// Note: use intFile here (previously used caFile mistakenly)
return certlib.LoadPEMCertPool(intFile)
}
func addBundledIntermediates(chain []*x509.Certificate, pool *x509.CertPool, verbose bool) {
for _, intermediate := range chain[1:] {
if verbose {
fmt.Printf("[+] adding intermediate with SKI %x\n", intermediate.SubjectKeyId)
}
pool.AddCert(intermediate)
}
}
func verifyCert(cert *x509.Certificate, roots, ints *x509.CertPool) error {
opts := x509.VerifyOptions{
Intermediates: ints,
Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
}
_, err := cert.Verify(opts)
return err
}
func run(cfg appConfig) error {
roots, err := loadRoots(cfg.caFile, cfg.verbose)
if err != nil {
return err
}
ints, err := loadIntermediates(cfg.intFile, cfg.verbose)
if err != nil {
return err
}
if flag.NArg() != 1 {
fmt.Fprintf(os.Stderr, "Usage: %s [-ca bundle] [-i bundle] cert", lib.ProgName())
}
combinedPool, err := certlib.LoadFullCertPool(cfg.caFile, cfg.intFile)
if err != nil {
return fmt.Errorf("failed to build combined pool: %w", err)
}
opts := &lib.FetcherOpts{
Roots: combinedPool,
SkipVerify: cfg.skipVerify,
}
chain, err := lib.GetCertificateChain(flag.Arg(0), opts)
if err != nil {
return err
}
if cfg.verbose {
fmt.Printf("[+] %s has %d certificates\n", flag.Arg(0), len(chain))
}
cert := chain[0]
if len(chain) > 1 && !cfg.forceIntermediateBundle {
addBundledIntermediates(chain, ints, cfg.verbose)
}
if err = verifyCert(cert, roots, ints); err != nil {
return fmt.Errorf("certificate verification failed: %w", err)
}
if cfg.verbose {
fmt.Println("OK")
}
if cfg.revexp {
printRevocation(cert)
}
return nil
}
func main() {
var (
roots, ints *x509.CertPool
err error
failed bool
)
cfg := parseFlags()
if err := run(cfg); err != nil {
fmt.Fprintf(os.Stderr, "%v\n", err)
opts := &verify.Opts{
CheckRevocation: cfg.revexp,
ForceIntermediates: cfg.forceIntermediateBundle,
Verbose: cfg.verbose,
}
if cfg.caFile != "" {
if cfg.verbose {
fmt.Printf("loading CA certificates from %s\n", cfg.caFile)
}
roots, err = certlib.LoadPEMCertPool(cfg.caFile)
die.If(err)
}
if cfg.intFile != "" {
if cfg.verbose {
fmt.Printf("loading intermediate certificates from %s\n", cfg.intFile)
}
ints, err = certlib.LoadPEMCertPool(cfg.intFile)
die.If(err)
}
opts.Config, err = lib.BaselineTLSConfig(cfg.skipVerify, cfg.strictTLS)
die.If(err)
opts.Config.RootCAs = roots
opts.Intermediates = ints
for _, arg := range flag.Args() {
_, err = verify.Chain(os.Stdout, arg, opts)
if err != nil {
lib.Warn(err, "while verifying %s", arg)
failed = true
} else {
fmt.Printf("%s: OK\n", arg)
}
}
if failed {
os.Exit(1)
}
}

View File

@@ -3,18 +3,31 @@ kgz
kgz is like gzip, but supports compressing and decompressing to a different
directory than the source file is in.
Usage: kgz [-l] source [target]
Usage: kgz [-l] [-k] [-m] [-x] [--uid N] [--gid N] source [target]
If target is a directory, the basename of the sourcefile will be used
If target is a directory, the basename of the source file will be used
as the target filename. Compression and decompression is selected
based on whether the source filename ends in ".gz".
Flags:
-l level Compression level (0-9). Only meaninful when
compressing a file.
-l level Compression level (0-9). Only meaningful when compressing.
-u Do not restrict the size during decompression. As
a safeguard against gzip bombs, the maximum size
allowed is 32 * the compressed file size.
-k Keep the source file (do not remove it after successful
compression or decompression).
-m On decompression, set the file mtime from the gzip header.
-x On compression, include uid/gid/mode/ctime in the gzip Extra
field so that decompression can restore them. The Extra payload
is an ASN.1 DER-encoded struct.
--uid N When used with -x, set UID in Extra to N (override source).
--gid N When used with -x, set GID in Extra to N (override source).
Metadata notes:
- mtime is stored in the standard gzip header and restored with -m.
- uid/gid/mode/ctime are stored in a kgz-specific Extra subfield as an ASN.1
DER-encoded struct. Restoring
uid/gid may fail without sufficient privileges; such errors are ignored.

View File

@@ -3,23 +3,178 @@ package main
import (
"compress/flate"
"compress/gzip"
"encoding/asn1"
"encoding/binary"
"flag"
"fmt"
"io"
"math"
"os"
"path/filepath"
"strings"
"golang.org/x/sys/unix"
goutilslib "git.wntrmute.dev/kyle/goutils/lib"
)
const gzipExt = ".gz"
func compress(path, target string, level int) error {
// kgzExtraID is the two-byte subfield identifier used in the gzip Extra field
// for kgz-specific metadata.
var kgzExtraID = [2]byte{'K', 'G'}
// buildKGExtra constructs the gzip Extra subfield payload for kgz metadata.
//
// The payload is an ASN.1 DER-encoded struct with the following fields:
//
// Version INTEGER (currently 1)
// UID INTEGER
// GID INTEGER
// Mode INTEGER (permission bits)
// CTimeSec INTEGER (seconds)
// CTimeNSec INTEGER (nanoseconds)
//
// The ASN.1 blob is wrapped in a gzip Extra subfield with ID 'K','G'.
func buildKGExtra(uid, gid, mode uint32, ctimeS int64, ctimeNs int32) []byte {
// Define the ASN.1 structure to encode
type KGZExtra struct {
Version int
UID int
GID int
Mode int
CTimeSec int64
CTimeNSec int32
}
payload, err := asn1.Marshal(KGZExtra{
Version: 1,
UID: int(uid),
GID: int(gid),
Mode: int(mode),
CTimeSec: ctimeS,
CTimeNSec: ctimeNs,
})
if err != nil {
// On marshal failure, return empty to avoid breaking compression
return nil
}
// Wrap in gzip subfield: [ID1 ID2 LEN(lo) LEN(hi) PAYLOAD]
// Guard against payload length overflow to uint16 for the extra subfield length.
if len(payload) > int(math.MaxUint16) {
return nil
}
extra := make([]byte, 4+len(payload))
extra[0] = kgzExtraID[0]
extra[1] = kgzExtraID[1]
binary.LittleEndian.PutUint16(extra[2:], uint16(len(payload)&0xFFFF)) //#nosec G115 - masked
copy(extra[4:], payload)
return extra
}
// clampToInt32 clamps an int value into the int32 range using a switch to
// satisfy linters that prefer switch over if-else chains for ordered checks.
func clampToInt32(v int) int32 {
switch {
case v > int(math.MaxInt32):
return math.MaxInt32
case v < int(math.MinInt32):
return math.MinInt32
default:
return int32(v)
}
}
// buildExtraForPath prepares the gzip Extra field for kgz by collecting
// uid/gid/mode and ctime information, applying any overrides, and encoding it.
func buildExtraForPath(st unix.Stat_t, path string, setUID, setGID int) []byte {
uid := st.Uid
gid := st.Gid
if setUID >= 0 {
if uint64(setUID) <= math.MaxUint32 {
uid = uint32(setUID & 0xFFFFFFFF) //#nosec G115 - masked
}
}
if setGID >= 0 {
if uint64(setGID) <= math.MaxUint32 {
gid = uint32(setGID & 0xFFFFFFFF) //#nosec G115 - masked
}
}
mode := uint32(st.Mode & 0o7777)
// Use portable helper to gather ctime
var cts int64
var ctns int32
if ft, err := goutilslib.LoadFileTime(path); err == nil {
cts = ft.Changed.Unix()
ctns = clampToInt32(ft.Changed.Nanosecond())
}
return buildKGExtra(uid, gid, mode, cts, ctns)
}
// parseKGExtra scans a gzip Extra blob and returns kgz metadata if present.
func parseKGExtra(extra []byte) (uint32, uint32, uint32, int64, int32, bool) {
i := 0
for i+4 <= len(extra) {
id1 := extra[i]
id2 := extra[i+1]
l := int(binary.LittleEndian.Uint16(extra[i+2 : i+4]))
i += 4
if i+l > len(extra) {
break
}
if id1 == kgzExtraID[0] && id2 == kgzExtraID[1] {
// ASN.1 decode payload
payload := extra[i : i+l]
var s struct {
Version int
UID int
GID int
Mode int
CTimeSec int64
CTimeNSec int32
}
if _, err := asn1.Unmarshal(payload, &s); err != nil {
return 0, 0, 0, 0, 0, false
}
if s.Version != 1 {
return 0, 0, 0, 0, 0, false
}
// Validate ranges before converting from int -> uint32 to avoid overflow.
if s.UID < 0 || s.GID < 0 || s.Mode < 0 {
return 0, 0, 0, 0, 0, false
}
if uint64(s.UID) > math.MaxUint32 || uint64(s.GID) > math.MaxUint32 || uint64(s.Mode) > math.MaxUint32 {
return 0, 0, 0, 0, 0, false
}
return uint32(s.UID & 0xFFFFFFFF), uint32(s.GID & 0xFFFFFFFF),
uint32(s.Mode & 0xFFFFFFFF), s.CTimeSec, s.CTimeNSec, true //#nosec G115 - masked
}
i += l
}
return 0, 0, 0, 0, 0, false
}
func compress(path, target string, level int, includeExtra bool, setUID, setGID int) error {
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("opening file for read: %w", err)
}
defer sourceFile.Close()
// Gather file metadata
var st unix.Stat_t
if err = unix.Stat(path, &st); err != nil {
return fmt.Errorf("stat source: %w", err)
}
fi, err := sourceFile.Stat()
if err != nil {
return fmt.Errorf("stat source file: %w", err)
}
destFile, err := os.Create(target)
if err != nil {
return fmt.Errorf("opening file for write: %w", err)
@@ -30,6 +185,11 @@ func compress(path, target string, level int) error {
if err != nil {
return fmt.Errorf("invalid compression level: %w", err)
}
// Set header metadata
gzipCompressor.ModTime = fi.ModTime()
if includeExtra {
gzipCompressor.Extra = buildExtraForPath(st, path, setUID, setGID)
}
defer gzipCompressor.Close()
_, err = io.Copy(gzipCompressor, sourceFile)
@@ -40,7 +200,7 @@ func compress(path, target string, level int) error {
return nil
}
func uncompress(path, target string, unrestrict bool) error {
func uncompress(path, target string, unrestrict bool, preserveMtime bool) error {
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("opening file for read: %w", err)
@@ -79,19 +239,40 @@ func uncompress(path, target string, unrestrict bool) error {
if err != nil {
return fmt.Errorf("uncompressing file: %w", err)
}
// Apply metadata from Extra (uid/gid/mode) if present
if gzipUncompressor.Header.Extra != nil {
if uid, gid, mode, _, _, ok := parseKGExtra(gzipUncompressor.Header.Extra); ok {
// Chmod
_ = os.Chmod(target, os.FileMode(mode))
// Chown (may fail without privileges)
_ = os.Chown(target, int(uid), int(gid))
}
}
// Preserve mtime if requested
if preserveMtime {
mt := gzipUncompressor.Header.ModTime
if !mt.IsZero() {
// Set both atime and mtime to mt for simplicity
_ = os.Chtimes(target, mt, mt)
}
}
return nil
}
func usage(w io.Writer) {
fmt.Fprintf(w, `Usage: %s [-l] source [target]
fmt.Fprintf(w, `Usage: %s [-l] [-k] [-m] [-x] [--uid N] [--gid N] source [target]
kgz is like gzip, but supports compressing and decompressing to a different
directory than the source file is in.
Flags:
-l level Compression level (0-9). Only meaninful when
compressing a file.
-l level Compression level (0-9). Only meaningful when compressing.
-u Do not restrict the size during decompression (gzip bomb guard is 32x).
-k Keep the source file (do not remove it after successful (de)compression).
-m On decompression, set the file mtime from the gzip header.
-x On compression, include uid/gid/mode/ctime in the gzip Extra field.
--uid N When used with -x, set UID in Extra to N (overrides source owner).
--gid N When used with -x, set GID in Extra to N (overrides source group).
`, os.Args[0])
}
@@ -150,9 +331,19 @@ func main() {
var target = "."
var err error
var unrestrict bool
var keep bool
var preserveMtime bool
var includeExtra bool
var setUID int
var setGID int
flag.IntVar(&level, "l", flate.DefaultCompression, "compression level")
flag.BoolVar(&unrestrict, "u", false, "do not restrict decompression")
flag.BoolVar(&keep, "k", false, "keep the source file (do not remove it)")
flag.BoolVar(&preserveMtime, "m", false, "on decompression, set mtime from gzip header")
flag.BoolVar(&includeExtra, "x", false, "on compression, include uid/gid/mode/ctime in gzip Extra")
flag.IntVar(&setUID, "uid", -1, "when used with -x, set UID in Extra to this value")
flag.IntVar(&setGID, "gid", -1, "when used with -x, set GID in Extra to this value")
flag.Parse()
if flag.NArg() < 1 || flag.NArg() > 2 {
@@ -172,12 +363,15 @@ func main() {
os.Exit(1)
}
err = uncompress(path, target, unrestrict)
err = uncompress(path, target, unrestrict, preserveMtime)
if err != nil {
os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
if !keep {
_ = os.Remove(path)
}
return
}
@@ -187,10 +381,13 @@ func main() {
os.Exit(1)
}
err = compress(path, target, level)
err = compress(path, target, level, includeExtra, setUID, setGID)
if err != nil {
os.Remove(target)
fmt.Fprintf(os.Stderr, "%s\n", err)
os.Exit(1)
}
if !keep {
_ = os.Remove(path)
}
}

View File

@@ -1,29 +1,19 @@
package main
import (
"bytes"
"crypto/ecdsa"
"crypto/rsa"
"crypto/sha1" // #nosec G505
"crypto/x509"
"crypto/x509/pkix"
"encoding/asn1"
"encoding/pem"
// #nosec G505
"flag"
"fmt"
"io"
"os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/certlib/ski"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
const (
keyTypeRSA = "RSA"
keyTypeECDSA = "ECDSA"
)
func usage(w io.Writer) {
fmt.Fprintf(w, `ski: print subject key info for PEM-encoded files
@@ -42,117 +32,6 @@ func init() {
flag.Usage = func() { usage(os.Stderr) }
}
func parse(path string) ([]byte, string, string) {
data, err := os.ReadFile(path)
die.If(err)
data = bytes.TrimSpace(data)
p, rest := pem.Decode(data)
if len(rest) > 0 {
_, _ = lib.Warnx("trailing data in PEM file")
}
if p == nil {
die.With("no PEM data found")
}
data = p.Bytes
var (
public []byte
kt string
ft string
)
switch p.Type {
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
public, kt = parseKey(data)
ft = "private key"
case "CERTIFICATE":
public, kt = parseCertificate(data)
ft = "certificate"
case "CERTIFICATE REQUEST":
public, kt = parseCSR(data)
ft = "certificate request"
default:
die.With("unknown PEM type %s", p.Type)
}
return public, kt, ft
}
func parseKey(data []byte) ([]byte, string) {
priv, err := certlib.ParsePrivateKeyDER(data)
if err != nil {
die.If(err)
}
var kt string
switch priv.Public().(type) {
case *rsa.PublicKey:
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = keyTypeECDSA
default:
die.With("unknown private key type %T", priv)
}
public, err := x509.MarshalPKIXPublicKey(priv.Public())
die.If(err)
return public, kt
}
func parseCertificate(data []byte) ([]byte, string) {
cert, err := x509.ParseCertificate(data)
die.If(err)
pub := cert.PublicKey
var kt string
switch pub.(type) {
case *rsa.PublicKey:
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = keyTypeECDSA
default:
die.With("unknown public key type %T", pub)
}
public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err)
return public, kt
}
func parseCSR(data []byte) ([]byte, string) {
// Use certlib to support both PEM and DER and to centralize validation.
csr, _, err := certlib.ParseCSR(data)
die.If(err)
pub := csr.PublicKey
var kt string
switch pub.(type) {
case *rsa.PublicKey:
kt = keyTypeRSA
case *ecdsa.PublicKey:
kt = keyTypeECDSA
default:
die.With("unknown public key type %T", pub)
}
public, err := x509.MarshalPKIXPublicKey(pub)
die.If(err)
return public, kt
}
func dumpHex(in []byte, mode lib.HexEncodeMode) string {
return lib.HexEncode(in, mode)
}
type subjectPublicKeyInfo struct {
Algorithm pkix.AlgorithmIdentifier
SubjectPublicKey asn1.BitString
}
func main() {
var help, shouldMatch bool
var displayModeString string
@@ -168,27 +47,22 @@ func main() {
os.Exit(0)
}
var ski string
var matchSKI string
for _, path := range flag.Args() {
public, kt, ft := parse(path)
keyInfo, err := ski.ParsePEM(path)
die.If(err)
var subPKI subjectPublicKeyInfo
_, err := asn1.Unmarshal(public, &subPKI)
if err != nil {
_, _ = lib.Warn(err, "failed to get subject PKI")
continue
keySKI, err := keyInfo.SKI(displayMode)
die.If(err)
if matchSKI == "" {
matchSKI = keySKI
}
pubHash := sha1.Sum(subPKI.SubjectPublicKey.Bytes) // #nosec G401 this is the standard
pubHashString := dumpHex(pubHash[:], displayMode)
if ski == "" {
ski = pubHashString
}
if shouldMatch && ski != pubHashString {
if shouldMatch && matchSKI != keySKI {
_, _ = lib.Warnx("%s: SKI mismatch (%s != %s)",
path, ski, pubHashString)
path, matchSKI, keySKI)
}
fmt.Printf("%s %s (%s %s)\n", path, pubHashString, kt, ft)
fmt.Printf("%s %s (%s %s)\n", path, keySKI, keyInfo.KeyType, keyInfo.FileType)
}
}

View File

@@ -3,50 +3,47 @@ package main
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"flag"
"fmt"
"net"
"os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib"
)
func main() {
var cfg = &tls.Config{} // #nosec G402
var sysRoot, serverName string
var skipVerify bool
var strictTLS bool
lib.StrictTLSFlag(&strictTLS)
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
flag.StringVar(&cfg.ServerName, "sni", cfg.ServerName, "provide an SNI name")
flag.BoolVar(&cfg.InsecureSkipVerify, "noverify", false, "don't verify certificates")
flag.StringVar(&serverName, "sni", "", "provide an SNI name")
flag.BoolVar(&skipVerify, "noverify", false, "don't verify certificates")
flag.Parse()
tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS)
die.If(err)
if sysRoot != "" {
pemList, err := os.ReadFile(sysRoot)
tlsCfg.RootCAs, err = certlib.LoadPEMCertPool(sysRoot)
die.If(err)
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(pemList) {
fmt.Printf("[!] no valid roots found")
roots = nil
}
cfg.RootCAs = roots
}
if serverName != "" {
cfg.ServerName = serverName
tlsCfg.ServerName = serverName
}
for _, site := range flag.Args() {
_, _, err := net.SplitHostPort(site)
_, _, err = net.SplitHostPort(site)
if err != nil {
site += ":443"
}
// Use proxy-aware TLS dialer
conn, err := lib.DialTLS(context.Background(), site, lib.DialerOpts{TLSConfig: cfg})
var conn *tls.Conn
conn, err = lib.DialTLS(context.Background(), site, lib.DialerOpts{TLSConfig: tlsCfg})
die.If(err)
cs := conn.ConnectionState()

View File

@@ -1,167 +1,33 @@
package main
import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"flag"
"fmt"
"os"
"flag"
"fmt"
"os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die"
)
var validPEMs = map[string]bool{
"PRIVATE KEY": true,
"RSA PRIVATE KEY": true,
"EC PRIVATE KEY": true,
}
const (
curveInvalid = iota // any invalid curve
curveRSA // indicates key is an RSA key, not an EC key
curveP256
curveP384
curveP521
)
func getECCurve(pub any) int {
switch pub := pub.(type) {
case *ecdsa.PublicKey:
switch pub.Curve {
case elliptic.P256():
return curveP256
case elliptic.P384():
return curveP384
case elliptic.P521():
return curveP521
default:
return curveInvalid
}
case *rsa.PublicKey:
return curveRSA
default:
return curveInvalid
}
}
// matchRSA compares an RSA public key from certificate against RSA public key from private key.
// It returns true on match.
func matchRSA(certPub *rsa.PublicKey, keyPub *rsa.PublicKey) bool {
return keyPub.N.Cmp(certPub.N) == 0 && keyPub.E == certPub.E
}
// matchECDSA compares ECDSA public keys for equality and compatible curve.
// It returns match=true when they are on the same curve and have the same X/Y.
// If curves mismatch, match is false.
func matchECDSA(certPub *ecdsa.PublicKey, keyPub *ecdsa.PublicKey) bool {
if getECCurve(certPub) != getECCurve(keyPub) {
return false
}
if keyPub.X.Cmp(certPub.X) != 0 {
return false
}
if keyPub.Y.Cmp(certPub.Y) != 0 {
return false
}
return true
}
// matchKeys determines whether the certificate's public key matches the given private key.
// It returns true if they match; otherwise, it returns false and a human-friendly reason.
func matchKeys(cert *x509.Certificate, priv crypto.Signer) (bool, string) {
switch keyPub := priv.Public().(type) {
case *rsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *rsa.PublicKey:
if matchRSA(certPub, keyPub) {
return true, ""
}
return false, "public keys don't match"
case *ecdsa.PublicKey:
return false, "RSA private key, EC public key"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
case *ecdsa.PublicKey:
switch certPub := cert.PublicKey.(type) {
case *ecdsa.PublicKey:
if matchECDSA(certPub, keyPub) {
return true, ""
}
// Determine a more precise reason
kc := getECCurve(keyPub)
cc := getECCurve(certPub)
if kc == curveInvalid {
return false, "invalid private key curve"
}
if cc == curveRSA {
return false, "private key is EC, certificate is RSA"
}
if kc != cc {
return false, "EC curves don't match"
}
return false, "public keys don't match"
case *rsa.PublicKey:
return false, "private key is EC, certificate is RSA"
default:
return false, fmt.Sprintf("unsupported certificate public key type: %T", cert.PublicKey)
}
default:
return false, fmt.Sprintf("unrecognised private key type: %T", priv.Public())
}
}
func loadKey(path string) (crypto.Signer, error) {
in, err := os.ReadFile(path)
if err != nil {
return nil, err
}
in = bytes.TrimSpace(in)
if p, _ := pem.Decode(in); p != nil {
if !validPEMs[p.Type] {
return nil, errors.New("invalid private key file type " + p.Type)
}
return certlib.ParsePrivateKeyPEM(in)
}
return certlib.ParsePrivateKeyDER(in)
}
// functionality refactored into certlib
func main() {
var keyFile, certFile string
flag.StringVar(&keyFile, "k", "", "TLS private `key` file")
flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
flag.Parse()
var keyFile, certFile string
flag.StringVar(&keyFile, "k", "", "TLS private `key` file")
flag.StringVar(&certFile, "c", "", "TLS `certificate` file")
flag.Parse()
in, err := os.ReadFile(certFile)
die.If(err)
cert, err := certlib.LoadCertificate(certFile)
die.If(err)
p, _ := pem.Decode(in)
if p != nil {
if p.Type != "CERTIFICATE" {
die.With("invalid certificate (type is %s)", p.Type)
}
in = p.Bytes
}
cert, err := x509.ParseCertificate(in)
die.If(err)
priv, err := certlib.LoadPrivateKey(keyFile)
die.If(err)
priv, err := loadKey(keyFile)
die.If(err)
matched, reason := matchKeys(cert, priv)
if matched {
fmt.Println("Match.")
return
}
fmt.Printf("No match (%s).\n", reason)
os.Exit(1)
matched, reason := certlib.MatchKeys(cert, priv)
if matched {
fmt.Println("Match.")
return
}
fmt.Printf("No match (%s).\n", reason)
os.Exit(1)
}

View File

@@ -1,12 +1,34 @@
// Package dbg implements a debug printer.
// Package dbg implements a simple debug printer.
//
// There are two main ways to use it:
// - By using one of the constructors and calling flag.BoolVar(&debug.Enabled...)
// - By setting the environment variable GOUTILS_ENABLE_DEBUG to true or false and
// calling NewFromEnv().
//
// If enabled, any of the print statements will be written to stdout. Otherwise,
// nothing will be emitted.
package dbg
import (
"fmt"
"io"
"os"
"runtime/debug"
"strings"
)
const DebugEnvKey = "GOUTILS_ENABLE_DEBUG"
var enabledValues = map[string]bool{
"1": true,
"true": true,
"yes": true,
"on": true,
"y": true,
"enable": true,
"enabled": true,
}
// A DebugPrinter is a drop-in replacement for fmt.Print*, and also acts as
// an io.WriteCloser when enabled.
type DebugPrinter struct {
@@ -15,6 +37,23 @@ type DebugPrinter struct {
out io.WriteCloser
}
// New returns a new DebugPrinter on os.Stdout.
func New() *DebugPrinter {
return &DebugPrinter{
out: os.Stderr,
}
}
// NewFromEnv returns a new DebugPrinter based on the value of the environment
// variable GOUTILS_ENABLE_DEBUG.
func NewFromEnv() *DebugPrinter {
enabled := strings.ToLower(os.Getenv(DebugEnvKey))
return &DebugPrinter{
out: os.Stderr,
Enabled: enabledValues[enabled],
}
}
// Close satisfies the Closer interface.
func (dbg *DebugPrinter) Close() error {
return dbg.out.Close()
@@ -28,13 +67,6 @@ func (dbg *DebugPrinter) Write(p []byte) (int, error) {
return 0, nil
}
// New returns a new DebugPrinter on os.Stdout.
func New() *DebugPrinter {
return &DebugPrinter{
out: os.Stdout,
}
}
// ToFile sets up a new DebugPrinter to a file, truncating it if it exists.
func ToFile(path string) (*DebugPrinter, error) {
file, err := os.Create(path)
@@ -74,3 +106,7 @@ func (dbg *DebugPrinter) Printf(format string, v ...any) {
fmt.Fprintf(dbg.out, format, v...)
}
}
func (dbg *DebugPrinter) StackTrace() {
dbg.Write(debug.Stack())
}

View File

@@ -11,3 +11,11 @@ const (
// ExitFailure is the failing exit status.
ExitFailure = 1
)
const (
OneTrueDateFormat = "2006-01-02T15:04:05-0700"
DateShortFormat = "2006-01-02"
TimeShortFormat = "15:04:05"
TimeShorterFormat = "15:04"
TimeStandardDateTime = "2006-01-02 15:04"
)

View File

@@ -20,6 +20,7 @@ import (
"crypto/tls"
"encoding/base64"
"errors"
"flag"
"fmt"
"net"
"net/http"
@@ -29,8 +30,42 @@ import (
"time"
xproxy "golang.org/x/net/proxy"
"git.wntrmute.dev/kyle/goutils/dbg"
)
// StrictBaselineTLSConfig returns a secure TLS config.
// Many of the tools in this repo are designed to debug broken TLS systems
// and therefore explicitly support old or insecure TLS setups.
func StrictBaselineTLSConfig() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: false, // explicitly set
}
}
func StrictTLSFlag(useStrict *bool) {
flag.BoolVar(useStrict, "strict-tls", false, "Use strict TLS configuration (disables certificate verification)")
}
func BaselineTLSConfig(skipVerify bool, secure bool) (*tls.Config, error) {
if secure && skipVerify {
return nil, errors.New("cannot skip verification and use secure TLS")
}
if skipVerify {
return &tls.Config{InsecureSkipVerify: true}, nil // #nosec G402 - intentional
}
if secure {
return StrictBaselineTLSConfig(), nil
}
return &tls.Config{}, nil // #nosec G402 - intentional
}
var debug = dbg.NewFromEnv()
// DialerOpts controls creation of proxy-aware dialers.
//
// Timeout controls the maximum amount of time spent establishing the
@@ -94,24 +129,30 @@ func NewNetDialer(opts DialerOpts) (ContextDialer, error) {
}
if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil {
debug.Printf("using SOCKS5 proxy %q\n", u)
return newSOCKS5Dialer(u, opts)
}
if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil {
// Respect the proxy URL scheme. Zscaler may set HTTPS_PROXY to an HTTP proxy
// running locally; in that case we must NOT TLS-wrap the proxy connection.
debug.Printf("using HTTPS proxy %q\n", u)
return &httpProxyDialer{
proxyURL: u,
timeout: opts.Timeout,
secure: true,
secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig,
}, nil
}
if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil {
debug.Printf("using HTTP proxy %q\n", u)
return &httpProxyDialer{
proxyURL: u,
timeout: opts.Timeout,
secure: true,
config: opts.TLSConfig,
// Only TLS-wrap the proxy connection if the URL scheme is https.
secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig,
}, nil
}
@@ -131,6 +172,7 @@ func NewTLSDialer(opts DialerOpts) (ContextDialer, error) {
// Prefer SOCKS5 if present.
if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil {
debug.Printf("using SOCKS5 proxy %q\n", u)
base, err := newSOCKS5Dialer(u, opts)
if err != nil {
return nil, err
@@ -140,19 +182,22 @@ func NewTLSDialer(opts DialerOpts) (ContextDialer, error) {
// For TLS, prefer HTTPS proxy over HTTP if both set.
if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil {
debug.Printf("using HTTPS proxy %q\n", u)
base := &httpProxyDialer{
proxyURL: u,
timeout: opts.Timeout,
secure: true,
secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig,
}
return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil
}
if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil {
debug.Printf("using HTTP proxy %q\n", u)
base := &httpProxyDialer{
proxyURL: u,
timeout: opts.Timeout,
secure: true,
secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig,
}
return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil
@@ -246,13 +291,8 @@ type httpProxyDialer struct {
config *tls.Config
}
func (d *httpProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
return nil, fmt.Errorf("http proxy dialer only supports TCP, got %q", network)
}
// Dial to proxy
var nd = &net.Dialer{Timeout: d.timeout}
// proxyAddress returns host:port for the proxy, applying defaults by scheme when missing.
func (d *httpProxyDialer) proxyAddress() string {
proxyAddr := d.proxyURL.Host
if !strings.Contains(proxyAddr, ":") {
if d.secure {
@@ -261,7 +301,61 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri
proxyAddr += ":80"
}
}
conn, err := nd.DialContext(ctx, "tcp", proxyAddr)
return proxyAddr
}
// tlsWrapProxyConn performs a TLS handshake to the proxy when d.secure is true.
// It clones the provided tls.Config (if any), ensures ServerName and a safe
// minimum TLS version.
func (d *httpProxyDialer) tlsWrapProxyConn(ctx context.Context, conn net.Conn) (net.Conn, error) {
host := d.proxyURL.Hostname()
// Clone provided config (if any) to avoid mutating caller's config.
cfg := &tls.Config{} // #nosec G402 - intentional
if d.config != nil {
cfg = d.config.Clone()
}
if cfg.ServerName == "" {
cfg.ServerName = host
}
tlsConn := tls.Client(conn, cfg)
if err := tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("tls handshake with https proxy failed: %w", err)
}
return tlsConn, nil
}
// readConnectResponse reads and validates the proxy's response to a CONNECT
// request. It returns nil on a 200 status and an error otherwise.
func readConnectResponse(br *bufio.Reader) error {
statusLine, err := br.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read CONNECT response: %w", err)
}
if !strings.HasPrefix(statusLine, "HTTP/") {
return fmt.Errorf("invalid proxy response: %q", strings.TrimSpace(statusLine))
}
if !strings.Contains(statusLine, " 200 ") && !strings.HasSuffix(strings.TrimSpace(statusLine), " 200") {
// Drain headers for context
_ = drainHeaders(br)
return fmt.Errorf("proxy CONNECT failed: %s", strings.TrimSpace(statusLine))
}
return drainHeaders(br)
}
func (d *httpProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
return nil, fmt.Errorf("http proxy dialer only supports TCP, got %q", network)
}
// Dial to proxy
var nd = &net.Dialer{Timeout: d.timeout}
conn, err := nd.DialContext(ctx, "tcp", d.proxyAddress())
if err != nil {
return nil, err
}
@@ -273,14 +367,11 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri
// If HTTPS proxy, wrap with TLS to the proxy itself.
if d.secure {
host := d.proxyURL.Hostname()
d.config.ServerName = host
tlsConn := tls.Client(conn, d.config)
if err = tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("tls handshake with https proxy failed: %w", err)
c, werr := d.tlsWrapProxyConn(ctx, conn)
if werr != nil {
return nil, werr
}
conn = tlsConn
conn = c
}
req := buildConnectRequest(d.proxyURL, address)
@@ -291,25 +382,7 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri
// Read proxy response until end of headers
br := bufio.NewReader(conn)
statusLine, err := br.ReadString('\n')
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("failed to read CONNECT response: %w", err)
}
if !strings.HasPrefix(statusLine, "HTTP/") {
_ = conn.Close()
return nil, fmt.Errorf("invalid proxy response: %q", strings.TrimSpace(statusLine))
}
if !strings.Contains(statusLine, " 200 ") && !strings.HasSuffix(strings.TrimSpace(statusLine), " 200") {
// Drain headers for context
_ = drainHeaders(br)
_ = conn.Close()
return nil, fmt.Errorf("proxy CONNECT failed: %s", strings.TrimSpace(statusLine))
}
if err = drainHeaders(br); err != nil {
if err = readConnectResponse(br); err != nil {
_ = conn.Close()
return nil, err
}
@@ -429,7 +502,7 @@ func (t *tlsWrappingDialer) DialContext(ctx context.Context, network, address st
}
cfg = c
} else {
cfg = &tls.Config{ServerName: host} // #nosec G402 - intentional
cfg = &tls.Config{ServerName: host, MinVersion: tls.VersionTLS12}
}
tlsConn := tls.Client(raw, cfg)

View File

@@ -14,18 +14,8 @@ import (
"git.wntrmute.dev/kyle/goutils/fileutil"
)
// FetcherOpts are options for fetching certificates. They are only applicable to ServerFetcher.
type FetcherOpts struct {
SkipVerify bool
Roots *x509.CertPool
}
func (fo *FetcherOpts) TLSConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: fo.SkipVerify, // #nosec G402 - intentional
RootCAs: fo.Roots,
}
}
// Note: Previously this package exposed a FetcherOpts type. It has been
// refactored to use *tls.Config directly for configuring TLS behavior.
// Fetcher is an interface for fetching certificates from a remote source. It
// currently supports fetching from a server or a file.
@@ -143,7 +133,10 @@ func (ff *FileFetcher) Get() (*x509.Certificate, error) {
}
// GetCertificateChain fetches a certificate chain from a remote source.
func GetCertificateChain(spec string, opts *FetcherOpts) ([]*x509.Certificate, error) {
// If cfg is non-nil and spec refers to a TLS server, the provided TLS
// configuration will be used to control verification behavior (e.g.,
// InsecureSkipVerify, RootCAs).
func GetCertificateChain(spec string, cfg *tls.Config) ([]*x509.Certificate, error) {
if fileutil.FileDoesExist(spec) {
return NewFileFetcher(spec).GetChain()
}
@@ -153,17 +146,17 @@ func GetCertificateChain(spec string, opts *FetcherOpts) ([]*x509.Certificate, e
return nil, err
}
if opts != nil {
fetcher.insecure = opts.SkipVerify
fetcher.roots = opts.Roots
if cfg != nil {
fetcher.insecure = cfg.InsecureSkipVerify
fetcher.roots = cfg.RootCAs
}
return fetcher.GetChain()
}
// GetCertificate fetches the first certificate from a certificate chain.
func GetCertificate(spec string, opts *FetcherOpts) (*x509.Certificate, error) {
certs, err := GetCertificateChain(spec, opts)
func GetCertificate(spec string, cfg *tls.Config) (*x509.Certificate, error) {
certs, err := GetCertificateChain(spec, cfg)
if err != nil {
return nil, err
}

View File

@@ -45,7 +45,7 @@ fi
# Use the first tag if multiple are present; warn the user.
# Avoid readarray for broader Bash compatibility (e.g., macOS Bash 3.2).
TAG_ARRAY=($TAGS)
TAG_ARRAY=("$TAGS")
TAG="${TAG_ARRAY[0]}"
if (( ${#TAG_ARRAY[@]} > 1 )); then