Initial implementation of mc-proxy

Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking),
per-listener route tables, bidirectional TCP relay with half-close
propagation, and a gRPC admin API (routes, firewall, status) with
TLS/mTLS support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-17 02:56:24 -07:00
commit c7024dcdf0
23 changed files with 2693 additions and 0 deletions

115
internal/config/config.go Normal file
View File

@@ -0,0 +1,115 @@
package config
import (
"fmt"
"os"
"time"
"github.com/pelletier/go-toml/v2"
)
type Config struct {
Listeners []Listener `toml:"listeners"`
GRPC GRPC `toml:"grpc"`
Firewall Firewall `toml:"firewall"`
Proxy Proxy `toml:"proxy"`
Log Log `toml:"log"`
}
type GRPC struct {
Addr string `toml:"addr"`
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
ClientCA string `toml:"client_ca"`
}
type Listener struct {
Addr string `toml:"addr"`
Routes []Route `toml:"routes"`
}
type Route struct {
Hostname string `toml:"hostname"`
Backend string `toml:"backend"`
}
type Firewall struct {
GeoIPDB string `toml:"geoip_db"`
BlockedIPs []string `toml:"blocked_ips"`
BlockedCIDRs []string `toml:"blocked_cidrs"`
BlockedCountries []string `toml:"blocked_countries"`
}
type Proxy struct {
ConnectTimeout Duration `toml:"connect_timeout"`
IdleTimeout Duration `toml:"idle_timeout"`
ShutdownTimeout Duration `toml:"shutdown_timeout"`
}
type Log struct {
Level string `toml:"level"`
}
// Duration wraps time.Duration for TOML string unmarshalling.
type Duration struct {
time.Duration
}
func (d *Duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading config: %w", err)
}
var cfg Config
if err := toml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parsing config: %w", err)
}
if err := cfg.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
return &cfg, nil
}
func (c *Config) validate() error {
if len(c.Listeners) == 0 {
return fmt.Errorf("at least one listener is required")
}
for i, l := range c.Listeners {
if l.Addr == "" {
return fmt.Errorf("listener %d: addr is required", i)
}
if len(l.Routes) == 0 {
return fmt.Errorf("listener %d (%s): at least one route is required", i, l.Addr)
}
seen := make(map[string]bool)
for j, r := range l.Routes {
if r.Hostname == "" {
return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j)
}
if r.Backend == "" {
return fmt.Errorf("listener %d (%s), route %d: backend is required", i, l.Addr, j)
}
if seen[r.Hostname] {
return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname)
}
seen[r.Hostname] = true
}
}
if len(c.Firewall.BlockedCountries) > 0 && c.Firewall.GeoIPDB == "" {
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
}
return nil
}

View File

