Add tooling to enable strict TLS.

This commit is contained in:
2025-11-18 17:25:49 -08:00
parent 3f92963c74
commit b714c75a43
10 changed files with 217 additions and 101 deletions

View File

@@ -247,11 +247,12 @@ linters:
# Default: false # Default: false
check-type-assertions: true check-type-assertions: true
exclude-functions: exclude-functions:
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write - (*git.wntrmute.dev/kyle/goutils/dbg.DebugPrinter).Write
- git.wntrmute.dev/kyle/goutils/lib.Warn - git.wntrmute.dev/kyle/goutils/lib.Warn
- git.wntrmute.dev/kyle/goutils/lib.Warnx - git.wntrmute.dev/kyle/goutils/lib.Warnx
- git.wntrmute.dev/kyle/goutils/lib.Err - git.wntrmute.dev/kyle/goutils/lib.Err
- git.wntrmute.dev/kyle/goutils/lib.Errx - git.wntrmute.dev/kyle/goutils/lib.Errx
- (*git.wntrmute.dev/kyle/goutils/sbuf.Buffer).Write
exhaustive: exhaustive:
# Program elements to check for exhaustiveness. # Program elements to check for exhaustiveness.

View File

@@ -7,6 +7,7 @@ import (
"crypto/elliptic" "crypto/elliptic"
"crypto/rsa" "crypto/rsa"
"crypto/sha256" "crypto/sha256"
"crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"flag" "flag"
@@ -350,14 +351,11 @@ func main() {
flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate") flag.BoolVar(&leafOnly, "l", false, "only show the leaf certificate")
flag.Parse() flag.Parse()
opts := &lib.FetcherOpts{ tlsCfg := &tls.Config{InsecureSkipVerify: true} // #nosec G402 - tool intentionally inspects broken TLS
SkipVerify: true,
Roots: nil,
}
for _, filename := range flag.Args() { for _, filename := range flag.Args() {
fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n") fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n")
certs, err := lib.GetCertificateChain(filename, opts) certs, err := lib.GetCertificateChain(filename, tlsCfg)
if err != nil { if err != nil {
_, _ = lib.Warn(err, "couldn't read certificate") _, _ = lib.Warn(err, "couldn't read certificate")
continue continue

View File

@@ -63,6 +63,7 @@ func checkCert(cert *x509.Certificate) {
warn := inDanger(cert) warn := inDanger(cert)
name := displayName(cert.Subject) name := displayName(cert.Subject)
name = fmt.Sprintf("%s/SN=%s", name, cert.SerialNumber) name = fmt.Sprintf("%s/SN=%s", name, cert.SerialNumber)
expiry := expires(cert) expiry := expires(cert)
if warnOnly { if warnOnly {
if warn { if warn {
@@ -74,15 +75,22 @@ func checkCert(cert *x509.Certificate) {
} }
func main() { func main() {
opts := &lib.FetcherOpts{} var skipVerify bool
var strictTLS bool
lib.StrictTLSFlag(&strictTLS)
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification") flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402
flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs") flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs")
flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring") flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring")
flag.Parse() flag.Parse()
tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS)
die.If(err)
for _, file := range flag.Args() { for _, file := range flag.Args() {
certs, err := lib.GetCertificateChain(file, opts) var certs []*x509.Certificate
certs, err = lib.GetCertificateChain(file, tlsCfg)
if err != nil { if err != nil {
_, _ = lib.Warn(err, "while parsing certificates") _, _ = lib.Warn(err, "while parsing certificates")
continue continue

View File

@@ -31,16 +31,23 @@ func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string {
} }
func main() { func main() {
opts := &lib.FetcherOpts{} var skipVerify bool
var strictTLS bool
lib.StrictTLSFlag(&strictTLS)
displayAs := flag.String("d", "int", "display mode (int, hex, uhex)") displayAs := flag.String("d", "int", "display mode (int, hex, uhex)")
showExpiry := flag.Bool("e", false, "show expiry date") showExpiry := flag.Bool("e", false, "show expiry date")
flag.BoolVar(&opts.SkipVerify, "k", false, "skip server verification") flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402
flag.Parse() flag.Parse()
tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS)
die.If(err)
displayMode := parseDisplayMode(*displayAs) displayMode := parseDisplayMode(*displayAs)
for _, arg := range flag.Args() { for _, arg := range flag.Args() {
cert, err := lib.GetCertificate(arg, opts) var cert *x509.Certificate
cert, err = lib.GetCertificate(arg, tlsCfg)
die.If(err) die.If(err)
fmt.Printf("%s: %s", arg, serialString(cert, displayMode)) fmt.Printf("%s: %s", arg, serialString(cert, displayMode))

View File

@@ -32,6 +32,7 @@ type appConfig struct {
caFile, intFile string caFile, intFile string
forceIntermediateBundle bool forceIntermediateBundle bool
revexp, skipVerify, verbose bool revexp, skipVerify, verbose bool
strictTLS bool
} }
func parseFlags() appConfig { func parseFlags() appConfig {
@@ -43,6 +44,7 @@ func parseFlags() appConfig {
flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification") flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification")
flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information") flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information")
flag.BoolVar(&cfg.verbose, "v", false, "verbose") flag.BoolVar(&cfg.verbose, "v", false, "verbose")
lib.StrictTLSFlag(&cfg.strictTLS)
flag.Parse() flag.Parse()
return cfg return cfg
} }
@@ -108,12 +110,13 @@ func run(cfg appConfig) error {
return fmt.Errorf("failed to build combined pool: %w", err) return fmt.Errorf("failed to build combined pool: %w", err)
} }
opts := &lib.FetcherOpts{ tlsCfg, err := lib.BaselineTLSConfig(cfg.skipVerify, cfg.strictTLS)
Roots: combinedPool, if err != nil {
SkipVerify: cfg.skipVerify, return err
} }
tlsCfg.RootCAs = combinedPool
chain, err := lib.GetCertificateChain(flag.Arg(0), opts) chain, err := lib.GetCertificateChain(flag.Arg(0), tlsCfg)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,50 +3,47 @@ package main
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"crypto/x509"
"encoding/pem" "encoding/pem"
"flag" "flag"
"fmt" "fmt"
"net" "net"
"os" "os"
"git.wntrmute.dev/kyle/goutils/certlib"
"git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/die"
"git.wntrmute.dev/kyle/goutils/lib" "git.wntrmute.dev/kyle/goutils/lib"
) )
func main() { func main() {
var cfg = &tls.Config{} // #nosec G402
var sysRoot, serverName string var sysRoot, serverName string
var skipVerify bool
var strictTLS bool
lib.StrictTLSFlag(&strictTLS)
flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle") flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle")
flag.StringVar(&cfg.ServerName, "sni", cfg.ServerName, "provide an SNI name") flag.StringVar(&serverName, "sni", "", "provide an SNI name")
flag.BoolVar(&cfg.InsecureSkipVerify, "noverify", false, "don't verify certificates") flag.BoolVar(&skipVerify, "noverify", false, "don't verify certificates")
flag.Parse() flag.Parse()
tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS)
die.If(err)
if sysRoot != "" { if sysRoot != "" {
pemList, err := os.ReadFile(sysRoot) tlsCfg.RootCAs, err = certlib.LoadPEMCertPool(sysRoot)
die.If(err) die.If(err)
roots := x509.NewCertPool()
if !roots.AppendCertsFromPEM(pemList) {
fmt.Printf("[!] no valid roots found")
roots = nil
}
cfg.RootCAs = roots
} }
if serverName != "" { if serverName != "" {
cfg.ServerName = serverName tlsCfg.ServerName = serverName
} }
for _, site := range flag.Args() { for _, site := range flag.Args() {
_, _, err := net.SplitHostPort(site) _, _, err = net.SplitHostPort(site)
if err != nil { if err != nil {
site += ":443" site += ":443"
} }
// Use proxy-aware TLS dialer
conn, err := lib.DialTLS(context.Background(), site, lib.DialerOpts{TLSConfig: cfg}) var conn *tls.Conn
conn, err = lib.DialTLS(context.Background(), site, lib.DialerOpts{TLSConfig: tlsCfg})
die.If(err) die.If(err)
cs := conn.ConnectionState() cs := conn.ConnectionState()

View File

@@ -1,12 +1,34 @@
// Package dbg implements a debug printer. // Package dbg implements a simple debug printer.
//
// There are two main ways to use it:
// - By using one of the constructors and calling flag.BoolVar(&debug.Enabled...)
// - By setting the environment variable GOUTILS_ENABLE_DEBUG to true or false and
// calling NewFromEnv().
//
// If enabled, any of the print statements will be written to stdout. Otherwise,
// nothing will be emitted.
package dbg package dbg
import ( import (
"fmt" "fmt"
"io" "io"
"os" "os"
"runtime/debug"
"strings"
) )
const DebugEnvKey = "GOUTILS_ENABLE_DEBUG"
var enabledValues = map[string]bool{
"1": true,
"true": true,
"yes": true,
"on": true,
"y": true,
"enable": true,
"enabled": true,
}
// A DebugPrinter is a drop-in replacement for fmt.Print*, and also acts as // A DebugPrinter is a drop-in replacement for fmt.Print*, and also acts as
// an io.WriteCloser when enabled. // an io.WriteCloser when enabled.
type DebugPrinter struct { type DebugPrinter struct {
@@ -15,6 +37,23 @@ type DebugPrinter struct {
out io.WriteCloser out io.WriteCloser
} }
// New returns a new DebugPrinter on os.Stdout.
func New() *DebugPrinter {
return &DebugPrinter{
out: os.Stderr,
}
}
// NewFromEnv returns a new DebugPrinter based on the value of the environment
// variable GOUTILS_ENABLE_DEBUG.
func NewFromEnv() *DebugPrinter {
enabled := strings.ToLower(os.Getenv(DebugEnvKey))
return &DebugPrinter{
out: os.Stderr,
Enabled: enabledValues[enabled],
}
}
// Close satisfies the Closer interface. // Close satisfies the Closer interface.
func (dbg *DebugPrinter) Close() error { func (dbg *DebugPrinter) Close() error {
return dbg.out.Close() return dbg.out.Close()
@@ -28,13 +67,6 @@ func (dbg *DebugPrinter) Write(p []byte) (int, error) {
return 0, nil return 0, nil
} }
// New returns a new DebugPrinter on os.Stdout.
func New() *DebugPrinter {
return &DebugPrinter{
out: os.Stdout,
}
}
// ToFile sets up a new DebugPrinter to a file, truncating it if it exists. // ToFile sets up a new DebugPrinter to a file, truncating it if it exists.
func ToFile(path string) (*DebugPrinter, error) { func ToFile(path string) (*DebugPrinter, error) {
file, err := os.Create(path) file, err := os.Create(path)
@@ -74,3 +106,7 @@ func (dbg *DebugPrinter) Printf(format string, v ...any) {
fmt.Fprintf(dbg.out, format, v...) fmt.Fprintf(dbg.out, format, v...)
} }
} }
func (dbg *DebugPrinter) StackTrace() {
dbg.Write(debug.Stack())
}

View File

@@ -20,6 +20,7 @@ import (
"crypto/tls" "crypto/tls"
"encoding/base64" "encoding/base64"
"errors" "errors"
"flag"
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
@@ -29,8 +30,42 @@ import (
"time" "time"
xproxy "golang.org/x/net/proxy" xproxy "golang.org/x/net/proxy"
"git.wntrmute.dev/kyle/goutils/dbg"
) )
// StrictBaselineTLSConfig returns a secure TLS config.
// Many of the tools in this repo are designed to debug broken TLS systems
// and therefore explicitly support old or insecure TLS setups.
func StrictBaselineTLSConfig() *tls.Config {
return &tls.Config{
MinVersion: tls.VersionTLS12,
InsecureSkipVerify: false, // explicitly set
}
}
func StrictTLSFlag(useStrict *bool) {
flag.BoolVar(useStrict, "strict-tls", false, "Use strict TLS configuration (disables certificate verification)")
}
func BaselineTLSConfig(skipVerify bool, secure bool) (*tls.Config, error) {
if secure && skipVerify {
return nil, errors.New("cannot skip verification and use secure TLS")
}
if skipVerify {
return &tls.Config{InsecureSkipVerify: true}, nil // #nosec G402 - intentional
}
if secure {
return StrictBaselineTLSConfig(), nil
}
return &tls.Config{}, nil // #nosec G402 - intentional
}
var debug = dbg.NewFromEnv()
// DialerOpts controls creation of proxy-aware dialers. // DialerOpts controls creation of proxy-aware dialers.
// //
// Timeout controls the maximum amount of time spent establishing the // Timeout controls the maximum amount of time spent establishing the
@@ -94,24 +129,30 @@ func NewNetDialer(opts DialerOpts) (ContextDialer, error) {
} }
if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil { if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil {
debug.Printf("using SOCKS5 proxy %q\n", u)
return newSOCKS5Dialer(u, opts) return newSOCKS5Dialer(u, opts)
} }
if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil { if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil {
// Respect the proxy URL scheme. Zscaler may set HTTPS_PROXY to an HTTP proxy
// running locally; in that case we must NOT TLS-wrap the proxy connection.
debug.Printf("using HTTPS proxy %q\n", u)
return &httpProxyDialer{ return &httpProxyDialer{
proxyURL: u, proxyURL: u,
timeout: opts.Timeout, timeout: opts.Timeout,
secure: true, secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig, config: opts.TLSConfig,
}, nil }, nil
} }
if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil { if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil {
debug.Printf("using HTTP proxy %q\n", u)
return &httpProxyDialer{ return &httpProxyDialer{
proxyURL: u, proxyURL: u,
timeout: opts.Timeout, timeout: opts.Timeout,
secure: true, // Only TLS-wrap the proxy connection if the URL scheme is https.
config: opts.TLSConfig, secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig,
}, nil }, nil
} }
@@ -131,6 +172,7 @@ func NewTLSDialer(opts DialerOpts) (ContextDialer, error) {
// Prefer SOCKS5 if present. // Prefer SOCKS5 if present.
if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil { if u := getProxyURLFromEnv("SOCKS5_PROXY"); u != nil {
debug.Printf("using SOCKS5 proxy %q\n", u)
base, err := newSOCKS5Dialer(u, opts) base, err := newSOCKS5Dialer(u, opts)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -140,19 +182,22 @@ func NewTLSDialer(opts DialerOpts) (ContextDialer, error) {
// For TLS, prefer HTTPS proxy over HTTP if both set. // For TLS, prefer HTTPS proxy over HTTP if both set.
if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil { if u := getProxyURLFromEnv("HTTPS_PROXY"); u != nil {
debug.Printf("using HTTPS proxy %q\n", u)
base := &httpProxyDialer{ base := &httpProxyDialer{
proxyURL: u, proxyURL: u,
timeout: opts.Timeout, timeout: opts.Timeout,
secure: true, secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig, config: opts.TLSConfig,
} }
return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil
} }
if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil { if u := getProxyURLFromEnv("HTTP_PROXY"); u != nil {
debug.Printf("using HTTP proxy %q\n", u)
base := &httpProxyDialer{ base := &httpProxyDialer{
proxyURL: u, proxyURL: u,
timeout: opts.Timeout, timeout: opts.Timeout,
secure: true, secure: strings.EqualFold(u.Scheme, "https"),
config: opts.TLSConfig, config: opts.TLSConfig,
} }
return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil return &tlsWrappingDialer{base: base, tcfg: opts.TLSConfig, timeout: opts.Timeout}, nil
@@ -246,13 +291,8 @@ type httpProxyDialer struct {
config *tls.Config config *tls.Config
} }
func (d *httpProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { // proxyAddress returns host:port for the proxy, applying defaults by scheme when missing.
if !strings.HasPrefix(network, "tcp") { func (d *httpProxyDialer) proxyAddress() string {
return nil, fmt.Errorf("http proxy dialer only supports TCP, got %q", network)
}
// Dial to proxy
var nd = &net.Dialer{Timeout: d.timeout}
proxyAddr := d.proxyURL.Host proxyAddr := d.proxyURL.Host
if !strings.Contains(proxyAddr, ":") { if !strings.Contains(proxyAddr, ":") {
if d.secure { if d.secure {
@@ -261,7 +301,61 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri
proxyAddr += ":80" proxyAddr += ":80"
} }
} }
conn, err := nd.DialContext(ctx, "tcp", proxyAddr) return proxyAddr
}
// tlsWrapProxyConn performs a TLS handshake to the proxy when d.secure is true.
// It clones the provided tls.Config (if any), ensures ServerName and a safe
// minimum TLS version.
func (d *httpProxyDialer) tlsWrapProxyConn(ctx context.Context, conn net.Conn) (net.Conn, error) {
host := d.proxyURL.Hostname()
// Clone provided config (if any) to avoid mutating caller's config.
cfg := &tls.Config{} // #nosec G402 - intentional
if d.config != nil {
cfg = d.config.Clone()
}
if cfg.ServerName == "" {
cfg.ServerName = host
}
tlsConn := tls.Client(conn, cfg)
if err := tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("tls handshake with https proxy failed: %w", err)
}
return tlsConn, nil
}
// readConnectResponse reads and validates the proxy's response to a CONNECT
// request. It returns nil on a 200 status and an error otherwise.
func readConnectResponse(br *bufio.Reader) error {
statusLine, err := br.ReadString('\n')
if err != nil {
return fmt.Errorf("failed to read CONNECT response: %w", err)
}
if !strings.HasPrefix(statusLine, "HTTP/") {
return fmt.Errorf("invalid proxy response: %q", strings.TrimSpace(statusLine))
}
if !strings.Contains(statusLine, " 200 ") && !strings.HasSuffix(strings.TrimSpace(statusLine), " 200") {
// Drain headers for context
_ = drainHeaders(br)
return fmt.Errorf("proxy CONNECT failed: %s", strings.TrimSpace(statusLine))
}
return drainHeaders(br)
}
func (d *httpProxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
return nil, fmt.Errorf("http proxy dialer only supports TCP, got %q", network)
}
// Dial to proxy
var nd = &net.Dialer{Timeout: d.timeout}
conn, err := nd.DialContext(ctx, "tcp", d.proxyAddress())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -273,14 +367,11 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri
// If HTTPS proxy, wrap with TLS to the proxy itself. // If HTTPS proxy, wrap with TLS to the proxy itself.
if d.secure { if d.secure {
host := d.proxyURL.Hostname() c, werr := d.tlsWrapProxyConn(ctx, conn)
d.config.ServerName = host if werr != nil {
tlsConn := tls.Client(conn, d.config) return nil, werr
if err = tlsConn.HandshakeContext(ctx); err != nil {
_ = conn.Close()
return nil, fmt.Errorf("tls handshake with https proxy failed: %w", err)
} }
conn = tlsConn conn = c
} }
req := buildConnectRequest(d.proxyURL, address) req := buildConnectRequest(d.proxyURL, address)
@@ -291,25 +382,7 @@ func (d *httpProxyDialer) DialContext(ctx context.Context, network, address stri
// Read proxy response until end of headers // Read proxy response until end of headers
br := bufio.NewReader(conn) br := bufio.NewReader(conn)
statusLine, err := br.ReadString('\n') if err = readConnectResponse(br); err != nil {
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("failed to read CONNECT response: %w", err)
}
if !strings.HasPrefix(statusLine, "HTTP/") {
_ = conn.Close()
return nil, fmt.Errorf("invalid proxy response: %q", strings.TrimSpace(statusLine))
}
if !strings.Contains(statusLine, " 200 ") && !strings.HasSuffix(strings.TrimSpace(statusLine), " 200") {
// Drain headers for context
_ = drainHeaders(br)
_ = conn.Close()
return nil, fmt.Errorf("proxy CONNECT failed: %s", strings.TrimSpace(statusLine))
}
if err = drainHeaders(br); err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
@@ -429,7 +502,7 @@ func (t *tlsWrappingDialer) DialContext(ctx context.Context, network, address st
} }
cfg = c cfg = c
} else { } else {
cfg = &tls.Config{ServerName: host} // #nosec G402 - intentional cfg = &tls.Config{ServerName: host, MinVersion: tls.VersionTLS12}
} }
tlsConn := tls.Client(raw, cfg) tlsConn := tls.Client(raw, cfg)

View File

@@ -14,18 +14,8 @@ import (
"git.wntrmute.dev/kyle/goutils/fileutil" "git.wntrmute.dev/kyle/goutils/fileutil"
) )
// FetcherOpts are options for fetching certificates. They are only applicable to ServerFetcher. // Note: Previously this package exposed a FetcherOpts type. It has been
type FetcherOpts struct { // refactored to use *tls.Config directly for configuring TLS behavior.
SkipVerify bool
Roots *x509.CertPool
}
func (fo *FetcherOpts) TLSConfig() *tls.Config {
return &tls.Config{
InsecureSkipVerify: fo.SkipVerify, // #nosec G402 - intentional
RootCAs: fo.Roots,
}
}
// Fetcher is an interface for fetching certificates from a remote source. It // Fetcher is an interface for fetching certificates from a remote source. It
// currently supports fetching from a server or a file. // currently supports fetching from a server or a file.
@@ -143,7 +133,10 @@ func (ff *FileFetcher) Get() (*x509.Certificate, error) {
} }
// GetCertificateChain fetches a certificate chain from a remote source. // GetCertificateChain fetches a certificate chain from a remote source.
func GetCertificateChain(spec string, opts *FetcherOpts) ([]*x509.Certificate, error) { // If cfg is non-nil and spec refers to a TLS server, the provided TLS
// configuration will be used to control verification behavior (e.g.,
// InsecureSkipVerify, RootCAs).
func GetCertificateChain(spec string, cfg *tls.Config) ([]*x509.Certificate, error) {
if fileutil.FileDoesExist(spec) { if fileutil.FileDoesExist(spec) {
return NewFileFetcher(spec).GetChain() return NewFileFetcher(spec).GetChain()
} }
@@ -153,17 +146,17 @@ func GetCertificateChain(spec string, opts *FetcherOpts) ([]*x509.Certificate, e
return nil, err return nil, err
} }
if opts != nil { if cfg != nil {
fetcher.insecure = opts.SkipVerify fetcher.insecure = cfg.InsecureSkipVerify
fetcher.roots = opts.Roots fetcher.roots = cfg.RootCAs
} }
return fetcher.GetChain() return fetcher.GetChain()
} }
// GetCertificate fetches the first certificate from a certificate chain. // GetCertificate fetches the first certificate from a certificate chain.
func GetCertificate(spec string, opts *FetcherOpts) (*x509.Certificate, error) { func GetCertificate(spec string, cfg *tls.Config) (*x509.Certificate, error) {
certs, err := GetCertificateChain(spec, opts) certs, err := GetCertificateChain(spec, cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -45,7 +45,7 @@ fi
# Use the first tag if multiple are present; warn the user. # Use the first tag if multiple are present; warn the user.
# Avoid readarray for broader Bash compatibility (e.g., macOS Bash 3.2). # Avoid readarray for broader Bash compatibility (e.g., macOS Bash 3.2).
TAG_ARRAY=($TAGS) TAG_ARRAY=("$TAGS")
TAG="${TAG_ARRAY[0]}" TAG="${TAG_ARRAY[0]}"
if (( ${#TAG_ARRAY[@]} > 1 )); then if (( ${#TAG_ARRAY[@]} > 1 )); then