diff --git a/.golangci.yml b/.golangci.yml index e6ddab5..aeb47f4 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -247,11 +247,12 @@ linters: # Default: false check-type-assertions: true exclude-functions: - - (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write + - (*git.wntrmute.dev/kyle/goutils/dbg.DebugPrinter).Write - git.wntrmute.dev/kyle/goutils/lib.Warn - git.wntrmute.dev/kyle/goutils/lib.Warnx - git.wntrmute.dev/kyle/goutils/lib.Err - git.wntrmute.dev/kyle/goutils/lib.Errx + - (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write exhaustive: # Program elements to check for exhaustiveness. diff --git a/cmd/certdump/main.go b/cmd/certdump/main.go index 03c4f2a..1b0d3ba 100644 --- a/cmd/certdump/main.go +++ b/cmd/certdump/main.go @@ -7,6 +7,7 @@ import ( "crypto/elliptic" "crypto/rsa" "crypto/sha256" + "crypto/tls" "crypto/x509" "crypto/x509/pkix" "flag" @@ -350,14 +351,11 @@ func main() { flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate") flag.Parse() - opts := &lib.FetcherOpts{ - SkipVerify: true, - Roots: nil, - } + tlsCfg := &tls.Config{InsecureSkipVerify: true} // #nosec G402 - tool intentionally inspects broken TLS for _, filename := range flag.Args() { fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n") - certs, err := lib.GetCertificateChain(filename, opts) + certs, err := lib.GetCertificateChain(filename, tlsCfg) if err != nil { _, _ = lib.Warn(err, "couldn't read certificate") continue diff --git a/cmd/certexpiry/main.go b/cmd/certexpiry/main.go index 7529606..ae5633f 100644 --- a/cmd/certexpiry/main.go +++ b/cmd/certexpiry/main.go @@ -63,6 +63,7 @@ func checkCert(cert *x509.Certificate) { warn := inDanger(cert) name := displayName(cert.Subject) name = fmt.Sprintf("%s/SN=%s", name, cert.SerialNumber) + expiry := expires(cert) if warnOnly { if warn { @@ -74,15 +75,22 @@ func checkCert(cert *x509.Certificate) { } func main() { - opts := &lib.FetcherOpts{} + var skipVerify bool + var strictTLS bool + lib.StrictTLSFlag(&strictTLS) - flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification") + flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402 flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs") flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring") flag.Parse() + tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS) + die.If(err) + for _, file := range flag.Args() { - certs, err := lib.GetCertificateChain(file, opts) + var certs []*x509.Certificate + + certs, err = lib.GetCertificateChain(file, tlsCfg) if err != nil { _, _ = lib.Warn(err, "while parsing certificates") continue diff --git a/cmd/certser/main.go b/cmd/certser/main.go index c7734ab..a801a84 100644 --- a/cmd/certser/main.go +++ b/cmd/certser/main.go @@ -31,16 +31,23 @@ func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string { } func main() { - opts := &lib.FetcherOpts{} + var skipVerify bool + var strictTLS bool + lib.StrictTLSFlag(&strictTLS) displayAs := flag.String("d", "int", "display mode (int, hex, uhex)") showExpiry := flag.Bool("e", false, "show expiry date") - flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification") + flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402 flag.Parse() + tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS) + die.If(err) + displayMode := parseDisplayMode(*displayAs) for _, arg := range flag.Args() { - cert, err := lib.GetCertificate(arg, opts) + var cert *x509.Certificate + + cert, err = lib.GetCertificate(arg, tlsCfg) die.If(err) fmt.Printf("%s: %s", arg, serialString(cert, displayMode)) diff --git a/cmd/certverify/main.go b/cmd/certverify/main.go index a997645..6c58fdf 100644 --- a/cmd/certverify/main.go +++ b/cmd/certverify/main.go @@ -32,6 +32,7 @@ type appConfig struct { caFile, intFile string forceIntermediateBundle bool revexp, skipVerify, verbose bool + strictTLS bool } func parseFlags() appConfig { @@ -43,6 +44,7 @@ func parseFlags() appConfig { flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification") flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information") flag.BoolVar(&cfg.verbose, "v", false, "verbose") + lib.StrictTLSFlag(&cfg.strictTLS) flag.Parse() return cfg } @@ -108,12 +110,13 @@ func run(cfg appConfig) error { return fmt.Errorf("failed to build combined pool: %w", err) } - opts := &lib.FetcherOpts{ - Roots: combinedPool, - SkipVerify: cfg.skipVerify, + tlsCfg, err := lib.BaselineTLSConfig(cfg.skipVerify, cfg.strictTLS) + if err != nil { + return err } + tlsCfg.RootCAs = combinedPool - chain, err := lib.GetCertificateChain(flag.Arg(0), opts) + chain, err := lib.GetCertificateChain(flag.Arg(0), tlsCfg) if err != nil { return err } diff --git a/cmd/stealchain/main.go b/cmd/stealchain/main.go index a9c960a..da2abc7 100644 --- a/cmd/stealchain/main.go +++ b/cmd/stealchain/main.go @@ -3,50 +3,47 @@ package main import ( "context" "crypto/tls" - "crypto/x509" "encoding/pem" "flag" "fmt" "net" "os" + "git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/lib" ) func main() { - var cfg = &tls.Config{} // #nosec G402 - var sysRoot, serverName string + var skipVerify bool + var strictTLS bool + lib.StrictTLSFlag(&strictTLS) flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle") - flag.StringVar(&cfg.ServerName, "sni", cfg.ServerName, "provide an SNI name") - flag.BoolVar(&cfg.InsecureSkipVerify, "noverify", false, "don't verify certificates") + flag.StringVar(&serverName, "sni", "", "provide an SNI name") + flag.BoolVar(&skipVerify, "noverify", false, "don't verify certificates") flag.Parse() + tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS) + die.If(err) + if sysRoot != "" { - pemList, err := os.ReadFile(sysRoot) + tlsCfg.RootCAs, err = certlib.LoadPEMCertPool(sysRoot) die.If(err) - - roots := x509.NewCertPool() - if !roots.AppendCertsFromPEM(pemList) { - fmt.Printf("[!] no valid roots found") - roots = nil - } - - cfg.RootCAs = roots } if serverName != "" { - cfg.ServerName = serverName + tlsCfg.ServerName = serverName } for _, site := range flag.Args() { - _, _, err := net.SplitHostPort(site) + _, _, err = net.SplitHostPort(site) if err != nil { site += ":443" } - // Use proxy-aware TLS dialer - conn, err := lib.DialTLS(context.Background(), site, lib.DialerOpts{TLSConfig: cfg}) + + var conn *tls.Conn + conn, err = lib.DialTLS(context.Background(), site, lib.DialerOpts{TLSConfig: tlsCfg}) die.If(err) cs := conn.ConnectionState() diff --git a/dbg/dbg.go b/dbg/dbg.go index 5d53764..aa9ccb6 100644 --- a/dbg/dbg.go +++ b/dbg/dbg.go @@ -1,12 +1,34 @@ -// Package dbg implements a debug printer. +// Package dbg implements a simple debug printer. +// +// There are two main ways to use it: +// - By using one of the constructors and calling flag.BoolVar(&debug.Enabled...) +// - By setting the environment variable GOUTILS_ENABLE_DEBUG to true or false and +// calling NewFromEnv(). +// +// If enabled, any of the print statements will be written to stdout. Otherwise, +// nothing will be emitted. package dbg import ( "fmt" "io" "os" + "runtime/debug" + "strings" ) +const DebugEnvKey = "GOUTILS_ENABLE_DEBUG" + +var enabledValues = map[string]bool{ + "1": true, + "true": true, + "yes": true, + "on": true, + "y": true, + "enable": true, + "enabled": true, +} + // A DebugPrinter is a drop-in replacement for fmt.Print*, and also acts as // an io.WriteCloser when enabled. type DebugPrinter struct { @@ -15,6 +37,23 @@ type DebugPrinter struct { out io.WriteCloser } +// New returns a new DebugPrinter on os.Stdout. +func New() *DebugPrinter { + return &DebugPrinter{ + out: os.Stderr, + } +} + +// NewFromEnv returns a new DebugPrinter based on the value of the environment +// variable GOUTILS_ENABLE_DEBUG. +func NewFromEnv() *DebugPrinter { + enabled := strings.ToLower(os.Getenv(DebugEnvKey)) + return &DebugPrinter{ + out: os.Stderr, + Enabled: enabledValues[enabled], + } +} + // Close satisfies the Closer interface. func (dbg *DebugPrinter) Close() error { return dbg.out.Close() @@ -28,13 +67,6 @@ func (dbg *DebugPrinter) Write(p []byte) (int, error) { return 0, nil } -// New returns a new DebugPrinter on os.Stdout. -func New() *DebugPrinter { - return &DebugPrinter{ - out: os.Stdout, - } -} - // ToFile sets up a new DebugPrinter to a file, truncating it if it exists. func ToFile(path string) (*DebugPrinter, error) { file, err := os.Create(path) @@ -74,3 +106,7 @@ func (dbg *DebugPrinter) Printf(format string, v ...any) { fmt.Fprintf(dbg.out, format, v...) } } + +func (dbg *DebugPrinter) StackTrace() { + dbg.Write(debug.Stack()) +} diff --git a/lib/dialer.go b/lib/dialer.go index 2f09315..666b8fa 100644 --- a/lib/dialer.go +++ b/lib/dialer.go @@ -20,6 +20,7 @@ import ( "crypto/tls" "encoding/base64" "errors" + "flag" "fmt" "net" "net/http" @@ -29,8 +30,42 @@ import ( "time" xproxy "golang.org/x/net/proxy" + + "git.wntrmute.dev/kyle/goutils/dbg" ) +// StrictBaselineTLSConfig returns a secure TLS config. +// Many of the tools in this repo are designed to debug broken TLS systems +// and therefore explicitly support old or insecure TLS setups. +func StrictBaselineTLSConfig() *tls.Config { + return &tls.Config{ + MinVersion: tls.VersionTLS12, + InsecureSkipVerify: false, // explicitly set + } +} + +func StrictTLSFlag(useStrict *bool) { + flag.BoolVar(useStrict, "strict-tls", false, "Use strict TLS configuration (disables certificate verification)") +} + +func BaselineTLSConfig(skipVerify bool, secure bool) (*tls.Config, error) { + if secure && skipVerify { + return nil, errors.New("cannot skip verification and use secure TLS") + } + + if skipVerify { + return &tls.Config{InsecureSkipVerify: true}, nil // #nosec G402 - intentional + } + + if secure { + return StrictBaselineTLSConfig(), nil + } + + return &tls.Config{}, nil // #nosec G402 - intentional +} + +var debug = dbg.NewFromEnv() + // DialerOpts controls creation of proxy-aware dialers. // // Timeout controls the maximum amount of time spent establishing the @@ -94,24 +129,30 @@ func NewNetDialer(opts DialerOpts) (ContextDialer, error) { } if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil { + debug.Printf("using SOCKS5 proxy %q\n", u) return newSOCKS5Dialer(u, opts) } if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil { + // Respect the proxy URL scheme. Zscaler may set HTTPS_PROXY to an HTTP proxy + // running locally; in that case we must NOT TLS-wrap the proxy connection. + debug.Printf("using HTTPS proxy %q\n", u) return &httpProxyDialer{ proxyURL: u, timeout: opts.Timeout, - secure: true, + secure: strings.EqualFold(u.Scheme, "https"), config: opts.TLSConfig, }, nil } if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil { + debug.Printf("using HTTP proxy %q\n", u) return &httpProxyDialer{ proxyURL: u, timeout: opts.Timeout, - secure: true, - config: opts.TLSConfig, + // Only TLS-wrap the proxy connection if the URL scheme is https. + secure: strings.EqualFold(u.Scheme, "https"), + config: opts.TLSConfig, }, nil } @@ -131,6 +172,7 @@ func NewTLSDialer(opts DialerOpts) (ContextDialer, error) { // Prefer SOCKS5 if present. if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil { + debug.Printf("using SOCKS5 proxy %q\n", u) base, err := newSOCKS5Dialer(u, opts) if err != nil { return nil, err @@ -140,19 +182,22 @@ func NewTLSDialer(opts DialerOpts) (ContextDialer, error) { // For TLS, prefer HTTPS proxy over HTTP if both set. if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil { + debug.Printf("using HTTPS proxy %q\n", u) base := &httpProxyDialer{ proxyURL: u, timeout: opts.Timeout, - secure: true, + secure: strings.EqualFold(u.Scheme, "https"), config: opts.TLSConfig, } return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil } + if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil { + debug.Printf("using HTTP proxy %q\n", u) base := &httpProxyDialer{ proxyURL: u, timeout: opts.Timeout, - secure: true, + secure: strings.EqualFold(u.Scheme, "https"), config: opts.TLSConfig, } return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil @@ -246,13 +291,8 @@ type httpProxyDialer struct { config *tls.Config } -func (d *httpProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { - if !strings.HasPrefix(network, "tcp") { - return nil, fmt.Errorf("http proxy dialer only supports TCP, got %q", network) - } - - // Dial to proxy - var nd = &net.Dialer{Timeout: d.timeout} +// proxyAddress returns host:port for the proxy, applying defaults by scheme when missing. +func (d *httpProxyDialer) proxyAddress() string { proxyAddr := d.proxyURL.Host if !strings.Contains(proxyAddr, ":") { if d.secure { @@ -261,7 +301,61 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri proxyAddr += ":80" } } - conn, err := nd.DialContext(ctx, "tcp", proxyAddr) + return proxyAddr +} + +// tlsWrapProxyConn performs a TLS handshake to the proxy when d.secure is true. +// It clones the provided tls.Config (if any), ensures ServerName and a safe +// minimum TLS version. +func (d *httpProxyDialer) tlsWrapProxyConn(ctx context.Context, conn net.Conn) (net.Conn, error) { + host := d.proxyURL.Hostname() + // Clone provided config (if any) to avoid mutating caller's config. + cfg := &tls.Config{} // #nosec G402 - intentional + if d.config != nil { + cfg = d.config.Clone() + } + + if cfg.ServerName == "" { + cfg.ServerName = host + } + + tlsConn := tls.Client(conn, cfg) + if err := tlsConn.HandshakeContext(ctx); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("tls handshake with https proxy failed: %w", err) + } + return tlsConn, nil +} + +// readConnectResponse reads and validates the proxy's response to a CONNECT +// request. It returns nil on a 200 status and an error otherwise. +func readConnectResponse(br *bufio.Reader) error { + statusLine, err := br.ReadString('\n') + if err != nil { + return fmt.Errorf("failed to read CONNECT response: %w", err) + } + + if !strings.HasPrefix(statusLine, "HTTP/") { + return fmt.Errorf("invalid proxy response: %q", strings.TrimSpace(statusLine)) + } + + if !strings.Contains(statusLine, " 200 ") && !strings.HasSuffix(strings.TrimSpace(statusLine), " 200") { + // Drain headers for context + _ = drainHeaders(br) + return fmt.Errorf("proxy CONNECT failed: %s", strings.TrimSpace(statusLine)) + } + + return drainHeaders(br) +} + +func (d *httpProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + if !strings.HasPrefix(network, "tcp") { + return nil, fmt.Errorf("http proxy dialer only supports TCP, got %q", network) + } + + // Dial to proxy + var nd = &net.Dialer{Timeout: d.timeout} + conn, err := nd.DialContext(ctx, "tcp", d.proxyAddress()) if err != nil { return nil, err } @@ -273,14 +367,11 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri // If HTTPS proxy, wrap with TLS to the proxy itself. if d.secure { - host := d.proxyURL.Hostname() - d.config.ServerName = host - tlsConn := tls.Client(conn, d.config) - if err = tlsConn.HandshakeContext(ctx); err != nil { - _ = conn.Close() - return nil, fmt.Errorf("tls handshake with https proxy failed: %w", err) + c, werr := d.tlsWrapProxyConn(ctx, conn) + if werr != nil { + return nil, werr } - conn = tlsConn + conn = c } req := buildConnectRequest(d.proxyURL, address) @@ -291,25 +382,7 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri // Read proxy response until end of headers br := bufio.NewReader(conn) - statusLine, err := br.ReadString('\n') - if err != nil { - _ = conn.Close() - return nil, fmt.Errorf("failed to read CONNECT response: %w", err) - } - - if !strings.HasPrefix(statusLine, "HTTP/") { - _ = conn.Close() - return nil, fmt.Errorf("invalid proxy response: %q", strings.TrimSpace(statusLine)) - } - - if !strings.Contains(statusLine, " 200 ") && !strings.HasSuffix(strings.TrimSpace(statusLine), " 200") { - // Drain headers for context - _ = drainHeaders(br) - _ = conn.Close() - return nil, fmt.Errorf("proxy CONNECT failed: %s", strings.TrimSpace(statusLine)) - } - - if err = drainHeaders(br); err != nil { + if err = readConnectResponse(br); err != nil { _ = conn.Close() return nil, err } @@ -429,7 +502,7 @@ func (t *tlsWrappingDialer) DialContext(ctx context.Context, network, address st } cfg = c } else { - cfg = &tls.Config{ServerName: host} // #nosec G402 - intentional + cfg = &tls.Config{ServerName: host, MinVersion: tls.VersionTLS12} } tlsConn := tls.Client(raw, cfg) diff --git a/lib/fetch.go b/lib/fetch.go index 2180d33..7faed18 100644 --- a/lib/fetch.go +++ b/lib/fetch.go @@ -14,18 +14,8 @@ import ( "git.wntrmute.dev/kyle/goutils/fileutil" ) -// FetcherOpts are options for fetching certificates. They are only applicable to ServerFetcher. -type FetcherOpts struct { - SkipVerify bool - Roots *x509.CertPool -} - -func (fo *FetcherOpts) TLSConfig() *tls.Config { - return &tls.Config{ - InsecureSkipVerify: fo.SkipVerify, // #nosec G402 - intentional - RootCAs: fo.Roots, - } -} +// Note: Previously this package exposed a FetcherOpts type. It has been +// refactored to use *tls.Config directly for configuring TLS behavior. // Fetcher is an interface for fetching certificates from a remote source. It // currently supports fetching from a server or a file. @@ -143,7 +133,10 @@ func (ff *FileFetcher) Get() (*x509.Certificate, error) { } // GetCertificateChain fetches a certificate chain from a remote source. -func GetCertificateChain(spec string, opts *FetcherOpts) ([]*x509.Certificate, error) { +// If cfg is non-nil and spec refers to a TLS server, the provided TLS +// configuration will be used to control verification behavior (e.g., +// InsecureSkipVerify, RootCAs). +func GetCertificateChain(spec string, cfg *tls.Config) ([]*x509.Certificate, error) { if fileutil.FileDoesExist(spec) { return NewFileFetcher(spec).GetChain() } @@ -153,17 +146,17 @@ func GetCertificateChain(spec string, opts *FetcherOpts) ([]*x509.Certificate, e return nil, err } - if opts != nil { - fetcher.insecure = opts.SkipVerify - fetcher.roots = opts.Roots + if cfg != nil { + fetcher.insecure = cfg.InsecureSkipVerify + fetcher.roots = cfg.RootCAs } return fetcher.GetChain() } // GetCertificate fetches the first certificate from a certificate chain. -func GetCertificate(spec string, opts *FetcherOpts) (*x509.Certificate, error) { - certs, err := GetCertificateChain(spec, opts) +func GetCertificate(spec string, cfg *tls.Config) (*x509.Certificate, error) { + certs, err := GetCertificateChain(spec, cfg) if err != nil { return nil, err } diff --git a/release-docker.sh b/release-docker.sh index 27a2fa3..2f60328 100755 --- a/release-docker.sh +++ b/release-docker.sh @@ -45,7 +45,7 @@ fi # Use the first tag if multiple are present; warn the user. # Avoid readarray for broader Bash compatibility (e.g., macOS Bash 3.2). -TAG_ARRAY=($TAGS) +TAG_ARRAY=("$TAGS") TAG="${TAG_ARRAY[0]}" if (( ${#TAG_ARRAY[@]} > 1 )); then