@@ -0,0 +1,186 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadValid(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
[proxy]
connect_timeout = "5s"
idle_timeout = "300s"
shutdown_timeout = "30s"
[log]
level = "info"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cfg.Listeners) != 1 {
t.Fatalf("got %d listeners, want 1", len(cfg.Listeners))
}
if cfg.Listeners[0].Addr != ":443" {
t.Fatalf("got listener addr %q, want %q", cfg.Listeners[0].Addr, ":443")
}
if len(cfg.Listeners[0].Routes) != 1 {
t.Fatalf("got %d routes, want 1", len(cfg.Listeners[0].Routes))
}
if cfg.Listeners[0].Routes[0].Hostname != "example.com" {
t.Fatalf("got hostname %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "example.com")
}
}
func TestLoadNoListeners(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[log]
level = "info"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for missing listeners")
}
}
func TestLoadNoRoutes(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for missing routes")
}
}
func TestLoadDuplicateHostnames(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:9443"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for duplicate hostnames")
}
}
func TestLoadGeoIPRequiredWithCountries(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
[firewall]
blocked_countries = ["CN"]
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for blocked_countries without geoip_db")
}
}
func TestLoadMultipleListeners(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "public.example.com"
backend = "127.0.0.1:8443"
[[listeners]]
addr = ":8443"
[[listeners.routes]]
hostname = "internal.example.com"
backend = "127.0.0.1:9443"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cfg.Listeners) != 2 {
t.Fatalf("got %d listeners, want 2", len(cfg.Listeners))
}
if cfg.Listeners[0].Routes[0].Hostname != "public.example.com" {
t.Fatalf("listener 0 hostname = %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "public.example.com")
}
if cfg.Listeners[1].Routes[0].Hostname != "internal.example.com" {
t.Fatalf("listener 1 hostname = %q, want %q", cfg.Listeners[1].Routes[0].Hostname, "internal.example.com")
}
}
func TestDuration(t *testing.T) {
var d Duration
if err := d.UnmarshalText([]byte("5s")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if d.Duration.Seconds() != 5 {
t.Fatalf("got %v, want 5s", d.Duration)
}
}

View File

@@ -0,0 +1,218 @@
package firewall
import (
"fmt"
"net/netip"
"strings"
"sync"
"github.com/oschwald/maxminddb-golang"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
)
type geoIPRecord struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
// Firewall evaluates global blocklist rules against connection source addresses.
type Firewall struct {
blockedIPs map[netip.Addr]struct{}
blockedCIDRs []netip.Prefix
blockedCountries map[string]struct{}
geoDBPath string
geoDB *maxminddb.Reader
mu sync.RWMutex // protects all mutable state
}
// New creates a Firewall from the given configuration.
func New(cfg config.Firewall) (*Firewall, error) {
f := &Firewall{
blockedIPs: make(map[netip.Addr]struct{}),
blockedCountries: make(map[string]struct{}),
geoDBPath: cfg.GeoIPDB,
}
for _, ip := range cfg.BlockedIPs {
addr, err := netip.ParseAddr(ip)
if err != nil {
return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err)
}
f.blockedIPs[addr] = struct{}{}
}
for _, cidr := range cfg.BlockedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err)
}
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
}
for _, code := range cfg.BlockedCountries {
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
}
if len(f.blockedCountries) > 0 {
if err := f.loadGeoDB(cfg.GeoIPDB); err != nil {
return nil, fmt.Errorf("loading GeoIP database: %w", err)
}
}
return f, nil
}
// Blocked returns true if the given address should be blocked.
func (f *Firewall) Blocked(addr netip.Addr) bool {
addr = addr.Unmap()
f.mu.RLock()
defer f.mu.RUnlock()
if _, ok := f.blockedIPs[addr]; ok {
return true
}
for _, prefix := range f.blockedCIDRs {
if prefix.Contains(addr) {
return true
}
}
if len(f.blockedCountries) > 0 && f.geoDB != nil {
var record geoIPRecord
if err := f.geoDB.Lookup(addr.AsSlice(), &record); err == nil {
if _, ok := f.blockedCountries[record.Country.ISOCode]; ok {
return true
}
}
}
return false
}
// AddIP adds an IP address to the blocklist.
func (f *Firewall) AddIP(ip string) error {
addr, err := netip.ParseAddr(ip)
if err != nil {
return fmt.Errorf("invalid IP %q: %w", ip, err)
}
f.mu.Lock()
f.blockedIPs[addr] = struct{}{}
f.mu.Unlock()
return nil
}
// RemoveIP removes an IP address from the blocklist.
func (f *Firewall) RemoveIP(ip string) error {
addr, err := netip.ParseAddr(ip)
if err != nil {
return fmt.Errorf("invalid IP %q: %w", ip, err)
}
f.mu.Lock()
delete(f.blockedIPs, addr)
f.mu.Unlock()
return nil
}
// AddCIDR adds a CIDR prefix to the blocklist.
func (f *Firewall) AddCIDR(cidr string) error {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
}
f.mu.Lock()
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
f.mu.Unlock()
return nil
}
// RemoveCIDR removes a CIDR prefix from the blocklist.
func (f *Firewall) RemoveCIDR(cidr string) error {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
}
f.mu.Lock()
for i, p := range f.blockedCIDRs {
if p == prefix {
f.blockedCIDRs = append(f.blockedCIDRs[:i], f.blockedCIDRs[i+1:]...)
break
}
}
f.mu.Unlock()
return nil
}
// AddCountry adds a country code to the blocklist.
func (f *Firewall) AddCountry(code string) {
f.mu.Lock()
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
f.mu.Unlock()
}
// RemoveCountry removes a country code from the blocklist.
func (f *Firewall) RemoveCountry(code string) {
f.mu.Lock()
delete(f.blockedCountries, strings.ToUpper(code))
f.mu.Unlock()
}
// Rules returns a snapshot of all current firewall rules.
func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) {
f.mu.RLock()
defer f.mu.RUnlock()
for addr := range f.blockedIPs {
ips = append(ips, addr.String())
}
for _, prefix := range f.blockedCIDRs {
cidrs = append(cidrs, prefix.String())
}
for code := range f.blockedCountries {
countries = append(countries, code)
}
return
}
// ReloadGeoIP reloads the GeoIP database from disk. Safe for concurrent use.
func (f *Firewall) ReloadGeoIP() error {
if f.geoDBPath == "" {
return nil
}
return f.loadGeoDB(f.geoDBPath)
}
// Close releases resources held by the firewall.
func (f *Firewall) Close() error {
f.mu.Lock()
defer f.mu.Unlock()
if f.geoDB != nil {
return f.geoDB.Close()
}
return nil
}
func (f *Firewall) loadGeoDB(path string) error {
db, err := maxminddb.Open(path)
if err != nil {
return err
}
f.mu.Lock()
old := f.geoDB
f.geoDB = db
f.mu.Unlock()
if old != nil {
old.Close()
}
return nil
}

