1. DB migration: add CHECK(mode IN ('l4', 'l7')) constraint on the
routes.mode column. ARCHITECTURE.md documented this constraint but
migration v2 omitted it. Enforces mode validity at the database
level in addition to application-level validation.
2. L7 reverse proxy: distinguish timeout errors from connection errors
in the ErrorHandler. Backend timeouts now return HTTP 504 Gateway
Timeout instead of 502 Bad Gateway. Timeouts are detected with
errors.Is(err, context.DeadlineExceeded) and net.Error's Timeout()
method. Added unit tests for isTimeoutError.
3. Config validation: warn when L4 routes have tls_cert or tls_key set
(they are silently ignored). ARCHITECTURE.md documented this warning
but config.validate() did not emit it. Uses slog.Warn.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
257 lines
7.1 KiB
Go
257 lines
7.1 KiB
Go
package config
|
|
|
|
import (
	"crypto/tls"
	"fmt"
	"log/slog"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/pelletier/go-toml/v2"
)
|
|
|
|
// Config is the root of the TOML configuration file. Load decodes it,
// applies MCPROXY_* environment variable overrides, and then validates it.
type Config struct {
	// Listeners is optional; validate notes that it is used for seeding
	// on first run.
	Listeners []Listener `toml:"listeners"`
	Database  Database   `toml:"database"`
	GRPC      GRPC       `toml:"grpc"`
	Firewall  Firewall   `toml:"firewall"`
	Proxy     Proxy      `toml:"proxy"`
	Log       Log        `toml:"log"`
}
|
|
|
|
// Database holds the [database] section of the configuration.
type Database struct {
	// Path locates the database. It is required (validate rejects an empty
	// value) and can be overridden with MCPROXY_DATABASE_PATH.
	Path string `toml:"path"`
}
|
|
|
|
// GRPC holds the [grpc] section of the configuration. validate only checks
// Addr when it is non-empty (the "enabled" case); a non-empty Addr must be a
// Unix socket path, optionally prefixed with "unix:" (see SocketPath).
// Addr can be overridden with MCPROXY_GRPC_ADDR.
type GRPC struct {
	Addr string `toml:"addr"` // Unix socket path (e.g., "/var/run/mc-proxy.sock")
}
|
|
|
|
// Listener is one entry of the listeners array: an address plus the routes
// served on it. validate requires Addr and rejects duplicate route
// hostnames within a single listener (the same hostname may appear on
// different listeners).
type Listener struct {
	Addr string `toml:"addr"`
	// NOTE(review): presumably enables parsing of inbound PROXY protocol
	// headers on this listener — confirm against the listener code.
	ProxyProtocol bool    `toml:"proxy_protocol"`
	Routes        []Route `toml:"routes"`
}
|
|
|
|
// Route maps a hostname on a Listener to a backend. Hostname and Backend
// are required. validate normalizes an empty Mode to "l4" and rejects any
// other value except "l7". For "l7" routes, TLSCert and TLSKey must be set
// and loadable as a key pair. For "l4" routes, setting them only logs a
// warning because they are ignored.
type Route struct {
	Hostname          string `toml:"hostname"`
	Backend           string `toml:"backend"`
	Mode              string `toml:"mode"`                // "l4" (default) or "l7"
	TLSCert           string `toml:"tls_cert"`            // PEM certificate path (L7 only)
	TLSKey            string `toml:"tls_key"`             // PEM private key path (L7 only)
	BackendTLS        bool   `toml:"backend_tls"`         // re-encrypt to backend (L7 only)
	SendProxyProtocol bool   `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
}
|
|
|
|
// Firewall holds the [firewall] section of the configuration.
type Firewall struct {
	// GeoIPDB is required whenever BlockedCountries is non-empty. It can be
	// overridden with MCPROXY_FIREWALL_GEOIP_DB.
	GeoIPDB          string   `toml:"geoip_db"`
	BlockedIPs       []string `toml:"blocked_ips"`
	BlockedCIDRs     []string `toml:"blocked_cidrs"`
	BlockedCountries []string `toml:"blocked_countries"`
	// RateLimit must not be negative. When it is positive, RateWindow must
	// be non-zero. NOTE(review): presumably the number of connections
	// allowed per RateWindow — confirm against the firewall package.
	// Overridable with MCPROXY_FIREWALL_RATE_LIMIT.
	RateLimit int64 `toml:"rate_limit"`
	// RateWindow must not be negative. Overridable with
	// MCPROXY_FIREWALL_RATE_WINDOW.
	RateWindow Duration `toml:"rate_window"`
}
|
|
|
|
// Proxy holds the [proxy] timeouts. validate rejects negative values.
// They can be overridden with MCPROXY_PROXY_CONNECT_TIMEOUT,
// MCPROXY_PROXY_IDLE_TIMEOUT and MCPROXY_PROXY_SHUTDOWN_TIMEOUT.
type Proxy struct {
	ConnectTimeout  Duration `toml:"connect_timeout"`
	IdleTimeout     Duration `toml:"idle_timeout"`
	ShutdownTimeout Duration `toml:"shutdown_timeout"`
}
|
|
|
|
// Log holds the [log] section of the configuration. Level is not checked
// by validate. It can be overridden with MCPROXY_LOG_LEVEL.
type Log struct {
	Level string `toml:"level"`
}
|
|
|
|
// Duration wraps time.Duration for TOML string unmarshalling. Its
// UnmarshalText method parses strings in time.ParseDuration syntax
// (for example "30s" or "1m30s"). Callers read the embedded value
// directly, e.g. RateWindow.Duration.
type Duration struct {
	time.Duration
}
|
|
|
|
// SocketPath returns the filesystem path for the Unix socket,
|
|
// stripping any "unix:" prefix.
|
|
func (g GRPC) SocketPath() string {
|
|
return strings.TrimPrefix(g.Addr, "unix:")
|
|
}
|
|
|
|
func (d *Duration) UnmarshalText(text []byte) error {
|
|
var err error
|
|
d.Duration, err = time.ParseDuration(string(text))
|
|
return err
|
|
}
|
|
|
|
func Load(path string) (*Config, error) {
|
|
data, err := os.ReadFile(path)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("reading config: %w", err)
|
|
}
|
|
|
|
var cfg Config
|
|
if err := toml.Unmarshal(data, &cfg); err != nil {
|
|
return nil, fmt.Errorf("parsing config: %w", err)
|
|
}
|
|
|
|
if err := cfg.applyEnvOverrides(); err != nil {
|
|
return nil, fmt.Errorf("applying env overrides: %w", err)
|
|
}
|
|
|
|
if err := cfg.validate(); err != nil {
|
|
return nil, fmt.Errorf("invalid config: %w", err)
|
|
}
|
|
|
|
return &cfg, nil
|
|
}
|
|
|
|
// applyEnvOverrides applies environment variable overrides to the config.
|
|
// Variables use the MCPROXY_ prefix with underscore-separated paths.
|
|
func (c *Config) applyEnvOverrides() error {
|
|
// Database
|
|
if v := os.Getenv("MCPROXY_DATABASE_PATH"); v != "" {
|
|
c.Database.Path = v
|
|
}
|
|
|
|
// gRPC
|
|
if v := os.Getenv("MCPROXY_GRPC_ADDR"); v != "" {
|
|
c.GRPC.Addr = v
|
|
}
|
|
|
|
// Firewall
|
|
if v := os.Getenv("MCPROXY_FIREWALL_GEOIP_DB"); v != "" {
|
|
c.Firewall.GeoIPDB = v
|
|
}
|
|
if v := os.Getenv("MCPROXY_FIREWALL_RATE_LIMIT"); v != "" {
|
|
var n int64
|
|
if _, err := fmt.Sscanf(v, "%d", &n); err != nil {
|
|
return fmt.Errorf("MCPROXY_FIREWALL_RATE_LIMIT: %w", err)
|
|
}
|
|
c.Firewall.RateLimit = n
|
|
}
|
|
if v := os.Getenv("MCPROXY_FIREWALL_RATE_WINDOW"); v != "" {
|
|
d, err := time.ParseDuration(v)
|
|
if err != nil {
|
|
return fmt.Errorf("MCPROXY_FIREWALL_RATE_WINDOW: %w", err)
|
|
}
|
|
c.Firewall.RateWindow = Duration{d}
|
|
}
|
|
|
|
// Proxy timeouts
|
|
if v := os.Getenv("MCPROXY_PROXY_CONNECT_TIMEOUT"); v != "" {
|
|
d, err := time.ParseDuration(v)
|
|
if err != nil {
|
|
return fmt.Errorf("MCPROXY_PROXY_CONNECT_TIMEOUT: %w", err)
|
|
}
|
|
c.Proxy.ConnectTimeout = Duration{d}
|
|
}
|
|
if v := os.Getenv("MCPROXY_PROXY_IDLE_TIMEOUT"); v != "" {
|
|
d, err := time.ParseDuration(v)
|
|
if err != nil {
|
|
return fmt.Errorf("MCPROXY_PROXY_IDLE_TIMEOUT: %w", err)
|
|
}
|
|
c.Proxy.IdleTimeout = Duration{d}
|
|
}
|
|
if v := os.Getenv("MCPROXY_PROXY_SHUTDOWN_TIMEOUT"); v != "" {
|
|
d, err := time.ParseDuration(v)
|
|
if err != nil {
|
|
return fmt.Errorf("MCPROXY_PROXY_SHUTDOWN_TIMEOUT: %w", err)
|
|
}
|
|
c.Proxy.ShutdownTimeout = Duration{d}
|
|
}
|
|
|
|
// Log
|
|
if v := os.Getenv("MCPROXY_LOG_LEVEL"); v != "" {
|
|
c.Log.Level = v
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Config) validate() error {
|
|
if c.Database.Path == "" {
|
|
return fmt.Errorf("database.path is required")
|
|
}
|
|
|
|
// Validate listeners if provided (used for seeding on first run).
|
|
for i := range c.Listeners {
|
|
l := &c.Listeners[i]
|
|
if l.Addr == "" {
|
|
return fmt.Errorf("listener %d: addr is required", i)
|
|
}
|
|
seen := make(map[string]bool)
|
|
for j := range l.Routes {
|
|
r := &l.Routes[j]
|
|
if r.Hostname == "" {
|
|
return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j)
|
|
}
|
|
if r.Backend == "" {
|
|
return fmt.Errorf("listener %d (%s), route %d: backend is required", i, l.Addr, j)
|
|
}
|
|
if seen[r.Hostname] {
|
|
return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname)
|
|
}
|
|
seen[r.Hostname] = true
|
|
|
|
// Normalize mode: empty defaults to "l4".
|
|
if r.Mode == "" {
|
|
r.Mode = "l4"
|
|
}
|
|
if r.Mode != "l4" && r.Mode != "l7" {
|
|
return fmt.Errorf("listener %d (%s), route %d (%s): mode must be \"l4\" or \"l7\", got %q",
|
|
i, l.Addr, j, r.Hostname, r.Mode)
|
|
}
|
|
|
|
// Warn if L4 routes have cert/key set (they are ignored).
|
|
if r.Mode == "l4" && (r.TLSCert != "" || r.TLSKey != "") {
|
|
slog.Warn("L4 route has tls_cert/tls_key set (ignored)",
|
|
"listener", l.Addr, "hostname", r.Hostname)
|
|
}
|
|
|
|
// L7 routes require TLS cert and key.
|
|
if r.Mode == "l7" {
|
|
if r.TLSCert == "" || r.TLSKey == "" {
|
|
return fmt.Errorf("listener %d (%s), route %d (%s): L7 routes require tls_cert and tls_key",
|
|
i, l.Addr, j, r.Hostname)
|
|
}
|
|
if _, err := tls.LoadX509KeyPair(r.TLSCert, r.TLSKey); err != nil {
|
|
return fmt.Errorf("listener %d (%s), route %d (%s): loading TLS cert/key: %w",
|
|
i, l.Addr, j, r.Hostname, err)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(c.Firewall.BlockedCountries) > 0 && c.Firewall.GeoIPDB == "" {
|
|
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
|
|
}
|
|
|
|
if c.Firewall.RateLimit < 0 {
|
|
return fmt.Errorf("firewall.rate_limit must not be negative")
|
|
}
|
|
if c.Firewall.RateWindow.Duration < 0 {
|
|
return fmt.Errorf("firewall.rate_window must not be negative")
|
|
}
|
|
if c.Firewall.RateLimit > 0 && c.Firewall.RateWindow.Duration == 0 {
|
|
return fmt.Errorf("firewall.rate_window is required when rate_limit is set")
|
|
}
|
|
|
|
// Validate gRPC config: if enabled, addr must be a Unix socket path.
|
|
if c.GRPC.Addr != "" {
|
|
path := c.GRPC.SocketPath()
|
|
if !strings.Contains(path, "/") {
|
|
return fmt.Errorf("grpc.addr must be a Unix socket path (e.g., /var/run/mc-proxy.sock)")
|
|
}
|
|
}
|
|
|
|
// Validate timeouts are non-negative.
|
|
if c.Proxy.ConnectTimeout.Duration < 0 {
|
|
return fmt.Errorf("proxy.connect_timeout must not be negative")
|
|
}
|
|
if c.Proxy.IdleTimeout.Duration < 0 {
|
|
return fmt.Errorf("proxy.idle_timeout must not be negative")
|
|
}
|
|
if c.Proxy.ShutdownTimeout.Duration < 0 {
|
|
return fmt.Errorf("proxy.shutdown_timeout must not be negative")
|
|
}
|
|
|
|
return nil
|
|
}
|