diff --git a/certlib/certgen/config.go b/certlib/certgen/config.go new file mode 100644 index 0000000..5d3e454 --- /dev/null +++ b/certlib/certgen/config.go @@ -0,0 +1,211 @@ +package certgen + +import ( + "crypto" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "fmt" + "math/big" + "net" + "strings" + "time" + + "git.wntrmute.dev/kyle/goutils/lib" +) + +type KeySpec struct { + Algorithm string `yaml:"algorithm"` + Size int `yaml:"size"` +} + +func (ks KeySpec) Generate() (crypto.PublicKey, crypto.PrivateKey, error) { + switch strings.ToLower(ks.Algorithm) { + case "rsa": + return GenerateKey(x509.RSA, ks.Size) + case "ecdsa": + return GenerateKey(x509.ECDSA, ks.Size) + case "ed25519": + return GenerateKey(x509.Ed25519, 0) + default: + return nil, nil, fmt.Errorf("unknown key algorithm: %s", ks.Algorithm) + } +} + +func (ks KeySpec) SigningAlgorithm() (x509.SignatureAlgorithm, error) { + switch strings.ToLower(ks.Algorithm) { + case "rsa": + return x509.SHA512WithRSAPSS, nil + case "ecdsa": + return x509.ECDSAWithSHA512, nil + case "ed25519": + return x509.PureEd25519, nil + default: + return 0, fmt.Errorf("unknown key algorithm: %s", ks.Algorithm) + } +} + +type Subject struct { + CommonName string `yaml:"common_name"` + Country string `yaml:"country"` + Locality string `yaml:"locality"` + Province string `yaml:"province"` + Organization string `yaml:"organization"` + OrganizationalUnit string `yaml:"organizational_unit"` + Email string `yaml:"email"` + DNSNames []string `yaml:"dns"` + IPAddresses []string `yaml:"ips"` +} + +type CertificateRequest struct { + KeySpec KeySpec `yaml:"key"` + Subject Subject `yaml:"subject"` + Profile Profile `yaml:"profile"` +} + +func (cs CertificateRequest) Generate() (crypto.PrivateKey, *x509.CertificateRequest, error) { + pub, priv, err := cs.KeySpec.Generate() + if err != nil { + return nil, nil, err + } + + subject := pkix.Name{} + subject.CommonName = cs.Subject.CommonName + subject.Country = []string{cs.Subject.Country} + subject.Locality = []string{cs.Subject.Locality} + subject.Province = []string{cs.Subject.Province} + subject.Organization = []string{cs.Subject.Organization} + subject.OrganizationalUnit = []string{cs.Subject.OrganizationalUnit} + + ipAddresses := make([]net.IP, 0, len(cs.Subject.IPAddresses)) + for i, ip := range cs.Subject.IPAddresses { + ipAddresses = append(ipAddresses, net.ParseIP(ip)) + if ipAddresses[i] == nil { + return nil, nil, fmt.Errorf("invalid IP address: %s", ip) + } + } + + req := &x509.CertificateRequest{ + PublicKeyAlgorithm: 0, + PublicKey: pub, + Subject: subject, + DNSNames: cs.Subject.DNSNames, + IPAddresses: ipAddresses, + } + + reqBytes, err := x509.CreateCertificateRequest(rand.Reader, req, priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to create certificate request: %w", err) + } + + req, err = x509.ParseCertificateRequest(reqBytes) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse certificate request: %w", err) + } + + return priv, req, nil +} + +type Profile struct { + IsCA bool `yaml:"is_ca"` + PathLen int `yaml:"path_len"` + KeyUse string `yaml:"key_uses"` + ExtKeyUsages []string `yaml:"ext_key_usages"` + Expiry string `yaml:"expiry"` +} + +func (p Profile) templateFromRequest(req *x509.CertificateRequest) (*x509.Certificate, error) { + serial, err := SerialNumber() + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + + expiry, err := lib.ParseDuration(p.Expiry) + if err != nil { + return nil, fmt.Errorf("parsing expiry: %w", err) + } + + certTemplate := &x509.Certificate{ + SignatureAlgorithm: req.SignatureAlgorithm, + PublicKeyAlgorithm: req.PublicKeyAlgorithm, + PublicKey: req.PublicKey, + SerialNumber: serial, + Subject: req.Subject, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(expiry), + BasicConstraintsValid: true, + IsCA: p.IsCA, + MaxPathLen: p.PathLen, + DNSNames: req.DNSNames, + IPAddresses: req.IPAddresses, + } + + var ok bool + certTemplate.KeyUsage, ok = keyUsageStrings[p.KeyUse] + if !ok { + return nil, fmt.Errorf("invalid key usage: %s", p.KeyUse) + } + + var eku x509.ExtKeyUsage + for _, extKeyUsage := range p.ExtKeyUsages { + eku, ok = extKeyUsageStrings[extKeyUsage] + if !ok { + return nil, fmt.Errorf("invalid extended key usage: %s", extKeyUsage) + } + certTemplate.ExtKeyUsage = append(certTemplate.ExtKeyUsage, eku) + } + + return certTemplate, nil +} + +func (p Profile) SignRequest(parent *x509.Certificate, req *x509.CertificateRequest, priv crypto.PrivateKey) (*x509.Certificate, error) { + tpl, err := p.templateFromRequest(req) + if err != nil { + return nil, fmt.Errorf("failed to create certificate template: %w", err) + } + + certBytes, err := x509.CreateCertificate(rand.Reader, tpl, parent, req.PublicKey, priv) + if err != nil { + return nil, fmt.Errorf("failed to create certificate: %w", err) + } + + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, fmt.Errorf("failed to parse certificate: %w", err) + } + + return cert, nil +} + +func (p Profile) SelfSign(req *x509.CertificateRequest, priv crypto.PrivateKey) (*x509.Certificate, error) { + certTemplate, err := p.templateFromRequest(req) + if err != nil { + return nil, fmt.Errorf("failed to create certificate template: %w", err) + } + + return p.SignRequest(certTemplate, req, priv) +} + +func SerialNumber() (*big.Int, error) { + serialNumberBytes := make([]byte, 20) + _, err := rand.Read(serialNumberBytes) + if err != nil { + return nil, fmt.Errorf("failed to generate serial number: %w", err) + } + return new(big.Int).SetBytes(serialNumberBytes), nil +} + +// GenerateSelfSigned generates a self-signed certificate using the given certificate request. +func GenerateSelfSigned(creq *CertificateRequest) (*x509.Certificate, crypto.PrivateKey, error) { + priv, req, err := creq.Generate() + if err != nil { + return nil, nil, fmt.Errorf("failed to generate certificate request: %w", err) + } + + cert, err := creq.Profile.SelfSign(req, priv) + if err != nil { + return nil, nil, fmt.Errorf("failed to self-sign certificate: %w", err) + } + + return cert, priv, nil +} diff --git a/certlib/certgen/keygen.go b/certlib/certgen/keygen.go index 8622d0a..299203c 100644 --- a/certlib/certgen/keygen.go +++ b/certlib/certgen/keygen.go @@ -1,11 +1,60 @@ package certgen import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" "encoding/asn1" + "errors" + "fmt" ) var ( oidEd25519 = asn1.ObjectIdentifier{1, 3, 101, 110} ) -func GenerateKey() {} +func GenerateKey(algorithm x509.PublicKeyAlgorithm, bitSize int) (crypto.PublicKey, crypto.PrivateKey, error) { + var key crypto.PrivateKey + var pub crypto.PublicKey + var err error + + switch algorithm { + case x509.RSA: + pub, key, err = ed25519.GenerateKey(rand.Reader) + case x509.Ed25519: + key, err = rsa.GenerateKey(rand.Reader, bitSize) + if err == nil { + pub = key.(*rsa.PrivateKey).Public() + } + case x509.ECDSA: + var curve elliptic.Curve + + switch bitSize { + case 256: + curve = elliptic.P256() + case 384: + curve = elliptic.P384() + case 521: + curve = elliptic.P521() + default: + return nil, nil, fmt.Errorf("unsupported curve size %d", bitSize) + } + + key, err = ecdsa.GenerateKey(curve, rand.Reader) + if err == nil { + pub = key.(*ecdsa.PrivateKey).Public() + } + default: + err = errors.New("unsupported algorithm") + } + + if err != nil { + return nil, nil, err + } + + return pub, key, nil +} diff --git a/certlib/certgen/ku.go b/certlib/certgen/ku.go new file mode 100644 index 0000000..7042ea9 --- /dev/null +++ b/certlib/certgen/ku.go @@ -0,0 +1,32 @@ +package certgen + +import "crypto/x509" + +var keyUsageStrings = map[string]x509.KeyUsage{ + "signing": x509.KeyUsageDigitalSignature, + "digital signature": x509.KeyUsageDigitalSignature, + "content commitment": x509.KeyUsageContentCommitment, + "key encipherment": x509.KeyUsageKeyEncipherment, + "key agreement": x509.KeyUsageKeyAgreement, + "data encipherment": x509.KeyUsageDataEncipherment, + "cert sign": x509.KeyUsageCertSign, + "crl sign": x509.KeyUsageCRLSign, + "encipher only": x509.KeyUsageEncipherOnly, + "decipher only": x509.KeyUsageDecipherOnly, +} + +var extKeyUsageStrings = map[string]x509.ExtKeyUsage{ + "any": x509.ExtKeyUsageAny, + "server auth": x509.ExtKeyUsageServerAuth, + "client auth": x509.ExtKeyUsageClientAuth, + "code signing": x509.ExtKeyUsageCodeSigning, + "email protection": x509.ExtKeyUsageEmailProtection, + "s/mime": x509.ExtKeyUsageEmailProtection, + "ipsec end system": x509.ExtKeyUsageIPSECEndSystem, + "ipsec tunnel": x509.ExtKeyUsageIPSECTunnel, + "ipsec user": x509.ExtKeyUsageIPSECUser, + "timestamping": x509.ExtKeyUsageTimeStamping, + "ocsp signing": x509.ExtKeyUsageOCSPSigning, + "microsoft sgc": x509.ExtKeyUsageMicrosoftServerGatedCrypto, + "netscape sgc": x509.ExtKeyUsageNetscapeServerGatedCrypto, +} diff --git a/certlib/dump/dump.go b/certlib/dump/dump.go index b82076d..9e9be0f 100644 --- a/certlib/dump/dump.go +++ b/certlib/dump/dump.go @@ -54,6 +54,8 @@ var extKeyUsages = map[x509.ExtKeyUsage]string{ x509.ExtKeyUsageMicrosoftKernelCodeSigning: "microsoft kernel code signing", } + + func sigAlgoPK(a x509.SignatureAlgorithm) string { switch a { case x509.MD2WithRSA, x509.MD5WithRSA, x509.SHA1WithRSA, x509.SHA256WithRSA, x509.SHA384WithRSA, x509.SHA512WithRSA: diff --git a/certlib/verify/verify.go b/certlib/verify/verify.go index cb6a285..28ca56a 100644 --- a/certlib/verify/verify.go +++ b/certlib/verify/verify.go @@ -8,7 +8,8 @@ import ( "io" "git.wntrmute.dev/kyle/goutils/certlib/revoke" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" + "git.wntrmute.dev/kyle/goutils/lib/fetch" ) func bundleIntermediates(w io.Writer, chain []*x509.Certificate, pool *x509.CertPool, verbose bool) *x509.CertPool { @@ -45,7 +46,7 @@ func prepareVerification(w io.Writer, target string, opts *Opts) (*verifyResult, if opts == nil { opts = &Opts{ - Config: lib.StrictBaselineTLSConfig(), + Config: dialer.StrictBaselineTLSConfig(), ForceIntermediates: false, } } @@ -67,7 +68,7 @@ func prepareVerification(w io.Writer, target string, opts *Opts) (*verifyResult, roots = opts.Config.RootCAs.Clone() - chain, err := lib.GetCertificateChain(target, opts.Config) + chain, err := fetch.GetCertificateChain(target, opts.Config) if err != nil { return nil, fmt.Errorf("fetching certificate chain: %w", err) } diff --git a/cmd/ca-signed/main.go b/cmd/ca-signed/main.go index c5f2fd6..6c78196 100644 --- a/cmd/ca-signed/main.go +++ b/cmd/ca-signed/main.go @@ -12,6 +12,8 @@ import ( "git.wntrmute.dev/kyle/goutils/certlib/verify" "git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" + "git.wntrmute.dev/kyle/goutils/lib/fetch" ) //go:embed testdata/*.pem @@ -137,11 +139,11 @@ func selftest() int { func main() { var skipVerify, useStrict bool - lib.StrictTLSFlag(&useStrict) + dialer.StrictTLSFlag(&useStrict) flag.BoolVar(&skipVerify, "k", false, "don't verify certificates") flag.Parse() - tcfg, err := lib.BaselineTLSConfig(skipVerify, useStrict) + tcfg, err := dialer.BaselineTLSConfig(skipVerify, useStrict) die.If(err) args := flag.Args() @@ -171,7 +173,7 @@ func main() { for _, arg := range args { var cert *x509.Certificate - cert, err = lib.GetCertificate(arg, tcfg) + cert, err = fetch.GetCertificate(arg, tcfg) if err != nil { lib.Warn(err, "while parsing certificate from %s", arg) continue diff --git a/cmd/cert-revcheck/main.go b/cmd/cert-revcheck/main.go index fc3b6d0..cb7925e 100644 --- a/cmd/cert-revcheck/main.go +++ b/cmd/cert-revcheck/main.go @@ -15,7 +15,7 @@ import ( hosts "git.wntrmute.dev/kyle/goutils/certlib/hosts" "git.wntrmute.dev/kyle/goutils/certlib/revoke" "git.wntrmute.dev/kyle/goutils/fileutil" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" ) var ( @@ -39,7 +39,7 @@ func main() { revoke.HardFail = hardfail // Build a proxy-aware HTTP client for OCSP/CRL fetches - if httpClient, err := lib.NewHTTPClient(lib.DialerOpts{Timeout: timeout}); err == nil { + if httpClient, err := dialer.NewHTTPClient(dialer.DialerOpts{Timeout: timeout}); err == nil { revoke.HTTPClient = httpClient } @@ -105,7 +105,7 @@ func checkSite(hostport string) (string, error) { defer cancel() // Use proxy-aware TLS dialer - conn, err := lib.DialTLS(ctx, target.String(), lib.DialerOpts{Timeout: timeout, TLSConfig: &tls.Config{ + conn, err := dialer.DialTLS(ctx, target.String(), dialer.DialerOpts{Timeout: timeout, TLSConfig: &tls.Config{ InsecureSkipVerify: true, // #nosec G402 -- CLI tool only verifies revocation ServerName: target.Host, }}) diff --git a/cmd/certchain/main.go b/cmd/certchain/main.go index 042a2dc..6c1161e 100644 --- a/cmd/certchain/main.go +++ b/cmd/certchain/main.go @@ -11,7 +11,7 @@ import ( "strings" "git.wntrmute.dev/kyle/goutils/die" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" ) var hasPort = regexp.MustCompile(`:\d+$`) @@ -25,7 +25,7 @@ func main() { } // Use proxy-aware TLS dialer - conn, err := lib.DialTLS(context.Background(), server, lib.DialerOpts{TLSConfig: &tls.Config{}}) // #nosec G402 + conn, err := dialer.DialTLS(context.Background(), server, dialer.DialerOpts{TLSConfig: &tls.Config{}}) // #nosec G402 die.If(err) defer conn.Close() diff --git a/cmd/certdump/main.go b/cmd/certdump/main.go index a27a89b..0c7f9ef 100644 --- a/cmd/certdump/main.go +++ b/cmd/certdump/main.go @@ -9,6 +9,7 @@ import ( "git.wntrmute.dev/kyle/goutils/certlib/dump" "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/fetch" ) var config struct { @@ -27,7 +28,7 @@ func main() { for _, filename := range flag.Args() { fmt.Fprintf(os.Stdout, "--%s ---%s", filename, "\n") - certs, err := lib.GetCertificateChain(filename, tlsCfg) + certs, err := fetch.GetCertificateChain(filename, tlsCfg) if err != nil { lib.Warn(err, "couldn't read certificate") continue diff --git a/cmd/certexpiry/main.go b/cmd/certexpiry/main.go index 3496616..256e8b5 100644 --- a/cmd/certexpiry/main.go +++ b/cmd/certexpiry/main.go @@ -8,6 +8,8 @@ import ( "git.wntrmute.dev/kyle/goutils/certlib/verify" "git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" + "git.wntrmute.dev/kyle/goutils/lib/fetch" ) func main() { @@ -18,20 +20,20 @@ func main() { warnOnly bool ) - lib.StrictTLSFlag(&strictTLS) + dialer.StrictTLSFlag(&strictTLS) flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402 flag.BoolVar(&warnOnly, "q", false, "only warn about expiring certs") flag.DurationVar(&leeway, "t", leeway, "warn if certificates are closer than this to expiring") flag.Parse() - tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS) + tlsCfg, err := dialer.BaselineTLSConfig(skipVerify, strictTLS) die.If(err) for _, file := range flag.Args() { var certs []*x509.Certificate - certs, err = lib.GetCertificateChain(file, tlsCfg) + certs, err = fetch.GetCertificateChain(file, tlsCfg) if err != nil { _, _ = lib.Warn(err, "while parsing certificates") continue diff --git a/cmd/certser/main.go b/cmd/certser/main.go index a801a84..3bc5848 100644 --- a/cmd/certser/main.go +++ b/cmd/certser/main.go @@ -8,6 +8,8 @@ import ( "git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" + "git.wntrmute.dev/kyle/goutils/lib/fetch" ) const displayInt lib.HexEncodeMode = iota @@ -33,13 +35,13 @@ func serialString(cert *x509.Certificate, mode lib.HexEncodeMode) string { func main() { var skipVerify bool var strictTLS bool - lib.StrictTLSFlag(&strictTLS) + dialer.StrictTLSFlag(&strictTLS) displayAs := flag.String("d", "int", "display mode (int, hex, uhex)") showExpiry := flag.Bool("e", false, "show expiry date") flag.BoolVar(&skipVerify, "k", false, "skip server verification") // #nosec G402 flag.Parse() - tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS) + tlsCfg, err := dialer.BaselineTLSConfig(skipVerify, strictTLS) die.If(err) displayMode := parseDisplayMode(*displayAs) @@ -47,7 +49,7 @@ func main() { for _, arg := range flag.Args() { var cert *x509.Certificate - cert, err = lib.GetCertificate(arg, tlsCfg) + cert, err = fetch.GetCertificate(arg, tlsCfg) die.If(err) fmt.Printf("%s: %s", arg, serialString(cert, displayMode)) diff --git a/cmd/certverify/main.go b/cmd/certverify/main.go index fc370b8..287c730 100644 --- a/cmd/certverify/main.go +++ b/cmd/certverify/main.go @@ -10,6 +10,7 @@ import ( "git.wntrmute.dev/kyle/goutils/certlib/verify" "git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" ) type appConfig struct { @@ -28,7 +29,7 @@ func parseFlags() appConfig { flag.BoolVar(&cfg.skipVerify, "k", false, "skip CA verification") flag.BoolVar(&cfg.revexp, "r", false, "print revocation and expiry information") flag.BoolVar(&cfg.verbose, "v", false, "verbose") - lib.StrictTLSFlag(&cfg.strictTLS) + dialer.StrictTLSFlag(&cfg.strictTLS) flag.Parse() if flag.NArg() == 0 { @@ -71,7 +72,7 @@ func main() { die.If(err) } - opts.Config, err = lib.BaselineTLSConfig(cfg.skipVerify, cfg.strictTLS) + opts.Config, err = dialer.BaselineTLSConfig(cfg.skipVerify, cfg.strictTLS) die.If(err) opts.Config.RootCAs = roots diff --git a/cmd/rhash/main.go b/cmd/rhash/main.go index 67d909a..71d7a39 100644 --- a/cmd/rhash/main.go +++ b/cmd/rhash/main.go @@ -14,6 +14,7 @@ import ( "git.wntrmute.dev/kyle/goutils/ahash" "git.wntrmute.dev/kyle/goutils/die" "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" ) func usage(w io.Writer) { @@ -84,7 +85,7 @@ func main() { continue } // Use proxy-aware HTTP client with a reasonable timeout for connects/handshakes - httpClient, err := lib.NewHTTPClient(lib.DialerOpts{Timeout: 30 * time.Second}) + httpClient, err := dialer.NewHTTPClient(dialer.DialerOpts{Timeout: 30 * time.Second}) if err != nil { _, _ = lib.Warn(err, "building HTTP client for %s", remote) continue diff --git a/cmd/stealchain/main.go b/cmd/stealchain/main.go index da2abc7..a5772ef 100644 --- a/cmd/stealchain/main.go +++ b/cmd/stealchain/main.go @@ -11,20 +11,20 @@ import ( "git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/die" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" ) func main() { var sysRoot, serverName string var skipVerify bool var strictTLS bool - lib.StrictTLSFlag(&strictTLS) + dialer.StrictTLSFlag(&strictTLS) flag.StringVar(&sysRoot, "ca", "", "provide an alternate CA bundle") flag.StringVar(&serverName, "sni", "", "provide an SNI name") flag.BoolVar(&skipVerify, "noverify", false, "don't verify certificates") flag.Parse() - tlsCfg, err := lib.BaselineTLSConfig(skipVerify, strictTLS) + tlsCfg, err := dialer.BaselineTLSConfig(skipVerify, strictTLS) die.If(err) if sysRoot != "" { @@ -43,7 +43,7 @@ func main() { } var conn *tls.Conn - conn, err = lib.DialTLS(context.Background(), site, lib.DialerOpts{TLSConfig: tlsCfg}) + conn, err = dialer.DialTLS(context.Background(), site, dialer.DialerOpts{TLSConfig: tlsCfg}) die.If(err) cs := conn.ConnectionState() diff --git a/cmd/tlsinfo/main.go b/cmd/tlsinfo/main.go index f26c9fe..a258eaf 100644 --- a/cmd/tlsinfo/main.go +++ b/cmd/tlsinfo/main.go @@ -9,7 +9,7 @@ import ( "git.wntrmute.dev/kyle/goutils/certlib/hosts" "git.wntrmute.dev/kyle/goutils/die" - "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" ) func main() { @@ -22,10 +22,10 @@ func main() { die.If(err) // Use proxy-aware TLS dialer; skip verification as before - conn, err := lib.DialTLS( + conn, err := dialer.DialTLS( context.Background(), hostPort.String(), - lib.DialerOpts{TLSConfig: &tls.Config{InsecureSkipVerify: true}}, + dialer.DialerOpts{TLSConfig: &tls.Config{InsecureSkipVerify: true}}, ) // #nosec G402 die.If(err) diff --git a/lib/dialer.go b/lib/dialer/dialer.go similarity index 99% rename from lib/dialer.go rename to lib/dialer/dialer.go index 666b8fa..09c24da 100644 --- a/lib/dialer.go +++ b/lib/dialer/dialer.go @@ -12,7 +12,7 @@ // 3. HTTP_PROXY // // Both uppercase and lowercase variable names are honored. -package lib +package dialer import ( "bufio" @@ -468,8 +468,8 @@ func (s *socks5ContextDialer) DialContext(ctx context.Context, network, address // tlsWrappingDialer performs a TLS handshake over an existing base dialer. type tlsWrappingDialer struct { - base ContextDialer - tcfg *tls.Config + base ContextDialer + tcfg *tls.Config timeout time.Duration } diff --git a/lib/duration/duration.go b/lib/duration/duration.go new file mode 100644 index 0000000..b3fd0e5 --- /dev/null +++ b/lib/duration/duration.go @@ -0,0 +1 @@ +package duration diff --git a/lib/fetch.go b/lib/fetch/fetch.go similarity index 92% rename from lib/fetch.go rename to lib/fetch/fetch.go index 7faed18..ecd32fa 100644 --- a/lib/fetch.go +++ b/lib/fetch/fetch.go @@ -1,4 +1,4 @@ -package lib +package fetch import ( "context" @@ -12,6 +12,8 @@ import ( "git.wntrmute.dev/kyle/goutils/certlib" "git.wntrmute.dev/kyle/goutils/certlib/hosts" "git.wntrmute.dev/kyle/goutils/fileutil" + "git.wntrmute.dev/kyle/goutils/lib" + "git.wntrmute.dev/kyle/goutils/lib/dialer" ) // Note: Previously this package exposed a FetcherOpts type. It has been @@ -61,18 +63,18 @@ func ParseServer(host string) (*ServerFetcher, error) { } func (sf *ServerFetcher) String() string { - return fmt.Sprintf("tls://%s", net.JoinHostPort(sf.host, Itoa(sf.port, -1))) + return fmt.Sprintf("tls://%s", net.JoinHostPort(sf.host, lib.Itoa(sf.port, -1))) } func (sf *ServerFetcher) GetChain() ([]*x509.Certificate, error) { - opts := DialerOpts{ + opts := dialer.DialerOpts{ TLSConfig: &tls.Config{ InsecureSkipVerify: sf.insecure, // #nosec G402 - no shit sherlock RootCAs: sf.roots, }, } - conn, err := DialTLS(context.Background(), net.JoinHostPort(sf.host, Itoa(sf.port, -1)), opts) + conn, err := dialer.DialTLS(context.Background(), net.JoinHostPort(sf.host, lib.Itoa(sf.port, -1)), opts) if err != nil { return nil, fmt.Errorf("failed to dial server: %w", err) } diff --git a/lib/lib.go b/lib/lib.go index feca520..6938e9d 100644 --- a/lib/lib.go +++ b/lib/lib.go @@ -6,6 +6,7 @@ import ( "fmt" "os" "path/filepath" + "strconv" "strings" "time" ) @@ -112,6 +113,85 @@ func Duration(d time.Duration) string { return s } +// IsDigit checks if a byte is a decimal digit. +func IsDigit(b byte) bool { + return b >= '0' && b <= '9' +} + +// ParseDuration parses a duration string into a time.Duration. +// It supports standard units (ns, us/µs, ms, s, m, h) plus extended units: +// d (days, 24h), w (weeks, 7d), y (years, 365d). +// Units can be combined without spaces, e.g., "1y2w3d4h5m6s". +// Case-insensitive. Years and days are approximations (no leap seconds/months). +// Returns an error for invalid input. +func ParseDuration(s string) (time.Duration, error) { + s = strings.ToLower(s) // Normalize to lowercase for case-insensitivity. + if s == "" { + return 0, fmt.Errorf("empty duration string") + } + + var total time.Duration + i := 0 + for i < len(s) { + // Parse the number part. + start := i + for i < len(s) && IsDigit(s[i]) { + i++ + } + if start == i { + return 0, fmt.Errorf("expected number at position %d", start) + } + numStr := s[start:i] + num, err := strconv.ParseUint(numStr, 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid number %q: %w", numStr, err) + } + + // Parse the unit part. + if i >= len(s) { + return 0, fmt.Errorf("expected unit after number %q", numStr) + } + unitStart := i + i++ // Consume the first char of the unit. + unit := s[unitStart:i] + + // Handle potential two-char units like "ms". + if unit == "m" && i < len(s) && s[i] == 's' { + i++ // Consume the 's'. + unit = "ms" + } + + // Convert to duration based on unit. + var d time.Duration + switch unit { + case "ns": + d = time.Nanosecond * time.Duration(num) + case "us", "µs": + d = time.Microsecond * time.Duration(num) + case "ms": + d = time.Millisecond * time.Duration(num) + case "s": + d = time.Second * time.Duration(num) + case "m": + d = time.Minute * time.Duration(num) + case "h": + d = time.Hour * time.Duration(num) + case "d": + d = 24 * time.Hour * time.Duration(num) + case "w": + d = 7 * 24 * time.Hour * time.Duration(num) + case "y": + d = 365 * 24 * time.Hour * time.Duration(num) // Approximate, non-leap year. + default: + return 0, fmt.Errorf("unknown unit %q at position %d", s[unitStart:i], unitStart) + } + + total += d + } + + return total, nil +} + type HexEncodeMode uint8 const ( diff --git a/lib/lib_test.go b/lib/lib_test.go index 6851e97..67759f8 100644 --- a/lib/lib_test.go +++ b/lib/lib_test.go @@ -2,10 +2,46 @@ package lib_test import ( "testing" + "time" "git.wntrmute.dev/kyle/goutils/lib" ) +func TestParseDuration(t *testing.T) { + tests := []struct { + name string + input string + expected time.Duration + wantErr bool + }{ + // Valid durations + {"hour", "1h", time.Hour, false}, + {"day", "2d", 2 * 24 * time.Hour, false}, + {"minute", "3m", 3 * time.Minute, false}, + {"second", "4s", 4 * time.Second, false}, + + // Edge cases + {"zero seconds", "0s", 0, false}, + {"empty string", "", 0, true}, + {"no numeric before unit", "h", 0, true}, + {"invalid unit", "1x", 0, true}, + {"non-numeric input", "abc", 0, true}, + {"missing unit", "10", 0, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := lib.ParseDuration(tc.input) + if (err != nil) != tc.wantErr { + t.Fatalf("unexpected error: %v, wantErr: %v", err, tc.wantErr) + } + if got != tc.expected { + t.Fatalf("expected %v, got %v", tc.expected, got) + } + }) + } +} + func TestHexEncode_LowerUpper(t *testing.T) { b := []byte{0x0f, 0xa1, 0x00, 0xff}