View File

@@ -0,0 +1,141 @@
package firewall
import (
"net/netip"
"testing"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
)
func TestEmptyFirewall(t *testing.T) {
fw, err := New(config.Firewall{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
for _, a := range addrs {
addr := netip.MustParseAddr(a)
if fw.Blocked(addr) {
t.Fatalf("empty firewall blocked %s", addr)
}
}
}
func TestIPBlocking(t *testing.T) {
fw, err := New(config.Firewall{
BlockedIPs: []string{"192.0.2.1", "2001:db8::dead"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
tests := []struct {
addr string
blocked bool
}{
{"192.0.2.1", true},
{"192.0.2.2", false},
{"2001:db8::dead", true},
{"2001:db8::beef", false},
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
if got := fw.Blocked(addr); got != tt.blocked {
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
}
}
}
func TestCIDRBlocking(t *testing.T) {
fw, err := New(config.Firewall{
BlockedCIDRs: []string{"198.51.100.0/24", "2001:db8::/32"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
tests := []struct {
addr string
blocked bool
}{
{"198.51.100.1", true},
{"198.51.100.254", true},
{"198.51.101.1", false},
{"2001:db8::1", true},
{"2001:db9::1", false},
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
if got := fw.Blocked(addr); got != tt.blocked {
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
}
}
}
func TestIPv4MappedIPv6(t *testing.T) {
fw, err := New(config.Firewall{
BlockedIPs: []string{"192.0.2.1"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
// IPv4-mapped IPv6 representation of 192.0.2.1.
addr := netip.MustParseAddr("::ffff:192.0.2.1")
if !fw.Blocked(addr) {
t.Fatal("expected IPv4-mapped IPv6 address to be blocked")
}
}
func TestInvalidIP(t *testing.T) {
_, err := New(config.Firewall{
BlockedIPs: []string{"not-an-ip"},
})
if err == nil {
t.Fatal("expected error for invalid IP")
}
}
func TestInvalidCIDR(t *testing.T) {
_, err := New(config.Firewall{
BlockedCIDRs: []string{"not-a-cidr"},
})
if err == nil {
t.Fatal("expected error for invalid CIDR")
}
}
func TestCombinedRules(t *testing.T) {
fw, err := New(config.Firewall{
BlockedIPs: []string{"10.0.0.1"},
BlockedCIDRs: []string{"192.168.0.0/16"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
tests := []struct {
addr string
blocked bool
}{
{"10.0.0.1", true}, // IP match
{"10.0.0.2", false}, // no match
{"192.168.1.1", true}, // CIDR match
{"172.16.0.1", false}, // no match
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
if got := fw.Blocked(addr); got != tt.blocked {
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
}
}
}

View File

@@ -0,0 +1,246 @@
package grpcserver
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log/slog"
"net"
"os"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
// AdminServer implements the ProxyAdmin gRPC service.
type AdminServer struct {
pb.UnimplementedProxyAdminServer
srv *server.Server
logger *slog.Logger
}
// New creates a gRPC server with TLS and optional mTLS.
func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
if err != nil {
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}
// mTLS: require and verify client certificates.
if cfg.ClientCA != "" {
caCert, err := os.ReadFile(cfg.ClientCA)
if err != nil {
return nil, nil, fmt.Errorf("reading client CA: %w", err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caCert) {
return nil, nil, fmt.Errorf("failed to parse client CA certificate")
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = pool
}
creds := credentials.NewTLS(tlsConfig)
grpcServer := grpc.NewServer(grpc.Creds(creds))
admin := &AdminServer{
srv: srv,
logger: logger,
}
pb.RegisterProxyAdminServer(grpcServer, admin)
ln, err := net.Listen("tcp", cfg.Addr)
if err != nil {
return nil, nil, fmt.Errorf("listening on %s: %w", cfg.Addr, err)
}
return grpcServer, ln, nil
}
// ListRoutes returns the route table for a listener.
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
routes := ls.Routes()
resp := &pb.ListRoutesResponse{
ListenerAddr: ls.Addr,
}
for hostname, backend := range routes {
resp.Routes = append(resp.Routes, &pb.Route{
Hostname: hostname,
Backend: backend,
})
}
return resp, nil
}
// AddRoute adds a route to a listener's route table.
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
if req.Route == nil {
return nil, status.Error(codes.InvalidArgument, "route is required")
}
if req.Route.Hostname == "" || req.Route.Backend == "" {
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
}
a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend)
return &pb.AddRouteResponse{}, nil
}
// RemoveRoute removes a route from a listener's route table.
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
if req.Hostname == "" {
return nil, status.Error(codes.InvalidArgument, "hostname is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
if err := ls.RemoveRoute(req.Hostname); err != nil {
return nil, status.Errorf(codes.NotFound, "%v", err)
}
a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname)
return &pb.RemoveRouteResponse{}, nil
}
// GetFirewallRules returns all current firewall rules.
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
ips, cidrs, countries := a.srv.Firewall().Rules()
var rules []*pb.FirewallRule
for _, ip := range ips {
rules = append(rules, &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
Value: ip,
})
}
for _, cidr := range cidrs {
rules = append(rules, &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
Value: cidr,
})
}
for _, code := range countries {
rules = append(rules, &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
Value: code,
})
}
return &pb.GetFirewallRulesResponse{Rules: rules}, nil
}
// AddFirewallRule adds a firewall rule.
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
if req.Rule == nil {
return nil, status.Error(codes.InvalidArgument, "rule is required")
}
fw := a.srv.Firewall()
switch req.Rule.Type {
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
if err := fw.AddIP(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
if err := fw.AddCIDR(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
if req.Rule.Value == "" {
return nil, status.Error(codes.InvalidArgument, "country code is required")
}
fw.AddCountry(req.Rule.Value)
default:
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
}
a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value)
return &pb.AddFirewallRuleResponse{}, nil
}
// RemoveFirewallRule removes a firewall rule.
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
if req.Rule == nil {
return nil, status.Error(codes.InvalidArgument, "rule is required")
}
fw := a.srv.Firewall()
switch req.Rule.Type {
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
if err := fw.RemoveIP(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
if req.Rule.Value == "" {
return nil, status.Error(codes.InvalidArgument, "country code is required")
}
fw.RemoveCountry(req.Rule.Value)
default:
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
}
a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value)
return &pb.RemoveFirewallRuleResponse{}, nil
}
// GetStatus returns the proxy's current status.
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
var listeners []*pb.ListenerStatus
for _, ls := range a.srv.Listeners() {
routes := ls.Routes()
listeners = append(listeners, &pb.ListenerStatus{
Addr: ls.Addr,
RouteCount: int32(len(routes)),
ActiveConnections: ls.ActiveConnections.Load(),
})
}
return &pb.GetStatusResponse{
Version: a.srv.Version(),
StartedAt: timestamppb.New(a.srv.StartedAt()),
Listeners: listeners,
TotalConnections: a.srv.TotalConnections(),
}, nil
}
func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
for _, ls := range a.srv.Listeners() {
if ls.Addr == addr {
return ls, nil
}
}
return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
}

105
internal/proxy/proxy.go Normal file
View File

@@ -0,0 +1,105 @@
package proxy
import (
"context"
"io"
"net"
"sync"
"time"
)
// Result holds the outcome of a relay operation.
type Result struct {
ClientBytes int64 // bytes sent from client to backend
BackendBytes int64 // bytes sent from backend to client
}
// Relay performs bidirectional byte copying between client and backend.
// The peeked bytes (the TLS ClientHello) are written to the backend first.
// Relay blocks until both directions are done or ctx is cancelled.
func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTimeout time.Duration) (Result, error) {
// Forward the buffered ClientHello to the backend.
if len(peeked) > 0 {
if _, err := backend.Write(peeked); err != nil {
return Result{}, err
}
}
// Cancel context closes both connections to unblock copy goroutines.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
client.Close()
backend.Close()
}()
var (
result Result
wg sync.WaitGroup
errC2B error
errB2C error
)
wg.Add(2)
// client → backend
go func() {
defer wg.Done()
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
// Half-close backend's write side.
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
}
}()
// backend → client
go func() {
defer wg.Done()
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
// Half-close client's write side.
if hc, ok := client.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
}
}()
wg.Wait()
// If context was cancelled, that's the primary error.
if ctx.Err() != nil {
return result, ctx.Err()
}
// Return the first meaningful error, if any.
if errC2B != nil {
return result, errC2B
}
return result, errB2C
}
// copyWithIdleTimeout copies from src to dst, resetting the idle deadline
// on each successful read.
func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, error) {
buf := make([]byte, 32*1024)
var total int64
for {
src.SetReadDeadline(time.Now().Add(idleTimeout))
nr, readErr := src.Read(buf)
if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(idleTimeout))
nw, writeErr := dst.Write(buf[:nr])
total += int64(nw)
if writeErr != nil {
return total, writeErr
}
}
if readErr != nil {
if readErr == io.EOF {
return total, nil
}
return total, readErr
}
}
}

