From b25e1b0e79cd47ad4fc8bea179527f1e912ded92 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Tue, 17 Mar 2026 14:37:21 -0700 Subject: [PATCH] Add per-IP rate limiting and Unix socket support for gRPC admin API Rate limiting: per-source-IP connection rate limiter in the firewall layer with configurable limit and sliding window. Blocklisted IPs are rejected before rate limit evaluation to avoid wasting quota. Unix socket: the gRPC admin API can now listen on a Unix domain socket (no TLS required), secured by file permissions (0600), as a simpler alternative for local-only access. Co-Authored-By: Claude Opus 4.6 (1M context) --- ARCHITECTURE.md | 58 ++++++++++++-- cmd/mc-proxy/server.go | 13 ++- cmd/mc-proxy/snapshot.go | 25 +++++- cmd/mc-proxy/status.go | 6 ++ deploy/mc-proxy.toml.example | 44 ++++++++-- internal/config/config.go | 44 ++++++++++ internal/config/config_test.go | 107 +++++++++++++++++++++++++ internal/firewall/firewall.go | 20 ++++- internal/firewall/firewall_test.go | 62 ++++++++++++-- internal/firewall/ratelimit.go | 81 +++++++++++++++++++ internal/firewall/ratelimit_test.go | 95 ++++++++++++++++++++++ internal/grpcserver/grpcserver.go | 75 +++++++++++++++-- internal/grpcserver/grpcserver_test.go | 66 ++++++++++++++- internal/server/server.go | 32 +++++++- internal/server/server_test.go | 6 +- internal/sni/sni.go | 3 + 16 files changed, 694 insertions(+), 43 deletions(-) create mode 100644 internal/firewall/ratelimit.go create mode 100644 internal/firewall/ratelimit_test.go diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index 6aa37e7..f372697 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -174,6 +174,47 @@ gRPC admin API. --- +## gRPC Admin API + +The admin API is optional (disabled if `[grpc]` is omitted from the config). +When enabled on a TCP address, it requires TLS and supports optional mTLS for client +authentication; TLS 1.3 is enforced. When `addr` is a Unix domain socket path, TLS is not used and access is controlled by the socket's file permissions (0600). The API provides runtime management +of routes and firewall rules without restarting the proxy.
+ +### RPCs + +| RPC | Description | +|-----|-------------| +| `ListRoutes` | List all routes for a given listener | +| `AddRoute` | Add a route to a listener (write-through to DB) | +| `RemoveRoute` | Remove a route from a listener (write-through to DB) | +| `GetFirewallRules` | List all firewall rules | +| `AddFirewallRule` | Add a firewall rule (write-through to DB) | +| `RemoveFirewallRule` | Remove a firewall rule (write-through to DB) | +| `GetStatus` | Return version, uptime, listener status, connection counts | + +### Input Validation + +The admin API validates all inputs before persisting: + +- **Route backends** must be valid `host:port` tuples. +- **IP firewall rules** must be valid IP addresses (`netip.ParseAddr`). +- **CIDR firewall rules** must be valid prefixes in canonical form. +- **Country firewall rules** must be exactly 2 uppercase letters (ISO 3166-1 alpha-2). + +### Security + +The gRPC admin API has no MCIAS integration — mc-proxy is pre-auth +infrastructure. Access control relies on: + +1. **Network binding**: bind to `127.0.0.1` (default) to restrict to local access, or listen on a Unix domain socket (file mode 0600, no TLS) instead of TCP. +2. **mTLS**: configure `client_ca` to require client certificates on TCP listeners. + +If the admin API is exposed on a non-loopback interface without mTLS, +any network client can modify routing and firewall rules. + +--- + ## Configuration TOML configuration file, loaded at startup. The proxy refuses to start if @@ -210,12 +251,16 @@ addr = ":9443" hostname = "mcias.metacircular.net" backend = "127.0.0.1:28443" -# gRPC admin API. Optional — omit addr to disable. +# gRPC admin API. Optional — omit or leave addr empty to disable. +# If enabled on a TCP address, tls_cert and tls_key are required (TLS 1.3 only); a Unix socket path needs neither. +# client_ca enables mTLS and is strongly recommended for non-loopback addresses. +# ca_cert is used by the `status` CLI command to verify the server certificate.
[grpc] addr = "127.0.0.1:9090" tls_cert = "/srv/mc-proxy/certs/cert.pem" tls_key = "/srv/mc-proxy/certs/key.pem" client_ca = "/srv/mc-proxy/certs/ca.pem" +ca_cert = "/srv/mc-proxy/certs/ca.pem" # Firewall. Global blocklist, evaluated before routing. Default allow. [firewall] @@ -333,6 +378,8 @@ Multi-stage Docker build: | File | Purpose | |------|---------| | `mc-proxy.service` | Main proxy service | +| `mc-proxy-backup.service` | Oneshot database backup (VACUUM INTO) | +| `mc-proxy-backup.timer` | Daily backup timer (02:00 UTC, 5-minute jitter) | The proxy binds to privileged ports (443) and should use `AmbientCapabilities=CAP_NET_BIND_SERVICE` in the systemd unit rather than running as root. @@ -354,9 +401,10 @@ On `SIGHUP`: 1. Reload the GeoIP database from disk. 2. Continue serving with the updated database. -Configuration changes (routes, listeners, firewall rules) require a full -restart. Hot reload of routing rules is deferred to the future SQLite-backed -implementation. +Routes and firewall rules can be modified at runtime via the gRPC admin API +(write-through to SQLite). Listener changes (adding/removing ports) require +a full restart. TOML configuration changes (timeouts, log level, GeoIP path) +also require a restart. --- @@ -373,7 +421,7 @@ It has no authentication or authorization of its own. | Resource exhaustion (connection flood) | Idle timeout closes stale connections. Per-listener connection limits (future). Rate limiting (future). | | GeoIP evasion via IPv6 | GeoLite2 database includes IPv6 mappings. Both IPv4 and IPv6 source addresses are checked. | | GeoIP evasion via VPN/proxy | Accepted risk. GeoIP blocking is a compliance measure, not a security boundary. Determined adversaries will bypass it. | -| Slowloris / slow ClientHello | Timeout on the SNI extraction phase. If a complete ClientHello is not received within a reasonable window (e.g. 10s), the connection is reset. 
| +| Slowloris / slow ClientHello | Hardcoded 10-second timeout on the SNI extraction phase. If a complete ClientHello is not received within this window, the connection is reset. | | Backend unavailability | Connect timeout prevents indefinite hangs. Connection is reset if the backend is unreachable. | | Information leakage | Blocked connections receive only a TCP RST. No version strings, no error messages, no TLS alerts. | diff --git a/cmd/mc-proxy/server.go b/cmd/mc-proxy/server.go index f96c624..17453fb 100644 --- a/cmd/mc-proxy/server.go +++ b/cmd/mc-proxy/server.go @@ -68,7 +68,7 @@ func serverCmd() *cobra.Command { } // Load firewall rules from DB. - fw, err := loadFirewallFromDB(store, cfg.Firewall.GeoIPDB) + fw, err := loadFirewallFromDB(store, cfg.Firewall) if err != nil { return err } @@ -90,7 +90,12 @@ func serverCmd() *cobra.Command { logger.Error("gRPC server error", "error", err) } }() - defer grpcSrv.GracefulStop() + defer func() { + grpcSrv.GracefulStop() + if cfg.GRPC.IsUnixSocket() { + os.Remove(cfg.GRPC.SocketPath()) + } + }() } // SIGHUP reloads the GeoIP database. 
@@ -140,7 +145,7 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) { return result, nil } -func loadFirewallFromDB(store *db.Store, geoIPPath string) (*firewall.Firewall, error) { +func loadFirewallFromDB(store *db.Store, fwCfg config.Firewall) (*firewall.Firewall, error) { rules, err := store.ListFirewallRules() if err != nil { return nil, fmt.Errorf("loading firewall rules: %w", err) @@ -158,7 +163,7 @@ func loadFirewallFromDB(store *db.Store, geoIPPath string) (*firewall.Firewall, } } - fw, err := firewall.New(geoIPPath, ips, cidrs, countries) + fw, err := firewall.New(fwCfg.GeoIPDB, ips, cidrs, countries, fwCfg.RateLimit, fwCfg.RateWindow.Duration) if err != nil { return nil, fmt.Errorf("initializing firewall: %w", err) } diff --git a/cmd/mc-proxy/snapshot.go b/cmd/mc-proxy/snapshot.go index 4c36570..13c47a6 100644 --- a/cmd/mc-proxy/snapshot.go +++ b/cmd/mc-proxy/snapshot.go @@ -2,7 +2,9 @@ package main import ( "fmt" + "os" "path/filepath" + "strings" "time" "github.com/spf13/cobra" @@ -32,10 +34,29 @@ func snapshotCmd() *cobra.Command { } defer store.Close() + dataDir := filepath.Dir(cfg.Database.Path) + if outputPath == "" { - dir := filepath.Dir(cfg.Database.Path) ts := time.Now().UTC().Format("20060102T150405Z") - outputPath = filepath.Join(dir, "backups", fmt.Sprintf("mc-proxy-%s.db", ts)) + outputPath = filepath.Join(dataDir, "backups", fmt.Sprintf("mc-proxy-%s.db", ts)) + } + + // Validate the output path is within the data directory. + absOutput, err := filepath.Abs(filepath.Clean(outputPath)) + if err != nil { + return fmt.Errorf("resolving output path: %w", err) + } + absDataDir, err := filepath.Abs(dataDir) + if err != nil { + return fmt.Errorf("resolving data directory: %w", err) + } + if !strings.HasPrefix(absOutput, absDataDir+string(os.PathSeparator)) { + return fmt.Errorf("output path must be within the data directory (%s)", absDataDir) + } + + // Ensure the parent directory exists. 
+ if err := os.MkdirAll(filepath.Dir(outputPath), 0700); err != nil { + return fmt.Errorf("creating backup directory: %w", err) } if err := store.Snapshot(outputPath); err != nil { diff --git a/cmd/mc-proxy/status.go b/cmd/mc-proxy/status.go index cf7c720..0f77a46 100644 --- a/cmd/mc-proxy/status.go +++ b/cmd/mc-proxy/status.go @@ -11,6 +11,7 @@ import ( "github.com/spf13/cobra" "google.golang.org/grpc" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" "git.wntrmute.dev/kyle/mc-proxy/internal/config" @@ -70,6 +71,11 @@ func statusCmd() *cobra.Command { } func dialGRPC(cfg config.GRPC) (*grpc.ClientConn, error) { + if cfg.IsUnixSocket() { + return grpc.NewClient("unix://"+cfg.SocketPath(), + grpc.WithTransportCredentials(insecure.NewCredentials())) + } + tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS13, } diff --git a/deploy/mc-proxy.toml.example b/deploy/mc-proxy.toml.example index 47c4ec8..df1f28e 100644 --- a/deploy/mc-proxy.toml.example +++ b/deploy/mc-proxy.toml.example @@ -1,23 +1,49 @@ # mc-proxy configuration +# +# This file seeds the database on first run. After that, the database is +# the source of truth — listener, route, and firewall fields here are ignored. -# Listeners. Each entry binds a TCP listener on the specified address. +# Database. Required. +[database] +path = "/srv/mc-proxy/mc-proxy.db" + +# Listeners. Each listener binds a TCP port and has its own route table. [[listeners]] addr = ":443" + [[listeners.routes]] + hostname = "metacrypt.metacircular.net" + backend = "127.0.0.1:18443" + + [[listeners.routes]] + hostname = "mcias.metacircular.net" + backend = "127.0.0.1:28443" + [[listeners]] addr = ":8443" + [[listeners.routes]] + hostname = "metacrypt.metacircular.net" + backend = "127.0.0.1:18443" + [[listeners]] addr = ":9443" -# Routes. SNI hostname → backend address. 
-[[routes]] -hostname = "metacrypt.metacircular.net" -backend = "127.0.0.1:18443" + [[listeners.routes]] + hostname = "mcias.metacircular.net" + backend = "127.0.0.1:28443" -[[routes]] -hostname = "mcias.metacircular.net" -backend = "127.0.0.1:28443" +# gRPC admin API. Optional — omit or leave addr empty to disable. +# If enabled over TCP, tls_cert and tls_key are required. mTLS (client_ca) +# is strongly recommended for any non-loopback listen address. +[grpc] +addr = "127.0.0.1:9090" +tls_cert = "/srv/mc-proxy/certs/cert.pem" +tls_key = "/srv/mc-proxy/certs/key.pem" +client_ca = "/srv/mc-proxy/certs/ca.pem" # mTLS; omit to disable client auth + +# Unix socket alternative (no TLS needed, secured by file permissions): +# addr = "/srv/mc-proxy/admin.sock" # Firewall. Global blocklist, evaluated before routing. Default allow. [firewall] @@ -25,6 +51,8 @@ geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb" blocked_ips = [] blocked_cidrs = [] blocked_countries = ["KP", "CN", "IN", "IL"] +rate_limit = 100 # max connections per source IP per window (0 = disabled) +rate_window = "1m" # window duration; each source IP's counter resets once the window elapses (required if rate_limit > 0) # Proxy behavior. [proxy] diff --git a/internal/config/config.go b/internal/config/config.go index 3f34f7d..16e7074 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -3,6 +3,7 @@ package config import ( "fmt" "os" + "strings" "time" "github.com/pelletier/go-toml/v2" @@ -44,6 +45,8 @@ type Firewall struct { BlockedIPs []string `toml:"blocked_ips"` BlockedCIDRs []string `toml:"blocked_cidrs"` BlockedCountries []string `toml:"blocked_countries"` + RateLimit int64 `toml:"rate_limit"` + RateWindow Duration `toml:"rate_window"` } type Proxy struct { @@ -61,6 +64,18 @@ type Duration struct { time.Duration } +// IsUnixSocket returns true if the gRPC address refers to a Unix domain socket.
+func (g GRPC) IsUnixSocket() bool { + path := strings.TrimPrefix(g.Addr, "unix:") + return strings.Contains(path, "/") +} + +// SocketPath returns the filesystem path for a Unix socket address, +// stripping any "unix:" prefix. +func (g GRPC) SocketPath() string { + return strings.TrimPrefix(g.Addr, "unix:") +} + func (d *Duration) UnmarshalText(text []byte) error { var err error d.Duration, err = time.ParseDuration(string(text)) @@ -114,5 +129,34 @@ func (c *Config) validate() error { return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set") } + if c.Firewall.RateLimit < 0 { + return fmt.Errorf("firewall.rate_limit must not be negative") + } + if c.Firewall.RateWindow.Duration < 0 { + return fmt.Errorf("firewall.rate_window must not be negative") + } + if c.Firewall.RateLimit > 0 && c.Firewall.RateWindow.Duration == 0 { + return fmt.Errorf("firewall.rate_window is required when rate_limit is set") + } + + // Validate gRPC config: if enabled, TLS cert and key are required + // (unless using a Unix socket, which doesn't need TLS). + if c.GRPC.Addr != "" && !c.GRPC.IsUnixSocket() { + if c.GRPC.TLSCert == "" || c.GRPC.TLSKey == "" { + return fmt.Errorf("grpc: tls_cert and tls_key are required when grpc.addr is a TCP address") + } + } + + // Validate timeouts are non-negative. 
+ if c.Proxy.ConnectTimeout.Duration < 0 { + return fmt.Errorf("proxy.connect_timeout must not be negative") + } + if c.Proxy.IdleTimeout.Duration < 0 { + return fmt.Errorf("proxy.idle_timeout must not be negative") + } + if c.Proxy.ShutdownTimeout.Duration < 0 { + return fmt.Errorf("proxy.shutdown_timeout must not be negative") + } + return nil } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index 909ce4a..d31c7e1 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -195,6 +195,113 @@ addr = ":8443" } } +func TestGRPCIsUnixSocket(t *testing.T) { + tests := []struct { + addr string + want bool + }{ + {"/var/run/mc-proxy.sock", true}, + {"unix:/var/run/mc-proxy.sock", true}, + {"127.0.0.1:9090", false}, + {":9090", false}, + {"", false}, + } + for _, tt := range tests { + g := GRPC{Addr: tt.addr} + if got := g.IsUnixSocket(); got != tt.want { + t.Fatalf("IsUnixSocket(%q) = %v, want %v", tt.addr, got, tt.want) + } + } +} + +func TestValidateGRPCUnixNoTLS(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" + +[grpc] +addr = "/var/run/mc-proxy.sock" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err != nil { + t.Fatalf("expected Unix socket without TLS to be valid, got: %v", err) + } +} + +func TestValidateGRPCTCPRequiresTLS(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" + +[grpc] +addr = "127.0.0.1:9090" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for TCP gRPC addr without TLS certs") + } +} + +func TestValidateRateLimitRequiresWindow(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` 
+[database] +path = "/tmp/test.db" + +[firewall] +rate_limit = 100 +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for rate_limit without rate_window") + } +} + +func TestValidateRateLimitWithWindow(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" + +[firewall] +rate_limit = 100 +rate_window = "1m" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Firewall.RateLimit != 100 { + t.Fatalf("got rate_limit %d, want 100", cfg.Firewall.RateLimit) + } +} + func TestDuration(t *testing.T) { var d Duration if err := d.UnmarshalText([]byte("5s")); err != nil { diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 13669a4..931f8ef 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -5,9 +5,9 @@ import ( "net/netip" "strings" "sync" + "time" "github.com/oschwald/maxminddb-golang" - ) type geoIPRecord struct { @@ -23,17 +23,23 @@ type Firewall struct { blockedCountries map[string]struct{} geoDBPath string geoDB *maxminddb.Reader + rl *rateLimiter mu sync.RWMutex // protects all mutable state } // New creates a Firewall from raw rule lists and an optional GeoIP database path. -func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) { +// If rateLimit > 0, per-source-IP rate limiting is enabled with the given window. 
+func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rateWindow time.Duration) (*Firewall, error) { f := &Firewall{ blockedIPs: make(map[netip.Addr]struct{}), blockedCountries: make(map[string]struct{}), geoDBPath: geoIPPath, } + if rateLimit > 0 && rateWindow > 0 { + f.rl = newRateLimiter(rateLimit, rateWindow) + } + for _, ip := range ips { addr, err := netip.ParseAddr(ip) if err != nil { @@ -89,6 +95,12 @@ func (f *Firewall) Blocked(addr netip.Addr) bool { } } + // Rate limiting is checked after blocklist — no point tracking state + // for already-blocked IPs. + if f.rl != nil && !f.rl.Allow(addr) { + return true + } + return false } @@ -190,6 +202,10 @@ func (f *Firewall) ReloadGeoIP() error { // Close releases resources held by the firewall. func (f *Firewall) Close() error { + if f.rl != nil { + f.rl.Stop() + } + f.mu.Lock() defer f.mu.Unlock() diff --git a/internal/firewall/firewall_test.go b/internal/firewall/firewall_test.go index 6b9655e..5ef38d7 100644 --- a/internal/firewall/firewall_test.go +++ b/internal/firewall/firewall_test.go @@ -3,10 +3,11 @@ package firewall import ( "net/netip" "testing" + "time" ) func TestEmptyFirewall(t *testing.T) { - fw, err := New("", nil, nil, nil) + fw, err := New("", nil, nil, nil, 0, 0) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -22,7 +23,7 @@ func TestEmptyFirewall(t *testing.T) { } func TestIPBlocking(t *testing.T) { - fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil) + fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil, 0, 0) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -47,7 +48,7 @@ func TestIPBlocking(t *testing.T) { } func TestCIDRBlocking(t *testing.T) { - fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil) + fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil, 0, 0) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -73,7 +74,7 @@ func TestCIDRBlocking(t *testing.T) 
{ } func TestIPv4MappedIPv6(t *testing.T) { - fw, err := New("", []string{"192.0.2.1"}, nil, nil) + fw, err := New("", []string{"192.0.2.1"}, nil, nil, 0, 0) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -86,21 +87,21 @@ func TestIPv4MappedIPv6(t *testing.T) { } func TestInvalidIP(t *testing.T) { - _, err := New("", []string{"not-an-ip"}, nil, nil) + _, err := New("", []string{"not-an-ip"}, nil, nil, 0, 0) if err == nil { t.Fatal("expected error for invalid IP") } } func TestInvalidCIDR(t *testing.T) { - _, err := New("", nil, []string{"not-a-cidr"}, nil) + _, err := New("", nil, []string{"not-a-cidr"}, nil, 0, 0) if err == nil { t.Fatal("expected error for invalid CIDR") } } func TestCombinedRules(t *testing.T) { - fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil) + fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil, 0, 0) if err != nil { t.Fatalf("unexpected error: %v", err) } @@ -124,8 +125,53 @@ func TestCombinedRules(t *testing.T) { } } +func TestRateLimitBlocking(t *testing.T) { + fw, err := New("", nil, nil, nil, 2, time.Minute) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + addr := netip.MustParseAddr("10.0.0.1") + + if fw.Blocked(addr) { + t.Fatal("first request should be allowed") + } + if fw.Blocked(addr) { + t.Fatal("second request should be allowed") + } + if !fw.Blocked(addr) { + t.Fatal("third request should be blocked (limit=2)") + } +} + +func TestRateLimitBlocklistFirst(t *testing.T) { + // A blocklisted IP should be blocked without consuming rate limit quota. + fw, err := New("", []string{"10.0.0.1"}, nil, nil, 1, time.Minute) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + blockedAddr := netip.MustParseAddr("10.0.0.1") + otherAddr := netip.MustParseAddr("10.0.0.2") + + // Blocked by blocklist — should not touch the rate limiter. 
+ if !fw.Blocked(blockedAddr) { + t.Fatal("blocklisted IP should be blocked") + } + + // Other address should still have its full rate limit quota. + if fw.Blocked(otherAddr) { + t.Fatal("other IP should be allowed (within rate limit)") + } + if !fw.Blocked(otherAddr) { + t.Fatal("other IP should be blocked after exceeding rate limit") + } +} + func TestRuntimeMutation(t *testing.T) { - fw, err := New("", nil, nil, nil) + fw, err := New("", nil, nil, nil, 0, 0) if err != nil { t.Fatalf("unexpected error: %v", err) } diff --git a/internal/firewall/ratelimit.go b/internal/firewall/ratelimit.go new file mode 100644 index 0000000..9b8738a --- /dev/null +++ b/internal/firewall/ratelimit.go @@ -0,0 +1,81 @@ +package firewall + +import ( + "net/netip" + "sync" + "sync/atomic" + "time" +) + +type rateLimitEntry struct { + count atomic.Int64 + start atomic.Int64 // UnixNano +} + +type rateLimiter struct { + limit int64 + window time.Duration + entries sync.Map // netip.Addr → *rateLimitEntry + now func() time.Time + done chan struct{} +} + +func newRateLimiter(limit int64, window time.Duration) *rateLimiter { + rl := &rateLimiter{ + limit: limit, + window: window, + now: time.Now, + done: make(chan struct{}), + } + + go rl.cleanup() + return rl +} + +// Allow checks whether the given address is within its rate limit window. +// Returns false if the address has exceeded the limit. +func (rl *rateLimiter) Allow(addr netip.Addr) bool { + now := rl.now().UnixNano() + + val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{}) + entry := val.(*rateLimitEntry) + + windowStart := entry.start.Load() + if now-windowStart >= rl.window.Nanoseconds() { + // Window expired — reset. Intentionally non-atomic across the two + // stores; worst case a few extra connections slip through at the + // boundary, which is acceptable for a connection rate limiter. 
+ entry.start.Store(now) + entry.count.Store(1) + return true + } + + n := entry.count.Add(1) + return n <= rl.limit +} + +// Stop terminates the cleanup goroutine. +func (rl *rateLimiter) Stop() { + close(rl.done) +} + +func (rl *rateLimiter) cleanup() { + ticker := time.NewTicker(rl.window) + defer ticker.Stop() + + for { + select { + case <-rl.done: + return + case <-ticker.C: + cutoff := rl.now().Add(-2 * rl.window).UnixNano() + rl.entries.Range(func(key, value any) bool { + entry := value.(*rateLimitEntry) + if entry.start.Load() < cutoff { + rl.entries.Delete(key) + } + return true + }) + } + } +} diff --git a/internal/firewall/ratelimit_test.go b/internal/firewall/ratelimit_test.go new file mode 100644 index 0000000..7d599aa --- /dev/null +++ b/internal/firewall/ratelimit_test.go @@ -0,0 +1,95 @@ +package firewall + +import ( + "net/netip" + "sync/atomic" + "testing" + "time" +) + +func TestRateLimiterAllow(t *testing.T) { + rl := newRateLimiter(3, time.Minute) + defer rl.Stop() + + addr := netip.MustParseAddr("10.0.0.1") + + for i := 0; i < 3; i++ { + if !rl.Allow(addr) { + t.Fatalf("call %d: expected Allow=true", i+1) + } + } + + if rl.Allow(addr) { + t.Fatal("call 4: expected Allow=false (over limit)") + } +} + +func TestRateLimiterDifferentIPs(t *testing.T) { + rl := newRateLimiter(2, time.Minute) + defer rl.Stop() + + a := netip.MustParseAddr("10.0.0.1") + b := netip.MustParseAddr("10.0.0.2") + + // Exhaust a's limit. + for i := 0; i < 2; i++ { + rl.Allow(a) + } + if rl.Allow(a) { + t.Fatal("a should be rate limited") + } + + // b should be independent. + if !rl.Allow(b) { + t.Fatal("b should not be rate limited") + } +} + +func TestRateLimiterWindowReset(t *testing.T) { + rl := newRateLimiter(2, time.Minute) + defer rl.Stop() + + var fakeNow atomic.Int64 + fakeNow.Store(time.Now().UnixNano()) + rl.now = func() time.Time { + return time.Unix(0, fakeNow.Load()) + } + + addr := netip.MustParseAddr("10.0.0.1") + + // Exhaust the limit. 
+ rl.Allow(addr) + rl.Allow(addr) + if rl.Allow(addr) { + t.Fatal("should be rate limited") + } + + // Advance past the window. + fakeNow.Add(int64(2 * time.Minute)) + + // Should be allowed again. + if !rl.Allow(addr) { + t.Fatal("should be allowed after window reset") + } +} + +func TestRateLimiterCleanup(t *testing.T) { + rl := newRateLimiter(10, 50*time.Millisecond) + defer rl.Stop() + + addr := netip.MustParseAddr("10.0.0.1") + rl.Allow(addr) + + // Entry should exist. + if _, ok := rl.entries.Load(addr); !ok { + t.Fatal("entry should exist") + } + + // Wait for 2*window + a cleanup cycle to pass. + time.Sleep(200 * time.Millisecond) + + // Entry should have been cleaned up. + if _, ok := rl.entries.Load(addr); ok { + t.Fatal("stale entry should have been cleaned up") + } +} diff --git a/internal/grpcserver/grpcserver.go b/internal/grpcserver/grpcserver.go index a0cff93..9ddd353 100644 --- a/internal/grpcserver/grpcserver.go +++ b/internal/grpcserver/grpcserver.go @@ -7,7 +7,9 @@ import ( "fmt" "log/slog" "net" + "net/netip" "os" + "regexp" "strings" "google.golang.org/grpc" @@ -22,6 +24,8 @@ import ( "git.wntrmute.dev/kyle/mc-proxy/internal/server" ) +var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`) + // AdminServer implements the ProxyAdmin gRPC service. type AdminServer struct { pb.UnimplementedProxyAdminServiceServer @@ -30,8 +34,43 @@ type AdminServer struct { logger *slog.Logger } -// New creates a gRPC server with TLS and optional mTLS. +// New creates a gRPC server. For Unix sockets, no TLS is used. For TCP +// addresses, TLS is required with optional mTLS. 
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) { + admin := &AdminServer{ + srv: srv, + store: store, + logger: logger, + } + + if cfg.IsUnixSocket() { + return newUnixServer(cfg, admin) + } + return newTCPServer(cfg, admin) +} + +func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) { + path := cfg.SocketPath() + + // Remove stale socket file from a previous run. + os.Remove(path) + + ln, err := net.Listen("unix", path) + if err != nil { + return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err) + } + + if err := os.Chmod(path, 0600); err != nil { + ln.Close() + return nil, nil, fmt.Errorf("setting socket permissions: %w", err) + } + + grpcServer := grpc.NewServer() + pb.RegisterProxyAdminServiceServer(grpcServer, admin) + return grpcServer, ln, nil +} + +func newTCPServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) { cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) if err != nil { return nil, nil, fmt.Errorf("loading TLS keypair: %w", err) @@ -57,12 +96,6 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg creds := credentials.NewTLS(tlsConfig) grpcServer := grpc.NewServer(grpc.Creds(creds)) - - admin := &AdminServer{ - srv: srv, - store: store, - logger: logger, - } pb.RegisterProxyAdminServiceServer(grpcServer, admin) ln, err := net.Listen("tcp", cfg.Addr) @@ -102,6 +135,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb. return nil, status.Error(codes.InvalidArgument, "hostname and backend are required") } + // Validate backend is a valid host:port. 
+ if _, _, err := net.SplitHostPort(req.Route.Backend); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err) + } + ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err @@ -190,6 +228,29 @@ func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRule return nil, status.Error(codes.InvalidArgument, "value is required") } + // Validate the value matches the rule type before persisting. + switch ruleType { + case "ip": + if _, err := netip.ParseAddr(req.Rule.Value); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err) + } + case "cidr": + prefix, err := netip.ParsePrefix(req.Rule.Value) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err) + } + // Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16). + if prefix.Masked().String() != req.Rule.Value { + return nil, status.Errorf(codes.InvalidArgument, + "CIDR not in canonical form: use %s", prefix.Masked().String()) + } + case "country": + if !countryCodeRe.MatchString(req.Rule.Value) { + return nil, status.Error(codes.InvalidArgument, + "country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)") + } + } + // Write-through: DB first, then memory. if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil { return nil, status.Errorf(codes.AlreadyExists, "%v", err) diff --git a/internal/grpcserver/grpcserver_test.go b/internal/grpcserver/grpcserver_test.go index 9d38c49..3eb03bb 100644 --- a/internal/grpcserver/grpcserver_test.go +++ b/internal/grpcserver/grpcserver_test.go @@ -62,7 +62,7 @@ func setup(t *testing.T) *testEnv { } // Build server with matching in-memory state. 
- fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil) + fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0) if err != nil { t.Fatalf("firewall: %v", err) } @@ -268,6 +268,15 @@ func TestAddRouteValidation(t *testing.T) { if err == nil { t.Fatal("expected error for empty backend") } + + // Invalid backend (not host:port). + _, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"}, + }) + if err == nil { + t.Fatal("expected error for invalid backend address") + } } func TestRemoveRoute(t *testing.T) { @@ -410,6 +419,61 @@ func TestAddFirewallRuleValidation(t *testing.T) { if err == nil { t.Fatal("expected error for empty value") } + + // Invalid IP address. + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, + Value: "not-an-ip", + }, + }) + if err == nil { + t.Fatal("expected error for invalid IP") + } + + // Invalid CIDR. + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, + Value: "not-a-cidr", + }, + }) + if err == nil { + t.Fatal("expected error for invalid CIDR") + } + + // Non-canonical CIDR. + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, + Value: "192.168.1.5/16", + }, + }) + if err == nil { + t.Fatal("expected error for non-canonical CIDR") + } + + // Invalid country code (lowercase). + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, + Value: "cn", + }, + }) + if err == nil { + t.Fatal("expected error for lowercase country code") + } + + // Invalid country code (too long). 
+ _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, + Value: "USA", + }, + }) + if err == nil { + t.Fatal("expected error for 3-letter country code") + } } func TestRemoveFirewallRule(t *testing.T) { diff --git a/internal/server/server.go b/internal/server/server.go index 712e148..27b92da 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -24,6 +24,8 @@ type ListenerState struct { routes map[string]string // lowercase hostname → backend addr mu sync.RWMutex ActiveConnections atomic.Int64 + activeConns map[net.Conn]struct{} // tracked for forced shutdown + connMu sync.Mutex } // Routes returns a snapshot of the listener's route table. @@ -100,9 +102,10 @@ func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData, var listeners []*ListenerState for _, ld := range listenerData { listeners = append(listeners, &ListenerState{ - ID: ld.ID, - Addr: ld.Addr, - routes: ld.Routes, + ID: ld.ID, + Addr: ld.Addr, + routes: ld.Routes, + activeConns: make(map[net.Conn]struct{}), }) } @@ -186,6 +189,9 @@ func (s *Server) Run(ctx context.Context) error { s.logger.Info("all connections drained") case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration): s.logger.Warn("shutdown timeout exceeded, forcing close") + // Force-close all listener connections to unblock relay goroutines. + s.forceCloseAll() + <-done } s.fw.Close() @@ -214,11 +220,31 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) } } +// forceCloseAll closes all tracked connections across all listeners. 
+func (s *Server) forceCloseAll() { + for _, ls := range s.listeners { + ls.connMu.Lock() + for conn := range ls.activeConns { + conn.Close() + } + ls.connMu.Unlock() + } +} + func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) { defer s.wg.Done() defer ls.ActiveConnections.Add(-1) defer conn.Close() + ls.connMu.Lock() + ls.activeConns[conn] = struct{}{} + ls.connMu.Unlock() + defer func() { + ls.connMu.Lock() + delete(ls.activeConns, conn) + ls.connMu.Unlock() + }() + remoteAddr := conn.RemoteAddr().String() addrPort, err := netip.ParseAddrPort(remoteAddr) if err != nil { diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 5605ac7..fcba86a 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -28,7 +28,7 @@ func echoServer(t *testing.T, ln net.Listener) { // newTestServer creates a Server with the given listener data and no firewall rules. func newTestServer(t *testing.T, listeners []ListenerData) *Server { t.Helper() - fw, err := firewall.New("", nil, nil, nil) + fw, err := firewall.New("", nil, nil, nil, 0, 0) if err != nil { t.Fatalf("creating firewall: %v", err) } @@ -195,7 +195,7 @@ func TestFirewallBlocks(t *testing.T) { proxyLn.Close() // Create a firewall that blocks 127.0.0.1 (the test client). 
- fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil) + fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0) if err != nil { t.Fatalf("creating firewall: %v", err) } @@ -599,7 +599,7 @@ func TestGracefulShutdown(t *testing.T) { proxyAddr := proxyLn.Addr().String() proxyLn.Close() - fw, err := firewall.New("", nil, nil, nil) + fw, err := firewall.New("", nil, nil, nil, 0, 0) if err != nil { t.Fatalf("creating firewall: %v", err) } diff --git a/internal/sni/sni.go b/internal/sni/sni.go index 9fe8de1..11325be 100644 --- a/internal/sni/sni.go +++ b/internal/sni/sni.go @@ -165,6 +165,9 @@ func parseServerNameExtension(data []byte) (string, error) { } if nameType == 0x00 { // hostname + if nameLen > 255 { + return "", fmt.Errorf("SNI hostname exceeds 255 bytes") + } return strings.ToLower(string(data[:nameLen])), nil }