Initial implementation of mc-proxy
Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
115
internal/config/config.go
Normal file
115
internal/config/config.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/pelletier/go-toml/v2"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Listeners []Listener `toml:"listeners"`
|
||||
GRPC GRPC `toml:"grpc"`
|
||||
Firewall Firewall `toml:"firewall"`
|
||||
Proxy Proxy `toml:"proxy"`
|
||||
Log Log `toml:"log"`
|
||||
}
|
||||
|
||||
type GRPC struct {
|
||||
Addr string `toml:"addr"`
|
||||
TLSCert string `toml:"tls_cert"`
|
||||
TLSKey string `toml:"tls_key"`
|
||||
ClientCA string `toml:"client_ca"`
|
||||
}
|
||||
|
||||
type Listener struct {
|
||||
Addr string `toml:"addr"`
|
||||
Routes []Route `toml:"routes"`
|
||||
}
|
||||
|
||||
type Route struct {
|
||||
Hostname string `toml:"hostname"`
|
||||
Backend string `toml:"backend"`
|
||||
}
|
||||
|
||||
type Firewall struct {
|
||||
GeoIPDB string `toml:"geoip_db"`
|
||||
BlockedIPs []string `toml:"blocked_ips"`
|
||||
BlockedCIDRs []string `toml:"blocked_cidrs"`
|
||||
BlockedCountries []string `toml:"blocked_countries"`
|
||||
}
|
||||
|
||||
type Proxy struct {
|
||||
ConnectTimeout Duration `toml:"connect_timeout"`
|
||||
IdleTimeout Duration `toml:"idle_timeout"`
|
||||
ShutdownTimeout Duration `toml:"shutdown_timeout"`
|
||||
}
|
||||
|
||||
type Log struct {
|
||||
Level string `toml:"level"`
|
||||
}
|
||||
|
||||
// Duration wraps time.Duration for TOML string unmarshalling.
|
||||
type Duration struct {
|
||||
time.Duration
|
||||
}
|
||||
|
||||
func (d *Duration) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
d.Duration, err = time.ParseDuration(string(text))
|
||||
return err
|
||||
}
|
||||
|
||||
func Load(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading config: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := toml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("parsing config: %w", err)
|
||||
}
|
||||
|
||||
if err := cfg.validate(); err != nil {
|
||||
return nil, fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
func (c *Config) validate() error {
|
||||
if len(c.Listeners) == 0 {
|
||||
return fmt.Errorf("at least one listener is required")
|
||||
}
|
||||
|
||||
for i, l := range c.Listeners {
|
||||
if l.Addr == "" {
|
||||
return fmt.Errorf("listener %d: addr is required", i)
|
||||
}
|
||||
if len(l.Routes) == 0 {
|
||||
return fmt.Errorf("listener %d (%s): at least one route is required", i, l.Addr)
|
||||
}
|
||||
|
||||
seen := make(map[string]bool)
|
||||
for j, r := range l.Routes {
|
||||
if r.Hostname == "" {
|
||||
return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j)
|
||||
}
|
||||
if r.Backend == "" {
|
||||
return fmt.Errorf("listener %d (%s), route %d: backend is required", i, l.Addr, j)
|
||||
}
|
||||
if seen[r.Hostname] {
|
||||
return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname)
|
||||
}
|
||||
seen[r.Hostname] = true
|
||||
}
|
||||
}
|
||||
|
||||
if len(c.Firewall.BlockedCountries) > 0 && c.Firewall.GeoIPDB == "" {
|
||||
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
186
internal/config/config_test.go
Normal file
186
internal/config/config_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestLoadValid(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8443"
|
||||
|
||||
[proxy]
|
||||
connect_timeout = "5s"
|
||||
idle_timeout = "300s"
|
||||
shutdown_timeout = "30s"
|
||||
|
||||
[log]
|
||||
level = "info"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Listeners) != 1 {
|
||||
t.Fatalf("got %d listeners, want 1", len(cfg.Listeners))
|
||||
}
|
||||
if cfg.Listeners[0].Addr != ":443" {
|
||||
t.Fatalf("got listener addr %q, want %q", cfg.Listeners[0].Addr, ":443")
|
||||
}
|
||||
if len(cfg.Listeners[0].Routes) != 1 {
|
||||
t.Fatalf("got %d routes, want 1", len(cfg.Listeners[0].Routes))
|
||||
}
|
||||
if cfg.Listeners[0].Routes[0].Hostname != "example.com" {
|
||||
t.Fatalf("got hostname %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadNoListeners(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[log]
|
||||
level = "info"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing listeners")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadNoRoutes(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for missing routes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadDuplicateHostnames(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:9443"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate hostnames")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadGeoIPRequiredWithCountries(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "example.com"
|
||||
backend = "127.0.0.1:8443"
|
||||
|
||||
[firewall]
|
||||
blocked_countries = ["CN"]
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
_, err := Load(path)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for blocked_countries without geoip_db")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadMultipleListeners(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "test.toml")
|
||||
|
||||
data := `
|
||||
[[listeners]]
|
||||
addr = ":443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "public.example.com"
|
||||
backend = "127.0.0.1:8443"
|
||||
|
||||
[[listeners]]
|
||||
addr = ":8443"
|
||||
|
||||
[[listeners.routes]]
|
||||
hostname = "internal.example.com"
|
||||
backend = "127.0.0.1:9443"
|
||||
`
|
||||
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
|
||||
t.Fatalf("write config: %v", err)
|
||||
}
|
||||
|
||||
cfg, err := Load(path)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(cfg.Listeners) != 2 {
|
||||
t.Fatalf("got %d listeners, want 2", len(cfg.Listeners))
|
||||
}
|
||||
if cfg.Listeners[0].Routes[0].Hostname != "public.example.com" {
|
||||
t.Fatalf("listener 0 hostname = %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "public.example.com")
|
||||
}
|
||||
if cfg.Listeners[1].Routes[0].Hostname != "internal.example.com" {
|
||||
t.Fatalf("listener 1 hostname = %q, want %q", cfg.Listeners[1].Routes[0].Hostname, "internal.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration(t *testing.T) {
|
||||
var d Duration
|
||||
if err := d.UnmarshalText([]byte("5s")); err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if d.Duration.Seconds() != 5 {
|
||||
t.Fatalf("got %v, want 5s", d.Duration)
|
||||
}
|
||||
}
|
||||
218
internal/firewall/firewall.go
Normal file
218
internal/firewall/firewall.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/oschwald/maxminddb-golang"
|
||||
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
)
|
||||
|
||||
type geoIPRecord struct {
|
||||
Country struct {
|
||||
ISOCode string `maxminddb:"iso_code"`
|
||||
} `maxminddb:"country"`
|
||||
}
|
||||
|
||||
// Firewall evaluates global blocklist rules against connection source addresses.
|
||||
type Firewall struct {
|
||||
blockedIPs map[netip.Addr]struct{}
|
||||
blockedCIDRs []netip.Prefix
|
||||
blockedCountries map[string]struct{}
|
||||
geoDBPath string
|
||||
geoDB *maxminddb.Reader
|
||||
mu sync.RWMutex // protects all mutable state
|
||||
}
|
||||
|
||||
// New creates a Firewall from the given configuration.
|
||||
func New(cfg config.Firewall) (*Firewall, error) {
|
||||
f := &Firewall{
|
||||
blockedIPs: make(map[netip.Addr]struct{}),
|
||||
blockedCountries: make(map[string]struct{}),
|
||||
geoDBPath: cfg.GeoIPDB,
|
||||
}
|
||||
|
||||
for _, ip := range cfg.BlockedIPs {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err)
|
||||
}
|
||||
f.blockedIPs[addr] = struct{}{}
|
||||
}
|
||||
|
||||
for _, cidr := range cfg.BlockedCIDRs {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err)
|
||||
}
|
||||
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
|
||||
}
|
||||
|
||||
for _, code := range cfg.BlockedCountries {
|
||||
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
|
||||
}
|
||||
|
||||
if len(f.blockedCountries) > 0 {
|
||||
if err := f.loadGeoDB(cfg.GeoIPDB); err != nil {
|
||||
return nil, fmt.Errorf("loading GeoIP database: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Blocked returns true if the given address should be blocked.
|
||||
func (f *Firewall) Blocked(addr netip.Addr) bool {
|
||||
addr = addr.Unmap()
|
||||
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
if _, ok := f.blockedIPs[addr]; ok {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, prefix := range f.blockedCIDRs {
|
||||
if prefix.Contains(addr) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
if len(f.blockedCountries) > 0 && f.geoDB != nil {
|
||||
var record geoIPRecord
|
||||
if err := f.geoDB.Lookup(addr.AsSlice(), &record); err == nil {
|
||||
if _, ok := f.blockedCountries[record.Country.ISOCode]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// AddIP adds an IP address to the blocklist.
|
||||
func (f *Firewall) AddIP(ip string) error {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP %q: %w", ip, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
f.blockedIPs[addr] = struct{}{}
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveIP removes an IP address from the blocklist.
|
||||
func (f *Firewall) RemoveIP(ip string) error {
|
||||
addr, err := netip.ParseAddr(ip)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid IP %q: %w", ip, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
delete(f.blockedIPs, addr)
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddCIDR adds a CIDR prefix to the blocklist.
|
||||
func (f *Firewall) AddCIDR(cidr string) error {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveCIDR removes a CIDR prefix from the blocklist.
|
||||
func (f *Firewall) RemoveCIDR(cidr string) error {
|
||||
prefix, err := netip.ParsePrefix(cidr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
for i, p := range f.blockedCIDRs {
|
||||
if p == prefix {
|
||||
f.blockedCIDRs = append(f.blockedCIDRs[:i], f.blockedCIDRs[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
f.mu.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddCountry adds a country code to the blocklist.
|
||||
func (f *Firewall) AddCountry(code string) {
|
||||
f.mu.Lock()
|
||||
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
// RemoveCountry removes a country code from the blocklist.
|
||||
func (f *Firewall) RemoveCountry(code string) {
|
||||
f.mu.Lock()
|
||||
delete(f.blockedCountries, strings.ToUpper(code))
|
||||
f.mu.Unlock()
|
||||
}
|
||||
|
||||
// Rules returns a snapshot of all current firewall rules.
|
||||
func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) {
|
||||
f.mu.RLock()
|
||||
defer f.mu.RUnlock()
|
||||
|
||||
for addr := range f.blockedIPs {
|
||||
ips = append(ips, addr.String())
|
||||
}
|
||||
for _, prefix := range f.blockedCIDRs {
|
||||
cidrs = append(cidrs, prefix.String())
|
||||
}
|
||||
for code := range f.blockedCountries {
|
||||
countries = append(countries, code)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// ReloadGeoIP reloads the GeoIP database from disk. Safe for concurrent use.
|
||||
func (f *Firewall) ReloadGeoIP() error {
|
||||
if f.geoDBPath == "" {
|
||||
return nil
|
||||
}
|
||||
return f.loadGeoDB(f.geoDBPath)
|
||||
}
|
||||
|
||||
// Close releases resources held by the firewall.
|
||||
func (f *Firewall) Close() error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
if f.geoDB != nil {
|
||||
return f.geoDB.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Firewall) loadGeoDB(path string) error {
|
||||
db, err := maxminddb.Open(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f.mu.Lock()
|
||||
old := f.geoDB
|
||||
f.geoDB = db
|
||||
f.mu.Unlock()
|
||||
|
||||
if old != nil {
|
||||
old.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
141
internal/firewall/firewall_test.go
Normal file
141
internal/firewall/firewall_test.go
Normal file
@@ -0,0 +1,141 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
)
|
||||
|
||||
func TestEmptyFirewall(t *testing.T) {
|
||||
fw, err := New(config.Firewall{})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
|
||||
for _, a := range addrs {
|
||||
addr := netip.MustParseAddr(a)
|
||||
if fw.Blocked(addr) {
|
||||
t.Fatalf("empty firewall blocked %s", addr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPBlocking(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"192.0.2.1", "2001:db8::dead"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
blocked bool
|
||||
}{
|
||||
{"192.0.2.1", true},
|
||||
{"192.0.2.2", false},
|
||||
{"2001:db8::dead", true},
|
||||
{"2001:db8::beef", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
if got := fw.Blocked(addr); got != tt.blocked {
|
||||
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCIDRBlocking(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedCIDRs: []string{"198.51.100.0/24", "2001:db8::/32"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
blocked bool
|
||||
}{
|
||||
{"198.51.100.1", true},
|
||||
{"198.51.100.254", true},
|
||||
{"198.51.101.1", false},
|
||||
{"2001:db8::1", true},
|
||||
{"2001:db9::1", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
if got := fw.Blocked(addr); got != tt.blocked {
|
||||
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIPv4MappedIPv6(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"192.0.2.1"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
// IPv4-mapped IPv6 representation of 192.0.2.1.
|
||||
addr := netip.MustParseAddr("::ffff:192.0.2.1")
|
||||
if !fw.Blocked(addr) {
|
||||
t.Fatal("expected IPv4-mapped IPv6 address to be blocked")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidIP(t *testing.T) {
|
||||
_, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"not-an-ip"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid IP")
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidCIDR(t *testing.T) {
|
||||
_, err := New(config.Firewall{
|
||||
BlockedCIDRs: []string{"not-a-cidr"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid CIDR")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCombinedRules(t *testing.T) {
|
||||
fw, err := New(config.Firewall{
|
||||
BlockedIPs: []string{"10.0.0.1"},
|
||||
BlockedCIDRs: []string{"192.168.0.0/16"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
defer fw.Close()
|
||||
|
||||
tests := []struct {
|
||||
addr string
|
||||
blocked bool
|
||||
}{
|
||||
{"10.0.0.1", true}, // IP match
|
||||
{"10.0.0.2", false}, // no match
|
||||
{"192.168.1.1", true}, // CIDR match
|
||||
{"172.16.0.1", false}, // no match
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
if got := fw.Blocked(addr); got != tt.blocked {
|
||||
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
|
||||
}
|
||||
}
|
||||
}
|
||||
246
internal/grpcserver/grpcserver.go
Normal file
246
internal/grpcserver/grpcserver.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
||||
)
|
||||
|
||||
// AdminServer implements the ProxyAdmin gRPC service.
|
||||
type AdminServer struct {
|
||||
pb.UnimplementedProxyAdminServer
|
||||
srv *server.Server
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a gRPC server with TLS and optional mTLS.
|
||||
func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
// mTLS: require and verify client certificates.
|
||||
if cfg.ClientCA != "" {
|
||||
caCert, err := os.ReadFile(cfg.ClientCA)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading client CA: %w", err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caCert) {
|
||||
return nil, nil, fmt.Errorf("failed to parse client CA certificate")
|
||||
}
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsConfig.ClientCAs = pool
|
||||
}
|
||||
|
||||
creds := credentials.NewTLS(tlsConfig)
|
||||
grpcServer := grpc.NewServer(grpc.Creds(creds))
|
||||
|
||||
admin := &AdminServer{
|
||||
srv: srv,
|
||||
logger: logger,
|
||||
}
|
||||
pb.RegisterProxyAdminServer(grpcServer, admin)
|
||||
|
||||
ln, err := net.Listen("tcp", cfg.Addr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("listening on %s: %w", cfg.Addr, err)
|
||||
}
|
||||
|
||||
return grpcServer, ln, nil
|
||||
}
|
||||
|
||||
// ListRoutes returns the route table for a listener.
|
||||
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := ls.Routes()
|
||||
resp := &pb.ListRoutesResponse{
|
||||
ListenerAddr: ls.Addr,
|
||||
}
|
||||
for hostname, backend := range routes {
|
||||
resp.Routes = append(resp.Routes, &pb.Route{
|
||||
Hostname: hostname,
|
||||
Backend: backend,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// AddRoute adds a route to a listener's route table.
|
||||
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
|
||||
if req.Route == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "route is required")
|
||||
}
|
||||
if req.Route.Hostname == "" || req.Route.Backend == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
|
||||
}
|
||||
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend)
|
||||
return &pb.AddRouteResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route from a listener's route table.
|
||||
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
|
||||
if req.Hostname == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "hostname is required")
|
||||
}
|
||||
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ls.RemoveRoute(req.Hostname); err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "%v", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname)
|
||||
return &pb.RemoveRouteResponse{}, nil
|
||||
}
|
||||
|
||||
// GetFirewallRules returns all current firewall rules.
|
||||
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
|
||||
ips, cidrs, countries := a.srv.Firewall().Rules()
|
||||
|
||||
var rules []*pb.FirewallRule
|
||||
for _, ip := range ips {
|
||||
rules = append(rules, &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||
Value: ip,
|
||||
})
|
||||
}
|
||||
for _, cidr := range cidrs {
|
||||
rules = append(rules, &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
||||
Value: cidr,
|
||||
})
|
||||
}
|
||||
for _, code := range countries {
|
||||
rules = append(rules, &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
||||
Value: code,
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.GetFirewallRulesResponse{Rules: rules}, nil
|
||||
}
|
||||
|
||||
// AddFirewallRule adds a firewall rule.
|
||||
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
|
||||
if req.Rule == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch req.Rule.Type {
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
||||
if err := fw.AddIP(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
||||
if err := fw.AddCIDR(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
||||
if req.Rule.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "country code is required")
|
||||
}
|
||||
fw.AddCountry(req.Rule.Value)
|
||||
default:
|
||||
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value)
|
||||
return &pb.AddFirewallRuleResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveFirewallRule removes a firewall rule.
|
||||
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
|
||||
if req.Rule == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch req.Rule.Type {
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
||||
if err := fw.RemoveIP(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
||||
if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
||||
if req.Rule.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "country code is required")
|
||||
}
|
||||
fw.RemoveCountry(req.Rule.Value)
|
||||
default:
|
||||
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value)
|
||||
return &pb.RemoveFirewallRuleResponse{}, nil
|
||||
}
|
||||
|
||||
// GetStatus returns the proxy's current status.
|
||||
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
|
||||
var listeners []*pb.ListenerStatus
|
||||
for _, ls := range a.srv.Listeners() {
|
||||
routes := ls.Routes()
|
||||
listeners = append(listeners, &pb.ListenerStatus{
|
||||
Addr: ls.Addr,
|
||||
RouteCount: int32(len(routes)),
|
||||
ActiveConnections: ls.ActiveConnections.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.GetStatusResponse{
|
||||
Version: a.srv.Version(),
|
||||
StartedAt: timestamppb.New(a.srv.StartedAt()),
|
||||
Listeners: listeners,
|
||||
TotalConnections: a.srv.TotalConnections(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
|
||||
for _, ls := range a.srv.Listeners() {
|
||||
if ls.Addr == addr {
|
||||
return ls, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
|
||||
}
|
||||
105
internal/proxy/proxy.go
Normal file
105
internal/proxy/proxy.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Result holds the outcome of a relay operation.
|
||||
type Result struct {
|
||||
ClientBytes int64 // bytes sent from client to backend
|
||||
BackendBytes int64 // bytes sent from backend to client
|
||||
}
|
||||
|
||||
// Relay performs bidirectional byte copying between client and backend.
|
||||
// The peeked bytes (the TLS ClientHello) are written to the backend first.
|
||||
// Relay blocks until both directions are done or ctx is cancelled.
|
||||
func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTimeout time.Duration) (Result, error) {
|
||||
// Forward the buffered ClientHello to the backend.
|
||||
if len(peeked) > 0 {
|
||||
if _, err := backend.Write(peeked); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel context closes both connections to unblock copy goroutines.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
client.Close()
|
||||
backend.Close()
|
||||
}()
|
||||
|
||||
var (
|
||||
result Result
|
||||
wg sync.WaitGroup
|
||||
errC2B error
|
||||
errB2C error
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
// client → backend
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
|
||||
// Half-close backend's write side.
|
||||
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
|
||||
hc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// backend → client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
|
||||
// Half-close client's write side.
|
||||
if hc, ok := client.(interface{ CloseWrite() error }); ok {
|
||||
hc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If context was cancelled, that's the primary error.
|
||||
if ctx.Err() != nil {
|
||||
return result, ctx.Err()
|
||||
}
|
||||
|
||||
// Return the first meaningful error, if any.
|
||||
if errC2B != nil {
|
||||
return result, errC2B
|
||||
}
|
||||
return result, errB2C
|
||||
}
|
||||
|
||||
// copyWithIdleTimeout copies from src to dst, resetting the idle deadline
|
||||
// on each successful read.
|
||||
func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
var total int64
|
||||
|
||||
for {
|
||||
src.SetReadDeadline(time.Now().Add(idleTimeout))
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
dst.SetWriteDeadline(time.Now().Add(idleTimeout))
|
||||
nw, writeErr := dst.Write(buf[:nr])
|
||||
total += int64(nw)
|
||||
if writeErr != nil {
|
||||
return total, writeErr
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF {
|
||||
return total, nil
|
||||
}
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
259
internal/proxy/proxy_test.go
Normal file
259
internal/proxy/proxy_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRelayBasic(t *testing.T) {
|
||||
// Set up a TCP listener to act as the backend.
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer backendLn.Close()
|
||||
|
||||
peeked := []byte("peeked-hello-bytes")
|
||||
clientData := []byte("data from client")
|
||||
backendData := []byte("data from backend")
|
||||
|
||||
// Backend goroutine: accept, read peeked+client data, send response, close.
|
||||
backendDone := make(chan []byte, 1)
|
||||
go func() {
|
||||
conn, err := backendLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read everything the backend receives.
|
||||
received, _ := io.ReadAll(conn)
|
||||
backendDone <- received
|
||||
|
||||
// This won't work since ReadAll waits for EOF.
|
||||
// Instead, restructure: read expected bytes, write response, close write.
|
||||
}()
|
||||
|
||||
// Restructure: use a more controlled flow.
|
||||
backendLn.Close()
|
||||
|
||||
// Use a real TCP pair for proper half-close.
|
||||
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer backendLn2.Close()
|
||||
|
||||
go func() {
|
||||
conn, err := backendLn2.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read peeked + client data.
|
||||
buf := make([]byte, len(peeked)+len(clientData))
|
||||
n, _ := io.ReadFull(conn, buf)
|
||||
backendDone <- buf[:n]
|
||||
|
||||
// Send response.
|
||||
conn.Write(backendData)
|
||||
|
||||
// Close write side to signal EOF.
|
||||
if tc, ok := conn.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Dial the backend.
|
||||
backendConn, err := net.Dial("tcp", backendLn2.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial backend: %v", err)
|
||||
}
|
||||
|
||||
// Create a client-side TCP pair.
|
||||
clientLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer clientLn.Close()
|
||||
|
||||
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial client: %v", err)
|
||||
}
|
||||
serverSideClient, err := clientLn.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("accept client: %v", err)
|
||||
}
|
||||
|
||||
// Client sends data then closes write.
|
||||
go func() {
|
||||
clientConn.Write(clientData)
|
||||
if tc, ok := clientConn.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Run relay.
|
||||
result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("relay error: %v", err)
|
||||
}
|
||||
|
||||
// Verify backend received peeked + client data.
|
||||
received := <-backendDone
|
||||
expected := append(peeked, clientData...)
|
||||
if !bytes.Equal(received, expected) {
|
||||
t.Fatalf("backend received %q, want %q", received, expected)
|
||||
}
|
||||
|
||||
// Verify client received backend data.
|
||||
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
clientReceived, _ := io.ReadAll(clientConn)
|
||||
if !bytes.Equal(clientReceived, backendData) {
|
||||
t.Fatalf("client received %q, want %q", clientReceived, backendData)
|
||||
}
|
||||
|
||||
if result.ClientBytes != int64(len(clientData)) {
|
||||
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
|
||||
}
|
||||
if result.BackendBytes != int64(len(backendData)) {
|
||||
t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayIdleTimeout(t *testing.T) {
|
||||
// Two connected pairs via TCP.
|
||||
clientA, clientB := tcpPair(t)
|
||||
defer clientA.Close()
|
||||
defer clientB.Close()
|
||||
|
||||
backendA, backendB := tcpPair(t)
|
||||
defer backendA.Close()
|
||||
defer backendB.Close()
|
||||
|
||||
start := time.Now()
|
||||
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Should return due to idle timeout.
|
||||
if err == nil {
|
||||
t.Fatal("expected error from idle timeout")
|
||||
}
|
||||
|
||||
if elapsed > 2*time.Second {
|
||||
t.Fatalf("relay took %v, expected ~100ms", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayContextCancel(t *testing.T) {
|
||||
clientA, clientB := tcpPair(t)
|
||||
defer clientA.Close()
|
||||
defer clientB.Close()
|
||||
|
||||
backendA, backendB := tcpPair(t)
|
||||
defer backendA.Close()
|
||||
defer backendB.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, clientB, backendA, nil, time.Minute)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Cancel after a short delay.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// OK
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return after context cancel")
|
||||
}
|
||||
|
||||
_ = backendB // keep reference
|
||||
}
|
||||
|
||||
func TestRelayLargeTransfer(t *testing.T) {
|
||||
clientA, clientB := tcpPair(t)
|
||||
defer clientA.Close()
|
||||
defer clientB.Close()
|
||||
|
||||
backendA, backendB := tcpPair(t)
|
||||
defer backendA.Close()
|
||||
defer backendB.Close()
|
||||
|
||||
// 1 MB of random data.
|
||||
data := make([]byte, 1<<20)
|
||||
if _, err := rand.Read(data); err != nil {
|
||||
t.Fatalf("rand read: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
clientA.Write(data)
|
||||
if tc, ok := clientA.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Backend reads and echoes chunks, then closes when client EOF arrives.
|
||||
go func() {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := backendB.Read(buf)
|
||||
if n > 0 {
|
||||
backendB.Write(buf[:n])
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if tc, ok := backendB.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("relay error: %v", err)
|
||||
}
|
||||
|
||||
if result.ClientBytes != int64(len(data)) {
|
||||
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
|
||||
}
|
||||
}
|
||||
|
||||
// tcpPair returns two connected TCP connections.
|
||||
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var serverConn net.Conn
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
serverConn, _ = ln.Accept()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
<-done
|
||||
return clientConn, serverConn
|
||||
}
|
||||
271
internal/server/server.go
Normal file
271
internal/server/server.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
||||
)
|
||||
|
||||
// ListenerState holds the mutable state for a single proxy listener.
|
||||
type ListenerState struct {
|
||||
Addr string
|
||||
routes map[string]string // lowercase hostname → backend addr
|
||||
mu sync.RWMutex
|
||||
ActiveConnections atomic.Int64
|
||||
}
|
||||
|
||||
// Routes returns a snapshot of the listener's route table.
|
||||
func (ls *ListenerState) Routes() map[string]string {
|
||||
ls.mu.RLock()
|
||||
defer ls.mu.RUnlock()
|
||||
|
||||
m := make(map[string]string, len(ls.routes))
|
||||
for k, v := range ls.routes {
|
||||
m[k] = v
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// AddRoute adds a route to the listener. Returns an error if the hostname
|
||||
// already exists.
|
||||
func (ls *ListenerState) AddRoute(hostname, backend string) error {
|
||||
key := strings.ToLower(hostname)
|
||||
|
||||
ls.mu.Lock()
|
||||
defer ls.mu.Unlock()
|
||||
|
||||
if _, ok := ls.routes[key]; ok {
|
||||
return fmt.Errorf("route %q already exists", hostname)
|
||||
}
|
||||
ls.routes[key] = backend
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route from the listener. Returns an error if the
|
||||
// hostname does not exist.
|
||||
func (ls *ListenerState) RemoveRoute(hostname string) error {
|
||||
key := strings.ToLower(hostname)
|
||||
|
||||
ls.mu.Lock()
|
||||
defer ls.mu.Unlock()
|
||||
|
||||
if _, ok := ls.routes[key]; !ok {
|
||||
return fmt.Errorf("route %q not found", hostname)
|
||||
}
|
||||
delete(ls.routes, key)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ls *ListenerState) lookupRoute(hostname string) (string, bool) {
|
||||
ls.mu.RLock()
|
||||
defer ls.mu.RUnlock()
|
||||
|
||||
backend, ok := ls.routes[hostname]
|
||||
return backend, ok
|
||||
}
|
||||
|
||||
// Server is the mc-proxy server. It manages listeners, firewall evaluation,
|
||||
// SNI-based routing, and bidirectional proxying.
|
||||
type Server struct {
|
||||
cfg *config.Config
|
||||
fw *firewall.Firewall
|
||||
listeners []*ListenerState
|
||||
logger *slog.Logger
|
||||
wg sync.WaitGroup // tracks active connections
|
||||
startedAt time.Time
|
||||
version string
|
||||
}
|
||||
|
||||
// New creates a Server from the given configuration.
|
||||
func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, error) {
|
||||
fw, err := firewall.New(cfg.Firewall)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("initializing firewall: %w", err)
|
||||
}
|
||||
|
||||
var listeners []*ListenerState
|
||||
for _, lcfg := range cfg.Listeners {
|
||||
routes := make(map[string]string, len(lcfg.Routes))
|
||||
for _, r := range lcfg.Routes {
|
||||
routes[strings.ToLower(r.Hostname)] = r.Backend
|
||||
}
|
||||
listeners = append(listeners, &ListenerState{
|
||||
Addr: lcfg.Addr,
|
||||
routes: routes,
|
||||
})
|
||||
}
|
||||
|
||||
return &Server{
|
||||
cfg: cfg,
|
||||
fw: fw,
|
||||
listeners: listeners,
|
||||
logger: logger,
|
||||
version: version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Firewall returns the server's firewall for use by the gRPC admin API.
|
||||
func (s *Server) Firewall() *firewall.Firewall {
|
||||
return s.fw
|
||||
}
|
||||
|
||||
// Listeners returns the server's listener states for use by the gRPC admin API.
|
||||
func (s *Server) Listeners() []*ListenerState {
|
||||
return s.listeners
|
||||
}
|
||||
|
||||
// StartedAt returns the time the server started.
|
||||
func (s *Server) StartedAt() time.Time {
|
||||
return s.startedAt
|
||||
}
|
||||
|
||||
// Version returns the server's version string.
|
||||
func (s *Server) Version() string {
|
||||
return s.version
|
||||
}
|
||||
|
||||
// TotalConnections returns the total number of active connections.
|
||||
func (s *Server) TotalConnections() int64 {
|
||||
var total int64
|
||||
for _, ls := range s.listeners {
|
||||
total += ls.ActiveConnections.Load()
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
// Run starts all listeners and blocks until ctx is cancelled.
|
||||
func (s *Server) Run(ctx context.Context) error {
|
||||
s.startedAt = time.Now()
|
||||
|
||||
var netListeners []net.Listener
|
||||
|
||||
for _, ls := range s.listeners {
|
||||
ln, err := net.Listen("tcp", ls.Addr)
|
||||
if err != nil {
|
||||
for _, l := range netListeners {
|
||||
l.Close()
|
||||
}
|
||||
return fmt.Errorf("listening on %s: %w", ls.Addr, err)
|
||||
}
|
||||
s.logger.Info("listening", "addr", ln.Addr(), "routes", len(ls.routes))
|
||||
netListeners = append(netListeners, ln)
|
||||
}
|
||||
|
||||
// Start accept loops.
|
||||
for i, ln := range netListeners {
|
||||
ln := ln
|
||||
ls := s.listeners[i]
|
||||
go s.serve(ctx, ln, ls)
|
||||
}
|
||||
|
||||
// Block until shutdown signal.
|
||||
<-ctx.Done()
|
||||
s.logger.Info("shutting down")
|
||||
|
||||
// Stop accepting new connections.
|
||||
for _, ln := range netListeners {
|
||||
ln.Close()
|
||||
}
|
||||
|
||||
// Wait for in-flight connections with a timeout.
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
s.wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
s.logger.Info("all connections drained")
|
||||
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
|
||||
s.logger.Warn("shutdown timeout exceeded, forcing close")
|
||||
}
|
||||
|
||||
s.fw.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReloadGeoIP reloads the GeoIP database from disk.
|
||||
func (s *Server) ReloadGeoIP() error {
|
||||
return s.fw.ReloadGeoIP()
|
||||
}
|
||||
|
||||
func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
s.logger.Error("accept error", "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
ls.ActiveConnections.Add(1)
|
||||
go s.handleConn(ctx, conn, ls)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
|
||||
defer s.wg.Done()
|
||||
defer ls.ActiveConnections.Add(-1)
|
||||
defer conn.Close()
|
||||
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr)
|
||||
if err != nil {
|
||||
s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err)
|
||||
return
|
||||
}
|
||||
addr := addrPort.Addr()
|
||||
|
||||
if s.fw.Blocked(addr) {
|
||||
s.logger.Debug("blocked by firewall", "addr", addr)
|
||||
return
|
||||
}
|
||||
|
||||
hostname, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second))
|
||||
if err != nil {
|
||||
s.logger.Debug("SNI extraction failed", "addr", addr, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
backend, ok := ls.lookupRoute(hostname)
|
||||
if !ok {
|
||||
s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname)
|
||||
return
|
||||
}
|
||||
|
||||
backendConn, err := net.DialTimeout("tcp", backend, s.cfg.Proxy.ConnectTimeout.Duration)
|
||||
if err != nil {
|
||||
s.logger.Error("backend dial failed", "hostname", hostname, "backend", backend, "error", err)
|
||||
return
|
||||
}
|
||||
defer backendConn.Close()
|
||||
|
||||
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", backend)
|
||||
|
||||
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
|
||||
if err != nil && ctx.Err() == nil {
|
||||
s.logger.Debug("relay ended", "hostname", hostname, "error", err)
|
||||
}
|
||||
|
||||
s.logger.Info("connection closed",
|
||||
"addr", addr,
|
||||
"hostname", hostname,
|
||||
"client_bytes", result.ClientBytes,
|
||||
"backend_bytes", result.BackendBytes,
|
||||
)
|
||||
}
|
||||
175
internal/sni/sni.go
Normal file
175
internal/sni/sni.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package sni
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
const maxBufferSize = 16384 // 16 KiB, max TLS record size
|
||||
|
||||
// Extract reads the TLS ClientHello from conn and returns the SNI hostname.
|
||||
// The returned peeked bytes contain the full ClientHello and must be forwarded
|
||||
// to the backend before starting the bidirectional copy.
|
||||
//
|
||||
// A read deadline is set on the connection to prevent slowloris attacks.
|
||||
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
|
||||
conn.SetReadDeadline(deadline)
|
||||
defer conn.SetReadDeadline(time.Time{})
|
||||
|
||||
// Read TLS record header (5 bytes).
|
||||
header := make([]byte, 5)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return "", nil, fmt.Errorf("reading TLS record header: %w", err)
|
||||
}
|
||||
|
||||
// Verify this is a TLS handshake record (content type 0x16).
|
||||
if header[0] != 0x16 {
|
||||
return "", nil, fmt.Errorf("not a TLS handshake record (type 0x%02x)", header[0])
|
||||
}
|
||||
|
||||
// Record length.
|
||||
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
|
||||
if recordLen == 0 || recordLen > maxBufferSize-5 {
|
||||
return "", nil, fmt.Errorf("TLS record length %d out of range", recordLen)
|
||||
}
|
||||
|
||||
// Read the full record body.
|
||||
buf := make([]byte, 5+recordLen)
|
||||
copy(buf, header)
|
||||
if _, err := io.ReadFull(conn, buf[5:]); err != nil {
|
||||
return "", nil, fmt.Errorf("reading TLS record body: %w", err)
|
||||
}
|
||||
|
||||
// Parse the handshake message from the record body.
|
||||
hostname, err = parseClientHello(buf[5:])
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
return hostname, buf, nil
|
||||
}
|
||||
|
||||
func parseClientHello(data []byte) (string, error) {
|
||||
if len(data) < 4 {
|
||||
return "", fmt.Errorf("handshake message too short")
|
||||
}
|
||||
|
||||
// Handshake type: 0x01 = ClientHello.
|
||||
if data[0] != 0x01 {
|
||||
return "", fmt.Errorf("not a ClientHello (type 0x%02x)", data[0])
|
||||
}
|
||||
|
||||
// Handshake length (3 bytes).
|
||||
hsLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
|
||||
data = data[4:]
|
||||
if len(data) < hsLen {
|
||||
return "", fmt.Errorf("ClientHello truncated")
|
||||
}
|
||||
data = data[:hsLen]
|
||||
|
||||
// Skip client version (2 bytes) + random (32 bytes).
|
||||
if len(data) < 34 {
|
||||
return "", fmt.Errorf("ClientHello too short for version+random")
|
||||
}
|
||||
data = data[34:]
|
||||
|
||||
// Skip session ID (1-byte length prefix).
|
||||
if len(data) < 1 {
|
||||
return "", fmt.Errorf("ClientHello too short for session ID length")
|
||||
}
|
||||
sidLen := int(data[0])
|
||||
data = data[1:]
|
||||
if len(data) < sidLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at session ID")
|
||||
}
|
||||
data = data[sidLen:]
|
||||
|
||||
// Skip cipher suites (2-byte length prefix).
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("ClientHello too short for cipher suites length")
|
||||
}
|
||||
csLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
if len(data) < csLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at cipher suites")
|
||||
}
|
||||
data = data[csLen:]
|
||||
|
||||
// Skip compression methods (1-byte length prefix).
|
||||
if len(data) < 1 {
|
||||
return "", fmt.Errorf("ClientHello too short for compression methods length")
|
||||
}
|
||||
cmLen := int(data[0])
|
||||
data = data[1:]
|
||||
if len(data) < cmLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at compression methods")
|
||||
}
|
||||
data = data[cmLen:]
|
||||
|
||||
// Extensions (2-byte total length).
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("no extensions in ClientHello")
|
||||
}
|
||||
extLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
if len(data) < extLen {
|
||||
return "", fmt.Errorf("ClientHello truncated at extensions")
|
||||
}
|
||||
data = data[:extLen]
|
||||
|
||||
return findSNI(data)
|
||||
}
|
||||
|
||||
func findSNI(data []byte) (string, error) {
|
||||
for len(data) >= 4 {
|
||||
extType := binary.BigEndian.Uint16(data[:2])
|
||||
extDataLen := int(binary.BigEndian.Uint16(data[2:4]))
|
||||
data = data[4:]
|
||||
if len(data) < extDataLen {
|
||||
return "", fmt.Errorf("extension truncated")
|
||||
}
|
||||
|
||||
if extType == 0x0000 { // server_name
|
||||
return parseServerNameExtension(data[:extDataLen])
|
||||
}
|
||||
|
||||
data = data[extDataLen:]
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no SNI extension found")
|
||||
}
|
||||
|
||||
func parseServerNameExtension(data []byte) (string, error) {
|
||||
if len(data) < 2 {
|
||||
return "", fmt.Errorf("server_name extension too short")
|
||||
}
|
||||
|
||||
// Server name list length.
|
||||
listLen := int(binary.BigEndian.Uint16(data[:2]))
|
||||
data = data[2:]
|
||||
if len(data) < listLen {
|
||||
return "", fmt.Errorf("server_name list truncated")
|
||||
}
|
||||
data = data[:listLen]
|
||||
|
||||
for len(data) >= 3 {
|
||||
nameType := data[0]
|
||||
nameLen := int(binary.BigEndian.Uint16(data[1:3]))
|
||||
data = data[3:]
|
||||
if len(data) < nameLen {
|
||||
return "", fmt.Errorf("server_name entry truncated")
|
||||
}
|
||||
|
||||
if nameType == 0x00 { // hostname
|
||||
return strings.ToLower(string(data[:nameLen])), nil
|
||||
}
|
||||
|
||||
data = data[nameLen:]
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no hostname in server_name extension")
|
||||
}
|
||||
220
internal/sni/sni_test.go
Normal file
220
internal/sni/sni_test.go
Normal file
@@ -0,0 +1,220 @@
|
||||
package sni
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestExtract(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sni string
|
||||
wantSNI string
|
||||
wantErr bool
|
||||
}{
|
||||
{"basic", "example.com", "example.com", false},
|
||||
{"case insensitive", "FoO.BaR.CoM", "foo.bar.com", false},
|
||||
{"subdomain", "metacrypt.metacircular.net", "metacrypt.metacircular.net", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
hello := buildClientHello(tt.sni)
|
||||
|
||||
go func() {
|
||||
client.Write(hello)
|
||||
}()
|
||||
|
||||
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if tt.wantErr {
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if hostname != tt.wantSNI {
|
||||
t.Fatalf("got hostname %q, want %q", hostname, tt.wantSNI)
|
||||
}
|
||||
if len(peeked) != len(hello) {
|
||||
t.Fatalf("peeked %d bytes, want %d", len(peeked), len(hello))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNoSNI(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
hello := buildClientHelloNoSNI()
|
||||
|
||||
go func() {
|
||||
client.Write(hello)
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for ClientHello without SNI")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractNotTLS(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-TLS data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractTruncated(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
// Write just the TLS record header, then close.
|
||||
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
|
||||
client.Close()
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for truncated record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractOversizedRecord(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
go func() {
|
||||
// Record header claiming a length larger than 16 KiB.
|
||||
header := []byte{0x16, 0x03, 0x01}
|
||||
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
|
||||
client.Write(header)
|
||||
client.Close()
|
||||
}()
|
||||
|
||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err == nil {
|
||||
t.Fatal("expected error for oversized record")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractMultipleExtensions(t *testing.T) {
|
||||
client, server := net.Pipe()
|
||||
defer client.Close()
|
||||
defer server.Close()
|
||||
|
||||
hello := buildClientHelloWithExtraExtensions("target.example.com")
|
||||
|
||||
go func() {
|
||||
client.Write(hello)
|
||||
}()
|
||||
|
||||
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if hostname != "target.example.com" {
|
||||
t.Fatalf("got hostname %q, want %q", hostname, "target.example.com")
|
||||
}
|
||||
}
|
||||
|
||||
// buildClientHello constructs a minimal TLS 1.2 ClientHello with an SNI extension.
|
||||
func buildClientHello(serverName string) []byte {
|
||||
return buildClientHelloWithExtensions(sniExtension(serverName))
|
||||
}
|
||||
|
||||
// buildClientHelloNoSNI constructs a ClientHello with no extensions.
|
||||
func buildClientHelloNoSNI() []byte {
|
||||
return buildClientHelloWithExtensions(nil)
|
||||
}
|
||||
|
||||
// buildClientHelloWithExtraExtensions puts a dummy extension before the SNI.
|
||||
func buildClientHelloWithExtraExtensions(serverName string) []byte {
|
||||
// Dummy extension (type 0xFF01, empty data).
|
||||
dummy := []byte{0xFF, 0x01, 0x00, 0x00}
|
||||
ext := append(dummy, sniExtension(serverName)...)
|
||||
return buildClientHelloWithExtensions(ext)
|
||||
}
|
||||
|
||||
func buildClientHelloWithExtensions(extensions []byte) []byte {
|
||||
var hello []byte
|
||||
|
||||
// Client version: TLS 1.2.
|
||||
hello = append(hello, 0x03, 0x03)
|
||||
|
||||
// Random: 32 bytes of zeros.
|
||||
hello = append(hello, make([]byte, 32)...)
|
||||
|
||||
// Session ID: empty.
|
||||
hello = append(hello, 0x00)
|
||||
|
||||
// Cipher suites: one suite (TLS_RSA_WITH_AES_128_GCM_SHA256).
|
||||
hello = append(hello, 0x00, 0x02, 0x00, 0x9C)
|
||||
|
||||
// Compression methods: null.
|
||||
hello = append(hello, 0x01, 0x00)
|
||||
|
||||
// Extensions.
|
||||
if len(extensions) > 0 {
|
||||
hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions)))
|
||||
hello = append(hello, extensions...)
|
||||
}
|
||||
|
||||
// Wrap in handshake header (type 0x01 = ClientHello).
|
||||
handshake := []byte{0x01, 0x00, 0x00, 0x00}
|
||||
handshake[1] = byte(len(hello) >> 16)
|
||||
handshake[2] = byte(len(hello) >> 8)
|
||||
handshake[3] = byte(len(hello))
|
||||
handshake = append(handshake, hello...)
|
||||
|
||||
// Wrap in TLS record header (type 0x16 = handshake, version TLS 1.0).
|
||||
record := []byte{0x16, 0x03, 0x01}
|
||||
record = binary.BigEndian.AppendUint16(record, uint16(len(handshake)))
|
||||
record = append(record, handshake...)
|
||||
|
||||
return record
|
||||
}
|
||||
|
||||
func sniExtension(serverName string) []byte {
|
||||
name := []byte(serverName)
|
||||
|
||||
// Server name entry: type 0x00 (hostname), length, name.
|
||||
var entry []byte
|
||||
entry = append(entry, 0x00)
|
||||
entry = binary.BigEndian.AppendUint16(entry, uint16(len(name)))
|
||||
entry = append(entry, name...)
|
||||
|
||||
// Server name list: length prefix.
|
||||
var list []byte
|
||||
list = binary.BigEndian.AppendUint16(list, uint16(len(entry)))
|
||||
list = append(list, entry...)
|
||||
|
||||
// Extension: type 0x0000 (server_name), length, data.
|
||||
var ext []byte
|
||||
ext = binary.BigEndian.AppendUint16(ext, 0x0000)
|
||||
ext = binary.BigEndian.AppendUint16(ext, uint16(len(list)))
|
||||
ext = append(ext, list...)
|
||||
|
||||
return ext
|
||||
}
|
||||
Reference in New Issue
Block a user