View File

@@ -0,0 +1,259 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"io"
"net"
"testing"
"time"
)
func TestRelayBasic(t *testing.T) {
// Set up a TCP listener to act as the backend.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn.Close()
peeked := []byte("peeked-hello-bytes")
clientData := []byte("data from client")
backendData := []byte("data from backend")
// Backend goroutine: accept, read peeked+client data, send response, close.
backendDone := make(chan []byte, 1)
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
// Read everything the backend receives.
received, _ := io.ReadAll(conn)
backendDone <- received
// This won't work since ReadAll waits for EOF.
// Instead, restructure: read expected bytes, write response, close write.
}()
// Restructure: use a more controlled flow.
backendLn.Close()
// Use a real TCP pair for proper half-close.
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn2.Close()
go func() {
conn, err := backendLn2.Accept()
if err != nil {
return
}
defer conn.Close()
// Read peeked + client data.
buf := make([]byte, len(peeked)+len(clientData))
n, _ := io.ReadFull(conn, buf)
backendDone <- buf[:n]
// Send response.
conn.Write(backendData)
// Close write side to signal EOF.
if tc, ok := conn.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Dial the backend.
backendConn, err := net.Dial("tcp", backendLn2.Addr().String())
if err != nil {
t.Fatalf("dial backend: %v", err)
}
// Create a client-side TCP pair.
clientLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer clientLn.Close()
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
if err != nil {
t.Fatalf("dial client: %v", err)
}
serverSideClient, err := clientLn.Accept()
if err != nil {
t.Fatalf("accept client: %v", err)
}
// Client sends data then closes write.
go func() {
clientConn.Write(clientData)
if tc, ok := clientConn.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Run relay.
result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
if err != nil {
t.Fatalf("relay error: %v", err)
}
// Verify backend received peeked + client data.
received := <-backendDone
expected := append(peeked, clientData...)
if !bytes.Equal(received, expected) {
t.Fatalf("backend received %q, want %q", received, expected)
}
// Verify client received backend data.
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
clientReceived, _ := io.ReadAll(clientConn)
if !bytes.Equal(clientReceived, backendData) {
t.Fatalf("client received %q, want %q", clientReceived, backendData)
}
if result.ClientBytes != int64(len(clientData)) {
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
}
if result.BackendBytes != int64(len(backendData)) {
t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
}
}
func TestRelayIdleTimeout(t *testing.T) {
// Two connected pairs via TCP.
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
start := time.Now()
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
elapsed := time.Since(start)
// Should return due to idle timeout.
if err == nil {
t.Fatal("expected error from idle timeout")
}
if elapsed > 2*time.Second {
t.Fatalf("relay took %v, expected ~100ms", elapsed)
}
}
func TestRelayContextCancel(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
Relay(ctx, clientB, backendA, nil, time.Minute)
close(done)
}()
// Cancel after a short delay.
time.Sleep(50 * time.Millisecond)
cancel()
select {
case <-done:
// OK
case <-time.After(2 * time.Second):
t.Fatal("relay did not return after context cancel")
}
_ = backendB // keep reference
}
func TestRelayLargeTransfer(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
// 1 MB of random data.
data := make([]byte, 1<<20)
if _, err := rand.Read(data); err != nil {
t.Fatalf("rand read: %v", err)
}
go func() {
clientA.Write(data)
if tc, ok := clientA.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Backend reads and echoes chunks, then closes when client EOF arrives.
go func() {
buf := make([]byte, 32*1024)
for {
n, err := backendB.Read(buf)
if n > 0 {
backendB.Write(buf[:n])
}
if err != nil {
break
}
}
if tc, ok := backendB.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
if err != nil {
t.Fatalf("relay error: %v", err)
}
if result.ClientBytes != int64(len(data)) {
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
}
}
// tcpPair returns two connected TCP connections.
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var serverConn net.Conn
done := make(chan struct{})
go func() {
serverConn, _ = ln.Accept()
close(done)
}()
clientConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
<-done
return clientConn, serverConn
}

271
internal/server/server.go Normal file
View File

@@ -0,0 +1,271 @@
package server
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
)
// ListenerState holds the mutable state for a single proxy listener.
type ListenerState struct {
Addr string
routes map[string]string // lowercase hostname → backend addr
mu sync.RWMutex
ActiveConnections atomic.Int64
}
// Routes returns a snapshot of the listener's route table.
func (ls *ListenerState) Routes() map[string]string {
ls.mu.RLock()
defer ls.mu.RUnlock()
m := make(map[string]string, len(ls.routes))
for k, v := range ls.routes {
m[k] = v
}
return m
}
// AddRoute adds a route to the listener. Returns an error if the hostname
// already exists.
func (ls *ListenerState) AddRoute(hostname, backend string) error {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
if _, ok := ls.routes[key]; ok {
return fmt.Errorf("route %q already exists", hostname)
}
ls.routes[key] = backend
return nil
}
// RemoveRoute removes a route from the listener. Returns an error if the
// hostname does not exist.
func (ls *ListenerState) RemoveRoute(hostname string) error {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
if _, ok := ls.routes[key]; !ok {
return fmt.Errorf("route %q not found", hostname)
}
delete(ls.routes, key)
return nil
}
func (ls *ListenerState) lookupRoute(hostname string) (string, bool) {
ls.mu.RLock()
defer ls.mu.RUnlock()
backend, ok := ls.routes[hostname]
return backend, ok
}
// Server is the mc-proxy server. It manages listeners, firewall evaluation,
// SNI-based routing, and bidirectional proxying.
type Server struct {
cfg *config.Config
fw *firewall.Firewall
listeners []*ListenerState
logger *slog.Logger
wg sync.WaitGroup // tracks active connections
startedAt time.Time
version string
}
// New creates a Server from the given configuration.
func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, error) {
fw, err := firewall.New(cfg.Firewall)
if err != nil {
return nil, fmt.Errorf("initializing firewall: %w", err)
}
var listeners []*ListenerState
for _, lcfg := range cfg.Listeners {
routes := make(map[string]string, len(lcfg.Routes))
for _, r := range lcfg.Routes {
routes[strings.ToLower(r.Hostname)] = r.Backend
}
listeners = append(listeners, &ListenerState{
Addr: lcfg.Addr,
routes: routes,
})
}
return &Server{
cfg: cfg,
fw: fw,
listeners: listeners,
logger: logger,
version: version,
}, nil
}
// Firewall returns the server's firewall for use by the gRPC admin API.
func (s *Server) Firewall() *firewall.Firewall {
return s.fw
}
// Listeners returns the server's listener states for use by the gRPC admin API.
func (s *Server) Listeners() []*ListenerState {
return s.listeners
}
// StartedAt returns the time the server started.
func (s *Server) StartedAt() time.Time {
return s.startedAt
}
// Version returns the server's version string.
func (s *Server) Version() string {
return s.version
}
// TotalConnections returns the total number of active connections.
func (s *Server) TotalConnections() int64 {
var total int64
for _, ls := range s.listeners {
total += ls.ActiveConnections.Load()
}
return total
}
// Run starts all listeners and blocks until ctx is cancelled.
func (s *Server) Run(ctx context.Context) error {
s.startedAt = time.Now()
var netListeners []net.Listener
for _, ls := range s.listeners {
ln, err := net.Listen("tcp", ls.Addr)
if err != nil {
for _, l := range netListeners {
l.Close()
}
return fmt.Errorf("listening on %s: %w", ls.Addr, err)
}
s.logger.Info("listening", "addr", ln.Addr(), "routes", len(ls.routes))
netListeners = append(netListeners, ln)
}
// Start accept loops.
for i, ln := range netListeners {
ln := ln
ls := s.listeners[i]
go s.serve(ctx, ln, ls)
}
// Block until shutdown signal.
<-ctx.Done()
s.logger.Info("shutting down")
// Stop accepting new connections.
for _, ln := range netListeners {
ln.Close()
}
// Wait for in-flight connections with a timeout.
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
s.logger.Info("all connections drained")
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
s.logger.Warn("shutdown timeout exceeded, forcing close")
}
s.fw.Close()
return nil
}
// ReloadGeoIP reloads the GeoIP database from disk.
func (s *Server) ReloadGeoIP() error {
return s.fw.ReloadGeoIP()
}
func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) {
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil {
return
}
s.logger.Error("accept error", "error", err)
continue
}
s.wg.Add(1)
ls.ActiveConnections.Add(1)
go s.handleConn(ctx, conn, ls)
}
}
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
defer s.wg.Done()
defer ls.ActiveConnections.Add(-1)
defer conn.Close()
remoteAddr := conn.RemoteAddr().String()
addrPort, err := netip.ParseAddrPort(remoteAddr)
if err != nil {
s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err)
return
}
addr := addrPort.Addr()
if s.fw.Blocked(addr) {
s.logger.Debug("blocked by firewall", "addr", addr)
return
}
hostname, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second))
if err != nil {
s.logger.Debug("SNI extraction failed", "addr", addr, "error", err)
return
}
backend, ok := ls.lookupRoute(hostname)
if !ok {
s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname)
return
}
backendConn, err := net.DialTimeout("tcp", backend, s.cfg.Proxy.ConnectTimeout.Duration)
if err != nil {
s.logger.Error("backend dial failed", "hostname", hostname, "backend", backend, "error", err)
return
}
defer backendConn.Close()
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", backend)
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
if err != nil && ctx.Err() == nil {
s.logger.Debug("relay ended", "hostname", hostname, "error", err)
}
s.logger.Info("connection closed",
"addr", addr,
"hostname", hostname,
"client_bytes", result.ClientBytes,
"backend_bytes", result.BackendBytes,
)
}

