Add per-IP rate limiting and Unix socket support for gRPC admin API
Rate limiting: per-source-IP connection rate limiter in the firewall layer with configurable limit and sliding window. Blocklisted IPs are rejected before rate limit evaluation to avoid wasting quota. Unix socket: the gRPC admin API can now listen on a Unix domain socket (no TLS required), secured by file permissions (0600), as a simpler alternative for local-only access. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -174,6 +174,47 @@ gRPC admin API.
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
## gRPC Admin API
|
||||||
|
|
||||||
|
The admin API is optional (disabled if `[grpc]` is omitted from the config).
|
||||||
|
When enabled, it requires TLS and supports optional mTLS for client
|
||||||
|
authentication. TLS 1.3 is enforced. The API provides runtime management
|
||||||
|
of routes and firewall rules without restarting the proxy.
|
||||||
|
|
||||||
|
### RPCs
|
||||||
|
|
||||||
|
| RPC | Description |
|
||||||
|
|-----|-------------|
|
||||||
|
| `ListRoutes` | List all routes for a given listener |
|
||||||
|
| `AddRoute` | Add a route to a listener (write-through to DB) |
|
||||||
|
| `RemoveRoute` | Remove a route from a listener (write-through to DB) |
|
||||||
|
| `GetFirewallRules` | List all firewall rules |
|
||||||
|
| `AddFirewallRule` | Add a firewall rule (write-through to DB) |
|
||||||
|
| `RemoveFirewallRule` | Remove a firewall rule (write-through to DB) |
|
||||||
|
| `GetStatus` | Return version, uptime, listener status, connection counts |
|
||||||
|
|
||||||
|
### Input Validation
|
||||||
|
|
||||||
|
The admin API validates all inputs before persisting:
|
||||||
|
|
||||||
|
- **Route backends** must be valid `host:port` tuples.
|
||||||
|
- **IP firewall rules** must be valid IP addresses (`netip.ParseAddr`).
|
||||||
|
- **CIDR firewall rules** must be valid prefixes in canonical form.
|
||||||
|
- **Country firewall rules** must be exactly 2 uppercase letters (ISO 3166-1 alpha-2).
|
||||||
|
|
||||||
|
### Security
|
||||||
|
|
||||||
|
The gRPC admin API has no MCIAS integration — mc-proxy is pre-auth
|
||||||
|
infrastructure. Access control relies on:
|
||||||
|
|
||||||
|
1. **Network binding**: bind to `127.0.0.1` (default) to restrict to local access.
|
||||||
|
2. **mTLS**: configure `client_ca` to require client certificates.
|
||||||
|
|
||||||
|
If the admin API is exposed on a non-loopback interface without mTLS,
|
||||||
|
any network client can modify routing and firewall rules.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
TOML configuration file, loaded at startup. The proxy refuses to start if
|
TOML configuration file, loaded at startup. The proxy refuses to start if
|
||||||
@@ -210,12 +251,16 @@ addr = ":9443"
|
|||||||
hostname = "mcias.metacircular.net"
|
hostname = "mcias.metacircular.net"
|
||||||
backend = "127.0.0.1:28443"
|
backend = "127.0.0.1:28443"
|
||||||
|
|
||||||
# gRPC admin API. Optional — omit addr to disable.
|
# gRPC admin API. Optional — omit or leave addr empty to disable.
|
||||||
|
# If enabled, tls_cert and tls_key are required (TLS 1.3 only).
|
||||||
|
# client_ca enables mTLS and is strongly recommended for non-loopback addresses.
|
||||||
|
# ca_cert is used by the `status` CLI command to verify the server certificate.
|
||||||
[grpc]
|
[grpc]
|
||||||
addr = "127.0.0.1:9090"
|
addr = "127.0.0.1:9090"
|
||||||
tls_cert = "/srv/mc-proxy/certs/cert.pem"
|
tls_cert = "/srv/mc-proxy/certs/cert.pem"
|
||||||
tls_key = "/srv/mc-proxy/certs/key.pem"
|
tls_key = "/srv/mc-proxy/certs/key.pem"
|
||||||
client_ca = "/srv/mc-proxy/certs/ca.pem"
|
client_ca = "/srv/mc-proxy/certs/ca.pem"
|
||||||
|
ca_cert = "/srv/mc-proxy/certs/ca.pem"
|
||||||
|
|
||||||
# Firewall. Global blocklist, evaluated before routing. Default allow.
|
# Firewall. Global blocklist, evaluated before routing. Default allow.
|
||||||
[firewall]
|
[firewall]
|
||||||
@@ -333,6 +378,8 @@ Multi-stage Docker build:
|
|||||||
| File | Purpose |
|
| File | Purpose |
|
||||||
|------|---------|
|
|------|---------|
|
||||||
| `mc-proxy.service` | Main proxy service |
|
| `mc-proxy.service` | Main proxy service |
|
||||||
|
| `mc-proxy-backup.service` | Oneshot database backup (VACUUM INTO) |
|
||||||
|
| `mc-proxy-backup.timer` | Daily backup timer (02:00 UTC, 5-minute jitter) |
|
||||||
|
|
||||||
The proxy binds to privileged ports (443) and should use `AmbientCapabilities=CAP_NET_BIND_SERVICE`
|
The proxy binds to privileged ports (443) and should use `AmbientCapabilities=CAP_NET_BIND_SERVICE`
|
||||||
in the systemd unit rather than running as root.
|
in the systemd unit rather than running as root.
|
||||||
@@ -354,9 +401,10 @@ On `SIGHUP`:
|
|||||||
1. Reload the GeoIP database from disk.
|
1. Reload the GeoIP database from disk.
|
||||||
2. Continue serving with the updated database.
|
2. Continue serving with the updated database.
|
||||||
|
|
||||||
Configuration changes (routes, listeners, firewall rules) require a full
|
Routes and firewall rules can be modified at runtime via the gRPC admin API
|
||||||
restart. Hot reload of routing rules is deferred to the future SQLite-backed
|
(write-through to SQLite). Listener changes (adding/removing ports) require
|
||||||
implementation.
|
a full restart. TOML configuration changes (timeouts, log level, GeoIP path)
|
||||||
|
also require a restart.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -373,7 +421,7 @@ It has no authentication or authorization of its own.
|
|||||||
| Resource exhaustion (connection flood) | Idle timeout closes stale connections. Per-listener connection limits (future). Rate limiting (future). |
|
| Resource exhaustion (connection flood) | Idle timeout closes stale connections. Per-listener connection limits (future). Rate limiting (future). |
|
||||||
| GeoIP evasion via IPv6 | GeoLite2 database includes IPv6 mappings. Both IPv4 and IPv6 source addresses are checked. |
|
| GeoIP evasion via IPv6 | GeoLite2 database includes IPv6 mappings. Both IPv4 and IPv6 source addresses are checked. |
|
||||||
| GeoIP evasion via VPN/proxy | Accepted risk. GeoIP blocking is a compliance measure, not a security boundary. Determined adversaries will bypass it. |
|
| GeoIP evasion via VPN/proxy | Accepted risk. GeoIP blocking is a compliance measure, not a security boundary. Determined adversaries will bypass it. |
|
||||||
| Slowloris / slow ClientHello | Timeout on the SNI extraction phase. If a complete ClientHello is not received within a reasonable window (e.g. 10s), the connection is reset. |
|
| Slowloris / slow ClientHello | Hardcoded 10-second timeout on the SNI extraction phase. If a complete ClientHello is not received within this window, the connection is reset. |
|
||||||
| Backend unavailability | Connect timeout prevents indefinite hangs. Connection is reset if the backend is unreachable. |
|
| Backend unavailability | Connect timeout prevents indefinite hangs. Connection is reset if the backend is unreachable. |
|
||||||
| Information leakage | Blocked connections receive only a TCP RST. No version strings, no error messages, no TLS alerts. |
|
| Information leakage | Blocked connections receive only a TCP RST. No version strings, no error messages, no TLS alerts. |
|
||||||
|
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ func serverCmd() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Load firewall rules from DB.
|
// Load firewall rules from DB.
|
||||||
fw, err := loadFirewallFromDB(store, cfg.Firewall.GeoIPDB)
|
fw, err := loadFirewallFromDB(store, cfg.Firewall)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -90,7 +90,12 @@ func serverCmd() *cobra.Command {
|
|||||||
logger.Error("gRPC server error", "error", err)
|
logger.Error("gRPC server error", "error", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
defer grpcSrv.GracefulStop()
|
defer func() {
|
||||||
|
grpcSrv.GracefulStop()
|
||||||
|
if cfg.GRPC.IsUnixSocket() {
|
||||||
|
os.Remove(cfg.GRPC.SocketPath())
|
||||||
|
}
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SIGHUP reloads the GeoIP database.
|
// SIGHUP reloads the GeoIP database.
|
||||||
@@ -140,7 +145,7 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) {
|
|||||||
return result, nil
|
return result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadFirewallFromDB(store *db.Store, geoIPPath string) (*firewall.Firewall, error) {
|
func loadFirewallFromDB(store *db.Store, fwCfg config.Firewall) (*firewall.Firewall, error) {
|
||||||
rules, err := store.ListFirewallRules()
|
rules, err := store.ListFirewallRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("loading firewall rules: %w", err)
|
return nil, fmt.Errorf("loading firewall rules: %w", err)
|
||||||
@@ -158,7 +163,7 @@ func loadFirewallFromDB(store *db.Store, geoIPPath string) (*firewall.Firewall,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fw, err := firewall.New(geoIPPath, ips, cidrs, countries)
|
fw, err := firewall.New(fwCfg.GeoIPDB, ips, cidrs, countries, fwCfg.RateLimit, fwCfg.RateWindow.Duration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("initializing firewall: %w", err)
|
return nil, fmt.Errorf("initializing firewall: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,9 @@ package main
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
@@ -32,10 +34,29 @@ func snapshotCmd() *cobra.Command {
|
|||||||
}
|
}
|
||||||
defer store.Close()
|
defer store.Close()
|
||||||
|
|
||||||
|
dataDir := filepath.Dir(cfg.Database.Path)
|
||||||
|
|
||||||
if outputPath == "" {
|
if outputPath == "" {
|
||||||
dir := filepath.Dir(cfg.Database.Path)
|
|
||||||
ts := time.Now().UTC().Format("20060102T150405Z")
|
ts := time.Now().UTC().Format("20060102T150405Z")
|
||||||
outputPath = filepath.Join(dir, "backups", fmt.Sprintf("mc-proxy-%s.db", ts))
|
outputPath = filepath.Join(dataDir, "backups", fmt.Sprintf("mc-proxy-%s.db", ts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate the output path is within the data directory.
|
||||||
|
absOutput, err := filepath.Abs(filepath.Clean(outputPath))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("resolving output path: %w", err)
|
||||||
|
}
|
||||||
|
absDataDir, err := filepath.Abs(dataDir)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("resolving data directory: %w", err)
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(absOutput, absDataDir+string(os.PathSeparator)) {
|
||||||
|
return fmt.Errorf("output path must be within the data directory (%s)", absDataDir)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the parent directory exists.
|
||||||
|
if err := os.MkdirAll(filepath.Dir(outputPath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("creating backup directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := store.Snapshot(outputPath); err != nil {
|
if err := store.Snapshot(outputPath); err != nil {
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||||
@@ -70,6 +71,11 @@ func statusCmd() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func dialGRPC(cfg config.GRPC) (*grpc.ClientConn, error) {
|
func dialGRPC(cfg config.GRPC) (*grpc.ClientConn, error) {
|
||||||
|
if cfg.IsUnixSocket() {
|
||||||
|
return grpc.NewClient("unix://"+cfg.SocketPath(),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
}
|
||||||
|
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
MinVersion: tls.VersionTLS13,
|
MinVersion: tls.VersionTLS13,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,23 +1,49 @@
|
|||||||
# mc-proxy configuration
|
# mc-proxy configuration
|
||||||
|
#
|
||||||
|
# This file seeds the database on first run. After that, the database is
|
||||||
|
# the source of truth — listener, route, and firewall fields here are ignored.
|
||||||
|
|
||||||
# Listeners. Each entry binds a TCP listener on the specified address.
|
# Database. Required.
|
||||||
|
[database]
|
||||||
|
path = "/srv/mc-proxy/mc-proxy.db"
|
||||||
|
|
||||||
|
# Listeners. Each listener binds a TCP port and has its own route table.
|
||||||
[[listeners]]
|
[[listeners]]
|
||||||
addr = ":443"
|
addr = ":443"
|
||||||
|
|
||||||
|
[[listeners.routes]]
|
||||||
|
hostname = "metacrypt.metacircular.net"
|
||||||
|
backend = "127.0.0.1:18443"
|
||||||
|
|
||||||
|
[[listeners.routes]]
|
||||||
|
hostname = "mcias.metacircular.net"
|
||||||
|
backend = "127.0.0.1:28443"
|
||||||
|
|
||||||
[[listeners]]
|
[[listeners]]
|
||||||
addr = ":8443"
|
addr = ":8443"
|
||||||
|
|
||||||
|
[[listeners.routes]]
|
||||||
|
hostname = "metacrypt.metacircular.net"
|
||||||
|
backend = "127.0.0.1:18443"
|
||||||
|
|
||||||
[[listeners]]
|
[[listeners]]
|
||||||
addr = ":9443"
|
addr = ":9443"
|
||||||
|
|
||||||
# Routes. SNI hostname → backend address.
|
[[listeners.routes]]
|
||||||
[[routes]]
|
hostname = "mcias.metacircular.net"
|
||||||
hostname = "metacrypt.metacircular.net"
|
backend = "127.0.0.1:28443"
|
||||||
backend = "127.0.0.1:18443"
|
|
||||||
|
|
||||||
[[routes]]
|
# gRPC admin API. Optional — omit or leave addr empty to disable.
|
||||||
hostname = "mcias.metacircular.net"
|
# If enabled over TCP, tls_cert and tls_key are required. mTLS (client_ca)
|
||||||
backend = "127.0.0.1:28443"
|
# is strongly recommended for any non-loopback listen address.
|
||||||
|
[grpc]
|
||||||
|
addr = "127.0.0.1:9090"
|
||||||
|
tls_cert = "/srv/mc-proxy/certs/cert.pem"
|
||||||
|
tls_key = "/srv/mc-proxy/certs/key.pem"
|
||||||
|
client_ca = "/srv/mc-proxy/certs/ca.pem" # mTLS; omit to disable client auth
|
||||||
|
|
||||||
|
# Unix socket alternative (no TLS needed, secured by file permissions):
|
||||||
|
# addr = "/srv/mc-proxy/admin.sock"
|
||||||
|
|
||||||
# Firewall. Global blocklist, evaluated before routing. Default allow.
|
# Firewall. Global blocklist, evaluated before routing. Default allow.
|
||||||
[firewall]
|
[firewall]
|
||||||
@@ -25,6 +51,8 @@ geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb"
|
|||||||
blocked_ips = []
|
blocked_ips = []
|
||||||
blocked_cidrs = []
|
blocked_cidrs = []
|
||||||
blocked_countries = ["KP", "CN", "IN", "IL"]
|
blocked_countries = ["KP", "CN", "IN", "IL"]
|
||||||
|
rate_limit = 100 # max connections per source IP per window (0 = disabled)
|
||||||
|
rate_window = "1m" # sliding window duration (required if rate_limit > 0)
|
||||||
|
|
||||||
# Proxy behavior.
|
# Proxy behavior.
|
||||||
[proxy]
|
[proxy]
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package config
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pelletier/go-toml/v2"
|
"github.com/pelletier/go-toml/v2"
|
||||||
@@ -44,6 +45,8 @@ type Firewall struct {
|
|||||||
BlockedIPs []string `toml:"blocked_ips"`
|
BlockedIPs []string `toml:"blocked_ips"`
|
||||||
BlockedCIDRs []string `toml:"blocked_cidrs"`
|
BlockedCIDRs []string `toml:"blocked_cidrs"`
|
||||||
BlockedCountries []string `toml:"blocked_countries"`
|
BlockedCountries []string `toml:"blocked_countries"`
|
||||||
|
RateLimit int64 `toml:"rate_limit"`
|
||||||
|
RateWindow Duration `toml:"rate_window"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
@@ -61,6 +64,18 @@ type Duration struct {
|
|||||||
time.Duration
|
time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsUnixSocket returns true if the gRPC address refers to a Unix domain socket.
|
||||||
|
func (g GRPC) IsUnixSocket() bool {
|
||||||
|
path := strings.TrimPrefix(g.Addr, "unix:")
|
||||||
|
return strings.Contains(path, "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
// SocketPath returns the filesystem path for a Unix socket address,
|
||||||
|
// stripping any "unix:" prefix.
|
||||||
|
func (g GRPC) SocketPath() string {
|
||||||
|
return strings.TrimPrefix(g.Addr, "unix:")
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Duration) UnmarshalText(text []byte) error {
|
func (d *Duration) UnmarshalText(text []byte) error {
|
||||||
var err error
|
var err error
|
||||||
d.Duration, err = time.ParseDuration(string(text))
|
d.Duration, err = time.ParseDuration(string(text))
|
||||||
@@ -114,5 +129,34 @@ func (c *Config) validate() error {
|
|||||||
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
|
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.Firewall.RateLimit < 0 {
|
||||||
|
return fmt.Errorf("firewall.rate_limit must not be negative")
|
||||||
|
}
|
||||||
|
if c.Firewall.RateWindow.Duration < 0 {
|
||||||
|
return fmt.Errorf("firewall.rate_window must not be negative")
|
||||||
|
}
|
||||||
|
if c.Firewall.RateLimit > 0 && c.Firewall.RateWindow.Duration == 0 {
|
||||||
|
return fmt.Errorf("firewall.rate_window is required when rate_limit is set")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate gRPC config: if enabled, TLS cert and key are required
|
||||||
|
// (unless using a Unix socket, which doesn't need TLS).
|
||||||
|
if c.GRPC.Addr != "" && !c.GRPC.IsUnixSocket() {
|
||||||
|
if c.GRPC.TLSCert == "" || c.GRPC.TLSKey == "" {
|
||||||
|
return fmt.Errorf("grpc: tls_cert and tls_key are required when grpc.addr is a TCP address")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate timeouts are non-negative.
|
||||||
|
if c.Proxy.ConnectTimeout.Duration < 0 {
|
||||||
|
return fmt.Errorf("proxy.connect_timeout must not be negative")
|
||||||
|
}
|
||||||
|
if c.Proxy.IdleTimeout.Duration < 0 {
|
||||||
|
return fmt.Errorf("proxy.idle_timeout must not be negative")
|
||||||
|
}
|
||||||
|
if c.Proxy.ShutdownTimeout.Duration < 0 {
|
||||||
|
return fmt.Errorf("proxy.shutdown_timeout must not be negative")
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -195,6 +195,113 @@ addr = ":8443"
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestGRPCIsUnixSocket(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
addr string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{"/var/run/mc-proxy.sock", true},
|
||||||
|
{"unix:/var/run/mc-proxy.sock", true},
|
||||||
|
{"127.0.0.1:9090", false},
|
||||||
|
{":9090", false},
|
||||||
|
{"", false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
g := GRPC{Addr: tt.addr}
|
||||||
|
if got := g.IsUnixSocket(); got != tt.want {
|
||||||
|
t.Fatalf("IsUnixSocket(%q) = %v, want %v", tt.addr, got, tt.want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateGRPCUnixNoTLS(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.toml")
|
||||||
|
|
||||||
|
data := `
|
||||||
|
[database]
|
||||||
|
path = "/tmp/test.db"
|
||||||
|
|
||||||
|
[grpc]
|
||||||
|
addr = "/var/run/mc-proxy.sock"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||||
|
t.Fatalf("write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("expected Unix socket without TLS to be valid, got: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateGRPCTCPRequiresTLS(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.toml")
|
||||||
|
|
||||||
|
data := `
|
||||||
|
[database]
|
||||||
|
path = "/tmp/test.db"
|
||||||
|
|
||||||
|
[grpc]
|
||||||
|
addr = "127.0.0.1:9090"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||||
|
t.Fatalf("write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Load(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for TCP gRPC addr without TLS certs")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRateLimitRequiresWindow(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.toml")
|
||||||
|
|
||||||
|
data := `
|
||||||
|
[database]
|
||||||
|
path = "/tmp/test.db"
|
||||||
|
|
||||||
|
[firewall]
|
||||||
|
rate_limit = 100
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||||
|
t.Fatalf("write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := Load(path)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for rate_limit without rate_window")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestValidateRateLimitWithWindow(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
path := filepath.Join(dir, "test.toml")
|
||||||
|
|
||||||
|
data := `
|
||||||
|
[database]
|
||||||
|
path = "/tmp/test.db"
|
||||||
|
|
||||||
|
[firewall]
|
||||||
|
rate_limit = 100
|
||||||
|
rate_window = "1m"
|
||||||
|
`
|
||||||
|
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||||
|
t.Fatalf("write config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := Load(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if cfg.Firewall.RateLimit != 100 {
|
||||||
|
t.Fatalf("got rate_limit %d, want 100", cfg.Firewall.RateLimit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestDuration(t *testing.T) {
|
func TestDuration(t *testing.T) {
|
||||||
var d Duration
|
var d Duration
|
||||||
if err := d.UnmarshalText([]byte("5s")); err != nil {
|
if err := d.UnmarshalText([]byte("5s")); err != nil {
|
||||||
|
|||||||
@@ -5,9 +5,9 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/oschwald/maxminddb-golang"
|
"github.com/oschwald/maxminddb-golang"
|
||||||
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type geoIPRecord struct {
|
type geoIPRecord struct {
|
||||||
@@ -23,17 +23,23 @@ type Firewall struct {
|
|||||||
blockedCountries map[string]struct{}
|
blockedCountries map[string]struct{}
|
||||||
geoDBPath string
|
geoDBPath string
|
||||||
geoDB *maxminddb.Reader
|
geoDB *maxminddb.Reader
|
||||||
|
rl *rateLimiter
|
||||||
mu sync.RWMutex // protects all mutable state
|
mu sync.RWMutex // protects all mutable state
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a Firewall from raw rule lists and an optional GeoIP database path.
|
// New creates a Firewall from raw rule lists and an optional GeoIP database path.
|
||||||
func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) {
|
// If rateLimit > 0, per-source-IP rate limiting is enabled with the given window.
|
||||||
|
func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rateWindow time.Duration) (*Firewall, error) {
|
||||||
f := &Firewall{
|
f := &Firewall{
|
||||||
blockedIPs: make(map[netip.Addr]struct{}),
|
blockedIPs: make(map[netip.Addr]struct{}),
|
||||||
blockedCountries: make(map[string]struct{}),
|
blockedCountries: make(map[string]struct{}),
|
||||||
geoDBPath: geoIPPath,
|
geoDBPath: geoIPPath,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rateLimit > 0 && rateWindow > 0 {
|
||||||
|
f.rl = newRateLimiter(rateLimit, rateWindow)
|
||||||
|
}
|
||||||
|
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
addr, err := netip.ParseAddr(ip)
|
addr, err := netip.ParseAddr(ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -89,6 +95,12 @@ func (f *Firewall) Blocked(addr netip.Addr) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Rate limiting is checked after blocklist — no point tracking state
|
||||||
|
// for already-blocked IPs.
|
||||||
|
if f.rl != nil && !f.rl.Allow(addr) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -190,6 +202,10 @@ func (f *Firewall) ReloadGeoIP() error {
|
|||||||
|
|
||||||
// Close releases resources held by the firewall.
|
// Close releases resources held by the firewall.
|
||||||
func (f *Firewall) Close() error {
|
func (f *Firewall) Close() error {
|
||||||
|
if f.rl != nil {
|
||||||
|
f.rl.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
f.mu.Lock()
|
f.mu.Lock()
|
||||||
defer f.mu.Unlock()
|
defer f.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ package firewall
|
|||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestEmptyFirewall(t *testing.T) {
|
func TestEmptyFirewall(t *testing.T) {
|
||||||
fw, err := New("", nil, nil, nil)
|
fw, err := New("", nil, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -22,7 +23,7 @@ func TestEmptyFirewall(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIPBlocking(t *testing.T) {
|
func TestIPBlocking(t *testing.T) {
|
||||||
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil)
|
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -47,7 +48,7 @@ func TestIPBlocking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCIDRBlocking(t *testing.T) {
|
func TestCIDRBlocking(t *testing.T) {
|
||||||
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil)
|
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -73,7 +74,7 @@ func TestCIDRBlocking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestIPv4MappedIPv6(t *testing.T) {
|
func TestIPv4MappedIPv6(t *testing.T) {
|
||||||
fw, err := New("", []string{"192.0.2.1"}, nil, nil)
|
fw, err := New("", []string{"192.0.2.1"}, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -86,21 +87,21 @@ func TestIPv4MappedIPv6(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidIP(t *testing.T) {
|
func TestInvalidIP(t *testing.T) {
|
||||||
_, err := New("", []string{"not-an-ip"}, nil, nil)
|
_, err := New("", []string{"not-an-ip"}, nil, nil, 0, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for invalid IP")
|
t.Fatal("expected error for invalid IP")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInvalidCIDR(t *testing.T) {
|
func TestInvalidCIDR(t *testing.T) {
|
||||||
_, err := New("", nil, []string{"not-a-cidr"}, nil)
|
_, err := New("", nil, []string{"not-a-cidr"}, nil, 0, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for invalid CIDR")
|
t.Fatal("expected error for invalid CIDR")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCombinedRules(t *testing.T) {
|
func TestCombinedRules(t *testing.T) {
|
||||||
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil)
|
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
@@ -124,8 +125,53 @@ func TestCombinedRules(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRateLimitBlocking(t *testing.T) {
|
||||||
|
fw, err := New("", nil, nil, nil, 2, time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer fw.Close()
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("10.0.0.1")
|
||||||
|
|
||||||
|
if fw.Blocked(addr) {
|
||||||
|
t.Fatal("first request should be allowed")
|
||||||
|
}
|
||||||
|
if fw.Blocked(addr) {
|
||||||
|
t.Fatal("second request should be allowed")
|
||||||
|
}
|
||||||
|
if !fw.Blocked(addr) {
|
||||||
|
t.Fatal("third request should be blocked (limit=2)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimitBlocklistFirst(t *testing.T) {
|
||||||
|
// A blocklisted IP should be blocked without consuming rate limit quota.
|
||||||
|
fw, err := New("", []string{"10.0.0.1"}, nil, nil, 1, time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
defer fw.Close()
|
||||||
|
|
||||||
|
blockedAddr := netip.MustParseAddr("10.0.0.1")
|
||||||
|
otherAddr := netip.MustParseAddr("10.0.0.2")
|
||||||
|
|
||||||
|
// Blocked by blocklist — should not touch the rate limiter.
|
||||||
|
if !fw.Blocked(blockedAddr) {
|
||||||
|
t.Fatal("blocklisted IP should be blocked")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Other address should still have its full rate limit quota.
|
||||||
|
if fw.Blocked(otherAddr) {
|
||||||
|
t.Fatal("other IP should be allowed (within rate limit)")
|
||||||
|
}
|
||||||
|
if !fw.Blocked(otherAddr) {
|
||||||
|
t.Fatal("other IP should be blocked after exceeding rate limit")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRuntimeMutation(t *testing.T) {
|
func TestRuntimeMutation(t *testing.T) {
|
||||||
fw, err := New("", nil, nil, nil)
|
fw, err := New("", nil, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
81
internal/firewall/ratelimit.go
Normal file
81
internal/firewall/ratelimit.go
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type rateLimitEntry struct {
|
||||||
|
count atomic.Int64
|
||||||
|
start atomic.Int64 // UnixNano
|
||||||
|
}
|
||||||
|
|
||||||
|
type rateLimiter struct {
|
||||||
|
limit int64
|
||||||
|
window time.Duration
|
||||||
|
entries sync.Map // netip.Addr → *rateLimitEntry
|
||||||
|
now func() time.Time
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRateLimiter(limit int64, window time.Duration) *rateLimiter {
|
||||||
|
rl := &rateLimiter{
|
||||||
|
limit: limit,
|
||||||
|
window: window,
|
||||||
|
now: time.Now,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
go rl.cleanup()
|
||||||
|
return rl
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allow checks whether the given address is within its rate limit window.
|
||||||
|
// Returns false if the address has exceeded the limit.
|
||||||
|
func (rl *rateLimiter) Allow(addr netip.Addr) bool {
|
||||||
|
now := rl.now().UnixNano()
|
||||||
|
|
||||||
|
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
|
||||||
|
entry := val.(*rateLimitEntry)
|
||||||
|
|
||||||
|
windowStart := entry.start.Load()
|
||||||
|
if now-windowStart >= rl.window.Nanoseconds() {
|
||||||
|
// Window expired — reset. Intentionally non-atomic across the two
|
||||||
|
// stores; worst case a few extra connections slip through at the
|
||||||
|
// boundary, which is acceptable for a connection rate limiter.
|
||||||
|
entry.start.Store(now)
|
||||||
|
entry.count.Store(1)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
n := entry.count.Add(1)
|
||||||
|
return n <= rl.limit
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop terminates the cleanup goroutine.
|
||||||
|
func (rl *rateLimiter) Stop() {
|
||||||
|
close(rl.done)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rl *rateLimiter) cleanup() {
|
||||||
|
ticker := time.NewTicker(rl.window)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-rl.done:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
cutoff := rl.now().Add(-2 * rl.window).UnixNano()
|
||||||
|
rl.entries.Range(func(key, value any) bool {
|
||||||
|
entry := value.(*rateLimitEntry)
|
||||||
|
if entry.start.Load() < cutoff {
|
||||||
|
rl.entries.Delete(key)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
95
internal/firewall/ratelimit_test.go
Normal file
95
internal/firewall/ratelimit_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package firewall
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimiterAllow(t *testing.T) {
|
||||||
|
rl := newRateLimiter(3, time.Minute)
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("10.0.0.1")
|
||||||
|
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
|
if !rl.Allow(addr) {
|
||||||
|
t.Fatalf("call %d: expected Allow=true", i+1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if rl.Allow(addr) {
|
||||||
|
t.Fatal("call 4: expected Allow=false (over limit)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||||
|
rl := newRateLimiter(2, time.Minute)
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
a := netip.MustParseAddr("10.0.0.1")
|
||||||
|
b := netip.MustParseAddr("10.0.0.2")
|
||||||
|
|
||||||
|
// Exhaust a's limit.
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
rl.Allow(a)
|
||||||
|
}
|
||||||
|
if rl.Allow(a) {
|
||||||
|
t.Fatal("a should be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
// b should be independent.
|
||||||
|
if !rl.Allow(b) {
|
||||||
|
t.Fatal("b should not be rate limited")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterWindowReset(t *testing.T) {
|
||||||
|
rl := newRateLimiter(2, time.Minute)
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
var fakeNow atomic.Int64
|
||||||
|
fakeNow.Store(time.Now().UnixNano())
|
||||||
|
rl.now = func() time.Time {
|
||||||
|
return time.Unix(0, fakeNow.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("10.0.0.1")
|
||||||
|
|
||||||
|
// Exhaust the limit.
|
||||||
|
rl.Allow(addr)
|
||||||
|
rl.Allow(addr)
|
||||||
|
if rl.Allow(addr) {
|
||||||
|
t.Fatal("should be rate limited")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Advance past the window.
|
||||||
|
fakeNow.Add(int64(2 * time.Minute))
|
||||||
|
|
||||||
|
// Should be allowed again.
|
||||||
|
if !rl.Allow(addr) {
|
||||||
|
t.Fatal("should be allowed after window reset")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRateLimiterCleanup(t *testing.T) {
|
||||||
|
rl := newRateLimiter(10, 50*time.Millisecond)
|
||||||
|
defer rl.Stop()
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("10.0.0.1")
|
||||||
|
rl.Allow(addr)
|
||||||
|
|
||||||
|
// Entry should exist.
|
||||||
|
if _, ok := rl.entries.Load(addr); !ok {
|
||||||
|
t.Fatal("entry should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for 2*window + a cleanup cycle to pass.
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
// Entry should have been cleaned up.
|
||||||
|
if _, ok := rl.entries.Load(addr); ok {
|
||||||
|
t.Fatal("stale entry should have been cleaned up")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,7 +7,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -22,6 +24,8 @@ import (
|
|||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
|
||||||
|
|
||||||
// AdminServer implements the ProxyAdmin gRPC service.
|
// AdminServer implements the ProxyAdmin gRPC service.
|
||||||
type AdminServer struct {
|
type AdminServer struct {
|
||||||
pb.UnimplementedProxyAdminServiceServer
|
pb.UnimplementedProxyAdminServiceServer
|
||||||
@@ -30,8 +34,43 @@ type AdminServer struct {
|
|||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a gRPC server with TLS and optional mTLS.
|
// New creates a gRPC server. For Unix sockets, no TLS is used. For TCP
|
||||||
|
// addresses, TLS is required with optional mTLS.
|
||||||
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
||||||
|
admin := &AdminServer{
|
||||||
|
srv: srv,
|
||||||
|
store: store,
|
||||||
|
logger: logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.IsUnixSocket() {
|
||||||
|
return newUnixServer(cfg, admin)
|
||||||
|
}
|
||||||
|
return newTCPServer(cfg, admin)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
|
||||||
|
path := cfg.SocketPath()
|
||||||
|
|
||||||
|
// Remove stale socket file from a previous run.
|
||||||
|
os.Remove(path)
|
||||||
|
|
||||||
|
ln, err := net.Listen("unix", path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Chmod(path, 0600); err != nil {
|
||||||
|
ln.Close()
|
||||||
|
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcServer := grpc.NewServer()
|
||||||
|
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
|
||||||
|
return grpcServer, ln, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTCPServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
|
||||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
|
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
|
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
|
||||||
@@ -57,12 +96,6 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
|
|||||||
|
|
||||||
creds := credentials.NewTLS(tlsConfig)
|
creds := credentials.NewTLS(tlsConfig)
|
||||||
grpcServer := grpc.NewServer(grpc.Creds(creds))
|
grpcServer := grpc.NewServer(grpc.Creds(creds))
|
||||||
|
|
||||||
admin := &AdminServer{
|
|
||||||
srv: srv,
|
|
||||||
store: store,
|
|
||||||
logger: logger,
|
|
||||||
}
|
|
||||||
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
|
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
|
||||||
|
|
||||||
ln, err := net.Listen("tcp", cfg.Addr)
|
ln, err := net.Listen("tcp", cfg.Addr)
|
||||||
@@ -102,6 +135,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
|
|||||||
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
|
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate backend is a valid host:port.
|
||||||
|
if _, _, err := net.SplitHostPort(req.Route.Backend); err != nil {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
ls, err := a.findListener(req.ListenerAddr)
|
ls, err := a.findListener(req.ListenerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -190,6 +228,29 @@ func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRule
|
|||||||
return nil, status.Error(codes.InvalidArgument, "value is required")
|
return nil, status.Error(codes.InvalidArgument, "value is required")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate the value matches the rule type before persisting.
|
||||||
|
switch ruleType {
|
||||||
|
case "ip":
|
||||||
|
if _, err := netip.ParseAddr(req.Rule.Value); err != nil {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
|
||||||
|
}
|
||||||
|
case "cidr":
|
||||||
|
prefix, err := netip.ParsePrefix(req.Rule.Value)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
|
||||||
|
}
|
||||||
|
// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
|
||||||
|
if prefix.Masked().String() != req.Rule.Value {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument,
|
||||||
|
"CIDR not in canonical form: use %s", prefix.Masked().String())
|
||||||
|
}
|
||||||
|
case "country":
|
||||||
|
if !countryCodeRe.MatchString(req.Rule.Value) {
|
||||||
|
return nil, status.Error(codes.InvalidArgument,
|
||||||
|
"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Write-through: DB first, then memory.
|
// Write-through: DB first, then memory.
|
||||||
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
|
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
|
||||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func setup(t *testing.T) *testEnv {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build server with matching in-memory state.
|
// Build server with matching in-memory state.
|
||||||
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil)
|
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("firewall: %v", err)
|
t.Fatalf("firewall: %v", err)
|
||||||
}
|
}
|
||||||
@@ -268,6 +268,15 @@ func TestAddRouteValidation(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for empty backend")
|
t.Fatal("expected error for empty backend")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Invalid backend (not host:port).
|
||||||
|
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid backend address")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveRoute(t *testing.T) {
|
func TestRemoveRoute(t *testing.T) {
|
||||||
@@ -410,6 +419,61 @@ func TestAddFirewallRuleValidation(t *testing.T) {
|
|||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected error for empty value")
|
t.Fatal("expected error for empty value")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Invalid IP address.
|
||||||
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||||
|
Rule: &pb.FirewallRule{
|
||||||
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||||
|
Value: "not-an-ip",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid IP")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid CIDR.
|
||||||
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||||
|
Rule: &pb.FirewallRule{
|
||||||
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
||||||
|
Value: "not-a-cidr",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid CIDR")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-canonical CIDR.
|
||||||
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||||
|
Rule: &pb.FirewallRule{
|
||||||
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
||||||
|
Value: "192.168.1.5/16",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for non-canonical CIDR")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid country code (lowercase).
|
||||||
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||||
|
Rule: &pb.FirewallRule{
|
||||||
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
||||||
|
Value: "cn",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for lowercase country code")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid country code (too long).
|
||||||
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||||
|
Rule: &pb.FirewallRule{
|
||||||
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
||||||
|
Value: "USA",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for 3-letter country code")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRemoveFirewallRule(t *testing.T) {
|
func TestRemoveFirewallRule(t *testing.T) {
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ type ListenerState struct {
|
|||||||
routes map[string]string // lowercase hostname → backend addr
|
routes map[string]string // lowercase hostname → backend addr
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
ActiveConnections atomic.Int64
|
ActiveConnections atomic.Int64
|
||||||
|
activeConns map[net.Conn]struct{} // tracked for forced shutdown
|
||||||
|
connMu sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// Routes returns a snapshot of the listener's route table.
|
// Routes returns a snapshot of the listener's route table.
|
||||||
@@ -100,9 +102,10 @@ func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData,
|
|||||||
var listeners []*ListenerState
|
var listeners []*ListenerState
|
||||||
for _, ld := range listenerData {
|
for _, ld := range listenerData {
|
||||||
listeners = append(listeners, &ListenerState{
|
listeners = append(listeners, &ListenerState{
|
||||||
ID: ld.ID,
|
ID: ld.ID,
|
||||||
Addr: ld.Addr,
|
Addr: ld.Addr,
|
||||||
routes: ld.Routes,
|
routes: ld.Routes,
|
||||||
|
activeConns: make(map[net.Conn]struct{}),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -186,6 +189,9 @@ func (s *Server) Run(ctx context.Context) error {
|
|||||||
s.logger.Info("all connections drained")
|
s.logger.Info("all connections drained")
|
||||||
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
|
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
|
||||||
s.logger.Warn("shutdown timeout exceeded, forcing close")
|
s.logger.Warn("shutdown timeout exceeded, forcing close")
|
||||||
|
// Force-close all listener connections to unblock relay goroutines.
|
||||||
|
s.forceCloseAll()
|
||||||
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
s.fw.Close()
|
s.fw.Close()
|
||||||
@@ -214,11 +220,31 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// forceCloseAll closes all tracked connections across all listeners.
|
||||||
|
func (s *Server) forceCloseAll() {
|
||||||
|
for _, ls := range s.listeners {
|
||||||
|
ls.connMu.Lock()
|
||||||
|
for conn := range ls.activeConns {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
ls.connMu.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
|
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
|
||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
defer ls.ActiveConnections.Add(-1)
|
defer ls.ActiveConnections.Add(-1)
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
|
ls.connMu.Lock()
|
||||||
|
ls.activeConns[conn] = struct{}{}
|
||||||
|
ls.connMu.Unlock()
|
||||||
|
defer func() {
|
||||||
|
ls.connMu.Lock()
|
||||||
|
delete(ls.activeConns, conn)
|
||||||
|
ls.connMu.Unlock()
|
||||||
|
}()
|
||||||
|
|
||||||
remoteAddr := conn.RemoteAddr().String()
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func echoServer(t *testing.T, ln net.Listener) {
|
|||||||
// newTestServer creates a Server with the given listener data and no firewall rules.
|
// newTestServer creates a Server with the given listener data and no firewall rules.
|
||||||
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
|
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
fw, err := firewall.New("", nil, nil, nil)
|
fw, err := firewall.New("", nil, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating firewall: %v", err)
|
t.Fatalf("creating firewall: %v", err)
|
||||||
}
|
}
|
||||||
@@ -195,7 +195,7 @@ func TestFirewallBlocks(t *testing.T) {
|
|||||||
proxyLn.Close()
|
proxyLn.Close()
|
||||||
|
|
||||||
// Create a firewall that blocks 127.0.0.1 (the test client).
|
// Create a firewall that blocks 127.0.0.1 (the test client).
|
||||||
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil)
|
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating firewall: %v", err)
|
t.Fatalf("creating firewall: %v", err)
|
||||||
}
|
}
|
||||||
@@ -599,7 +599,7 @@ func TestGracefulShutdown(t *testing.T) {
|
|||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
proxyLn.Close()
|
||||||
|
|
||||||
fw, err := firewall.New("", nil, nil, nil)
|
fw, err := firewall.New("", nil, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating firewall: %v", err)
|
t.Fatalf("creating firewall: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -165,6 +165,9 @@ func parseServerNameExtension(data []byte) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nameType == 0x00 { // hostname
|
if nameType == 0x00 { // hostname
|
||||||
|
if nameLen > 255 {
|
||||||
|
return "", fmt.Errorf("SNI hostname exceeds 255 bytes")
|
||||||
|
}
|
||||||
return strings.ToLower(string(data[:nameLen])), nil
|
return strings.ToLower(string(data[:nameLen])), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user