package config

import (
	"crypto/tls"
	"fmt"
	"os"
	"strings"
	"time"

	"github.com/pelletier/go-toml/v2"
)

// Config is the root of the TOML configuration file. Load decodes it, applies
// MCPROXY_* environment variable overrides, and then validates it.
type Config struct {
	// Listeners are optional; the validation code notes they are used for
	// seeding on first run.
	Listeners []Listener `toml:"listeners"`
	Database  Database   `toml:"database"`
	GRPC      GRPC       `toml:"grpc"`
	Firewall  Firewall   `toml:"firewall"`
	Proxy     Proxy      `toml:"proxy"`
	Log       Log        `toml:"log"`
}

// Database holds storage settings. Path is required: validation rejects an
// empty value, and MCPROXY_DATABASE_PATH overrides it.
type Database struct {
	Path string `toml:"path"`
}

// GRPC configures the gRPC endpoint. Validation only checks Addr when it is
// non-empty, so an empty Addr presumably leaves gRPC disabled — confirm with
// the server setup code.
type GRPC struct {
	Addr string `toml:"addr"` // Unix socket path (e.g., "/var/run/mc-proxy.sock"), optionally prefixed with "unix:"
}

// Listener is a listening address together with the hostname-keyed routes
// served on it.
type Listener struct {
	Addr          string  `toml:"addr"`           // required
	ProxyProtocol bool    `toml:"proxy_protocol"` // NOTE(review): presumably expects a PROXY protocol header on inbound connections — confirm
	Routes        []Route `toml:"routes"`         // hostnames must be unique within one listener
}