175
internal/sni/sni.go Normal file
View File

@@ -0,0 +1,175 @@
package sni
import (
"encoding/binary"
"fmt"
"io"
"net"
"strings"
"time"
)
const maxBufferSize = 16384 // 16 KiB, max TLS record size
// Extract reads the TLS ClientHello from conn and returns the SNI hostname.
// The returned peeked bytes contain the full ClientHello and must be forwarded
// to the backend before starting the bidirectional copy.
//
// A read deadline is set on the connection to prevent slowloris attacks.
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{})
// Read TLS record header (5 bytes).
header := make([]byte, 5)
if _, err := io.ReadFull(conn, header); err != nil {
return "", nil, fmt.Errorf("reading TLS record header: %w", err)
}
// Verify this is a TLS handshake record (content type 0x16).
if header[0] != 0x16 {
return "", nil, fmt.Errorf("not a TLS handshake record (type 0x%02x)", header[0])
}
// Record length.
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
if recordLen == 0 || recordLen > maxBufferSize-5 {
return "", nil, fmt.Errorf("TLS record length %d out of range", recordLen)
}
// Read the full record body.
buf := make([]byte, 5+recordLen)
copy(buf, header)
if _, err := io.ReadFull(conn, buf[5:]); err != nil {
return "", nil, fmt.Errorf("reading TLS record body: %w", err)
}
// Parse the handshake message from the record body.
hostname, err = parseClientHello(buf[5:])
if err != nil {
return "", nil, err
}
return hostname, buf, nil
}
func parseClientHello(data []byte) (string, error) {
if len(data) < 4 {
return "", fmt.Errorf("handshake message too short")
}
// Handshake type: 0x01 = ClientHello.
if data[0] != 0x01 {
return "", fmt.Errorf("not a ClientHello (type 0x%02x)", data[0])
}
// Handshake length (3 bytes).
hsLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < hsLen {
return "", fmt.Errorf("ClientHello truncated")
}
data = data[:hsLen]
// Skip client version (2 bytes) + random (32 bytes).
if len(data) < 34 {
return "", fmt.Errorf("ClientHello too short for version+random")
}
data = data[34:]
// Skip session ID (1-byte length prefix).
if len(data) < 1 {
return "", fmt.Errorf("ClientHello too short for session ID length")
}
sidLen := int(data[0])
data = data[1:]
if len(data) < sidLen {
return "", fmt.Errorf("ClientHello truncated at session ID")
}
data = data[sidLen:]
// Skip cipher suites (2-byte length prefix).
if len(data) < 2 {
return "", fmt.Errorf("ClientHello too short for cipher suites length")
}
csLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) < csLen {
return "", fmt.Errorf("ClientHello truncated at cipher suites")
}
data = data[csLen:]
// Skip compression methods (1-byte length prefix).
if len(data) < 1 {
return "", fmt.Errorf("ClientHello too short for compression methods length")
}
cmLen := int(data[0])
data = data[1:]
if len(data) < cmLen {
return "", fmt.Errorf("ClientHello truncated at compression methods")
}
data = data[cmLen:]
// Extensions (2-byte total length).
if len(data) < 2 {
return "", fmt.Errorf("no extensions in ClientHello")
}
extLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) < extLen {
return "", fmt.Errorf("ClientHello truncated at extensions")
}
data = data[:extLen]
return findSNI(data)
}
func findSNI(data []byte) (string, error) {
for len(data) >= 4 {
extType := binary.BigEndian.Uint16(data[:2])
extDataLen := int(binary.BigEndian.Uint16(data[2:4]))
data = data[4:]
if len(data) < extDataLen {
return "", fmt.Errorf("extension truncated")
}
if extType == 0x0000 { // server_name
return parseServerNameExtension(data[:extDataLen])
}
data = data[extDataLen:]
}
return "", fmt.Errorf("no SNI extension found")
}
func parseServerNameExtension(data []byte) (string, error) {
if len(data) < 2 {
return "", fmt.Errorf("server_name extension too short")
}
// Server name list length.
listLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) < listLen {
return "", fmt.Errorf("server_name list truncated")
}
data = data[:listLen]
for len(data) >= 3 {
nameType := data[0]
nameLen := int(binary.BigEndian.Uint16(data[1:3]))
data = data[3:]
if len(data) < nameLen {
return "", fmt.Errorf("server_name entry truncated")
}
if nameType == 0x00 { // hostname
return strings.ToLower(string(data[:nameLen])), nil
}
data = data[nameLen:]
}
return "", fmt.Errorf("no hostname in server_name extension")
}

