Add per-IP rate limiting and Unix socket support for gRPC admin API

Rate limiting: per-source-IP connection rate limiter in the firewall layer
with configurable limit and sliding window. Blocklisted IPs are rejected
before rate limit evaluation to avoid wasting quota. Unix socket: the gRPC
admin API can now listen on a Unix domain socket (no TLS required), secured
by file permissions (0600), as a simpler alternative for local-only access.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-17 14:37:21 -07:00
parent e84093b7fb
commit b25e1b0e79
16 changed files with 694 additions and 43 deletions

View File

@@ -3,6 +3,7 @@ package config
import (
"fmt"
"os"
"strings"
"time"
"github.com/pelletier/go-toml/v2"
@@ -44,6 +45,8 @@ type Firewall struct {
BlockedIPs []string `toml:"blocked_ips"`
BlockedCIDRs []string `toml:"blocked_cidrs"`
BlockedCountries []string `toml:"blocked_countries"`
RateLimit int64 `toml:"rate_limit"`
RateWindow Duration `toml:"rate_window"`
}
type Proxy struct {
@@ -61,6 +64,18 @@ type Duration struct {
time.Duration
}
// IsUnixSocket returns true if the gRPC address refers to a Unix domain socket.
func (g GRPC) IsUnixSocket() bool {
path := strings.TrimPrefix(g.Addr, "unix:")
return strings.Contains(path, "/")
}
// SocketPath returns the filesystem path for a Unix socket address,
// stripping any "unix:" prefix.
func (g GRPC) SocketPath() string {
return strings.TrimPrefix(g.Addr, "unix:")
}
func (d *Duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
@@ -114,5 +129,34 @@ func (c *Config) validate() error {
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
}
if c.Firewall.RateLimit < 0 {
return fmt.Errorf("firewall.rate_limit must not be negative")
}
if c.Firewall.RateWindow.Duration < 0 {
return fmt.Errorf("firewall.rate_window must not be negative")
}
if c.Firewall.RateLimit > 0 && c.Firewall.RateWindow.Duration == 0 {
return fmt.Errorf("firewall.rate_window is required when rate_limit is set")
}
// Validate gRPC config: if enabled, TLS cert and key are required
// (unless using a Unix socket, which doesn't need TLS).
if c.GRPC.Addr != "" && !c.GRPC.IsUnixSocket() {
if c.GRPC.TLSCert == "" || c.GRPC.TLSKey == "" {
return fmt.Errorf("grpc: tls_cert and tls_key are required when grpc.addr is a TCP address")
}
}
// Validate timeouts are non-negative.
if c.Proxy.ConnectTimeout.Duration < 0 {
return fmt.Errorf("proxy.connect_timeout must not be negative")
}
if c.Proxy.IdleTimeout.Duration < 0 {
return fmt.Errorf("proxy.idle_timeout must not be negative")
}
if c.Proxy.ShutdownTimeout.Duration < 0 {
return fmt.Errorf("proxy.shutdown_timeout must not be negative")
}
return nil
}

View File

@@ -195,6 +195,113 @@ addr = ":8443"
}
}
func TestGRPCIsUnixSocket(t *testing.T) {
tests := []struct {
addr string
want bool
}{
{"/var/run/mc-proxy.sock", true},
{"unix:/var/run/mc-proxy.sock", true},
{"127.0.0.1:9090", false},
{":9090", false},
{"", false},
}
for _, tt := range tests {
g := GRPC{Addr: tt.addr}
if got := g.IsUnixSocket(); got != tt.want {
t.Fatalf("IsUnixSocket(%q) = %v, want %v", tt.addr, got, tt.want)
}
}
}
func TestValidateGRPCUnixNoTLS(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[grpc]
addr = "/var/run/mc-proxy.sock"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err != nil {
t.Fatalf("expected Unix socket without TLS to be valid, got: %v", err)
}
}
func TestValidateGRPCTCPRequiresTLS(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[grpc]
addr = "127.0.0.1:9090"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for TCP gRPC addr without TLS certs")
}
}
func TestValidateRateLimitRequiresWindow(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[firewall]
rate_limit = 100
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for rate_limit without rate_window")
}
}
func TestValidateRateLimitWithWindow(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[database]
path = "/tmp/test.db"
[firewall]
rate_limit = 100
rate_window = "1m"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if cfg.Firewall.RateLimit != 100 {
t.Fatalf("got rate_limit %d, want 100", cfg.Firewall.RateLimit)
}
}
func TestDuration(t *testing.T) {
var d Duration
if err := d.UnmarshalText([]byte("5s")); err != nil {

View File

@@ -5,9 +5,9 @@ import (
"net/netip"
"strings"
"sync"
"time"
"github.com/oschwald/maxminddb-golang"
)
type geoIPRecord struct {
@@ -23,17 +23,23 @@ type Firewall struct {
blockedCountries map[string]struct{}
geoDBPath string
geoDB *maxminddb.Reader
rl *rateLimiter
mu sync.RWMutex // protects all mutable state
}
// New creates a Firewall from raw rule lists and an optional GeoIP database path.
func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) {
// If rateLimit > 0, per-source-IP rate limiting is enabled with the given window.
func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rateWindow time.Duration) (*Firewall, error) {
f := &Firewall{
blockedIPs: make(map[netip.Addr]struct{}),
blockedCountries: make(map[string]struct{}),
geoDBPath: geoIPPath,
}
if rateLimit > 0 && rateWindow > 0 {
f.rl = newRateLimiter(rateLimit, rateWindow)
}
for _, ip := range ips {
addr, err := netip.ParseAddr(ip)
if err != nil {
@@ -89,6 +95,12 @@ func (f *Firewall) Blocked(addr netip.Addr) bool {
}
}
// Rate limiting is checked after blocklist — no point tracking state
// for already-blocked IPs.
if f.rl != nil && !f.rl.Allow(addr) {
return true
}
return false
}
@@ -190,6 +202,10 @@ func (f *Firewall) ReloadGeoIP() error {
// Close releases resources held by the firewall.
func (f *Firewall) Close() error {
if f.rl != nil {
f.rl.Stop()
}
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -3,10 +3,11 @@ package firewall
import (
"net/netip"
"testing"
"time"
)
func TestEmptyFirewall(t *testing.T) {
fw, err := New("", nil, nil, nil)
fw, err := New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -22,7 +23,7 @@ func TestEmptyFirewall(t *testing.T) {
}
func TestIPBlocking(t *testing.T) {
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil)
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -47,7 +48,7 @@ func TestIPBlocking(t *testing.T) {
}
func TestCIDRBlocking(t *testing.T) {
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil)
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -73,7 +74,7 @@ func TestCIDRBlocking(t *testing.T) {
}
func TestIPv4MappedIPv6(t *testing.T) {
fw, err := New("", []string{"192.0.2.1"}, nil, nil)
fw, err := New("", []string{"192.0.2.1"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -86,21 +87,21 @@ func TestIPv4MappedIPv6(t *testing.T) {
}
func TestInvalidIP(t *testing.T) {
_, err := New("", []string{"not-an-ip"}, nil, nil)
_, err := New("", []string{"not-an-ip"}, nil, nil, 0, 0)
if err == nil {
t.Fatal("expected error for invalid IP")
}
}
func TestInvalidCIDR(t *testing.T) {
_, err := New("", nil, []string{"not-a-cidr"}, nil)
_, err := New("", nil, []string{"not-a-cidr"}, nil, 0, 0)
if err == nil {
t.Fatal("expected error for invalid CIDR")
}
}
func TestCombinedRules(t *testing.T) {
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil)
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
@@ -124,8 +125,53 @@ func TestCombinedRules(t *testing.T) {
}
}
func TestRateLimitBlocking(t *testing.T) {
fw, err := New("", nil, nil, nil, 2, time.Minute)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
addr := netip.MustParseAddr("10.0.0.1")
if fw.Blocked(addr) {
t.Fatal("first request should be allowed")
}
if fw.Blocked(addr) {
t.Fatal("second request should be allowed")
}
if !fw.Blocked(addr) {
t.Fatal("third request should be blocked (limit=2)")
}
}
func TestRateLimitBlocklistFirst(t *testing.T) {
// A blocklisted IP should be blocked without consuming rate limit quota.
fw, err := New("", []string{"10.0.0.1"}, nil, nil, 1, time.Minute)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
blockedAddr := netip.MustParseAddr("10.0.0.1")
otherAddr := netip.MustParseAddr("10.0.0.2")
// Blocked by blocklist — should not touch the rate limiter.
if !fw.Blocked(blockedAddr) {
t.Fatal("blocklisted IP should be blocked")
}
// Other address should still have its full rate limit quota.
if fw.Blocked(otherAddr) {
t.Fatal("other IP should be allowed (within rate limit)")
}
if !fw.Blocked(otherAddr) {
t.Fatal("other IP should be blocked after exceeding rate limit")
}
}
func TestRuntimeMutation(t *testing.T) {
fw, err := New("", nil, nil, nil)
fw, err := New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

View File

@@ -0,0 +1,81 @@
package firewall
import (
"net/netip"
"sync"
"sync/atomic"
"time"
)
type rateLimitEntry struct {
count atomic.Int64
start atomic.Int64 // UnixNano
}
type rateLimiter struct {
limit int64
window time.Duration
entries sync.Map // netip.Addr → *rateLimitEntry
now func() time.Time
done chan struct{}
}
func newRateLimiter(limit int64, window time.Duration) *rateLimiter {
rl := &rateLimiter{
limit: limit,
window: window,
now: time.Now,
done: make(chan struct{}),
}
go rl.cleanup()
return rl
}
// Allow checks whether the given address is within its rate limit window.
// Returns false if the address has exceeded the limit.
func (rl *rateLimiter) Allow(addr netip.Addr) bool {
now := rl.now().UnixNano()
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
entry := val.(*rateLimitEntry)
windowStart := entry.start.Load()
if now-windowStart >= rl.window.Nanoseconds() {
// Window expired — reset. Intentionally non-atomic across the two
// stores; worst case a few extra connections slip through at the
// boundary, which is acceptable for a connection rate limiter.
entry.start.Store(now)
entry.count.Store(1)
return true
}
n := entry.count.Add(1)
return n <= rl.limit
}
// Stop terminates the cleanup goroutine.
func (rl *rateLimiter) Stop() {
close(rl.done)
}
func (rl *rateLimiter) cleanup() {
ticker := time.NewTicker(rl.window)
defer ticker.Stop()
for {
select {
case <-rl.done:
return
case <-ticker.C:
cutoff := rl.now().Add(-2 * rl.window).UnixNano()
rl.entries.Range(func(key, value any) bool {
entry := value.(*rateLimitEntry)
if entry.start.Load() < cutoff {
rl.entries.Delete(key)
}
return true
})
}
}
}

View File

@@ -0,0 +1,95 @@
package firewall
import (
"net/netip"
"sync/atomic"
"testing"
"time"
)
func TestRateLimiterAllow(t *testing.T) {
rl := newRateLimiter(3, time.Minute)
defer rl.Stop()
addr := netip.MustParseAddr("10.0.0.1")
for i := 0; i < 3; i++ {
if !rl.Allow(addr) {
t.Fatalf("call %d: expected Allow=true", i+1)
}
}
if rl.Allow(addr) {
t.Fatal("call 4: expected Allow=false (over limit)")
}
}
func TestRateLimiterDifferentIPs(t *testing.T) {
rl := newRateLimiter(2, time.Minute)
defer rl.Stop()
a := netip.MustParseAddr("10.0.0.1")
b := netip.MustParseAddr("10.0.0.2")
// Exhaust a's limit.
for i := 0; i < 2; i++ {
rl.Allow(a)
}
if rl.Allow(a) {
t.Fatal("a should be rate limited")
}
// b should be independent.
if !rl.Allow(b) {
t.Fatal("b should not be rate limited")
}
}
func TestRateLimiterWindowReset(t *testing.T) {
rl := newRateLimiter(2, time.Minute)
defer rl.Stop()
var fakeNow atomic.Int64
fakeNow.Store(time.Now().UnixNano())
rl.now = func() time.Time {
return time.Unix(0, fakeNow.Load())
}
addr := netip.MustParseAddr("10.0.0.1")
// Exhaust the limit.
rl.Allow(addr)
rl.Allow(addr)
if rl.Allow(addr) {
t.Fatal("should be rate limited")
}
// Advance past the window.
fakeNow.Add(int64(2 * time.Minute))
// Should be allowed again.
if !rl.Allow(addr) {
t.Fatal("should be allowed after window reset")
}
}
func TestRateLimiterCleanup(t *testing.T) {
rl := newRateLimiter(10, 50*time.Millisecond)
defer rl.Stop()
addr := netip.MustParseAddr("10.0.0.1")
rl.Allow(addr)
// Entry should exist.
if _, ok := rl.entries.Load(addr); !ok {
t.Fatal("entry should exist")
}
// Wait for 2*window + a cleanup cycle to pass.
time.Sleep(200 * time.Millisecond)
// Entry should have been cleaned up.
if _, ok := rl.entries.Load(addr); ok {
t.Fatal("stale entry should have been cleaned up")
}
}

View File

@@ -7,7 +7,9 @@ import (
"fmt"
"log/slog"
"net"
"net/netip"
"os"
"regexp"
"strings"
"google.golang.org/grpc"
@@ -22,6 +24,8 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
// AdminServer implements the ProxyAdmin gRPC service.
type AdminServer struct {
pb.UnimplementedProxyAdminServiceServer
@@ -30,8 +34,43 @@ type AdminServer struct {
logger *slog.Logger
}
// New creates a gRPC server with TLS and optional mTLS.
// New creates a gRPC server. For Unix sockets, no TLS is used. For TCP
// addresses, TLS is required with optional mTLS.
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
admin := &AdminServer{
srv: srv,
store: store,
logger: logger,
}
if cfg.IsUnixSocket() {
return newUnixServer(cfg, admin)
}
return newTCPServer(cfg, admin)
}
func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
path := cfg.SocketPath()
// Remove stale socket file from a previous run.
os.Remove(path)
ln, err := net.Listen("unix", path)
if err != nil {
return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
}
if err := os.Chmod(path, 0600); err != nil {
ln.Close()
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
}
grpcServer := grpc.NewServer()
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
return grpcServer, ln, nil
}
func newTCPServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
if err != nil {
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
@@ -57,12 +96,6 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
creds := credentials.NewTLS(tlsConfig)
grpcServer := grpc.NewServer(grpc.Creds(creds))
admin := &AdminServer{
srv: srv,
store: store,
logger: logger,
}
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
ln, err := net.Listen("tcp", cfg.Addr)
@@ -102,6 +135,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
}
// Validate backend is a valid host:port.
if _, _, err := net.SplitHostPort(req.Route.Backend); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
@@ -190,6 +228,29 @@ func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRule
return nil, status.Error(codes.InvalidArgument, "value is required")
}
// Validate the value matches the rule type before persisting.
switch ruleType {
case "ip":
if _, err := netip.ParseAddr(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
}
case "cidr":
prefix, err := netip.ParsePrefix(req.Rule.Value)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
}
// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
if prefix.Masked().String() != req.Rule.Value {
return nil, status.Errorf(codes.InvalidArgument,
"CIDR not in canonical form: use %s", prefix.Masked().String())
}
case "country":
if !countryCodeRe.MatchString(req.Rule.Value) {
return nil, status.Error(codes.InvalidArgument,
"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
}
}
// Write-through: DB first, then memory.
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)

View File

@@ -62,7 +62,7 @@ func setup(t *testing.T) *testEnv {
}
// Build server with matching in-memory state.
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil)
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("firewall: %v", err)
}
@@ -268,6 +268,15 @@ func TestAddRouteValidation(t *testing.T) {
if err == nil {
t.Fatal("expected error for empty backend")
}
// Invalid backend (not host:port).
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
ListenerAddr: ":443",
Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"},
})
if err == nil {
t.Fatal("expected error for invalid backend address")
}
}
func TestRemoveRoute(t *testing.T) {
@@ -410,6 +419,61 @@ func TestAddFirewallRuleValidation(t *testing.T) {
if err == nil {
t.Fatal("expected error for empty value")
}
// Invalid IP address.
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
Value: "not-an-ip",
},
})
if err == nil {
t.Fatal("expected error for invalid IP")
}
// Invalid CIDR.
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
Value: "not-a-cidr",
},
})
if err == nil {
t.Fatal("expected error for invalid CIDR")
}
// Non-canonical CIDR.
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
Value: "192.168.1.5/16",
},
})
if err == nil {
t.Fatal("expected error for non-canonical CIDR")
}
// Invalid country code (lowercase).
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
Value: "cn",
},
})
if err == nil {
t.Fatal("expected error for lowercase country code")
}
// Invalid country code (too long).
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
Rule: &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
Value: "USA",
},
})
if err == nil {
t.Fatal("expected error for 3-letter country code")
}
}
func TestRemoveFirewallRule(t *testing.T) {

View File

@@ -24,6 +24,8 @@ type ListenerState struct {
routes map[string]string // lowercase hostname → backend addr
mu sync.RWMutex
ActiveConnections atomic.Int64
activeConns map[net.Conn]struct{} // tracked for forced shutdown
connMu sync.Mutex
}
// Routes returns a snapshot of the listener's route table.
@@ -100,9 +102,10 @@ func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData,
var listeners []*ListenerState
for _, ld := range listenerData {
listeners = append(listeners, &ListenerState{
ID: ld.ID,
Addr: ld.Addr,
routes: ld.Routes,
ID: ld.ID,
Addr: ld.Addr,
routes: ld.Routes,
activeConns: make(map[net.Conn]struct{}),
})
}
@@ -186,6 +189,9 @@ func (s *Server) Run(ctx context.Context) error {
s.logger.Info("all connections drained")
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
s.logger.Warn("shutdown timeout exceeded, forcing close")
// Force-close all listener connections to unblock relay goroutines.
s.forceCloseAll()
<-done
}
s.fw.Close()
@@ -214,11 +220,31 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
}
}
// forceCloseAll closes all tracked connections across all listeners.
func (s *Server) forceCloseAll() {
for _, ls := range s.listeners {
ls.connMu.Lock()
for conn := range ls.activeConns {
conn.Close()
}
ls.connMu.Unlock()
}
}
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
defer s.wg.Done()
defer ls.ActiveConnections.Add(-1)
defer conn.Close()
ls.connMu.Lock()
ls.activeConns[conn] = struct{}{}
ls.connMu.Unlock()
defer func() {
ls.connMu.Lock()
delete(ls.activeConns, conn)
ls.connMu.Unlock()
}()
remoteAddr := conn.RemoteAddr().String()
addrPort, err := netip.ParseAddrPort(remoteAddr)
if err != nil {

View File

@@ -28,7 +28,7 @@ func echoServer(t *testing.T, ln net.Listener) {
// newTestServer creates a Server with the given listener data and no firewall rules.
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
t.Helper()
fw, err := firewall.New("", nil, nil, nil)
fw, err := firewall.New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}
@@ -195,7 +195,7 @@ func TestFirewallBlocks(t *testing.T) {
proxyLn.Close()
// Create a firewall that blocks 127.0.0.1 (the test client).
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil)
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}
@@ -599,7 +599,7 @@ func TestGracefulShutdown(t *testing.T) {
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
fw, err := firewall.New("", nil, nil, nil)
fw, err := firewall.New("", nil, nil, nil, 0, 0)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}

View File

@@ -165,6 +165,9 @@ func parseServerNameExtension(data []byte) (string, error) {
}
if nameType == 0x00 { // hostname
if nameLen > 255 {
return "", fmt.Errorf("SNI hostname exceeds 255 bytes")
}
return strings.ToLower(string(data[:nameLen])), nil
}