package config import ( "crypto/tls" "fmt" "log/slog" "os" "strings" mcdslconfig "git.wntrmute.dev/mc/mcdsl/config" ) // Duration is an alias for the mcdsl config.Duration type, which wraps // time.Duration with TOML string unmarshalling support. Exported so // existing code that references config.Duration continues to work. type Duration = mcdslconfig.Duration // Config is the top-level mc-proxy configuration. type Config struct { Listeners []Listener `toml:"listeners"` Database Database `toml:"database"` GRPC GRPC `toml:"grpc"` Firewall Firewall `toml:"firewall"` Proxy Proxy `toml:"proxy"` Metrics Metrics `toml:"metrics"` Log Log `toml:"log"` } // Metrics holds the Prometheus metrics endpoint configuration. type Metrics struct { Addr string `toml:"addr"` // e.g. "127.0.0.1:9090" Path string `toml:"path"` // e.g. "/metrics" (default) } // Database holds the database configuration. type Database struct { Path string `toml:"path"` } // GRPC holds the gRPC admin API configuration. type GRPC struct { Addr string `toml:"addr"` // Unix socket path (e.g., "/srv/mc-proxy/mc-proxy.sock") } // Listener is a proxy listener with its routes. type Listener struct { Addr string `toml:"addr"` ProxyProtocol bool `toml:"proxy_protocol"` MaxConnections int64 `toml:"max_connections"` // 0 = unlimited Routes []Route `toml:"routes"` } // Route is a proxy route within a listener. type Route struct { Hostname string `toml:"hostname"` Backend string `toml:"backend"` Mode string `toml:"mode"` // "l4" (default) or "l7" TLSCert string `toml:"tls_cert"` // PEM certificate path (L7 only) TLSKey string `toml:"tls_key"` // PEM private key path (L7 only) BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only) SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend L7Policies []L7Policy `toml:"l7_policies"` // HTTP-level policies (L7 only) } // L7Policy is an HTTP-level blocking policy for L7 routes. type L7Policy struct { Type string `toml:"type"` // "block_user_agent" or "require_header" Value string `toml:"value"` // UA substring or header name } // Firewall holds the global firewall configuration. type Firewall struct { GeoIPDB string `toml:"geoip_db"` BlockedIPs []string `toml:"blocked_ips"` BlockedCIDRs []string `toml:"blocked_cidrs"` BlockedCountries []string `toml:"blocked_countries"` RateLimit int64 `toml:"rate_limit"` RateWindow Duration `toml:"rate_window"` } // Proxy holds proxy behavior timeouts. type Proxy struct { ConnectTimeout Duration `toml:"connect_timeout"` IdleTimeout Duration `toml:"idle_timeout"` ShutdownTimeout Duration `toml:"shutdown_timeout"` } // Log holds logging configuration. type Log struct { Level string `toml:"level"` } // SocketPath returns the filesystem path for the Unix socket, // stripping any "unix:" prefix. func (g GRPC) SocketPath() string { return strings.TrimPrefix(g.Addr, "unix:") } // Load reads and validates the mc-proxy configuration from a TOML file. // Environment variables with the MCPROXY_ prefix override config values. func Load(path string) (*Config, error) { cfg, err := mcdslconfig.Load[Config](path, "MCPROXY") if err != nil { return nil, err } return cfg, nil } // Validate implements the mcdsl config.Validator interface. It applies // manual env overrides for fields that the generic reflection-based // system cannot handle (int64, error-returning duration parsing), then // validates all config fields. func (c *Config) Validate() error { if err := c.applyManualEnvOverrides(); err != nil { return err } return c.validate() } // applyManualEnvOverrides handles env overrides that need error reporting // or non-standard types (int64 rate limits, duration fields that // reflection already handles but we want error semantics for). func (c *Config) applyManualEnvOverrides() error { if v := os.Getenv("MCPROXY_FIREWALL_RATE_LIMIT"); v != "" { var n int64 if _, err := fmt.Sscanf(v, "%d", &n); err != nil { return fmt.Errorf("MCPROXY_FIREWALL_RATE_LIMIT: %w", err) } c.Firewall.RateLimit = n } return nil } func (c *Config) validate() error { if c.Database.Path == "" { return fmt.Errorf("database.path is required") } // Validate listeners if provided (used for seeding on first run). for i := range c.Listeners { l := &c.Listeners[i] if l.Addr == "" { return fmt.Errorf("listener %d: addr is required", i) } if l.MaxConnections < 0 { return fmt.Errorf("listener %d (%s): max_connections must not be negative", i, l.Addr) } seen := make(map[string]bool) for j := range l.Routes { r := &l.Routes[j] if r.Hostname == "" { return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j) } if r.Backend == "" { return fmt.Errorf("listener %d (%s), route %d: backend is required", i, l.Addr, j) } if seen[r.Hostname] { return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname) } seen[r.Hostname] = true if r.Mode == "" { r.Mode = "l4" } if r.Mode != "l4" && r.Mode != "l7" { return fmt.Errorf("listener %d (%s), route %d (%s): mode must be \"l4\" or \"l7\", got %q", i, l.Addr, j, r.Hostname, r.Mode) } if r.Mode == "l4" && (r.TLSCert != "" || r.TLSKey != "") { slog.Warn("L4 route has tls_cert/tls_key set (ignored)", "listener", l.Addr, "hostname", r.Hostname) } if r.Mode == "l7" { if r.TLSCert == "" || r.TLSKey == "" { return fmt.Errorf("listener %d (%s), route %d (%s): L7 routes require tls_cert and tls_key", i, l.Addr, j, r.Hostname) } if _, err := tls.LoadX509KeyPair(r.TLSCert, r.TLSKey); err != nil { return fmt.Errorf("listener %d (%s), route %d (%s): loading TLS cert/key: %w", i, l.Addr, j, r.Hostname, err) } } // Validate L7 policies. if r.Mode == "l4" && len(r.L7Policies) > 0 { slog.Warn("L4 route has l7_policies set (ignored)", "listener", l.Addr, "hostname", r.Hostname) } for k, p := range r.L7Policies { if p.Type != "block_user_agent" && p.Type != "require_header" { return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: type must be \"block_user_agent\" or \"require_header\", got %q", i, l.Addr, j, r.Hostname, k, p.Type) } if p.Value == "" { return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: value is required", i, l.Addr, j, r.Hostname, k) } } } } if len(c.Firewall.BlockedCountries) > 0 && c.Firewall.GeoIPDB == "" { return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set") } if c.Firewall.RateLimit < 0 { return fmt.Errorf("firewall.rate_limit must not be negative") } if c.Firewall.RateWindow.Duration < 0 { return fmt.Errorf("firewall.rate_window must not be negative") } if c.Firewall.RateLimit > 0 && c.Firewall.RateWindow.Duration == 0 { return fmt.Errorf("firewall.rate_window is required when rate_limit is set") } if c.GRPC.Addr != "" { socketPath := c.GRPC.SocketPath() if !strings.Contains(socketPath, "/") { return fmt.Errorf("grpc.addr must be a Unix socket path (e.g., /srv/mc-proxy/mc-proxy.sock)") } } if c.Metrics.Addr != "" && c.Metrics.Path != "" && !strings.HasPrefix(c.Metrics.Path, "/") { return fmt.Errorf("metrics.path must start with \"/\"") } if c.Proxy.ConnectTimeout.Duration < 0 { return fmt.Errorf("proxy.connect_timeout must not be negative") } if c.Proxy.IdleTimeout.Duration < 0 { return fmt.Errorf("proxy.idle_timeout must not be negative") } if c.Proxy.ShutdownTimeout.Duration < 0 { return fmt.Errorf("proxy.shutdown_timeout must not be negative") } return nil }