package firewall import ( "fmt" "net/netip" "strings" "sync" "github.com/oschwald/maxminddb-golang" ) type geoIPRecord struct { Country struct { ISOCode string `maxminddb:"iso_code"` } `maxminddb:"country"` } // Firewall evaluates global blocklist rules against connection source addresses. type Firewall struct { blockedIPs map[netip.Addr]struct{} blockedCIDRs []netip.Prefix blockedCountries map[string]struct{} geoDBPath string geoDB *maxminddb.Reader mu sync.RWMutex // protects all mutable state } // New creates a Firewall from raw rule lists and an optional GeoIP database path. func New(geoIPPath string, ips, cidrs, countries []string) (*Firewall, error) { f := &Firewall{ blockedIPs: make(map[netip.Addr]struct{}), blockedCountries: make(map[string]struct{}), geoDBPath: geoIPPath, } for _, ip := range ips { addr, err := netip.ParseAddr(ip) if err != nil { return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err) } f.blockedIPs[addr] = struct{}{} } for _, cidr := range cidrs { prefix, err := netip.ParsePrefix(cidr) if err != nil { return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err) } f.blockedCIDRs = append(f.blockedCIDRs, prefix) } for _, code := range countries { f.blockedCountries[strings.ToUpper(code)] = struct{}{} } if len(f.blockedCountries) > 0 && geoIPPath != "" { if err := f.loadGeoDB(geoIPPath); err != nil { return nil, fmt.Errorf("loading GeoIP database: %w", err) } } return f, nil } // Blocked returns true if the given address should be blocked. 
func (f *Firewall) Blocked(addr netip.Addr) bool {
	addr = addr.Unmap()

	f.mu.RLock()
	defer f.mu.RUnlock()

	if _, ok := f.blockedIPs[addr]; ok {
		return true
	}
	for _, prefix := range f.blockedCIDRs {
		if prefix.Contains(addr) {
			return true
		}
	}

	if len(f.blockedCountries) == 0 || f.geoDB == nil {
		return false
	}
	var record geoIPRecord
	// Fail open: a lookup error is treated as "not blocked" rather than
	// rejecting traffic because of a database problem.
	if err := f.geoDB.Lookup(addr.AsSlice(), &record); err != nil {
		return false
	}
	// A record without country data leaves ISOCode empty. That includes an
	// address absent from the database, where record stays zero-valued.
	// Never treat "no country" as a match.
	if record.Country.ISOCode == "" {
		return false
	}
	_, ok := f.blockedCountries[record.Country.ISOCode]
	return ok
}

// AddIP adds an IP address to the blocklist. IPv4-mapped IPv6 addresses are
// stored in their unmapped form, because that is the form Blocked looks up.
func (f *Firewall) AddIP(ip string) error {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return fmt.Errorf("invalid IP %q: %w", ip, err)
	}
	addr = addr.Unmap()

	f.mu.Lock()
	f.blockedIPs[addr] = struct{}{}
	f.mu.Unlock()
	return nil
}

// RemoveIP removes an IP address from the blocklist. Removing an address that
// is not blocked is a no-op.
func (f *Firewall) RemoveIP(ip string) error {
	addr, err := netip.ParseAddr(ip)
	if err != nil {
		return fmt.Errorf("invalid IP %q: %w", ip, err)
	}

	f.mu.Lock()
	// Delete the unmapped form (how AddIP stores rules) and also the literal
	// form, in case the rule was stored without normalization.
	delete(f.blockedIPs, addr.Unmap())
	delete(f.blockedIPs, addr)
	f.mu.Unlock()
	return nil
}

// AddCIDR adds a CIDR prefix to the blocklist. The prefix is stored masked
// ("10.0.0.1/8" becomes "10.0.0.0/8"). Adding a prefix that is already
// present is a no-op.
func (f *Firewall) AddCIDR(cidr string) error {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
	}
	prefix = prefix.Masked()

	f.mu.Lock()
	defer f.mu.Unlock()
	for _, p := range f.blockedCIDRs {
		if p.Masked() == prefix {
			return nil
		}
	}
	f.blockedCIDRs = append(f.blockedCIDRs, prefix)
	return nil
}

