Add per-IP rate limiting and Unix socket support for gRPC admin API

Rate limiting: per-source-IP connection rate limiter in the firewall layer
with configurable limit and sliding window. Blocklisted IPs are rejected
before rate limit evaluation to avoid wasting quota. Unix socket: the gRPC
admin API can now listen on a Unix domain socket (no TLS required), secured
by file permissions (0600), as a simpler alternative for local-only access.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-17 14:37:21 -07:00
parent e84093b7fb
commit b25e1b0e79
16 changed files with 694 additions and 43 deletions

View File

@@ -174,6 +174,47 @@ gRPC admin API.
---
## gRPC Admin API
The admin API is optional (disabled if `[grpc]` is omitted from the config).
When enabled, it requires TLS and supports optional mTLS for client
authentication. TLS 1.3 is enforced. The API provides runtime management
of routes and firewall rules without restarting the proxy.
### RPCs
| RPC | Description |
|-----|-------------|
| `ListRoutes` | List all routes for a given listener |
| `AddRoute` | Add a route to a listener (write-through to DB) |
| `RemoveRoute` | Remove a route from a listener (write-through to DB) |
| `GetFirewallRules` | List all firewall rules |
| `AddFirewallRule` | Add a firewall rule (write-through to DB) |
| `RemoveFirewallRule` | Remove a firewall rule (write-through to DB) |
| `GetStatus` | Return version, uptime, listener status, connection counts |
### Input Validation
The admin API validates all inputs before persisting:
- **Route backends** must be valid `host:port` tuples.
- **IP firewall rules** must be valid IP addresses (`netip.ParseAddr`).
- **CIDR firewall rules** must be valid prefixes in canonical form.
- **Country firewall rules** must be exactly 2 uppercase letters (ISO 3166-1 alpha-2).
### Security
The gRPC admin API has no MCIAS integration — mc-proxy is pre-auth
infrastructure. Access control relies on:
1. **Network binding**: bind to `127.0.0.1` (default) to restrict to local access.
2. **mTLS**: configure `client_ca` to require client certificates.
If the admin API is exposed on a non-loopback interface without mTLS,
any network client can modify routing and firewall rules.
---
## Configuration
TOML configuration file, loaded at startup. The proxy refuses to start if
@@ -210,12 +251,16 @@ addr = ":9443"
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
# gRPC admin API. Optional — omit addr to disable.
# gRPC admin API. Optional — omit or leave addr empty to disable.
# If enabled, tls_cert and tls_key are required (TLS 1.3 only).
# client_ca enables mTLS and is strongly recommended for non-loopback addresses.
# ca_cert is used by the `status` CLI command to verify the server certificate.
[grpc]
addr = "127.0.0.1:9090"
tls_cert = "/srv/mc-proxy/certs/cert.pem"
tls_key = "/srv/mc-proxy/certs/key.pem"
client_ca = "/srv/mc-proxy/certs/ca.pem"
ca_cert = "/srv/mc-proxy/certs/ca.pem"
# Firewall. Global blocklist, evaluated before routing. Default allow.
[firewall]
@@ -333,6 +378,8 @@ Multi-stage Docker build:
| File | Purpose |
|------|---------|
| `mc-proxy.service` | Main proxy service |
| `mc-proxy-backup.service` | Oneshot database backup (VACUUM INTO) |
| `mc-proxy-backup.timer` | Daily backup timer (02:00 UTC, 5-minute jitter) |
The proxy binds to privileged ports (443) and should use `AmbientCapabilities=CAP_NET_BIND_SERVICE`
in the systemd unit rather than running as root.
@@ -354,9 +401,10 @@ On `SIGHUP`:
1. Reload the GeoIP database from disk.
2. Continue serving with the updated database.
Configuration changes (routes, listeners, firewall rules) require a full
restart. Hot reload of routing rules is deferred to the future SQLite-backed
implementation.
Routes and firewall rules can be modified at runtime via the gRPC admin API
(write-through to SQLite). Listener changes (adding/removing ports) require
a full restart. TOML configuration changes (timeouts, log level, GeoIP path)
also require a restart.
---
@@ -373,7 +421,7 @@ It has no authentication or authorization of its own.
| Resource exhaustion (connection flood) | Idle timeout closes stale connections. Per-listener connection limits (future). Rate limiting (future). |
| GeoIP evasion via IPv6 | GeoLite2 database includes IPv6 mappings. Both IPv4 and IPv6 source addresses are checked. |
| GeoIP evasion via VPN/proxy | Accepted risk. GeoIP blocking is a compliance measure, not a security boundary. Determined adversaries will bypass it. |
| Slowloris / slow ClientHello | Timeout on the SNI extraction phase. If a complete ClientHello is not received within a reasonable window (e.g. 10s), the connection is reset. |
| Slowloris / slow ClientHello | Hardcoded 10-second timeout on the SNI extraction phase. If a complete ClientHello is not received within this window, the connection is reset. |
| Backend unavailability | Connect timeout prevents indefinite hangs. Connection is reset if the backend is unreachable. |
| Information leakage | Blocked connections receive only a TCP RST. No version strings, no error messages, no TLS alerts. |

