Add per-IP rate limiting and Unix socket support for gRPC admin API
Rate limiting: per-source-IP connection rate limiter in the firewall layer with configurable limit and sliding window. Blocklisted IPs are rejected before rate limit evaluation to avoid wasting quota. Unix socket: the gRPC admin API can now listen on a Unix domain socket (no TLS required), secured by file permissions (0600), as a simpler alternative for local-only access. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -5,9 +5,9 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
|
||||
)
|
||||
|
||||
type geoIPRecord struct {
|
||||
@@ -23,17 +23,23 @@ type Firewall struct {
|
||||
blockedCountries map[string]struct{}
|
||||
geoDBPath string
|
||||
geoDB *maxminddb.Reader
|
||||
rl *rateLimiter
|
||||
mu sync.RWMutex // protects all mutable state
|
||||
}
|
||||
|
||||
// New creates a Firewall from raw rule lists and an optional GeoIP database path.
|
||||
func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) {
|
||||
// If rateLimit > 0, per-source-IP rate limiting is enabled with the given window.
|
||||
func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rateWindow time.Duration) (*Firewall, error) {
|
||||
f := &Firewall{
|
||||
blockedIPs: make(map[netip.Addr]struct{}),
|
||||
blockedCountries: make(map[string]struct{}),
|
||||
geoDBPath: geoIPPath,
|
||||
}
|
||||
|
||||
if rateLimit > 0 && rateWindow > 0 {
|
||||
f.rl = newRateLimiter(rateLimit, rateWindow)
|
||||
}
|
||||
|
||||
for _, ip := range ips {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
@@ -89,6 +95,12 @@ func (f *Firewall) Blocked(addr netip.Addr) bool {
|
||||
}
|
||||
}
|
||||
|
||||
// Rate limiting is checked after blocklist — no point tracking state
|
||||
// for already-blocked IPs.
|
||||
if f.rl != nil && !f.rl.Allow(addr) {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -190,6 +202,10 @@ func (f *Firewall) ReloadGeoIP() error {
|
||||
|
||||
// Close releases resources held by the firewall.
|
||||
func (f *Firewall) Close() error {
|
||||
if f.rl != nil {
|
||||
f.rl.Stop()
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
|
||||
@@ -3,10 +3,11 @@ package firewall
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEmptyFirewall(t *testing.T) {
|
||||
fw, err := New("", nil, nil, nil)
|
||||
fw, err := New("", nil, nil, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -22,7 +23,7 @@ func TestEmptyFirewall(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPBlocking(t *testing.T) {
|
||||
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil)
|
||||
fw, err := New("", []string{"192.0.2.1", "2001:db8::dead"}, nil, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -47,7 +48,7 @@ func TestIPBlocking(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestCIDRBlocking(t *testing.T) {
|
||||
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil)
|
||||
fw, err := New("", nil, []string{"198.51.100.0/24", "2001:db8::/32"}, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -73,7 +74,7 @@ func TestCIDRBlocking(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestIPv4MappedIPv6(t *testing.T) {
|
||||
fw, err := New("", []string{"192.0.2.1"}, nil, nil)
|
||||
fw, err := New("", []string{"192.0.2.1"}, nil, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -86,21 +87,21 @@ func TestIPv4MappedIPv6(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInvalidIP(t *testing.T) {
|
||||
_, err := New("", []string{"not-an-ip"}, nil, nil)
|
||||
_, err := New("", []string{"not-an-ip"}, nil, nil, 0, 0)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidCIDR(t *testing.T) {
|
||||
_, err := New("", nil, []string{"not-a-cidr"}, nil)
|
||||
_, err := New("", nil, []string{"not-a-cidr"}, nil, 0, 0)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombinedRules(t *testing.T) {
|
||||
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil)
|
||||
fw, err := New("", []string{"10.0.0.1"}, []string{"192.168.0.0/16"}, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -124,8 +125,53 @@ func TestCombinedRules(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitBlocking(t *testing.T) {
|
||||
fw, err := New("", nil, nil, nil, 2, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
addr := netip.MustParseAddr("10.0.0.1")
|
||||
|
||||
if fw.Blocked(addr) {
|
||||
t.Fatal("first request should be allowed")
|
||||
}
|
||||
if fw.Blocked(addr) {
|
||||
t.Fatal("second request should be allowed")
|
||||
}
|
||||
if !fw.Blocked(addr) {
|
||||
t.Fatal("third request should be blocked (limit=2)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimitBlocklistFirst(t *testing.T) {
|
||||
// A blocklisted IP should be blocked without consuming rate limit quota.
|
||||
fw, err := New("", []string{"10.0.0.1"}, nil, nil, 1, time.Minute)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
blockedAddr := netip.MustParseAddr("10.0.0.1")
|
||||
otherAddr := netip.MustParseAddr("10.0.0.2")
|
||||
|
||||
// Blocked by blocklist — should not touch the rate limiter.
|
||||
if !fw.Blocked(blockedAddr) {
|
||||
t.Fatal("blocklisted IP should be blocked")
|
||||
}
|
||||
|
||||
// Other address should still have its full rate limit quota.
|
||||
if fw.Blocked(otherAddr) {
|
||||
t.Fatal("other IP should be allowed (within rate limit)")
|
||||
}
|
||||
if !fw.Blocked(otherAddr) {
|
||||
t.Fatal("other IP should be blocked after exceeding rate limit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRuntimeMutation(t *testing.T) {
|
||||
fw, err := New("", nil, nil, nil)
|
||||
fw, err := New("", nil, nil, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
81
internal/firewall/ratelimit.go
Normal file
81
internal/firewall/ratelimit.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type rateLimitEntry struct {
|
||||
count atomic.Int64
|
||||
start atomic.Int64 // UnixNano
|
||||
}
|
||||
|
||||
type rateLimiter struct {
|
||||
limit int64
|
||||
window time.Duration
|
||||
entries sync.Map // netip.Addr → *rateLimitEntry
|
||||
now func() time.Time
|
||||
done chan struct{}
|
||||
}
|
||||
|
||||
func newRateLimiter(limit int64, window time.Duration) *rateLimiter {
|
||||
rl := &rateLimiter{
|
||||
limit: limit,
|
||||
window: window,
|
||||
now: time.Now,
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
|
||||
go rl.cleanup()
|
||||
return rl
|
||||
}
|
||||
|
||||
// Allow checks whether the given address is within its rate limit window.
|
||||
// Returns false if the address has exceeded the limit.
|
||||
func (rl *rateLimiter) Allow(addr netip.Addr) bool {
|
||||
now := rl.now().UnixNano()
|
||||
|
||||
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
|
||||
entry := val.(*rateLimitEntry)
|
||||
|
||||
windowStart := entry.start.Load()
|
||||
if now-windowStart >= rl.window.Nanoseconds() {
|
||||
// Window expired — reset. Intentionally non-atomic across the two
|
||||
// stores; worst case a few extra connections slip through at the
|
||||
// boundary, which is acceptable for a connection rate limiter.
|
||||
entry.start.Store(now)
|
||||
entry.count.Store(1)
|
||||
return true
|
||||
}
|
||||
|
||||
n := entry.count.Add(1)
|
||||
return n <= rl.limit
|
||||
}
|
||||
|
||||
// Stop terminates the cleanup goroutine.
|
||||
func (rl *rateLimiter) Stop() {
|
||||
close(rl.done)
|
||||
}
|
||||
|
||||
func (rl *rateLimiter) cleanup() {
|
||||
ticker := time.NewTicker(rl.window)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-rl.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
cutoff := rl.now().Add(-2 * rl.window).UnixNano()
|
||||
rl.entries.Range(func(key, value any) bool {
|
||||
entry := value.(*rateLimitEntry)
|
||||
if entry.start.Load() < cutoff {
|
||||
rl.entries.Delete(key)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
95
internal/firewall/ratelimit_test.go
Normal file
95
internal/firewall/ratelimit_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiterAllow(t *testing.T) {
|
||||
rl := newRateLimiter(3, time.Minute)
|
||||
defer rl.Stop()
|
||||
|
||||
addr := netip.MustParseAddr("10.0.0.1")
|
||||
|
||||
for i := 0; i < 3; i++ {
|
||||
if !rl.Allow(addr) {
|
||||
t.Fatalf("call %d: expected Allow=true", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
if rl.Allow(addr) {
|
||||
t.Fatal("call 4: expected Allow=false (over limit)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterDifferentIPs(t *testing.T) {
|
||||
rl := newRateLimiter(2, time.Minute)
|
||||
defer rl.Stop()
|
||||
|
||||
a := netip.MustParseAddr("10.0.0.1")
|
||||
b := netip.MustParseAddr("10.0.0.2")
|
||||
|
||||
// Exhaust a's limit.
|
||||
for i := 0; i < 2; i++ {
|
||||
rl.Allow(a)
|
||||
}
|
||||
if rl.Allow(a) {
|
||||
t.Fatal("a should be rate limited")
|
||||
}
|
||||
|
||||
// b should be independent.
|
||||
if !rl.Allow(b) {
|
||||
t.Fatal("b should not be rate limited")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterWindowReset(t *testing.T) {
|
||||
rl := newRateLimiter(2, time.Minute)
|
||||
defer rl.Stop()
|
||||
|
||||
var fakeNow atomic.Int64
|
||||
fakeNow.Store(time.Now().UnixNano())
|
||||
rl.now = func() time.Time {
|
||||
return time.Unix(0, fakeNow.Load())
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddr("10.0.0.1")
|
||||
|
||||
// Exhaust the limit.
|
||||
rl.Allow(addr)
|
||||
rl.Allow(addr)
|
||||
if rl.Allow(addr) {
|
||||
t.Fatal("should be rate limited")
|
||||
}
|
||||
|
||||
// Advance past the window.
|
||||
fakeNow.Add(int64(2 * time.Minute))
|
||||
|
||||
// Should be allowed again.
|
||||
if !rl.Allow(addr) {
|
||||
t.Fatal("should be allowed after window reset")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRateLimiterCleanup(t *testing.T) {
|
||||
rl := newRateLimiter(10, 50*time.Millisecond)
|
||||
defer rl.Stop()
|
||||
|
||||
addr := netip.MustParseAddr("10.0.0.1")
|
||||
rl.Allow(addr)
|
||||
|
||||
// Entry should exist.
|
||||
if _, ok := rl.entries.Load(addr); !ok {
|
||||
t.Fatal("entry should exist")
|
||||
}
|
||||
|
||||
// Wait for 2*window + a cleanup cycle to pass.
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Entry should have been cleaned up.
|
||||
if _, ok := rl.entries.Load(addr); ok {
|
||||
t.Fatal("stale entry should have been cleaned up")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user