Files
mc-proxy/internal/server/server_test.go
Kyle Isom 1ad9a1a43b Add PROXY protocol v1/v2 support for multi-hop deployments
New internal/proxyproto package implements PROXY protocol parsing and
writing without buffering past the header boundary (reads exact byte
counts so the connection is correctly positioned for SNI extraction).

Parser: auto-detects v1 (text) and v2 (binary) by first byte. Parses
TCP4/TCP6 for both versions plus v2 LOCAL command. Enforces max header
sizes and read deadlines.

Writer: generates v2 binary headers for IPv4 and IPv6 with PROXY
command.

Server integration:
- Receive: when listener.ProxyProtocol is true, parses PROXY header
  before firewall check. Real client IP from header is used for
  firewall evaluation and logging. Malformed headers cause RST.
- Send: when route.SendProxyProtocol is true, writes PROXY v2 header
  to backend before forwarding the ClientHello bytes.

Tests cover v1/v2 parsing, malformed rejection, timeout, round-trip
write+parse, and five server integration tests: receive with valid
header, receive with garbage, send verification, send-disabled
verification, and firewall evaluation using the real client IP.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 13:28:49 -07:00

1094 lines
26 KiB
Go

package server
import (
"bytes"
"context"
"encoding/binary"
"io"
"log/slog"
"net"
"net/netip"
"sync"
"testing"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
)
// l4Route returns a RouteInfo describing a layer-4 ("l4") passthrough route
// whose traffic is sent to backend. All other RouteInfo fields are left at
// their zero values.
func l4Route(backend string) RouteInfo {
	var ri RouteInfo
	ri.Backend = backend
	ri.Mode = "l4"
	return ri
}
// echoServer accepts a single connection on ln and writes every byte it
// reads back to the peer until the read side ends, then closes the
// connection. If Accept fails (for example because ln was closed), it
// returns immediately.
func echoServer(t *testing.T, ln net.Listener) {
	t.Helper()
	c, acceptErr := ln.Accept()
	if acceptErr == nil {
		defer c.Close()
		io.Copy(c, c)
	}
}
// newTestServer builds a Server for the given listeners. The firewall has
// no rules, logs are discarded, and the proxy uses fixed test timeouts:
// 5s connect, 30s idle, 5s shutdown. It fails the test if the firewall
// cannot be created.
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
	t.Helper()
	noRules, err := firewall.New("", nil, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	proxyCfg := config.Proxy{
		ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
		IdleTimeout:     config.Duration{Duration: 30 * time.Second},
		ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
	}
	discard := slog.New(slog.NewTextHandler(io.Discard, nil))
	return New(&config.Config{Proxy: proxyCfg}, noRules, listeners, discard, "test")
}
// startAndStop runs srv.Run in a background goroutine and returns a function
// that cancels the server's context and blocks until Run has returned. A
// non-nil error from Run is reported with t.Errorf. The function sleeps
// briefly before returning so the listeners have time to bind.
func startAndStop(t *testing.T, srv *Server) context.CancelFunc {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		if err := srv.Run(ctx); err != nil {
			t.Errorf("server.Run: %v", err)
		}
	}()
	// Give the listeners a moment to bind.
	time.Sleep(50 * time.Millisecond)
	return func() {
		cancel()
		<-exited
	}
}
// TestProxyRoundTrip verifies that an L4 route relays the ClientHello and
// later application data to an echo backend, and that the backend's replies
// come back to the client byte-for-byte.
func TestProxyRoundTrip(t *testing.T) {
	// Start an echo backend.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	go echoServer(t, backendLn)

	// Pick a free port for the proxy listener.
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// Connect through the proxy with a fake ClientHello.
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	hello := buildClientHello("echo.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// The backend echoes our ClientHello back. Compare the bytes, not just
	// the length, so the test fails if the proxy altered or reordered them.
	echoed := make([]byte, len(hello))
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatalf("echoed ClientHello differs from the one sent:\n got %x\nwant %x", echoed, hello)
	}

	// Send some additional data.
	payload := []byte("hello from client")
	if _, err := conn.Write(payload); err != nil {
		t.Fatalf("write payload: %v", err)
	}
	buf := make([]byte, len(payload))
	if _, err := io.ReadFull(conn, buf); err != nil {
		t.Fatalf("read echoed payload: %v", err)
	}
	if !bytes.Equal(buf, payload) {
		t.Fatalf("got %q, want %q", buf, payload)
	}
}
// TestNoRouteResets verifies that the proxy closes a connection whose SNI
// hostname matches no route on the listener.
func TestNoRouteResets(t *testing.T) {
	// Proxy listener with no routes for the requested hostname.
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"other.test": l4Route("127.0.0.1:1"), // exists but won't match
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	hello := buildClientHello("unknown.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// The proxy should close the connection (no route match).
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed, but read succeeded")
	}
	// A deadline expiry also yields a non-nil error. It means the proxy left
	// the connection open instead of closing it, so it counts as a failure.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatalf("expected connection to be closed, but read timed out: %v", err)
	}
}
// TestFirewallBlocks verifies that a connection from a blocked source IP is
// closed by the proxy and is never forwarded to the backend.
func TestFirewallBlocks(t *testing.T) {
	// Start a backend that should never be reached.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	reached := make(chan struct{}, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		conn.Close()
		reached <- struct{}{}
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	// Create a firewall that blocks 127.0.0.1 (the test client).
	fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(cfg, fw, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		srv.Run(ctx) // Run's error is not checked by this test.
	}()
	// Stop the server on every exit path, including t.Fatal, so its goroutine
	// never outlives the test.
	defer func() {
		cancel()
		wg.Wait()
	}()
	time.Sleep(50 * time.Millisecond)

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	hello := buildClientHello("echo.test")
	// The proxy may already have dropped the connection, so a write error is
	// not itself a failure here.
	conn.Write(hello)

	// Connection should be closed (blocked by firewall).
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed by firewall")
	}
	// A read timeout means the proxy kept the connection open rather than
	// closing it, which is not what a firewall block should do.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatalf("expected connection to be closed by firewall, but read timed out: %v", err)
	}

	// Backend should not have been reached.
	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall block")
	case <-time.After(200 * time.Millisecond):
		// Expected.
	}
}
// TestNotTLSResets verifies that the proxy closes a connection whose first
// bytes are not a TLS record (here, a plain HTTP request).
func TestNotTLSResets(t *testing.T) {
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:     1,
			Addr:   proxyAddr,
			Routes: map[string]RouteInfo{"x.test": l4Route("127.0.0.1:1")},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send HTTP, not TLS. A write error is tolerated because the proxy may
	// reject the connection early.
	conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))

	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed for non-TLS data")
	}
	// A deadline expiry means the connection was left open, not closed.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatalf("expected connection to be closed for non-TLS data, but read timed out: %v", err)
	}
}
// TestConnectionTracking verifies that Server.TotalConnections counts active
// proxied connections and drops back down when one of them is torn down.
func TestConnectionTracking(t *testing.T) {
	// Backend that holds connections open until we close it.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()

	var (
		mu           sync.Mutex
		backendConns []net.Conn
	)
	go func() {
		for {
			conn, err := backendLn.Accept()
			if err != nil {
				return
			}
			mu.Lock()
			backendConns = append(backendConns, conn)
			mu.Unlock()
			// Hold connection open, drain input.
			go io.Copy(io.Discard, conn)
		}
	}()
	backendCount := func() int {
		mu.Lock()
		defer mu.Unlock()
		return len(backendConns)
	}

	// waitUntil polls cond until it holds or 5s pass, and reports whether it
	// held. Polling replaces fixed sleeps, which are flaky under load.
	waitUntil := func(cond func() bool) bool {
		deadline := time.Now().Add(5 * time.Second)
		for !cond() {
			if time.Now().After(deadline) {
				return false
			}
			time.Sleep(10 * time.Millisecond)
		}
		return true
	}

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"conn.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	var clientConns []net.Conn
	// Close every client and backend connection on all exit paths (including
	// t.Fatal). This defer runs before stop() and before the backend listener
	// is closed.
	defer func() {
		for _, c := range clientConns {
			c.Close()
		}
		mu.Lock()
		for _, c := range backendConns {
			c.Close()
		}
		mu.Unlock()
	}()

	if got := srv.TotalConnections(); got != 0 {
		t.Fatalf("expected 0 connections before any clients, got %d", got)
	}

	// Open two connections through the proxy, one at a time. Waiting for
	// the backend to accept each before dialing the next makes
	// backendConns[i] the backend side of clientConns[i].
	for i := range 2 {
		conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
		if err != nil {
			t.Fatalf("dial proxy %d: %v", i, err)
		}
		clientConns = append(clientConns, conn)
		hello := buildClientHello("conn.test")
		if _, err := conn.Write(hello); err != nil {
			t.Fatalf("write ClientHello %d: %v", i, err)
		}
		if !waitUntil(func() bool { return backendCount() == i+1 }) {
			t.Fatalf("backend did not accept proxied connection %d", i)
		}
	}

	if !waitUntil(func() bool { return int(srv.TotalConnections()) == 2 }) {
		t.Fatalf("expected 2 active connections, got %d", srv.TotalConnections())
	}

	// Close one client and its corresponding backend connection.
	clientConns[0].Close()
	mu.Lock()
	backendConns[0].Close()
	mu.Unlock()

	// Wait for the relay goroutines to detect the close.
	if !waitUntil(func() bool { return int(srv.TotalConnections()) == 1 }) {
		t.Fatalf("expected 1 active connection after closing one, got %d", srv.TotalConnections())
	}
}
// TestMultipleListeners verifies that two listeners with the same hostname
// route independently, each to its own backend.
func TestMultipleListeners(t *testing.T) {
	// Two backends.
	backendA, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend A listen: %v", err)
	}
	defer backendA.Close()
	backendB, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend B listen: %v", err)
	}
	defer backendB.Close()

	// serve accepts one connection on ln, writes id, and closes it. Incoming
	// data (the forwarded ClientHello) is drained in the background and never
	// echoed.
	serve := func(ln net.Listener, id string) {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		go io.Copy(io.Discard, conn)
		conn.Write([]byte(id))
	}
	go serve(backendA, "A")
	go serve(backendB, "B")

	// Two proxy listeners, same hostname, different backends.
	ln1, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port 1: %v", err)
	}
	addr1 := ln1.Addr().String()
	ln1.Close()
	ln2, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port 2: %v", err)
	}
	addr2 := ln2.Addr().String()
	ln2.Close()

	srv := newTestServer(t, []ListenerData{
		{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
		{ID: 2, Addr: addr2, Routes: map[string]RouteInfo{"svc.test": l4Route(backendB.Addr().String())}},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// readID sends a ClientHello for svc.test through proxyAddr and returns
	// everything the backend sent back before the connection ended.
	readID := func(proxyAddr string) string {
		t.Helper()
		conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
		if err != nil {
			t.Fatalf("dial %s: %v", proxyAddr, err)
		}
		defer conn.Close()
		if _, err := conn.Write(buildClientHello("svc.test")); err != nil {
			t.Fatalf("write ClientHello to %s: %v", proxyAddr, err)
		}
		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
		// The backend discards its input and replies with only its ID before
		// closing, so the whole response should be exactly that ID. Any read
		// error (close, reset, or the deadline) ends the read. The bytes
		// received before it are what gets checked, so the error is not
		// inspected.
		all, _ := io.ReadAll(conn)
		if len(all) == 0 {
			t.Fatalf("no data from %s", proxyAddr)
		}
		return string(all)
	}

	idA := readID(addr1)
	idB := readID(addr2)
	if idA != "A" {
		t.Fatalf("listener 1: got backend %q, want A", idA)
	}
	if idB != "B" {
		t.Fatalf("listener 2: got backend %q, want B", idB)
	}
}
// TestCaseInsensitiveRouting verifies that an upper-case SNI hostname matches
// a route registered in lower case, and that the ClientHello still reaches the
// backend intact.
func TestCaseInsensitiveRouting(t *testing.T) {
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	go echoServer(t, backendLn)

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// SNI extraction lowercases the hostname, so "ECHO.TEST" should match
	// the route for "echo.test".
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	hello := buildClientHello("ECHO.TEST")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// The echo backend returns what it received. Lowercasing is only for
	// route lookup, so the original upper-case bytes must come back
	// unchanged.
	echoed := make([]byte, len(hello))
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatalf("echoed ClientHello differs from the one sent:\n got %x\nwant %x", echoed, hello)
	}
}
// TestBackendUnreachable verifies that the proxy closes the client connection
// when the routed backend cannot be dialed.
func TestBackendUnreachable(t *testing.T) {
	// Find a port that nothing is listening on.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	deadAddr := ln.Addr().String()
	ln.Close()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"dead.test": l4Route(deadAddr),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	hello := buildClientHello("dead.test")
	// A write error is tolerated: the proxy may drop the connection quickly.
	conn.Write(hello)

	// Proxy should close the connection after failing to dial backend.
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed when backend is unreachable")
	}
	// A read timeout means the proxy left the connection open, not closed.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatalf("expected connection to be closed when backend is unreachable, but read timed out: %v", err)
	}
}
// TestGracefulShutdown verifies that Run returns nil within a bounded time
// after its context is cancelled, even while a proxied connection is still
// open.
func TestGracefulShutdown(t *testing.T) {
	// Backend that holds the connection open.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		io.Copy(io.Discard, conn)
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	fw, err := firewall.New("", nil, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 2 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(cfg, fw, []ListenerData{
		{ID: 1, Addr: proxyAddr, Routes: map[string]RouteInfo{"hold.test": l4Route(backendLn.Addr().String())}},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	// Ensure the server is stopped if the test fails before the explicit
	// cancel below. Calling cancel twice is harmless.
	defer cancel()
	done := make(chan error, 1)
	go func() {
		done <- srv.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond)

	// Establish a connection that will be in-flight during shutdown.
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	hello := buildClientHello("hold.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}
	time.Sleep(50 * time.Millisecond)

	// Trigger shutdown.
	cancel()

	// Server should exit within the shutdown timeout.
	select {
	case err := <-done:
		if err != nil {
			t.Fatalf("server.Run returned error: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("server did not shut down within 5 seconds")
	}
}
// TestListenerStateRoutes exercises ListenerState's route table directly.
// It adds a route, rejects a duplicate add, checks the Routes snapshot
// size, removes a route, and rejects removal of an unknown hostname.
func TestListenerStateRoutes(t *testing.T) {
	state := &ListenerState{
		ID:   1,
		Addr: ":443",
		routes: map[string]RouteInfo{
			"a.test": l4Route("127.0.0.1:1"),
		},
	}

	// Adding a new hostname succeeds; adding it a second time must fail.
	err := state.AddRoute("b.test", l4Route("127.0.0.1:2"))
	if err != nil {
		t.Fatalf("AddRoute: %v", err)
	}
	if state.AddRoute("b.test", l4Route("127.0.0.1:3")) == nil {
		t.Fatal("expected error for duplicate route")
	}

	if n := len(state.Routes()); n != 2 {
		t.Fatalf("expected 2 routes, got %d", n)
	}

	// Removing an existing hostname succeeds; removing an unknown one fails.
	if err = state.RemoveRoute("a.test"); err != nil {
		t.Fatalf("RemoveRoute: %v", err)
	}
	if state.RemoveRoute("nonexistent.test") == nil {
		t.Fatal("expected error for removing nonexistent route")
	}

	snapshot := state.Routes()
	if len(snapshot) != 1 {
		t.Fatalf("expected 1 route, got %d", len(snapshot))
	}
	if got := snapshot["b.test"].Backend; got != "127.0.0.1:2" {
		t.Fatalf("expected b.test → 127.0.0.1:2, got %q", got)
	}
}
// TestProxyProtocolReceive verifies that a ProxyProtocol listener consumes an
// incoming PROXY v2 header, routes on the ClientHello that follows, and sends
// only the ClientHello, not the PROXY header, to the backend.
func TestProxyProtocolReceive(t *testing.T) {
	// Backend echoes everything back.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	go echoServer(t, backendLn)

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:            1,
			Addr:          proxyAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send PROXY v2 header followed by TLS ClientHello.
	var ppBuf bytes.Buffer
	proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("203.0.113.50:12345"),
		netip.MustParseAddrPort("198.51.100.1:443"),
	)
	if _, err := conn.Write(ppBuf.Bytes()); err != nil {
		t.Fatalf("write PROXY header: %v", err)
	}
	hello := buildClientHello("echo.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// Backend should echo the ClientHello back (not the PROXY header). The
	// length alone proves nothing: a leaked header would still fill
	// len(hello) bytes. So compare the content.
	echoed := make([]byte, len(hello))
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatalf("backend did not receive exactly the ClientHello:\n got %x\nwant %x", echoed, hello)
	}
}
// TestProxyProtocolReceiveGarbage verifies that a ProxyProtocol listener
// closes the connection when the client sends bytes that are not a valid
// PROXY header.
func TestProxyProtocolReceiveGarbage(t *testing.T) {
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:            1,
			Addr:          proxyAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route("127.0.0.1:1"),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send garbage instead of a valid PROXY header. A write error is
	// tolerated because the proxy may reject the connection early.
	conn.Write([]byte("NOT A PROXY HEADER\r\n"))

	// Connection should be closed.
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed for invalid PROXY header")
	}
	// A timeout means the proxy is still waiting on the connection rather
	// than having rejected it.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatalf("expected connection to be closed for invalid PROXY header, but read timed out: %v", err)
	}
}
// TestProxyProtocolSend verifies that a route with SendProxyProtocol set
// writes a PROXY v2 header to the backend, immediately followed by the
// client's TLS record.
func TestProxyProtocolSend(t *testing.T) {
	const (
		// PROXY v2 IPv4 header: 12 sig + 1 ver/cmd + 1 fam + 2 len + 12 addrs.
		ppV2IPv4Len = 28
		// Enough of the TLS record to see its header.
		tlsRecordHdrLen = 5
	)

	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()

	// The backend collects bytes until it holds the PROXY header plus a TLS
	// record header, a read fails, or its 5s deadline passes.
	received := make(chan []byte, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
		var got []byte
		chunk := make([]byte, 4096)
		for len(got) < ppV2IPv4Len+tlsRecordHdrLen {
			n, rerr := conn.Read(chunk)
			got = append(got, chunk[:n]...)
			if rerr != nil {
				break
			}
		}
		received <- got
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"pp.test": {
					Backend:           backendLn.Addr().String(),
					Mode:              "l4",
					SendProxyProtocol: true,
				},
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	conn.Write(buildClientHello("pp.test"))

	var data []byte
	select {
	case data = <-received:
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for backend data")
	}

	if len(data) < ppV2IPv4Len {
		t.Fatalf("backend received only %d bytes, want at least 28", len(data))
	}
	signature := []byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
	if !bytes.Equal(data[:len(signature)], signature) {
		t.Fatal("backend data does not start with PROXY v2 signature")
	}
	if len(data) <= ppV2IPv4Len {
		t.Fatalf("backend received only PROXY header, no TLS data")
	}
	if first := data[ppV2IPv4Len]; first != 0x16 {
		t.Fatalf("expected TLS record (0x16) after PROXY header, got 0x%02x", first)
	}
}
// TestProxyProtocolNotSent verifies that a route without SendProxyProtocol
// forwards the TLS record to the backend with no PROXY header in front.
func TestProxyProtocolNotSent(t *testing.T) {
	// Backend captures first bytes.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	received := make(chan []byte, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		// Bound the read so this goroutine cannot block forever if the proxy
		// never forwards anything. Closing the listener does not unblock a
		// Read on an accepted connection.
		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
		buf := make([]byte, 4096)
		n, _ := conn.Read(buf) // n == 0 on error; checked by the test below
		received <- buf[:n]
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"nopp.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()
	hello := buildClientHello("nopp.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	select {
	case data := <-received:
		// Guard the index below: an empty read would otherwise panic.
		if len(data) == 0 {
			t.Fatal("backend received no data")
		}
		// First byte should be TLS record header, not PROXY signature.
		if data[0] != 0x16 {
			t.Fatalf("expected TLS record (0x16) as first byte, got 0x%02x", data[0])
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for backend data")
	}
}
// TestProxyProtocolFirewallUsesRealIP verifies that on a ProxyProtocol
// listener the firewall evaluates the client address carried in the PROXY
// header rather than the TCP peer address. The header's source is blocked
// while 127.0.0.1 is not, so the connection must be refused.
func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
	// Backend that should never be reached.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	reached := make(chan struct{}, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		conn.Close()
		reached <- struct{}{}
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	proxyLn.Close()

	// Block 203.0.113.50 (the "real" client IP from PROXY header).
	// 127.0.0.1 (the actual TCP peer) is NOT blocked.
	fw, err := firewall.New("", []string{"203.0.113.50"}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(cfg, fw, []ListenerData{
		{
			ID:            1,
			Addr:          proxyAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"blocked.test": l4Route(backendLn.Addr().String()),
			},
		},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		srv.Run(ctx) // Run's error is not checked by this test.
	}()
	// Stop the server on every exit path, including t.Fatal, so its goroutine
	// never outlives the test.
	defer func() {
		cancel()
		wg.Wait()
	}()
	time.Sleep(50 * time.Millisecond)

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send PROXY v2 with the blocked real IP. Write errors are tolerated:
	// the proxy may reset the connection as soon as it evaluates the header.
	var ppBuf bytes.Buffer
	proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("203.0.113.50:12345"),
		netip.MustParseAddrPort("198.51.100.1:443"),
	)
	conn.Write(ppBuf.Bytes())
	conn.Write(buildClientHello("blocked.test"))

	// Connection should be closed (firewall blocks real IP).
	conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed by firewall")
	}
	// A read timeout means the connection was left open, not closed.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatalf("expected connection to be closed by firewall, but read timed out: %v", err)
	}

	// Backend should not have been reached.
	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall blocking real IP")
	case <-time.After(200 * time.Millisecond):
		// Expected.
	}
}
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
func buildClientHello(serverName string) []byte {
return buildClientHelloWithExtensions(sniExtension(serverName))
}
// buildClientHelloWithExtensions builds a TLS handshake record that wraps a
// minimal ClientHello body. The body has version 0x0303, an all-zero 32-byte
// random, an empty session ID, the single cipher suite 0x009C, and null
// compression. The raw extensions block is appended after the body. When
// extensions is empty, the extensions length field is omitted entirely.
func buildClientHelloWithExtensions(extensions []byte) []byte {
	body := []byte{0x03, 0x03}                  // client version (TLS 1.2)
	body = append(body, make([]byte, 32)...)    // random: all zero
	body = append(body, 0x00)                   // session ID length: 0
	body = append(body, 0x00, 0x02, 0x00, 0x9C) // cipher suites: len 2, 0x009C
	body = append(body, 0x01, 0x00)             // compression: len 1, null
	if n := len(extensions); n != 0 {
		body = binary.BigEndian.AppendUint16(body, uint16(n))
		body = append(body, extensions...)
	}

	// Handshake header: type 0x01 (ClientHello) plus a 24-bit body length.
	bodyLen := len(body)
	msg := make([]byte, 0, 4+bodyLen)
	msg = append(msg, 0x01, byte(bodyLen>>16), byte(bodyLen>>8), byte(bodyLen))
	msg = append(msg, body...)

	// Record header: content type 0x16, version 0x0301, 16-bit length.
	rec := make([]byte, 0, 5+len(msg))
	rec = append(rec, 0x16, 0x03, 0x01)
	rec = binary.BigEndian.AppendUint16(rec, uint16(len(msg)))
	return append(rec, msg...)
}
// sniExtension encodes a server_name extension (type 0x0000) whose
// server-name list holds a single host_name (type 0x00) entry for
// serverName. The output is: extension type, extension data length,
// list length, name type, name length, name bytes.
func sniExtension(serverName string) []byte {
	nameLen := len(serverName)
	entryLen := 1 + 2 + nameLen // name type + name length + name
	listLen := 2 + entryLen     // list length field + entry

	out := make([]byte, 0, 4+listLen)
	out = binary.BigEndian.AppendUint16(out, 0x0000) // extension type
	out = binary.BigEndian.AppendUint16(out, uint16(listLen))
	out = binary.BigEndian.AppendUint16(out, uint16(entryLen))
	out = append(out, 0x00) // name type: host_name
	out = binary.BigEndian.AppendUint16(out, uint16(nameLen))
	return append(out, serverName...)
}