// Route maps a hostname on a listener to a backend. Hostname and Backend are
// required; validation fills an empty Mode with "l4".
type Route struct {
	Hostname          string `toml:"hostname"`
	Backend           string `toml:"backend"`
	Mode              string `toml:"mode"`                // "l4" (default) or "l7"
	TLSCert           string `toml:"tls_cert"`            // PEM certificate path (L7 only; required for L7)
	TLSKey            string `toml:"tls_key"`             // PEM private key path (L7 only; required for L7)
	BackendTLS        bool   `toml:"backend_tls"`         // re-encrypt to backend (L7 only)
	SendProxyProtocol bool   `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
}

// Firewall holds blocklist and rate-limit settings. Validation requires
// GeoIPDB whenever BlockedCountries is non-empty, rejects negative RateLimit
// and RateWindow values, and requires a non-zero RateWindow whenever
// RateLimit is positive.
type Firewall struct {
	GeoIPDB          string   `toml:"geoip_db"`
	BlockedIPs       []string `toml:"blocked_ips"`
	BlockedCIDRs     []string `toml:"blocked_cidrs"`
	BlockedCountries []string `toml:"blocked_countries"`
	RateLimit        int64    `toml:"rate_limit"` // NOTE(review): what is counted per window is not established here — confirm
	RateWindow       Duration `toml:"rate_window"`
}

// Proxy holds proxy timeouts. Validation rejects negative values; what a
// zero value means is up to the code that consumes these settings.
type Proxy struct {
	ConnectTimeout  Duration `toml:"connect_timeout"`
	IdleTimeout     Duration `toml:"idle_timeout"`
	ShutdownTimeout Duration `toml:"shutdown_timeout"`
}

// Log holds logging settings. Level is not validated in this package.
type Log struct {
	Level string `toml:"level"`
}

// Duration wraps time.Duration for TOML string unmarshalling. Values use
// time.ParseDuration syntax (for example "30s" or "1m30s").
type Duration struct {
	time.Duration
}

// SocketPath returns the filesystem path for the Unix socket,
// stripping any "unix:" prefix.
func (g GRPC) SocketPath() string { return strings.TrimPrefix(g.Addr, "unix:") } func (d *Duration) UnmarshalText(text []byte) error { var err error d.Duration, err = time.ParseDuration(string(text)) return err } func Load(path string) (*Config, error) { data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("reading config: %w", err) } var cfg Config if err := toml.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("parsing config: %w", err) } if err := cfg.applyEnvOverrides(); err != nil { return nil, fmt.Errorf("applying env overrides: %w", err) } if err := cfg.validate(); err != nil { return nil, fmt.Errorf("invalid config: %w", err) } return &cfg, nil } // applyEnvOverrides applies environment variable overrides to the config. // Variables use the MCPROXY_ prefix with underscore-separated paths. func (c *Config) applyEnvOverrides() error { // Database if v := os.Getenv("MCPROXY_DATABASE_PATH"); v != "" { c.Database.Path = v } // gRPC if v := os.Getenv("MCPROXY_GRPC_ADDR"); v != "" { c.GRPC.Addr = v } // Firewall if v := os.Getenv("MCPROXY_FIREWALL_GEOIP_DB"); v != "" { c.Firewall.GeoIPDB = v } if v := os.Getenv("MCPROXY_FIREWALL_RATE_LIMIT"); v != "" { var n int64 if _, err := fmt.Sscanf(v, "%d", &n); err != nil { return fmt.Errorf("MCPROXY_FIREWALL_RATE_LIMIT: %w", err) } c.Firewall.RateLimit = n } if v := os.Getenv("MCPROXY_FIREWALL_RATE_WINDOW"); v != "" { d, err := time.ParseDuration(v) if err != nil { return fmt.Errorf("MCPROXY_FIREWALL_RATE_WINDOW: %w", err) } c.Firewall.RateWindow = Duration{d} } // Proxy timeouts if v := os.Getenv("MCPROXY_PROXY_CONNECT_TIMEOUT"); v != "" { d, err := time.ParseDuration(v) if err != nil { return fmt.Errorf("MCPROXY_PROXY_CONNECT_TIMEOUT: %w", err) } c.Proxy.ConnectTimeout = Duration{d} } if v := os.Getenv("MCPROXY_PROXY_IDLE_TIMEOUT"); v != "" { d, err := time.ParseDuration(v) if err != nil { return fmt.Errorf("MCPROXY_PROXY_IDLE_TIMEOUT: %w", err) } c.Proxy.IdleTimeout = Duration{d} } if v := 
os.Getenv("MCPROXY_PROXY_SHUTDOWN_TIMEOUT"); v != "" { d, err := time.ParseDuration(v) if err != nil { return fmt.Errorf("MCPROXY_PROXY_SHUTDOWN_TIMEOUT: %w", err) } c.Proxy.ShutdownTimeout = Duration{d} } // Log if v := os.Getenv("MCPROXY_LOG_LEVEL"); v != "" { c.Log.Level = v } return nil } func (c *Config) validate() error { if c.Database.Path == "" { return fmt.Errorf("database.path is required") } // Validate listeners if provided (used for seeding on first run). for i := range c.Listeners { l := &c.Listeners[i] if l.Addr == "" { return fmt.Errorf("listener %d: addr is required", i) } seen := make(map[string]bool) for j := range l.Routes { r := &l.Routes[j] if r.Hostname == "" { return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j) } if r.Backend == "" { return fmt.Errorf("listener %d (%s), route %d: backend is required", i, l.Addr, j) } if seen[r.Hostname] { return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname) } seen[r.Hostname] = true // Normalize mode: empty defaults to "l4". if r.Mode == "" { r.Mode = "l4" } if r.Mode != "l4" && r.Mode != "l7" { return fmt.Errorf("listener %d (%s), route %d (%s): mode must be \"l4\" or \"l7\", got %q", i, l.Addr, j, r.Hostname, r.Mode) } // L7 routes require TLS cert and key. 
if r.Mode == "l7" { if r.TLSCert == "" || r.TLSKey == "" { return fmt.Errorf("listener %d (%s), route %d (%s): L7 routes require tls_cert and tls_key", i, l.Addr, j, r.Hostname) } if _, err := tls.LoadX509KeyPair(r.TLSCert, r.TLSKey); err != nil { return fmt.Errorf("listener %d (%s), route %d (%s): loading TLS cert/key: %w", i, l.Addr, j, r.Hostname, err) } } } } if len(c.Firewall.BlockedCountries) > 0 && c.Firewall.GeoIPDB == "" { return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set") } if c.Firewall.RateLimit < 0 { return fmt.Errorf("firewall.rate_limit must not be negative") } if c.Firewall.RateWindow.Duration < 0 { return fmt.Errorf("firewall.rate_window must not be negative") } if c.Firewall.RateLimit > 0 && c.Firewall.RateWindow.Duration == 0 { return fmt.Errorf("firewall.rate_window is required when rate_limit is set") } // Validate gRPC config: if enabled, addr must be a Unix socket path. if c.GRPC.Addr != "" { path := c.GRPC.SocketPath() if !strings.Contains(path, "/") { return fmt.Errorf("grpc.addr must be a Unix socket path (e.g., /var/run/mc-proxy.sock)") } } // Validate timeouts are non-negative. if c.Proxy.ConnectTimeout.Duration < 0 { return fmt.Errorf("proxy.connect_timeout must not be negative") } if c.Proxy.IdleTimeout.Duration < 0 { return fmt.Errorf("proxy.idle_timeout must not be negative") } if c.Proxy.ShutdownTimeout.Duration < 0 { return fmt.Errorf("proxy.shutdown_timeout must not be negative") } return nil }