Files
mc-proxy/internal/firewall/firewall.go
Kyle Isom 9cba3241e8 Add SQLite persistence and write-through gRPC mutations
Database (internal/db) stores listeners, routes, and firewall rules with
WAL mode, foreign keys, and idempotent migrations. First run seeds from
TOML config; subsequent runs load from DB as source of truth.

gRPC admin API now writes to the database before updating in-memory state
(write-through cache pattern). Adds snapshot command for VACUUM INTO
backups. Refactors firewall.New to accept raw rule slices instead of
config struct for flexibility.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 03:07:30 -07:00

218 lines
4.7 KiB
Go

package firewall
import (
	"fmt"
	"net/netip"
	"sort"
	"strings"
	"sync"

	"github.com/oschwald/maxminddb-golang"
)
// geoIPCountry holds the country fields decoded from a MaxMind record;
// only the ISO code is read.
type geoIPCountry struct {
	ISOCode string `maxminddb:"iso_code"`
}

// geoIPRecord is the subset of a MaxMind GeoIP record that the firewall
// decodes: the "country" map, of which only "iso_code" is used.
type geoIPRecord struct {
	Country geoIPCountry `maxminddb:"country"`
}
// Firewall evaluates global blocklist rules against connection source addresses.
//
// A Firewall must be created with New (its maps are allocated there). The
// rule-mutating methods, Blocked, and Rules all take mu.
type Firewall struct {
	// mu guards the rule sets and geoDB. geoDBPath is only assigned in New.
	mu sync.RWMutex

	blockedIPs       map[netip.Addr]struct{} // exact source addresses to reject
	blockedCIDRs     []netip.Prefix          // source ranges to reject
	blockedCountries map[string]struct{}     // upper-cased country codes to reject

	geoDBPath string            // GeoIP database file; empty makes ReloadGeoIP a no-op
	geoDB     *maxminddb.Reader // active GeoIP reader, nil when none is loaded
}
// New creates a Firewall from raw rule lists and an optional GeoIP database path.
//
// ips are parsed with netip.ParseAddr and cidrs with netip.ParsePrefix; the
// first invalid entry aborts construction with an error. IPv4-mapped IPv6
// addresses (::ffff:a.b.c.d) are stored in plain IPv4 form, because Blocked
// unmaps the address it checks and would otherwise never match them.
// Country codes are stored upper-cased.
//
// The GeoIP database at geoIPPath is opened only when at least one country
// rule is supplied; an open failure is returned as an error.
//
// NOTE(review): when New receives no country rules, the GeoIP database is not
// opened, so countries added later with AddCountry have no effect until
// ReloadGeoIP is called. Confirm this is intended.
func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) {
	f := &Firewall{
		blockedIPs:       make(map[netip.Addr]struct{}, len(ips)),
		blockedCIDRs:     make([]netip.Prefix, 0, len(cidrs)),
		blockedCountries: make(map[string]struct{}, len(countries)),
		geoDBPath:        geoIPPath,
	}

	for _, ip := range ips {
		addr, err := netip.ParseAddr(ip)
		if err != nil {
			return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err)
		}
		// Normalize to match the Unmap performed in Blocked.
		f.blockedIPs[addr.Unmap()] = struct{}{}
	}

	for _, cidr := range cidrs {
		prefix, err := netip.ParsePrefix(cidr)
		if err != nil {
			return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err)
		}
		f.blockedCIDRs = append(f.blockedCIDRs, prefix)
	}

	for _, code := range countries {
		f.blockedCountries[strings.ToUpper(code)] = struct{}{}
	}

	if len(f.blockedCountries) > 0 && geoIPPath != "" {
		if err := f.loadGeoDB(geoIPPath); err != nil {
			return nil, fmt.Errorf("loading GeoIP database: %w", err)
		}
	}
	return f, nil
}
// Blocked returns true if the given address should be blocked.
//
// Rules are checked in order: exact IP, CIDR prefix, then country (only when
// country rules exist and a GeoIP database is loaded). IPv4-mapped IPv6
// addresses are unmapped first so IPv4 rules apply to them. GeoIP lookup
// errors fail open: the address is not blocked on that basis.
func (f *Firewall) Blocked(addr netip.Addr) bool {
	addr = addr.Unmap()

	f.mu.RLock()
	defer f.mu.RUnlock()

	if _, ok := f.blockedIPs[addr]; ok {
		return true
	}
	for _, prefix := range f.blockedCIDRs {
		if prefix.Contains(addr) {
			return true
		}
	}

	if len(f.blockedCountries) == 0 || f.geoDB == nil {
		return false
	}
	var record geoIPRecord
	if err := f.geoDB.Lookup(addr.AsSlice(), &record); err != nil {
		return false
	}
	// An address without country data leaves ISOCode empty; never let that
	// match a stray "" country rule, which would block all such addresses.
	code := record.Country.ISOCode
	if code == "" {
		return false
	}
	_, blocked := f.blockedCountries[code]
	return blocked
}
// AddIP adds an IP address to the blocklist.
//
// IPv4-mapped IPv6 addresses (e.g. "::ffff:192.0.2.1") are stored in plain
// IPv4 form. Blocked unmaps the address it checks, so the mapped form would
// create a rule that can never match. It returns an error if ip does not parse.
func (f *Firewall) AddIP(ip string) error {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return fmt.Errorf("invalid IP %q: %w", ip, err)
	}
	addr = addr.Unmap()

	f.mu.Lock()
	defer f.mu.Unlock()
	f.blockedIPs[addr] = struct{}{}
	return nil
}
// RemoveIP removes an IP address from the blocklist.
//
// Both the address as given and its unmapped form are deleted. Blocked treats
// "::ffff:a.b.c.d" and "a.b.c.d" as the same address, so removing either
// spelling clears the rule. Removing an absent address is not an error;
// an unparseable ip is.
func (f *Firewall) RemoveIP(ip string) error {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return fmt.Errorf("invalid IP %q: %w", ip, err)
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	delete(f.blockedIPs, addr)
	delete(f.blockedIPs, addr.Unmap())
	return nil
}
// AddCIDR adds a CIDR prefix to the blocklist.
//
// Adding a prefix that is already present is a no-op. This prevents duplicate
// entries that a later RemoveCIDR might not fully clear. It returns an error
// if cidr does not parse.
func (f *Firewall) AddCIDR(cidr string) error {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	for _, p := range f.blockedCIDRs {
		if p == prefix {
			return nil
		}
	}
	f.blockedCIDRs = append(f.blockedCIDRs, prefix)
	return nil
}
// RemoveCIDR removes a CIDR prefix from the blocklist.
//
// Every stored entry equal to the parsed prefix is removed. Duplicates, such
// as the same rule listed twice in the initial configuration, therefore cannot
// leave the range blocked. Removing an absent prefix is not an error; an
// unparseable cidr is.
func (f *Firewall) RemoveCIDR(cidr string) error {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
	}

	f.mu.Lock()
	defer f.mu.Unlock()
	// Filter in place, reusing the existing backing array.
	kept := f.blockedCIDRs[:0]
	for _, p := range f.blockedCIDRs {
		if p != prefix {
			kept = append(kept, p)
		}
	}
	f.blockedCIDRs = kept
	return nil
}
// AddCountry adds a country code to the blocklist. The code is upper-cased
// before storage so it can be compared against the GeoIP record's iso_code.
func (f *Firewall) AddCountry(code string) {
	key := strings.ToUpper(code)

	f.mu.Lock()
	defer f.mu.Unlock()
	f.blockedCountries[key] = struct{}{}
}
// RemoveCountry removes a country code from the blocklist. The code is
// upper-cased first, matching how AddCountry and New store codes. Removing an
// absent code is a no-op.
func (f *Firewall) RemoveCountry(code string) {
	key := strings.ToUpper(code)

	f.mu.Lock()
	defer f.mu.Unlock()
	delete(f.blockedCountries, key)
}
// Rules returns a snapshot of all current firewall rules.
func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) {
f.mu.RLock()
defer f.mu.RUnlock()
for addr := range f.blockedIPs {
ips = append(ips, addr.String())
}
for _, prefix := range f.blockedCIDRs {
cidrs = append(cidrs, prefix.String())
}
for code := range f.blockedCountries {
countries = append(countries, code)
}
return
}
// ReloadGeoIP re-opens the GeoIP database from the path given to New and
// installs it, closing the previous reader. It is a no-op when no path was
// configured. It is safe for concurrent use; the reader swap happens under the
// write lock.
func (f *Firewall) ReloadGeoIP() error {
	if path := f.geoDBPath; path != "" {
		return f.loadGeoDB(path)
	}
	return nil
}
// Close releases resources held by the firewall, namely the GeoIP reader if
// one is loaded, and returns that reader's Close error.
//
// After Close, the reader reference is cleared. Blocked therefore skips
// country checks instead of querying a closed reader, and a second Close
// returns nil.
func (f *Firewall) Close() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.geoDB == nil {
		return nil
	}
	err := f.geoDB.Close()
	f.geoDB = nil
	return err
}
// loadGeoDB opens the MaxMind database at path and installs it as the active
// GeoIP reader, closing the reader it replaces. If opening fails, the current
// reader is left in place and the open error is returned unwrapped.
func (f *Firewall) loadGeoDB(path string) error {
	fresh, err := maxminddb.Open(path)
	if err != nil {
		return err
	}

	// Swap under the write lock: Blocked holds the read lock for the whole
	// lookup, so once this section ends no caller can still be using stale.
	f.mu.Lock()
	stale := f.geoDB
	f.geoDB = fresh
	f.mu.Unlock()

	if stale == nil {
		return nil
	}
	// Close outside the lock. The error is deliberately ignored because the
	// new reader is already in service.
	_ = stale.Close()
	return nil
}