View File

@@ -68,7 +68,7 @@ func serverCmd() *cobra.Command {
}
// Load firewall rules from DB.
fw, err := loadFirewallFromDB(store, cfg.Firewall.GeoIPDB)
fw, err := loadFirewallFromDB(store, cfg.Firewall)
if err != nil {
return err
}
@@ -90,7 +90,12 @@ func serverCmd() *cobra.Command {
logger.Error("gRPC server error", "error", err)
}
}()
defer grpcSrv.GracefulStop()
defer func() {
grpcSrv.GracefulStop()
if cfg.GRPC.IsUnixSocket() {
os.Remove(cfg.GRPC.SocketPath())
}
}()
}
// SIGHUP reloads the GeoIP database.
@@ -140,7 +145,7 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) {
return result, nil
}
func loadFirewallFromDB(store *db.Store, geoIPPath string) (*firewall.Firewall, error) {
func loadFirewallFromDB(store *db.Store, fwCfg config.Firewall) (*firewall.Firewall, error) {
rules, err := store.ListFirewallRules()
if err != nil {
return nil, fmt.Errorf("loading firewall rules: %w", err)
@@ -158,7 +163,7 @@ func loadFirewallFromDB(store *db.Store, geoIPPath string) (*firewall.Firewall,
}
}
fw, err := firewall.New(geoIPPath, ips, cidrs, countries)
fw, err := firewall.New(fwCfg.GeoIPDB, ips, cidrs, countries, fwCfg.RateLimit, fwCfg.RateWindow.Duration)
if err != nil {
return nil, fmt.Errorf("initializing firewall: %w", err)
}

View File

@@ -2,7 +2,9 @@ package main
import (
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/spf13/cobra"
@@ -32,10 +34,29 @@ func snapshotCmd() *cobra.Command {
}
defer store.Close()
dataDir := filepath.Dir(cfg.Database.Path)
if outputPath == "" {
dir := filepath.Dir(cfg.Database.Path)
ts := time.Now().UTC().Format("20060102T150405Z")
outputPath = filepath.Join(dir, "backups", fmt.Sprintf("mc-proxy-%s.db", ts))
outputPath = filepath.Join(dataDir, "backups", fmt.Sprintf("mc-proxy-%s.db", ts))
}
// Validate the output path is within the data directory.
absOutput, err := filepath.Abs(filepath.Clean(outputPath))
if err != nil {
return fmt.Errorf("resolving output path: %w", err)
}
absDataDir, err := filepath.Abs(dataDir)
if err != nil {
return fmt.Errorf("resolving data directory: %w", err)
}
if !strings.HasPrefix(absOutput, absDataDir+string(os.PathSeparator)) {
return fmt.Errorf("output path must be within the data directory (%s)", absDataDir)
}
// Ensure the parent directory exists.
if err := os.MkdirAll(filepath.Dir(outputPath), 0700); err != nil {
return fmt.Errorf("creating backup directory: %w", err)
}
if err := store.Snapshot(outputPath); err != nil {

View File

@@ -11,6 +11,7 @@ import (
"github.com/spf13/cobra"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
@@ -70,6 +71,11 @@ func statusCmd() *cobra.Command {
}
func dialGRPC(cfg config.GRPC) (*grpc.ClientConn, error) {
if cfg.IsUnixSocket() {
return grpc.NewClient("unix://"+cfg.SocketPath(),
grpc.WithTransportCredentials(insecure.NewCredentials()))
}
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS13,
}

View File

@@ -1,30 +1,58 @@
# mc-proxy configuration
#
# This file seeds the database on first run. After that, the database is
# the source of truth — listener, route, and firewall fields here are ignored.
# Listeners. Each entry binds a TCP listener on the specified address.
# Database. Required.
[database]
path = "/srv/mc-proxy/mc-proxy.db"
# Listeners. Each listener binds a TCP port and has its own route table.
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "metacrypt.metacircular.net"
backend = "127.0.0.1:18443"
[[listeners.routes]]
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
[[listeners]]
addr = ":8443"
[[listeners]]
addr = ":9443"
# Routes. SNI hostname → backend address.
[[routes]]
[[listeners.routes]]
hostname = "metacrypt.metacircular.net"
backend = "127.0.0.1:18443"
[[routes]]
[[listeners]]
addr = ":9443"
[[listeners.routes]]
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
# gRPC admin API. Optional — omit or leave addr empty to disable.
# If enabled over TCP, tls_cert and tls_key are required. mTLS (client_ca)
# is strongly recommended for any non-loopback listen address.
[grpc]
addr = "127.0.0.1:9090"
tls_cert = "/srv/mc-proxy/certs/cert.pem"
tls_key = "/srv/mc-proxy/certs/key.pem"
client_ca = "/srv/mc-proxy/certs/ca.pem" # mTLS; omit to disable client auth
# Unix socket alternative (no TLS needed, secured by file permissions):
# addr = "/srv/mc-proxy/admin.sock"
# Firewall. Global blocklist, evaluated before routing. Default allow.
[firewall]
geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb"
blocked_ips = []
blocked_cidrs = []
blocked_countries = ["KP", "CN", "IN", "IL"]
rate_limit = 100 # max connections per source IP per window (0 = disabled)
rate_window = "1m" # sliding window duration (required if rate_limit > 0)
# Proxy behavior.
[proxy]

View File

@@ -3,6 +3,7 @@ package config
import (
"fmt"
"os"
"strings"
"time"
"github.com/pelletier/go-toml/v2"
@@ -44,6 +45,8 @@ type Firewall struct {
BlockedIPs []string `toml:"blocked_ips"`
BlockedCIDRs []string `toml:"blocked_cidrs"`
BlockedCountries []string `toml:"blocked_countries"`
RateLimit int64 `toml:"rate_limit"`
RateWindow Duration `toml:"rate_window"`
}
type Proxy struct {
@@ -61,6 +64,18 @@ type Duration struct {
time.Duration
}
// IsUnixSocket returns true if the gRPC address refers to a Unix domain socket.
func (g GRPC) IsUnixSocket() bool {
path := strings.TrimPrefix(g.Addr, "unix:")
return strings.Contains(path, "/")
}
// SocketPath returns the filesystem path for a Unix socket address,
// stripping any "unix:" prefix.
func (g GRPC) SocketPath() string {
return strings.TrimPrefix(g.Addr, "unix:")
}
func (d *Duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
@@ -114,5 +129,34 @@ func (c *Config) validate() error {
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
}
if c.Firewall.RateLimit < 0 {
return fmt.Errorf("firewall.rate_limit must not be negative")
}
if c.Firewall.RateWindow.Duration < 0 {
return fmt.Errorf("firewall.rate_window must not be negative")
}
if c.Firewall.RateLimit > 0 && c.Firewall.RateWindow.Duration == 0 {
return fmt.Errorf("firewall.rate_window is required when rate_limit is set")
}
// Validate gRPC config: if enabled, TLS cert and key are required
// (unless using a Unix socket, which doesn't need TLS).
if c.GRPC.Addr != "" && !c.GRPC.IsUnixSocket() {
if c.GRPC.TLSCert == "" || c.GRPC.TLSKey == "" {
return fmt.Errorf("grpc: tls_cert and tls_key are required when grpc.addr is a TCP address")
}
}
// Validate timeouts are non-negative.
if c.Proxy.ConnectTimeout.Duration < 0 {
return fmt.Errorf("proxy.connect_timeout must not be negative")
}
if c.Proxy.IdleTimeout.Duration < 0 {
return fmt.Errorf("proxy.idle_timeout must not be negative")
}
if c.Proxy.ShutdownTimeout.Duration < 0 {
return fmt.Errorf("proxy.shutdown_timeout must not be negative")
}
return nil
}

View File

@@ -195,6 +195,113 @@ addr = ":8443"
}
}
func TestGRPCIsUnixSocket(t *testing.T) {
tests := []struct {
addr string
want bool
}{
{"/var/run/mc-proxy.sock", true},
{"unix:/var/run/mc-proxy.sock", true},
{"127.0.0.1:9090", false},
{":9090", false},
{"", false},
}
for _, tt := range tests {
g := GRPC{Addr: tt.addr}
if got := g.IsUnixSocket(); got != tt.want {
t.Fatalf("IsUnixSocket(%q) = %v, want %v", tt.addr, got, tt.want)
}
}
}
func TestValidateGRPCUnixNoTLS(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[grpc]
addr = "/var/run/mc-proxy.sock"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err != nil {
t.Fatalf("expected Unix socket without TLS to be valid, got: %v", err)
}
}
func TestValidateGRPCTCPRequiresTLS(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[grpc]
addr = "127.0.0.1:9090"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for TCP gRPC addr without TLS certs")
}
}
func TestValidateRateLimitRequiresWindow(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[firewall]
rate_limit = 100
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for rate_limit without rate_window")
}
}
func TestValidateRateLimitWithWindow(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[firewall]
rate_limit = 100
rate_window = "1m"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Firewall.RateLimit != 100 {
t.Fatalf("got rate_limit %d, want 100", cfg.Firewall.RateLimit)
}
}
func TestDuration(t *testing.T) {
var d Duration
if err := d.UnmarshalText([]byte("5s")); err != nil {

View File

@@ -5,9 +5,9 @@ import (
"net/netip"
"strings"
"sync"
"time"
"github.com/oschwald/maxminddb-golang"
)
type geoIPRecord struct {
@@ -23,17 +23,23 @@ type Firewall struct {
blockedCountries map[string]struct{}
geoDBPath string
geoDB *maxminddb.Reader
rl *rateLimiter
mu sync.RWMutex // protects all mutable state
}
// New creates a Firewall from raw rule lists and an optional GeoIP database path.
func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) {
// If rateLimit > 0, per-source-IP rate limiting is enabled with the given window.
func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rateWindow time.Duration) (*Firewall, error) {
f := &Firewall{
blockedIPs: make(map[netip.Addr]struct{}),
blockedCountries: make(map[string]struct{}),
geoDBPath: geoIPPath,
}
if rateLimit > 0 && rateWindow > 0 {
f.rl = newRateLimiter(rateLimit, rateWindow)
}
for _, ip := range ips {
addr, err := netip.ParseAddr(ip)
if err != nil {
@@ -89,6 +95,12 @@ func (f *Firewall) Blocked(addr netip.Addr) bool {
}
}
// Rate limiting is checked after blocklist — no point tracking state
// for already-blocked IPs.
if f.rl != nil && !f.rl.Allow(addr) {
return true
}
return false
}
@@ -190,6 +202,10 @@ func (f *Firewall) ReloadGeoIP() error {
// Close releases resources held by the firewall.
func (f *Firewall) Close() error {
if f.rl != nil {
f.rl.Stop()
}
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -3,10 +3,11 @@ package firewall
import (
"net/netip"
"testing"
"time"
)
func TestEmptyFirewall(t *testing.T) {
fw, err := New("", nil, nil, nil)
fw, err := New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -22,7 +23,7 @@ func TestEmptyFirewall(t *testing.T) {
}
func TestIPBlocking(t *testing.T) {
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil)
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -47,7 +48,7 @@ func TestIPBlocking(t *testing.T) {
}
func TestCIDRBlocking(t *testing.T) {
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil)
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -73,7 +74,7 @@ func TestCIDRBlocking(t *testing.T) {
}
func TestIPv4MappedIPv6(t *testing.T) {
fw, err := New("", []string{"192.0.2.1"}, nil, nil)
fw, err := New("", []string{"192.0.2.1"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -86,21 +87,21 @@ func TestIPv4MappedIPv6(t *testing.T) {
}
func TestInvalidIP(t *testing.T) {
_, err := New("", []string{"not-an-ip"}, nil, nil)
_, err := New("", []string{"not-an-ip"}, nil, nil, 0, 0)
if err == nil {
t.Fatal("expected error for invalid IP")
}
}
func TestInvalidCIDR(t *testing.T) {
_, err := New("", nil, []string{"not-a-cidr"}, nil)
_, err := New("", nil, []string{"not-a-cidr"}, nil, 0, 0)
if err == nil {
t.Fatal("expected error for invalid CIDR")
}
}
func TestCombinedRules(t *testing.T) {
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil)
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -124,8 +125,53 @@ func TestCombinedRules(t *testing.T) {
}
}
func TestRateLimitBlocking(t *testing.T) {
fw, err := New("", nil, nil, nil, 2, time.Minute)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
addr := netip.MustParseAddr("10.0.0.1")
if fw.Blocked(addr) {
t.Fatal("first request should be allowed")
}
if fw.Blocked(addr) {
t.Fatal("second request should be allowed")
}
if !fw.Blocked(addr) {
t.Fatal("third request should be blocked (limit=2)")
}
}
func TestRateLimitBlocklistFirst(t *testing.T) {
// A blocklisted IP should be blocked without consuming rate limit quota.
fw, err := New("", []string{"10.0.0.1"}, nil, nil, 1, time.Minute)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
blockedAddr := netip.MustParseAddr("10.0.0.1")
otherAddr := netip.MustParseAddr("10.0.0.2")
// Blocked by blocklist — should not touch the rate limiter.
if !fw.Blocked(blockedAddr) {
t.Fatal("blocklisted IP should be blocked")
}
// Other address should still have its full rate limit quota.
if fw.Blocked(otherAddr) {
t.Fatal("other IP should be allowed (within rate limit)")
}
if !fw.Blocked(otherAddr) {
t.Fatal("other IP should be blocked after exceeding rate limit")
}
}
func TestRuntimeMutation(t *testing.T) {
fw, err := New("", nil, nil, nil)
fw, err := New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

View File

@@ -0,0 +1,81 @@
package firewall
import (
"net/netip"
"sync"
"sync/atomic"
"time"
)
type rateLimitEntry struct {
count atomic.Int64
start atomic.Int64 // UnixNano
}
type rateLimiter struct {
limit int64
window time.Duration
entries sync.Map // netip.Addr → *rateLimitEntry
now func() time.Time
done chan struct{}
}
func newRateLimiter(limit int64, window time.Duration) *rateLimiter {
rl := &rateLimiter{
limit: limit,
window: window,
now: time.Now,
done: make(chan struct{}),
}
go rl.cleanup()
return rl
}
// Allow checks whether the given address is within its rate limit window.
// Returns false if the address has exceeded the limit.
func (rl *rateLimiter) Allow(addr netip.Addr) bool {
now := rl.now().UnixNano()
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
entry := val.(*rateLimitEntry)
windowStart := entry.start.Load()
if now-windowStart >= rl.window.Nanoseconds() {
// Window expired — reset. Intentionally non-atomic across the two
// stores; worst case a few extra connections slip through at the
// boundary, which is acceptable for a connection rate limiter.
entry.start.Store(now)
entry.count.Store(1)
return true
}
n := entry.count.Add(1)
return n <= rl.limit
}
// Stop terminates the cleanup goroutine.
func (rl *rateLimiter) Stop() {
close(rl.done)
}
func (rl *rateLimiter) cleanup() {
ticker := time.NewTicker(rl.window)
defer ticker.Stop()
for {
select {
case <-rl.done:
return
case <-ticker.C:
cutoff := rl.now().Add(-2 * rl.window).UnixNano()
rl.entries.Range(func(key, value any) bool {
entry := value.(*rateLimitEntry)
if entry.start.Load() < cutoff {
rl.entries.Delete(key)
}
return true
})
}
}
}

View File

@@ -0,0 +1,95 @@
package firewall
import (
"net/netip"
"sync/atomic"
"testing"
"time"
)
func TestRateLimiterAllow(t *testing.T) {
rl := newRateLimiter(3, time.Minute)
defer rl.Stop()
addr := netip.MustParseAddr("10.0.0.1")
for i := 0; i < 3; i++ {
if !rl.Allow(addr) {
t.Fatalf("call %d: expected Allow=true", i+1)
}
}
if rl.Allow(addr) {
t.Fatal("call 4: expected Allow=false (over limit)")
}
}
func TestRateLimiterDifferentIPs(t *testing.T) {
rl := newRateLimiter(2, time.Minute)
defer rl.Stop()
a := netip.MustParseAddr("10.0.0.1")
b := netip.MustParseAddr("10.0.0.2")
// Exhaust a's limit.
for i := 0; i < 2; i++ {
rl.Allow(a)
}
if rl.Allow(a) {
t.Fatal("a should be rate limited")
}
// b should be independent.
if !rl.Allow(b) {
t.Fatal("b should not be rate limited")
}
}
func TestRateLimiterWindowReset(t *testing.T) {
rl := newRateLimiter(2, time.Minute)
defer rl.Stop()
var fakeNow atomic.Int64
fakeNow.Store(time.Now().UnixNano())
rl.now = func() time.Time {
return time.Unix(0, fakeNow.Load())
}
addr := netip.MustParseAddr("10.0.0.1")
// Exhaust the limit.
rl.Allow(addr)
rl.Allow(addr)
if rl.Allow(addr) {
t.Fatal("should be rate limited")
}
// Advance past the window.
fakeNow.Add(int64(2 * time.Minute))
// Should be allowed again.
if !rl.Allow(addr) {
t.Fatal("should be allowed after window reset")
}
}
func TestRateLimiterCleanup(t *testing.T) {
rl := newRateLimiter(10, 50*time.Millisecond)
defer rl.Stop()
addr := netip.MustParseAddr("10.0.0.1")
rl.Allow(addr)
// Entry should exist.
if _, ok := rl.entries.Load(addr); !ok {
t.Fatal("entry should exist")
}
// Wait for 2*window + a cleanup cycle to pass.
time.Sleep(200 * time.Millisecond)
// Entry should have been cleaned up.
if _, ok := rl.entries.Load(addr); ok {
t.Fatal("stale entry should have been cleaned up")
}
}

View File

@@ -7,7 +7,9 @@ import (
"fmt"
"log/slog"
"net"
"net/netip"
"os"
"regexp"
"strings"
"google.golang.org/grpc"
@@ -22,6 +24,8 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
// AdminServer implements the ProxyAdmin gRPC service.
type AdminServer struct {
pb.UnimplementedProxyAdminServiceServer
@@ -30,8 +34,43 @@ type AdminServer struct {
logger *slog.Logger
}
// New creates a gRPC server with TLS and optional mTLS.
// New creates a gRPC server. For Unix sockets, no TLS is used. For TCP
// addresses, TLS is required with optional mTLS.
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
admin := &AdminServer{
srv: srv,
store: store,
logger: logger,
}
if cfg.IsUnixSocket() {
return newUnixServer(cfg, admin)
}
return newTCPServer(cfg, admin)
}
func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
path := cfg.SocketPath()
// Remove stale socket file from a previous run.
os.Remove(path)
ln, err := net.Listen("unix", path)
if err != nil {
return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
}
if err := os.Chmod(path, 0600); err != nil {
ln.Close()
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
}
grpcServer := grpc.NewServer()
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
return grpcServer, ln, nil
}
func newTCPServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
if err != nil {
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
@@ -57,12 +96,6 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
creds := credentials.NewTLS(tlsConfig)
grpcServer := grpc.NewServer(grpc.Creds(creds))
admin := &AdminServer{
srv: srv,
store: store,
logger: logger,
}
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
ln, err := net.Listen("tcp", cfg.Addr)
@@ -102,6 +135,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
}
// Validate backend is a valid host:port.
if _, _, err := net.SplitHostPort(req.Route.Backend); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
@@ -190,6 +228,29 @@ func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRule
return nil, status.Error(codes.InvalidArgument, "value is required")
}
// Validate the value matches the rule type before persisting.
switch ruleType {
case "ip":
if _, err := netip.ParseAddr(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
}
case "cidr":
prefix, err := netip.ParsePrefix(req.Rule.Value)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
}
// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
if prefix.Masked().String() != req.Rule.Value {
return nil, status.Errorf(codes.InvalidArgument,
"CIDR not in canonical form: use %s", prefix.Masked().String())
}
case "country":
if !countryCodeRe.MatchString(req.Rule.Value) {
return nil, status.Error(codes.InvalidArgument,
"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
}
}
// Write-through: DB first, then memory.
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)

View File

@@ -62,7 +62,7 @@ func setup(t *testing.T) *testEnv {
}
// Build server with matching in-memory state.
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil)
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("firewall: %v", err)
}
@@ -268,6 +268,15 @@ func TestAddRouteValidation(t *testing.T) {
if err == nil {
t.Fatal("expected error for empty backend")
}
// Invalid backend (not host:port).
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
ListenerAddr: ":443",
Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"},
})
if err == nil {
t.Fatal("expected error for invalid backend address")
}
}
func TestRemoveRoute(t *testing.T) {
@@ -410,6 +419,61 @@ func TestAddFirewallRuleValidation(t *testing.T) {
if err == nil {
t.Fatal("expected error for empty value")
}
// Invalid IP address.
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
Value: "not-an-ip",
},
})
if err == nil {
t.Fatal("expected error for invalid IP")
}
// Invalid CIDR.
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
Value: "not-a-cidr",
},
})
if err == nil {
t.Fatal("expected error for invalid CIDR")
}
// Non-canonical CIDR.
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
Value: "192.168.1.5/16",
},
})
if err == nil {
t.Fatal("expected error for non-canonical CIDR")
}
// Invalid country code (lowercase).
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
Value: "cn",
},
})
if err == nil {
t.Fatal("expected error for lowercase country code")
}
// Invalid country code (too long).
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
Value: "USA",
},
})
if err == nil {
t.Fatal("expected error for 3-letter country code")
}
}
func TestRemoveFirewallRule(t *testing.T) {

View File

@@ -24,6 +24,8 @@ type ListenerState struct {
routes map[string]string // lowercase hostname → backend addr
mu sync.RWMutex
ActiveConnections atomic.Int64
activeConns map[net.Conn]struct{} // tracked for forced shutdown
connMu sync.Mutex
}
// Routes returns a snapshot of the listener's route table.
@@ -103,6 +105,7 @@ func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData,
ID: ld.ID,
Addr: ld.Addr,
routes: ld.Routes,
activeConns: make(map[net.Conn]struct{}),
})
}
@@ -186,6 +189,9 @@ func (s *Server) Run(ctx context.Context) error {
s.logger.Info("all connections drained")
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
s.logger.Warn("shutdown timeout exceeded, forcing close")
// Force-close all listener connections to unblock relay goroutines.
s.forceCloseAll()
<-done
}
s.fw.Close()
@@ -214,11 +220,31 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
}
}
// forceCloseAll closes all tracked connections across all listeners.
func (s *Server) forceCloseAll() {
for _, ls := range s.listeners {
ls.connMu.Lock()
for conn := range ls.activeConns {
conn.Close()
}
ls.connMu.Unlock()
}
}
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
defer s.wg.Done()
defer ls.ActiveConnections.Add(-1)
defer conn.Close()
ls.connMu.Lock()
ls.activeConns[conn] = struct{}{}
ls.connMu.Unlock()
defer func() {
ls.connMu.Lock()
delete(ls.activeConns, conn)
ls.connMu.Unlock()
}()
remoteAddr := conn.RemoteAddr().String()
addrPort, err := netip.ParseAddrPort(remoteAddr)
if err != nil {

View File

@@ -28,7 +28,7 @@ func echoServer(t *testing.T, ln net.Listener) {
// newTestServer creates a Server with the given listener data and no firewall rules.
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
t.Helper()
fw, err := firewall.New("", nil, nil, nil)
fw, err := firewall.New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}
@@ -195,7 +195,7 @@ func TestFirewallBlocks(t *testing.T) {
proxyLn.Close()
// Create a firewall that blocks 127.0.0.1 (the test client).
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil)
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}
@@ -599,7 +599,7 @@ func TestGracefulShutdown(t *testing.T) {
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
fw, err := firewall.New("", nil, nil, nil)
fw, err := firewall.New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}

View File

@@ -165,6 +165,9 @@ func parseServerNameExtension(data []byte) (string, error) {
}
if nameType == 0x00 { // hostname
if nameLen > 255 {
return "", fmt.Errorf("SNI hostname exceeds 255 bytes")
}
return strings.ToLower(string(data[:nameLen])), nil
}