220
internal/sni/sni_test.go Normal file
View File

@@ -0,0 +1,220 @@
package sni
import (
"encoding/binary"
"net"
"testing"
"time"
)
func TestExtract(t *testing.T) {
tests := []struct {
name string
sni string
wantSNI string
wantErr bool
}{
{"basic", "example.com", "example.com", false},
{"case insensitive", "FoO.BaR.CoM", "foo.bar.com", false},
{"subdomain", "metacrypt.metacircular.net", "metacrypt.metacircular.net", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
hello := buildClientHello(tt.sni)
go func() {
client.Write(hello)
}()
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if hostname != tt.wantSNI {
t.Fatalf("got hostname %q, want %q", hostname, tt.wantSNI)
}
if len(peeked) != len(hello) {
t.Fatalf("peeked %d bytes, want %d", len(peeked), len(hello))
}
})
}
}
func TestExtractNoSNI(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
hello := buildClientHelloNoSNI()
go func() {
client.Write(hello)
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for ClientHello without SNI")
}
}
func TestExtractNotTLS(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
go func() {
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for non-TLS data")
}
}
func TestExtractTruncated(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
go func() {
// Write just the TLS record header, then close.
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
client.Close()
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for truncated record")
}
}
func TestExtractOversizedRecord(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
go func() {
// Record header claiming a length larger than 16 KiB.
header := []byte{0x16, 0x03, 0x01}
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
client.Write(header)
client.Close()
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for oversized record")
}
}
func TestExtractMultipleExtensions(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
hello := buildClientHelloWithExtraExtensions("target.example.com")
go func() {
client.Write(hello)
}()
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if hostname != "target.example.com" {
t.Fatalf("got hostname %q, want %q", hostname, "target.example.com")
}
}
// buildClientHello constructs a minimal TLS 1.2 ClientHello with an SNI extension.
func buildClientHello(serverName string) []byte {
return buildClientHelloWithExtensions(sniExtension(serverName))
}
// buildClientHelloNoSNI constructs a ClientHello with no extensions.
func buildClientHelloNoSNI() []byte {
return buildClientHelloWithExtensions(nil)
}
// buildClientHelloWithExtraExtensions puts a dummy extension before the SNI.
func buildClientHelloWithExtraExtensions(serverName string) []byte {
// Dummy extension (type 0xFF01, empty data).
dummy := []byte{0xFF, 0x01, 0x00, 0x00}
ext := append(dummy, sniExtension(serverName)...)
return buildClientHelloWithExtensions(ext)
}
func buildClientHelloWithExtensions(extensions []byte) []byte {
var hello []byte
// Client version: TLS 1.2.
hello = append(hello, 0x03, 0x03)
// Random: 32 bytes of zeros.
hello = append(hello, make([]byte, 32)...)
// Session ID: empty.
hello = append(hello, 0x00)
// Cipher suites: one suite (TLS_RSA_WITH_AES_128_GCM_SHA256).
hello = append(hello, 0x00, 0x02, 0x00, 0x9C)
// Compression methods: null.
hello = append(hello, 0x01, 0x00)
// Extensions.
if len(extensions) > 0 {
hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions)))
hello = append(hello, extensions...)
}
// Wrap in handshake header (type 0x01 = ClientHello).
handshake := []byte{0x01, 0x00, 0x00, 0x00}
handshake[1] = byte(len(hello) >> 16)
handshake[2] = byte(len(hello) >> 8)
handshake[3] = byte(len(hello))
handshake = append(handshake, hello...)
// Wrap in TLS record header (type 0x16 = handshake, version TLS 1.0).
record := []byte{0x16, 0x03, 0x01}
record = binary.BigEndian.AppendUint16(record, uint16(len(handshake)))
record = append(record, handshake...)
return record
}
func sniExtension(serverName string) []byte {
name := []byte(serverName)
// Server name entry: type 0x00 (hostname), length, name.
var entry []byte
entry = append(entry, 0x00)
entry = binary.BigEndian.AppendUint16(entry, uint16(len(name)))
entry = append(entry, name...)
// Server name list: length prefix.
var list []byte
list = binary.BigEndian.AppendUint16(list, uint16(len(entry)))
list = append(list, entry...)
// Extension: type 0x0000 (server_name), length, data.
var ext []byte
ext = binary.BigEndian.AppendUint16(ext, 0x0000)
ext = binary.BigEndian.AppendUint16(ext, uint16(len(list)))
ext = append(ext, list...)
return ext
}