// RemoveCIDR removes a CIDR prefix from the blocklist. Every stored entry
// denoting the same masked prefix is removed, so no equivalent rule stays
// active.
func (f *Firewall) RemoveCIDR(cidr string) error {
	prefix, err := netip.ParsePrefix(cidr)
	if err != nil {
		return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
	}
	prefix = prefix.Masked()

	f.mu.Lock()
	defer f.mu.Unlock()
	// In-place filter. No caller holds a reference to this slice, because
	// Rules copies prefixes out as strings.
	kept := f.blockedCIDRs[:0]
	for _, p := range f.blockedCIDRs {
		if p.Masked() != prefix {
			kept = append(kept, p)
		}
	}
	f.blockedCIDRs = kept
	return nil
}

// AddCountry adds a country code to the blocklist. Codes are matched
// case-insensitively, and an empty code is ignored.
func (f *Firewall) AddCountry(code string) {
	if code == "" {
		return
	}
	f.mu.Lock()
	f.blockedCountries[strings.ToUpper(code)] = struct{}{}
	f.mu.Unlock()
}

// RemoveCountry removes a country code from the blocklist. The code is
// matched case-insensitively.
func (f *Firewall) RemoveCountry(code string) {
	code = strings.ToUpper(code)
	f.mu.Lock()
	delete(f.blockedCountries, code)
	f.mu.Unlock()
}

// Rules returns a snapshot of all current firewall rules as strings. The
// returned slices are freshly allocated, and a slice is nil when there are no
// rules of that kind. ips and countries come out in unspecified (map
// iteration) order, while cidrs keep insertion order.
func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) {
	f.mu.RLock()
	defer f.mu.RUnlock()

	for addr := range f.blockedIPs {
		ips = append(ips, addr.String())
	}
	for _, prefix := range f.blockedCIDRs {
		cidrs = append(cidrs, prefix.String())
	}
	for code := range f.blockedCountries {
		countries = append(countries, code)
	}
	return ips, cidrs, countries
}

// ReloadGeoIP reopens the GeoIP database at the path given to New and swaps
// it in, closing the previous reader. It is a no-op when that path was empty.
// It also loads the database when New skipped loading because no country
// rules existed at the time. Safe for concurrent use.
func (f *Firewall) ReloadGeoIP() error {
	// geoDBPath is written only by New, so reading it without the lock is safe.
	if f.geoDBPath == "" {
		return nil
	}
	if err := f.loadGeoDB(f.geoDBPath); err != nil {
		return fmt.Errorf("reloading GeoIP database %q: %w", f.geoDBPath, err)
	}
	return nil
}

// Close releases resources held by the firewall. After Close, country rules
// are no longer evaluated, and calling Close again is a no-op.
// NOTE(review): a later or concurrent ReloadGeoIP will install a new reader.
func (f *Firewall) Close() error {
	f.mu.Lock()
	defer f.mu.Unlock()

	if f.geoDB == nil {
		return nil
	}
	err := f.geoDB.Close()
	// Drop the reference so Blocked never consults a closed reader. Per the
	// maxminddb docs, Close unmaps the database file from memory.
	f.geoDB = nil
	return err
}

// loadGeoDB opens the database at path and swaps it in under the write lock.
// Blocked reads geoDB only while holding the read lock, so once the write lock
// has been acquired and released no lookup can still be using the old reader.
// It is therefore closed after unlocking.
func (f *Firewall) loadGeoDB(path string) error {
	db, err := maxminddb.Open(path)
	if err != nil {
		return err
	}

	f.mu.Lock()
	old := f.geoDB
	f.geoDB = db
	f.mu.Unlock()

	if old != nil {
		// Best effort: the new reader is already live, so failing to release
		// the old one is not a reason to report the reload as failed.
		_ = old.Close()
	}
	return nil
}