Initial implementation of mc-proxy
Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
218
internal/firewall/firewall.go
Normal file
218
internal/firewall/firewall.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
)
|
||||
|
||||
type geoIPRecord struct {
|
||||
Country struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
}
|
||||
|
||||
// Firewall evaluates global blocklist rules against connection source addresses.
|
||||
type Firewall struct {
|
||||
blockedIPs map[netip.Addr]struct{}
|
||||
blockedCIDRs []netip.Prefix
|
||||
blockedCountries map[string]struct{}
|
||||
geoDBPath string
|
||||
geoDB *maxminddb.Reader
|
||||
mu sync.RWMutex // protects all mutable state
|
||||
}
|
||||
|
||||
// New creates a Firewall from the given configuration.
|
||||
func New(cfg config.Firewall) (*Firewall, error) {
|
||||
f := &Firewall{
|
||||
blockedIPs: make(map[netip.Addr]struct{}),
|
||||
blockedCountries: make(map[string]struct{}),
|
||||
geoDBPath: cfg.GeoIPDB,
|
||||
}
|
||||
|
||||
for _, ip := range cfg.BlockedIPs {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err)
|
||||
}
|
||||
f.blockedIPs[addr] = struct{}{}
|
||||
}
|
||||
|
||||
for _, cidr := range cfg.BlockedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err)
|
||||
}
|
||||
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
|
||||
}
|
||||
|
||||
for _, code := range cfg.BlockedCountries {
|
||||
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
|
||||
}
|
||||
|
||||
if len(f.blockedCountries) > 0 {
|
||||
if err := f.loadGeoDB(cfg.GeoIPDB); err != nil {
|
||||
return nil, fmt.Errorf("loading GeoIP database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Blocked returns true if the given address should be blocked.
|
||||
func (f *Firewall) Blocked(addr netip.Addr) bool {
|
||||
addr = addr.Unmap()
|
||||
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if _, ok := f.blockedIPs[addr]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, prefix := range f.blockedCIDRs {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(f.blockedCountries) > 0 && f.geoDB != nil {
|
||||
var record geoIPRecord
|
||||
if err := f.geoDB.Lookup(addr.AsSlice(), &record); err == nil {
|
||||
if _, ok := f.blockedCountries[record.Country.ISOCode]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddIP adds an IP address to the blocklist.
|
||||
func (f *Firewall) AddIP(ip string) error {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP %q: %w", ip, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
f.blockedIPs[addr] = struct{}{}
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveIP removes an IP address from the blocklist.
|
||||
func (f *Firewall) RemoveIP(ip string) error {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP %q: %w", ip, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
delete(f.blockedIPs, addr)
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddCIDR adds a CIDR prefix to the blocklist.
|
||||
func (f *Firewall) AddCIDR(cidr string) error {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveCIDR removes a CIDR prefix from the blocklist.
|
||||
func (f *Firewall) RemoveCIDR(cidr string) error {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
for i, p := range f.blockedCIDRs {
|
||||
if p == prefix {
|
||||
f.blockedCIDRs = append(f.blockedCIDRs[:i], f.blockedCIDRs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddCountry adds a country code to the blocklist.
|
||||
func (f *Firewall) AddCountry(code string) {
|
||||
f.mu.Lock()
|
||||
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
// RemoveCountry removes a country code from the blocklist.
|
||||
func (f *Firewall) RemoveCountry(code string) {
|
||||
f.mu.Lock()
|
||||
delete(f.blockedCountries, strings.ToUpper(code))
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
// Rules returns a snapshot of all current firewall rules.
|
||||
func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
for addr := range f.blockedIPs {
|
||||
ips = append(ips, addr.String())
|
||||
}
|
||||
for _, prefix := range f.blockedCIDRs {
|
||||
cidrs = append(cidrs, prefix.String())
|
||||
}
|
||||
for code := range f.blockedCountries {
|
||||
countries = append(countries, code)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ReloadGeoIP reloads the GeoIP database from disk. Safe for concurrent use.
|
||||
func (f *Firewall) ReloadGeoIP() error {
|
||||
if f.geoDBPath == "" {
|
||||
return nil
|
||||
}
|
||||
return f.loadGeoDB(f.geoDBPath)
|
||||
}
|
||||
|
||||
// Close releases resources held by the firewall.
|
||||
func (f *Firewall) Close() error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.geoDB != nil {
|
||||
return f.geoDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) loadGeoDB(path string) error {
|
||||
db, err := maxminddb.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
old := f.geoDB
|
||||
f.geoDB = db
|
||||
f.mu.Unlock()
|
||||
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
141
internal/firewall/firewall_test.go
Normal file
141
internal/firewall/firewall_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestEmptyFirewall(t *testing.T) {
|
||||
fw, err := New(config.Firewall{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
|
||||
for _, a := range addrs {
|
||||
addr := netip.MustParseAddr(a)
|
||||
if fw.Blocked(addr) {
|
||||
t.Fatalf("empty firewall blocked %s", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlocking(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"192.0.2.1", "2001:db8::dead"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
blocked bool
|
||||
}{
|
||||
{"192.0.2.1", true},
|
||||
{"192.0.2.2", false},
|
||||
{"2001:db8::dead", true},
|
||||
{"2001:db8::beef", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
if got := fw.Blocked(addr); got != tt.blocked {
|
||||
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDRBlocking(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedCIDRs: []string{"198.51.100.0/24", "2001:db8::/32"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
blocked bool
|
||||
}{
|
||||
{"198.51.100.1", true},
|
||||
{"198.51.100.254", true},
|
||||
{"198.51.101.1", false},
|
||||
{"2001:db8::1", true},
|
||||
{"2001:db9::1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
if got := fw.Blocked(addr); got != tt.blocked {
|
||||
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPv4MappedIPv6(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"192.0.2.1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
// IPv4-mapped IPv6 representation of 192.0.2.1.
|
||||
addr := netip.MustParseAddr("::ffff:192.0.2.1")
|
||||
if !fw.Blocked(addr) {
|
||||
t.Fatal("expected IPv4-mapped IPv6 address to be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidIP(t *testing.T) {
|
||||
_, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"not-an-ip"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidCIDR(t *testing.T) {
|
||||
_, err := New(config.Firewall{
|
||||
BlockedCIDRs: []string{"not-a-cidr"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombinedRules(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"10.0.0.1"},
|
||||
BlockedCIDRs: []string{"192.168.0.0/16"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
blocked bool
|
||||
}{
|
||||
{"10.0.0.1", true}, // IP match
|
||||
{"10.0.0.2", false}, // no match
|
||||
{"192.168.1.1", true}, // CIDR match
|
||||
{"172.16.0.1", false}, // no match
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
if got := fw.Blocked(addr); got != tt.blocked {
|
||||
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user