Files
mc-proxy/internal/config/config.go
Kyle Isom 5bc8f4fc8e Fix three doc-vs-implementation gaps found during audit
1. DB migration: add CHECK(mode IN ('l4', 'l7')) constraint on the
   routes.mode column. ARCHITECTURE.md documented this constraint but
   migration v2 omitted it. Enforces mode validity at the database
   level in addition to application-level validation.

2. L7 reverse proxy: distinguish timeout errors from connection errors
   in the ErrorHandler. Backend timeouts now return HTTP 504 Gateway
   Timeout instead of 502. Uses errors.Is(context.DeadlineExceeded)
   and net.Error.Timeout() detection. Added isTimeoutError unit tests.

3. Config validation: warn when L4 routes have tls_cert or tls_key set
   (they are silently ignored). ARCHITECTURE.md documented this warning
   but config.validate() did not emit it. Uses slog.Warn.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 14:25:41 -07:00

257 lines
7.1 KiB
Go

package config
import (
	"crypto/tls"
	"fmt"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/pelletier/go-toml/v2"
)
// Config is the root of the mc-proxy configuration file. Load decodes it from
// TOML, applies MCPROXY_* environment-variable overrides, and then validates
// it. Each field corresponds to one top-level TOML table (listeners is an
// array of tables).
type Config struct {
	Listeners []Listener `toml:"listeners"` // optional; validate notes these are used for seeding on first run
	Database  Database   `toml:"database"`
	GRPC      GRPC       `toml:"grpc"`
	Firewall  Firewall   `toml:"firewall"`
	Proxy     Proxy      `toml:"proxy"`
	Log       Log        `toml:"log"`
}
// Database holds the database settings. Path is required (validate rejects an
// empty value) and can be overridden with MCPROXY_DATABASE_PATH.
type Database struct {
	Path string `toml:"path"`
}
// GRPC configures the gRPC endpoint. validate only checks Addr when it is
// non-empty; in that case Addr must be a Unix socket path, optionally written
// with a "unix:" prefix (see SocketPath). Addr can be overridden with
// MCPROXY_GRPC_ADDR.
type GRPC struct {
	Addr string `toml:"addr"` // Unix socket path (e.g., "/var/run/mc-proxy.sock")
}
// Listener is one listening address together with the routes served on it.
// validate requires Addr and requires route hostnames to be unique within a
// single listener.
type Listener struct {
	Addr string `toml:"addr"`
	// NOTE(review): presumably means inbound connections on this listener
	// carry a PROXY protocol header — confirm against the listener code.
	ProxyProtocol bool    `toml:"proxy_protocol"`
	Routes        []Route `toml:"routes"`
}
// Route maps a hostname on a listener to a backend address. validate requires
// Hostname and Backend. It also normalizes an empty Mode to "l4", and for
// "l7" routes requires TLSCert/TLSKey to name a loadable key pair. For "l4"
// routes, TLSCert/TLSKey are ignored and only trigger a warning.
// NOTE(review): Hostname is presumably matched against the client's TLS SNI
// server name — confirm in the router.
type Route struct {
	Hostname          string `toml:"hostname"`
	Backend           string `toml:"backend"`
	Mode              string `toml:"mode"`                // "l4" (default) or "l7"
	TLSCert           string `toml:"tls_cert"`            // PEM certificate path (L7 only)
	TLSKey            string `toml:"tls_key"`             // PEM private key path (L7 only)
	BackendTLS        bool   `toml:"backend_tls"`         // re-encrypt to backend (L7 only)
	SendProxyProtocol bool   `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
}
// Firewall configures connection filtering. validate enforces the following:
//   - BlockedCountries requires GeoIPDB.
//   - RateLimit and RateWindow must not be negative.
//   - A positive RateLimit requires a non-zero RateWindow.
//
// GeoIPDB, RateLimit and RateWindow can be overridden with
// MCPROXY_FIREWALL_GEOIP_DB, MCPROXY_FIREWALL_RATE_LIMIT and
// MCPROXY_FIREWALL_RATE_WINDOW.
type Firewall struct {
	GeoIPDB          string   `toml:"geoip_db"`          // GeoIP database path; needed for BlockedCountries
	BlockedIPs       []string `toml:"blocked_ips"`       // parsed elsewhere — not checked here
	BlockedCIDRs     []string `toml:"blocked_cidrs"`     // parsed elsewhere — not checked here
	BlockedCountries []string `toml:"blocked_countries"` // NOTE(review): presumably ISO country codes — confirm
	// NOTE(review): presumably the maximum number of connections per source
	// within RateWindow; 0 appears to disable limiting — confirm.
	RateLimit  int64    `toml:"rate_limit"`
	RateWindow Duration `toml:"rate_window"`
}
// Proxy holds proxy-wide timeouts. validate rejects negative values. Each can
// be overridden with MCPROXY_PROXY_CONNECT_TIMEOUT, MCPROXY_PROXY_IDLE_TIMEOUT
// or MCPROXY_PROXY_SHUTDOWN_TIMEOUT.
// NOTE(review): whether a zero value means "no timeout" or "use a default" is
// decided by the consumer, not here — confirm.
type Proxy struct {
	ConnectTimeout  Duration `toml:"connect_timeout"`
	IdleTimeout     Duration `toml:"idle_timeout"`
	ShutdownTimeout Duration `toml:"shutdown_timeout"`
}
// Log holds logging settings. Level is not validated in this package and can
// be overridden with MCPROXY_LOG_LEVEL.
// NOTE(review): presumably a level name such as "info" or "debug" — confirm.
type Log struct {
	Level string `toml:"level"`
}
// Duration wraps time.Duration so it can be written in TOML as a string such
// as "30s" or "1m30s". Decoding goes through Duration.UnmarshalText, which
// uses time.ParseDuration. The embedded field exposes all time.Duration
// methods.
type Duration struct {
	time.Duration
}
// SocketPath returns the filesystem path of the gRPC Unix socket. A single
// leading "unix:" scheme prefix on Addr is dropped. Any other value, including
// the empty string, is returned unchanged.
func (g GRPC) SocketPath() string {
	if path, hadScheme := strings.CutPrefix(g.Addr, "unix:"); hadScheme {
		return path
	}
	return g.Addr
}
// UnmarshalText implements encoding.TextUnmarshaler. It parses text with
// time.ParseDuration (for example "30s" or "1m30s"). The wrapped duration is
// always overwritten with the parse result, and any parse error is returned
// unchanged.
func (d *Duration) UnmarshalText(text []byte) error {
	parsed, parseErr := time.ParseDuration(string(text))
	d.Duration = parsed
	return parseErr
}
// Load builds a Config in three stages:
//  1. read the TOML file at path and decode it;
//  2. apply MCPROXY_* environment-variable overrides;
//  3. validate the result.
//
// Each error is wrapped with the stage that failed ("reading config",
// "parsing config", "applying env overrides" or "invalid config"). On any
// error the returned *Config is nil.
func Load(path string) (*Config, error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, fmt.Errorf("reading config: %w", err)
	}

	cfg := new(Config)
	if err = toml.Unmarshal(raw, cfg); err != nil {
		return nil, fmt.Errorf("parsing config: %w", err)
	}
	if err = cfg.applyEnvOverrides(); err != nil {
		return nil, fmt.Errorf("applying env overrides: %w", err)
	}
	if err = cfg.validate(); err != nil {
		return nil, fmt.Errorf("invalid config: %w", err)
	}
	return cfg, nil
}
// applyEnvOverrides applies environment variable overrides to the config.
// Variables use the MCPROXY_ prefix with underscore-separated paths; for
// example, MCPROXY_PROXY_IDLE_TIMEOUT overrides proxy.idle_timeout. A
// variable that is unset or set to the empty string leaves the value from
// the TOML file untouched.
//
// It returns an error prefixed with the variable name when a value cannot be
// parsed. Range checks such as rejecting negative values are left to
// validate.
func (c *Config) applyEnvOverrides() error {
	// String-valued settings are copied verbatim.
	if v := os.Getenv("MCPROXY_DATABASE_PATH"); v != "" {
		c.Database.Path = v
	}
	if v := os.Getenv("MCPROXY_GRPC_ADDR"); v != "" {
		c.GRPC.Addr = v
	}
	if v := os.Getenv("MCPROXY_FIREWALL_GEOIP_DB"); v != "" {
		c.Firewall.GeoIPDB = v
	}

	// Use strconv, not fmt.Sscanf("%d"). Sscanf stops at the first non-digit
	// and ignores the rest of the input, so values like "100req" or "1e3"
	// would be silently accepted as 100 and 1. Surrounding whitespace is
	// still tolerated, as it was with Sscanf.
	if v := os.Getenv("MCPROXY_FIREWALL_RATE_LIMIT"); v != "" {
		n, err := strconv.ParseInt(strings.TrimSpace(v), 10, 64)
		if err != nil {
			return fmt.Errorf("MCPROXY_FIREWALL_RATE_LIMIT: %w", err)
		}
		c.Firewall.RateLimit = n
	}

	// Duration-valued settings are processed in this order, so the first
	// unparsable variable is the one reported.
	durations := []struct {
		env string
		dst *Duration
	}{
		{"MCPROXY_FIREWALL_RATE_WINDOW", &c.Firewall.RateWindow},
		{"MCPROXY_PROXY_CONNECT_TIMEOUT", &c.Proxy.ConnectTimeout},
		{"MCPROXY_PROXY_IDLE_TIMEOUT", &c.Proxy.IdleTimeout},
		{"MCPROXY_PROXY_SHUTDOWN_TIMEOUT", &c.Proxy.ShutdownTimeout},
	}
	for _, o := range durations {
		if err := envDuration(o.env, o.dst); err != nil {
			return err
		}
	}

	if v := os.Getenv("MCPROXY_LOG_LEVEL"); v != "" {
		c.Log.Level = v
	}
	return nil
}

// envDuration overwrites *dst with the time.ParseDuration value of
// environment variable name. It does nothing when the variable is unset or
// empty. Parse errors are returned prefixed with "name: ".
func envDuration(name string, dst *Duration) error {
	v := os.Getenv(name)
	if v == "" {
		return nil
	}
	d, err := time.ParseDuration(v)
	if err != nil {
		return fmt.Errorf("%s: %w", name, err)
	}
	*dst = Duration{d}
	return nil
}
// validate checks the loaded configuration and returns the first problem it
// finds, or nil. Checks run in a fixed order: database, listeners and their
// routes, firewall, gRPC, then proxy timeouts.
//
// validate is not read-only. It normalizes an empty route mode to "l4" in
// place. It also loads each L7 route's certificate/key pair from disk to
// prove the files are usable. L4 routes that set tls_cert/tls_key only
// produce a slog warning, because those fields are ignored for L4.
func (c *Config) validate() error {
	if c.Database.Path == "" {
		return fmt.Errorf("database.path is required")
	}

	// Listeners are optional; when present they are used for seeding on
	// first run.
	for i := range c.Listeners {
		if err := validateListener(i, &c.Listeners[i]); err != nil {
			return err
		}
	}

	fw := &c.Firewall
	switch {
	case len(fw.BlockedCountries) > 0 && fw.GeoIPDB == "":
		return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
	case fw.RateLimit < 0:
		return fmt.Errorf("firewall.rate_limit must not be negative")
	case fw.RateWindow.Duration < 0:
		return fmt.Errorf("firewall.rate_window must not be negative")
	case fw.RateLimit > 0 && fw.RateWindow.Duration == 0:
		return fmt.Errorf("firewall.rate_window is required when rate_limit is set")
	// gRPC is optional (empty addr); when set it must name a Unix socket path.
	case c.GRPC.Addr != "" && !strings.Contains(c.GRPC.SocketPath(), "/"):
		return fmt.Errorf("grpc.addr must be a Unix socket path (e.g., /var/run/mc-proxy.sock)")
	case c.Proxy.ConnectTimeout.Duration < 0:
		return fmt.Errorf("proxy.connect_timeout must not be negative")
	case c.Proxy.IdleTimeout.Duration < 0:
		return fmt.Errorf("proxy.idle_timeout must not be negative")
	case c.Proxy.ShutdownTimeout.Duration < 0:
		return fmt.Errorf("proxy.shutdown_timeout must not be negative")
	}
	return nil
}

// validateListener checks listener number i. Its address must be set, and
// every route must have a hostname and a backend. Hostnames must not repeat
// within this listener, and each route must also pass validateRoute, which
// may modify it in place.
func validateListener(i int, l *Listener) error {
	if l.Addr == "" {
		return fmt.Errorf("listener %d: addr is required", i)
	}
	seen := make(map[string]struct{}, len(l.Routes))
	for j := range l.Routes {
		r := &l.Routes[j]
		where := fmt.Sprintf("listener %d (%s), route %d", i, l.Addr, j)
		_, dup := seen[r.Hostname]
		switch {
		case r.Hostname == "":
			return fmt.Errorf("%s: hostname is required", where)
		case r.Backend == "":
			return fmt.Errorf("%s: backend is required", where)
		case dup:
			return fmt.Errorf("%s: duplicate hostname %q", where, r.Hostname)
		}
		seen[r.Hostname] = struct{}{}
		if err := validateRoute(fmt.Sprintf("%s (%s)", where, r.Hostname), l.Addr, r); err != nil {
			return err
		}
	}
	return nil
}

// validateRoute defaults an empty r.Mode to "l4" and then checks the
// mode-specific rules. L4 routes only get a warning if they carry TLS files.
// L7 routes must name a certificate and key that tls.LoadX509KeyPair can
// load. where prefixes error messages; listenerAddr is used for the warning.
func validateRoute(where, listenerAddr string, r *Route) error {
	if r.Mode == "" {
		r.Mode = "l4"
	}
	switch r.Mode {
	case "l4":
		// Warn if L4 routes have cert/key set (they are ignored).
		if r.TLSCert != "" || r.TLSKey != "" {
			slog.Warn("L4 route has tls_cert/tls_key set (ignored)",
				"listener", listenerAddr, "hostname", r.Hostname)
		}
		return nil
	case "l7":
		if r.TLSCert == "" || r.TLSKey == "" {
			return fmt.Errorf("%s: L7 routes require tls_cert and tls_key", where)
		}
		if _, err := tls.LoadX509KeyPair(r.TLSCert, r.TLSKey); err != nil {
			return fmt.Errorf("%s: loading TLS cert/key: %w", where, err)
		}
		return nil
	default:
		return fmt.Errorf("%s: mode must be \"l4\" or \"l7\", got %q", where, r.Mode)
	}
}