Files
mc-proxy/internal/firewall/firewall.go
Kyle Isom a60e5cb86a Fix golangci-lint v2 compliance, make all passes clean
- Fix 314 errcheck violations (blank identifier for unrecoverable errors)
- Fix errorlint violation (errors.Is for io.EOF)
- Remove unused serveL7Route test helper
- Simplify Duration.Seconds() selectors in tests
- Remove unnecessary fmt.Sprintf in test
- Migrate exclusion rules from issues.exclusions to linters.exclusions (v2 schema)
- Add gosec test exclusions (G115, G304, G402, G705)
- Disable fieldalignment govet analyzer (optimization, not correctness)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 13:30:43 -07:00

241 lines
5.5 KiB
Go

package firewall
import (
"fmt"
"net/netip"
"strings"
"sync"
"time"
"github.com/oschwald/maxminddb-golang"
)
// geoIPCountry mirrors the "country" map of a GeoIP record; only the ISO
// code is decoded.
type geoIPCountry struct {
	ISOCode string `maxminddb:"iso_code"`
}

// geoIPRecord is the decode target for GeoIP database lookups. It selects
// just the country.iso_code value from a record.
type geoIPRecord struct {
	Country geoIPCountry `maxminddb:"country"`
}
// Firewall evaluates global blocklist rules against connection source addresses.
type Firewall struct {
	// mu protects all mutable state: the rule sets and the geoDB handle.
	mu sync.RWMutex

	blockedIPs       map[netip.Addr]struct{} // exact-match address rules
	blockedCIDRs     []netip.Prefix          // prefix rules, scanned linearly
	blockedCountries map[string]struct{}     // upper-cased country codes
	geoDB            *maxminddb.Reader       // nil until a database is loaded

	geoDBPath string       // database path given to New; reused by ReloadGeoIP
	rl        *rateLimiter // nil when rate limiting is disabled
}
// New creates a Firewall from raw rule lists and an optional GeoIP database path.
// Per-source-IP rate limiting is enabled only when both rateLimit and
// rateWindow are greater than zero.
//
// Each entry of ips is parsed with netip.ParseAddr and each entry of cidrs
// with netip.ParsePrefix; the first entry that fails to parse aborts
// construction and is reported in the returned error. IPv4-mapped IPv6
// addresses in ips are stored in plain IPv4 form, because BlockedWithReason
// unmaps the address it checks before consulting the set. Country codes are
// upper-cased. The GeoIP database is opened only when at least one country
// is listed and geoIPPath is non-empty.
func New(geoIPPath string, ips, cidrs, countries []string, rateLimit int64, rateWindow time.Duration) (*Firewall, error) {
	f := &Firewall{
		blockedIPs:       make(map[netip.Addr]struct{}, len(ips)),
		blockedCountries: make(map[string]struct{}, len(countries)),
		geoDBPath:        geoIPPath,
	}

	if rateLimit > 0 && rateWindow > 0 {
		f.rl = newRateLimiter(rateLimit, rateWindow)
	}

	for _, ip := range ips {
		addr, err := netip.ParseAddr(ip)
		if err != nil {
			return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err)
		}
		// Store the unmapped form so "::ffff:a.b.c.d" rules match the
		// unmapped address used at lookup time.
		f.blockedIPs[addr.Unmap()] = struct{}{}
	}

	for _, cidr := range cidrs {
		prefix, err := netip.ParsePrefix(cidr)
		if err != nil {
			return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err)
		}
		f.blockedCIDRs = append(f.blockedCIDRs, prefix)
	}

	for _, code := range countries {
		f.blockedCountries[strings.ToUpper(code)] = struct{}{}
	}

	if len(f.blockedCountries) > 0 && geoIPPath != "" {
		if err := f.loadGeoDB(geoIPPath); err != nil {
			return nil, fmt.Errorf("loading GeoIP database: %w", err)
		}
	}

	return f, nil
}
// Blocked reports whether the given address should be blocked. It is
// BlockedWithReason with the reason discarded.
func (f *Firewall) Blocked(addr netip.Addr) bool {
	verdict, _ := f.BlockedWithReason(addr)
	return verdict
}
// BlockedWithReason reports whether addr is blocked and, if so, why.
// The reason is one of "ip", "cidr", "country" or "rate_limit"; it is ""
// when the address is not blocked. Checks run in that order and the first
// match wins. The read lock is held for the whole evaluation.
func (f *Firewall) BlockedWithReason(addr netip.Addr) (bool, string) {
	// Compare in plain IPv4 form when the address is IPv4-mapped IPv6.
	src := addr.Unmap()

	f.mu.RLock()
	defer f.mu.RUnlock()

	if _, listed := f.blockedIPs[src]; listed {
		return true, "ip"
	}

	for i := range f.blockedCIDRs {
		if f.blockedCIDRs[i].Contains(src) {
			return true, "cidr"
		}
	}

	if f.geoDB != nil && len(f.blockedCountries) > 0 {
		var rec geoIPRecord
		// A failed lookup is not treated as a match; evaluation continues.
		if f.geoDB.Lookup(src.AsSlice(), &rec) == nil {
			if _, listed := f.blockedCountries[rec.Country.ISOCode]; listed {
				return true, "country"
			}
		}
	}

	// The rate limiter is consulted last so that addresses already rejected
	// by a blocklist rule never create rate-limit state.
	if f.rl == nil || f.rl.Allow(src) {
		return false, ""
	}
	return true, "rate_limit"
}
// AddIP adds an IP address to the blocklist.
//
// ip is parsed with netip.ParseAddr; a parse failure is returned wrapped and
// the blocklist is left unchanged. An IPv4-mapped IPv6 address is stored in
// plain IPv4 form, because BlockedWithReason unmaps the address it checks
// before consulting the set — a mapped entry would otherwise never match.
func (f *Firewall) AddIP(ip string) error {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return fmt.Errorf("invalid IP %q: %w", ip, err)
	}
	addr = addr.Unmap()

	f.mu.Lock()
	defer f.mu.Unlock()
	f.blockedIPs[addr] = struct{}{}
	return nil
}
// RemoveIP removes an IP address from the blocklist.
//
// ip is parsed with netip.ParseAddr; a parse failure is returned wrapped and
// the blocklist is left unchanged. Both the address as written and its
// unmapped form are deleted, so an entry is removed whether it was stored as
// "a.b.c.d" or "::ffff:a.b.c.d". Removing an address that is not listed is
// not an error.
func (f *Firewall) RemoveIP(ip string) error {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return fmt.Errorf("invalid IP %q: %w", ip, err)
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	delete(f.blockedIPs, addr)
	// For non-mapped addresses Unmap returns addr itself and this is a no-op.
	delete(f.blockedIPs, addr.Unmap())
	return nil
}
// AddCIDR adds a CIDR prefix to the blocklist.
//
// cidr is parsed with netip.ParsePrefix; a parse failure is returned wrapped
// and the blocklist is left unchanged. Adding a prefix that is already
// listed is a no-op, so the list never accumulates duplicates that would
// lengthen the per-lookup scan.
func (f *Firewall) AddCIDR(cidr string) error {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	for _, existing := range f.blockedCIDRs {
		if existing == prefix {
			return nil
		}
	}
	f.blockedCIDRs = append(f.blockedCIDRs, prefix)
	return nil
}
// RemoveCIDR removes a CIDR prefix from the blocklist.
//
// cidr is parsed with netip.ParsePrefix; a parse failure is returned wrapped
// and the blocklist is left unchanged. Every entry equal to the parsed
// prefix is removed, so a prefix that was listed more than once does not
// stay blocked after removal. Removing a prefix that is not listed is not an
// error.
func (f *Firewall) RemoveCIDR(cidr string) error {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	// Filter in place; the write lock excludes concurrent readers of the
	// shared backing array.
	kept := f.blockedCIDRs[:0]
	for _, p := range f.blockedCIDRs {
		if p != prefix {
			kept = append(kept, p)
		}
	}
	f.blockedCIDRs = kept
	return nil
}
// AddCountry adds a country code to the blocklist. The code is upper-cased
// before it is stored.
func (f *Firewall) AddCountry(code string) {
	key := strings.ToUpper(code)

	f.mu.Lock()
	defer f.mu.Unlock()
	f.blockedCountries[key] = struct{}{}
}
// RemoveCountry removes a country code from the blocklist. The code is
// upper-cased before the lookup; removing a code that is not listed is a
// no-op.
func (f *Firewall) RemoveCountry(code string) {
	key := strings.ToUpper(code)

	f.mu.Lock()
	defer f.mu.Unlock()
	delete(f.blockedCountries, key)
}
// Rules returns a snapshot of all current firewall rules as strings.
// The ips and countries slices follow map iteration order, which is not
// stable between calls; cidrs follows the order of the stored prefix list.
// A rule set with no entries yields a nil slice.
func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	for ip := range f.blockedIPs {
		ips = append(ips, ip.String())
	}

	for i := range f.blockedCIDRs {
		cidrs = append(cidrs, f.blockedCIDRs[i].String())
	}

	for cc := range f.blockedCountries {
		countries = append(countries, cc)
	}

	return ips, cidrs, countries
}
// ReloadGeoIP reloads the GeoIP database from the path the Firewall was
// created with. It does nothing and returns nil when no path was configured.
// Safe for concurrent use: the handle is swapped under the write lock.
func (f *Firewall) ReloadGeoIP() error {
	if path := f.geoDBPath; path != "" {
		return f.loadGeoDB(path)
	}
	return nil
}
// Close releases resources held by the firewall. It stops the rate limiter
// when one is configured, then closes the GeoIP database if one is loaded
// and returns the result of that close.
func (f *Firewall) Close() error {
	if limiter := f.rl; limiter != nil {
		limiter.Stop()
	}

	f.mu.Lock()
	defer f.mu.Unlock()

	if f.geoDB == nil {
		return nil
	}
	return f.geoDB.Close()
}
// loadGeoDB opens the GeoIP database at path and installs it as the active
// reader. If opening fails the error is returned and the current reader is
// left in place. A previously installed reader is closed after the swap.
func (f *Firewall) loadGeoDB(path string) error {
	fresh, err := maxminddb.Open(path)
	if err != nil {
		return err
	}

	// Swap under the write lock so lookups see either the old or the new
	// reader, never a reader that has been closed.
	var prev *maxminddb.Reader
	f.mu.Lock()
	prev, f.geoDB = f.geoDB, fresh
	f.mu.Unlock()

	if prev == nil {
		return nil
	}
	// The replaced reader is closed outside the lock; its close error is
	// deliberately ignored because the new reader is already installed.
	_ = prev.Close()
	return nil
}