Files
mc-proxy/internal/server/server_test.go
Kyle Isom 28321e22f4 Make AddRoute idempotent (upsert instead of reject duplicates)
AddRoute now updates an existing route if one already exists for the
same (listener, hostname) pair, instead of returning AlreadyExists.
This makes repeated deploys idempotent — the MCP agent can register
routes on every deploy without needing to remove them first.

- DB: INSERT ... ON CONFLICT DO UPDATE (SQLite upsert)
- In-memory: overwrite existing route unconditionally
- gRPC: error code changed from AlreadyExists to Internal (for real DB errors)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 14:01:45 -07:00

1635 lines
41 KiB
Go

package server
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"encoding/pem"
"fmt"
"io"
"log/slog"
"math/big"
"net"
"net/http"
"net/netip"
"os"
"path/filepath"
"sync"
"testing"
"time"
"git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
// l4Route builds the RouteInfo for a layer-4 passthrough route that forwards
// raw connections to backend.
func l4Route(backend string) RouteInfo {
	route := RouteInfo{Mode: "l4"}
	route.Backend = backend
	return route
}
// echoServer accepts a single connection from ln, writes back every byte it
// reads until reading hits EOF or fails, and then closes that connection.
// If Accept fails (for example because ln was closed) it returns silently.
func echoServer(t *testing.T, ln net.Listener) {
	t.Helper()
	conn, acceptErr := ln.Accept()
	if acceptErr == nil {
		_, _ = io.Copy(conn, conn)
		_ = conn.Close()
	}
}
// newTestServer builds a Server over the given listeners. It uses a firewall
// constructed with no rules, 5s connect and shutdown timeouts, a 30s idle
// timeout, and a logger that discards all output. The test fails immediately
// if the firewall cannot be constructed.
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
	t.Helper()
	fw, err := firewall.New("", nil, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	proxyCfg := config.Proxy{
		ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
		IdleTimeout:     config.Duration{Duration: 30 * time.Second},
		ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
	}
	quiet := slog.New(slog.NewTextHandler(io.Discard, nil))
	return New(&config.Config{Proxy: proxyCfg}, fw, listeners, quiet, "test")
}
// startAndStop runs srv.Run in a background goroutine and pauses briefly so
// the listeners can bind. It returns a function that cancels the server's
// context and blocks until Run has returned. A non-nil error from Run is
// reported with t.Errorf, which is safe to call from the background goroutine.
func startAndStop(t *testing.T, srv *Server) context.CancelFunc {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	finished := make(chan struct{})
	go func() {
		defer close(finished)
		if err := srv.Run(ctx); err != nil {
			t.Errorf("server.Run: %v", err)
		}
	}()
	// Fixed grace period for the listeners to bind before the caller dials.
	time.Sleep(50 * time.Millisecond)
	return func() {
		cancel()
		<-finished
	}
}
// TestProxyRoundTrip routes a fake ClientHello through an L4 listener to an
// echo backend. It checks that the ClientHello and a subsequent payload both
// come back byte-for-byte unchanged.
func TestProxyRoundTrip(t *testing.T) {
	// Start an echo backend.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	go echoServer(t, backendLn)

	// Pick a free port for the proxy listener.
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// Connect through the proxy with a fake ClientHello.
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()
	hello := buildClientHello("echo.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// The proxy reads the ClientHello to pick a route; it must still forward
	// those bytes unchanged, and the echo backend sends them straight back.
	echoed := make([]byte, len(hello))
	_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatalf("echoed ClientHello differs from the one sent (%d bytes)", len(hello))
	}

	// Send some additional data.
	payload := []byte("hello from client")
	if _, err := conn.Write(payload); err != nil {
		t.Fatalf("write payload: %v", err)
	}
	buf := make([]byte, len(payload))
	if _, err := io.ReadFull(conn, buf); err != nil {
		t.Fatalf("read echoed payload: %v", err)
	}
	if !bytes.Equal(buf, payload) {
		t.Fatalf("got %q, want %q", buf, payload)
	}
}
// TestNoRouteResets checks that the proxy closes a connection whose SNI
// hostname has no route on the listener.
func TestNoRouteResets(t *testing.T) {
	// Proxy listener with no routes for the requested hostname.
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"other.test": l4Route("127.0.0.1:1"), // exists but won't match
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()
	hello := buildClientHello("unknown.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// The proxy should close the connection (no route match).
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed, but read succeeded")
	}
	// Hitting the read deadline also produces an error, but it means the
	// proxy left the connection open, so it must not count as a pass.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("connection was still open when the read deadline expired")
	}
}
// TestFirewallBlocks checks that a connection from a blocked source address
// (127.0.0.1, the test client) is closed by the proxy and never reaches the
// backend.
func TestFirewallBlocks(t *testing.T) {
	// Start a backend that should never be reached.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	reached := make(chan struct{}, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		_ = conn.Close()
		reached <- struct{}{}
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	// Create a firewall that blocks 127.0.0.1 (the test client).
	fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(cfg, fw, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		_ = srv.Run(ctx)
	}()
	// Stop the server on every exit path, including early t.Fatal returns.
	defer func() {
		cancel()
		wg.Wait()
	}()
	time.Sleep(50 * time.Millisecond)

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()
	// The write may race with the proxy closing a blocked connection, so its
	// error is deliberately ignored.
	hello := buildClientHello("echo.test")
	_, _ = conn.Write(hello)

	// Connection should be closed (blocked by firewall).
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed by firewall")
	}
	// A deadline timeout means the connection stayed open, which is not a block.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("connection was still open when the read deadline expired")
	}

	// Backend should not have been reached.
	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall block")
	case <-time.After(200 * time.Millisecond):
		// Expected.
	}
}
// TestNotTLSResets checks that the proxy closes a connection whose first
// bytes are not a TLS record (a plain HTTP request here).
func TestNotTLSResets(t *testing.T) {
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:     1,
			Addr:   proxyAddr,
			Routes: map[string]RouteInfo{"x.test": l4Route("127.0.0.1:1")},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Send HTTP, not TLS.
	_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed for non-TLS data")
	}
	// A deadline timeout means the proxy kept the connection open.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("connection was still open when the read deadline expired")
	}
}
// TestConnectionTracking checks that TotalConnections follows the number of
// live proxied connections as clients connect and then disconnect.
func TestConnectionTracking(t *testing.T) {
	// Backend that holds connections open until we close it.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	var backendConns []net.Conn
	var mu sync.Mutex
	go func() {
		for {
			conn, err := backendLn.Accept()
			if err != nil {
				return
			}
			mu.Lock()
			backendConns = append(backendConns, conn)
			mu.Unlock()
			// Hold connection open, drain input.
			go func() { _, _ = io.Copy(io.Discard, conn) }()
		}
	}()
	// backendAccepted reports how many connections the backend has accepted.
	backendAccepted := func() int {
		mu.Lock()
		defer mu.Unlock()
		return len(backendConns)
	}

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"conn.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	if got := srv.TotalConnections(); got != 0 {
		t.Fatalf("expected 0 connections before any clients, got %d", got)
	}

	var clientConns []net.Conn
	// Close every client and backend connection on all exit paths. This runs
	// before the deferred stop (LIFO), matching the original cleanup order.
	defer func() {
		for _, c := range clientConns {
			_ = c.Close()
		}
		mu.Lock()
		for _, c := range backendConns {
			_ = c.Close()
		}
		mu.Unlock()
	}()

	// Open two connections through the proxy. Each is dialed only after the
	// backend has accepted the previous one, so backendConns[i] is the
	// backend side of clientConns[i]. Otherwise the order in which the
	// backend accepts them is not tied to the dial order.
	for i := range 2 {
		conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
		if err != nil {
			t.Fatalf("dial proxy %d: %v", i, err)
		}
		clientConns = append(clientConns, conn)
		hello := buildClientHello("conn.test")
		if _, err := conn.Write(hello); err != nil {
			t.Fatalf("write ClientHello %d: %v", i, err)
		}
		acceptDeadline := time.Now().Add(5 * time.Second)
		for backendAccepted() <= i {
			if time.Now().After(acceptDeadline) {
				t.Fatalf("backend did not accept proxied connection %d", i)
			}
			time.Sleep(10 * time.Millisecond)
		}
	}

	// Wait for the proxy to report both connections.
	deadline := time.Now().Add(5 * time.Second)
	for srv.TotalConnections() != 2 && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	if got := srv.TotalConnections(); got != 2 {
		t.Fatalf("expected 2 active connections, got %d", got)
	}

	// Close one client and its corresponding backend connection.
	_ = clientConns[0].Close()
	mu.Lock()
	_ = backendConns[0].Close()
	mu.Unlock()

	// Wait for the relay goroutines to detect the close.
	deadline = time.Now().Add(5 * time.Second)
	for srv.TotalConnections() != 1 && time.Now().Before(deadline) {
		time.Sleep(50 * time.Millisecond)
	}
	if got := srv.TotalConnections(); got != 1 {
		t.Fatalf("expected 1 active connection after closing one, got %d", got)
	}
}
// TestMultipleListeners checks that the same SNI hostname configured on two
// listeners is routed per-listener: listener 1 reaches backend A and
// listener 2 reaches backend B.
func TestMultipleListeners(t *testing.T) {
	// Two backends.
	backendA, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend A listen: %v", err)
	}
	defer func() { _ = backendA.Close() }()
	backendB, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend B listen: %v", err)
	}
	defer func() { _ = backendB.Close() }()

	// Each backend accepts one connection, discards whatever it receives,
	// writes its identity, and closes.
	serve := func(ln net.Listener, id string) {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		defer func() { _ = conn.Close() }()
		// Drain the incoming data, then write identity.
		go func() { _, _ = io.Copy(io.Discard, conn) }()
		_, _ = conn.Write([]byte(id))
	}
	go serve(backendA, "A")
	go serve(backendB, "B")

	// Two proxy listeners, same hostname, different backends.
	ln1, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port 1: %v", err)
	}
	addr1 := ln1.Addr().String()
	_ = ln1.Close()
	ln2, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port 2: %v", err)
	}
	addr2 := ln2.Addr().String()
	_ = ln2.Close()

	srv := newTestServer(t, []ListenerData{
		{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
		{ID: 2, Addr: addr2, Routes: map[string]RouteInfo{"svc.test": l4Route(backendB.Addr().String())}},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// readID connects to proxyAddr, sends a ClientHello for svc.test, and
	// returns the last byte the backend sent before the stream ended.
	readID := func(proxyAddr string) string {
		t.Helper()
		conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
		if err != nil {
			t.Fatalf("dial %s: %v", proxyAddr, err)
		}
		defer func() { _ = conn.Close() }()
		hello := buildClientHello("svc.test")
		if _, err := conn.Write(hello); err != nil {
			t.Fatalf("write ClientHello to %s: %v", proxyAddr, err)
		}
		_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))

		// The backend does not echo: it discards the ClientHello and sends
		// only its one-byte ID before closing. Read until the stream ends
		// (EOF, reset, or deadline) and take the final byte as the ID.
		buf := make([]byte, 128)
		var all []byte
		for {
			n, err := conn.Read(buf)
			all = append(all, buf[:n]...)
			if err != nil {
				break
			}
		}
		if len(all) == 0 {
			t.Fatalf("no data from %s", proxyAddr)
		}
		return string(all[len(all)-1:])
	}

	idA := readID(addr1)
	idB := readID(addr2)
	if idA != "A" {
		t.Fatalf("listener 1: got backend %q, want A", idA)
	}
	if idB != "B" {
		t.Fatalf("listener 2: got backend %q, want B", idB)
	}
}
// TestCaseInsensitiveRouting checks that an upper-case SNI hostname matches a
// route registered in lower case, and that the ClientHello is forwarded
// unchanged to the backend.
func TestCaseInsensitiveRouting(t *testing.T) {
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	go echoServer(t, backendLn)

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// SNI extraction lowercases the hostname, so "ECHO.TEST" should match
	// the route for "echo.test".
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()
	hello := buildClientHello("ECHO.TEST")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	echoed := make([]byte, len(hello))
	_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	// Only the routing lookup is case-folded; the forwarded bytes must be
	// exactly what the client sent.
	if !bytes.Equal(echoed, hello) {
		t.Fatalf("echoed ClientHello differs from the one sent (%d bytes)", len(hello))
	}
}
// TestBackendUnreachable checks that the proxy closes the client connection
// when the routed backend address refuses connections.
func TestBackendUnreachable(t *testing.T) {
	// Find a port that nothing is listening on.
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	deadAddr := ln.Addr().String()
	_ = ln.Close()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"dead.test": l4Route(deadAddr),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()
	hello := buildClientHello("dead.test")
	_, _ = conn.Write(hello)

	// Proxy should close the connection after failing to dial backend.
	_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed when backend is unreachable")
	}
	// A deadline timeout means the proxy kept the client connection open.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("connection was still open when the read deadline expired")
	}
}
// TestGracefulShutdown checks that Run returns nil within a bounded time
// after its context is cancelled, even while a proxied connection to a
// backend that never closes is still open.
func TestGracefulShutdown(t *testing.T) {
	// Backend that holds the connection open.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer func() { _ = conn.Close() }()
		_, _ = io.Copy(io.Discard, conn)
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	fw, err := firewall.New("", nil, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 2 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(cfg, fw, []ListenerData{
		{ID: 1, Addr: proxyAddr, Routes: map[string]RouteInfo{"hold.test": l4Route(backendLn.Addr().String())}},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	// Make sure the server is told to stop even if the test fails before the
	// explicit cancel below. Calling cancel a second time does nothing.
	defer cancel()
	done := make(chan error, 1) // buffered so Run's goroutine never blocks on send
	go func() {
		done <- srv.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond)

	// Establish a connection that will be in-flight during shutdown.
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()
	hello := buildClientHello("hold.test")
	_, _ = conn.Write(hello)
	time.Sleep(50 * time.Millisecond)

	// Trigger shutdown.
	cancel()

	// Server should exit within the shutdown timeout.
	select {
	case err := <-done:
		if err != nil {
			t.Fatalf("server.Run returned error: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("server did not shut down within 5 seconds")
	}
}
// TestListenerStateRoutes exercises ListenerState route management. It checks
// that AddRoute on an existing hostname replaces the earlier backend,
// RemoveRoute deletes known hostnames and errors on unknown ones, and Routes
// reflects the current table.
func TestListenerStateRoutes(t *testing.T) {
	ls := &ListenerState{
		ID:   1,
		Addr: ":443",
		routes: map[string]RouteInfo{
			"a.test": l4Route("127.0.0.1:1"),
		},
	}

	// Register b.test twice; the second call must update it in place.
	for _, backend := range []string{"127.0.0.1:2", "127.0.0.1:3"} {
		ls.AddRoute("b.test", l4Route(backend))
	}
	if n := len(ls.Routes()); n != 2 {
		t.Fatalf("expected 2 routes, got %d", n)
	}

	if err := ls.RemoveRoute("a.test"); err != nil {
		t.Fatalf("RemoveRoute: %v", err)
	}
	if ls.RemoveRoute("nonexistent.test") == nil {
		t.Fatal("expected error for removing nonexistent route")
	}

	remaining := ls.Routes()
	if n := len(remaining); n != 1 {
		t.Fatalf("expected 1 route, got %d", n)
	}
	if got := remaining["b.test"].Backend; got != "127.0.0.1:3" {
		t.Fatalf("expected b.test → 127.0.0.1:3 (upserted), got %q", got)
	}
}
// TestProxyProtocolReceive checks that a listener with ProxyProtocol enabled
// consumes the incoming PROXY v2 header. Only the ClientHello that follows it
// may reach the backend.
func TestProxyProtocolReceive(t *testing.T) {
	// Backend echoes everything back.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	go echoServer(t, backendLn)

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:            1,
			Addr:          proxyAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Send PROXY v2 header followed by TLS ClientHello.
	var ppBuf bytes.Buffer
	if err := proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("203.0.113.50:12345"),
		netip.MustParseAddrPort("198.51.100.1:443"),
	); err != nil {
		t.Fatalf("building PROXY v2 header: %v", err)
	}
	if _, err := conn.Write(ppBuf.Bytes()); err != nil {
		t.Fatalf("write PROXY header: %v", err)
	}
	hello := buildClientHello("echo.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// Backend should echo the ClientHello back (not the PROXY header). If the
	// header leaked through, the first len(hello) bytes would start with the
	// PROXY signature instead, so compare the bytes rather than just the length.
	echoed := make([]byte, len(hello))
	_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatal("backend did not receive exactly the ClientHello (PROXY header leaked or data altered)")
	}
}
// TestProxyProtocolReceiveGarbage checks that a ProxyProtocol listener closes
// a connection whose first bytes are not a valid PROXY header.
func TestProxyProtocolReceiveGarbage(t *testing.T) {
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:            1,
			Addr:          proxyAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"echo.test": l4Route("127.0.0.1:1"),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Send garbage instead of a valid PROXY header.
	_, _ = conn.Write([]byte("NOT A PROXY HEADER\r\n"))

	// Connection should be closed.
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed for invalid PROXY header")
	}
	// A deadline timeout means the proxy kept waiting instead of rejecting.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("connection was still open when the read deadline expired")
	}
}
func TestProxyProtocolSend(t *testing.T) {
// Backend that captures the first bytes it receives.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer func() { _ = backendLn.Close() }()
received := make(chan []byte, 1)
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer func() { _ = conn.Close() }()
// Read all available data; the proxy sends PROXY header + ClientHello.
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var all []byte
buf := make([]byte, 4096)
for {
n, err := conn.Read(buf)
all = append(all, buf[:n]...)
if err != nil {
break
}
// We expect at least pp header (28) + some TLS data.
if len(all) >= 28+5 {
break
}
}
received <- all
}()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]RouteInfo{
"pp.test": {
Backend: backendLn.Addr().String(),
Mode: "l4",
SendProxyProtocol: true,
},
},
},
})
stop := startAndStop(t, srv)
defer stop()
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer func() { _ = conn.Close() }()
hello := buildClientHello("pp.test")
_, _ = conn.Write(hello)
// The backend should receive: PROXY v2 header + ClientHello.
select {
case data := <-received:
// PROXY v2 IPv4 header is 28 bytes (12 sig + 1 ver/cmd + 1 fam + 2 len + 12 addrs).
if len(data) < 28 {
t.Fatalf("backend received only %d bytes, want at least 28", len(data))
}
// Check PROXY v2 signature.
sig := [12]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
if [12]byte(data[:12]) != sig {
t.Fatal("backend data does not start with PROXY v2 signature")
}
// Verify TLS record header follows the PROXY header.
ppLen := 28 // v2 IPv4
if len(data) <= ppLen {
t.Fatalf("backend received only PROXY header, no TLS data")
}
if data[ppLen] != 0x16 {
t.Fatalf("expected TLS record (0x16) after PROXY header, got 0x%02x", data[ppLen])
}
case <-time.After(5 * time.Second):
t.Fatal("timeout waiting for backend data")
}
}
// TestProxyProtocolNotSent checks that a route without SendProxyProtocol
// forwards the client's TLS data directly, with no PROXY header in front.
func TestProxyProtocolNotSent(t *testing.T) {
	// Backend captures first bytes.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	received := make(chan []byte, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer func() { _ = conn.Close() }()
		// Bound the read so the goroutine cannot block indefinitely.
		_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
		buf := make([]byte, 4096)
		n, _ := conn.Read(buf)
		received <- buf[:n]
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"nopp.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()
	hello := buildClientHello("nopp.test")
	_, _ = conn.Write(hello)

	select {
	case data := <-received:
		// A failed or empty read yields no bytes. Report that instead of
		// panicking on data[0].
		if len(data) == 0 {
			t.Fatal("backend received no data")
		}
		// First byte should be TLS record header, not PROXY signature.
		if data[0] != 0x16 {
			t.Fatalf("expected TLS record (0x16) as first byte, got 0x%02x", data[0])
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for backend data")
	}
}
// TestProxyProtocolFirewallUsesRealIP checks that on a ProxyProtocol listener
// the firewall evaluates the client address carried in the PROXY header, not
// the TCP peer address.
func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
	// Backend that should never be reached.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	reached := make(chan struct{}, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		_ = conn.Close()
		reached <- struct{}{}
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	// Block 203.0.113.50 (the "real" client IP from PROXY header).
	// 127.0.0.1 (the actual TCP peer) is NOT blocked.
	fw, err := firewall.New("", []string{"203.0.113.50"}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(cfg, fw, []ListenerData{
		{
			ID:            1,
			Addr:          proxyAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"blocked.test": l4Route(backendLn.Addr().String()),
			},
		},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		_ = srv.Run(ctx)
	}()
	// Stop the server on every exit path, including early t.Fatal returns.
	defer func() {
		cancel()
		wg.Wait()
	}()
	time.Sleep(50 * time.Millisecond)

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer func() { _ = conn.Close() }()

	// Send PROXY v2 with the blocked real IP.
	var ppBuf bytes.Buffer
	if err := proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("203.0.113.50:12345"),
		netip.MustParseAddrPort("198.51.100.1:443"),
	); err != nil {
		t.Fatalf("building PROXY v2 header: %v", err)
	}
	// These writes may race with the proxy closing the blocked connection,
	// so their errors are deliberately ignored.
	_, _ = conn.Write(ppBuf.Bytes())
	_, _ = conn.Write(buildClientHello("blocked.test"))

	// Connection should be closed (firewall blocks real IP).
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed by firewall")
	}
	// A deadline timeout means the connection stayed open, which is not a block.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("connection was still open when the read deadline expired")
	}

	// Backend should not have been reached.
	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall blocking real IP")
	case <-time.After(200 * time.Millisecond):
		// Expected.
	}
}
// --- Connection limit tests ---

// TestConnectionLimitEnforced checks that a listener with MaxConnections=2
// closes a third concurrent connection, and that a new connection is
// accepted again after one of the first two is closed.
func TestConnectionLimitEnforced(t *testing.T) {
	// Backend that holds connections open.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	go func() {
		for {
			conn, err := backendLn.Accept()
			if err != nil {
				return
			}
			go func() { _, _ = io.Copy(io.Discard, conn) }()
		}
	}()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:             1,
			Addr:           proxyAddr,
			MaxConnections: 2,
			Routes: map[string]RouteInfo{
				"limit.test": l4Route(backendLn.Addr().String()),
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// Open 2 connections (should succeed). They are closed on every exit
	// path; closing conns[0] again later is harmless.
	var conns []net.Conn
	defer func() {
		for _, c := range conns {
			_ = c.Close()
		}
	}()
	for i := range 2 {
		conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
		if err != nil {
			t.Fatalf("dial %d: %v", i, err)
		}
		conns = append(conns, conn)
		_, _ = conn.Write(buildClientHello("limit.test"))
	}
	time.Sleep(100 * time.Millisecond)

	// 3rd connection should be rejected (closed immediately).
	conn3, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial 3: %v", err)
	}
	defer func() { _ = conn3.Close() }()
	_, _ = conn3.Write(buildClientHello("limit.test"))
	_ = conn3.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn3.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected 3rd connection to be closed due to limit")
	}
	// A deadline timeout means the over-limit connection was left open.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("3rd connection was still open when the read deadline expired")
	}
	_ = conn3.Close()

	// Close one existing connection.
	_ = conns[0].Close()
	time.Sleep(200 * time.Millisecond)

	// Now a new connection should succeed.
	conn4, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial 4: %v", err)
	}
	defer func() { _ = conn4.Close() }()
	_, _ = conn4.Write(buildClientHello("limit.test"))

	// Give it time to be proxied.
	time.Sleep(100 * time.Millisecond)
	if got := srv.TotalConnections(); got < 2 {
		t.Fatalf("expected at least 2 connections, got %d", got)
	}
}
// --- Multi-hop integration tests ---

// TestMultiHopProxyProtocol simulates an edge → origin deployment. The edge
// proxy passes TLS through at L4 and prepends a PROXY v2 header. The origin
// proxy reads that header, terminates TLS on an L7 route, and forwards to an
// h2c backend. The test checks that the client's address reaches the backend
// in X-Forwarded-For.
func TestMultiHopProxyProtocol(t *testing.T) {
	//   Client → [edge proxy] → PROXY v2 + TLS → [origin proxy] → h2c backend
	//                                             proxy_protocol=true
	//                                             L7 route
	certPath, keyPath := testCert(t, "multihop.test")

	// h2c backend on origin that echoes the X-Forwarded-For.
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
	}))

	// Origin proxy: proxy_protocol=true listener, L7 route to backend.
	originLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("origin listen: %v", err)
	}
	originAddr := originLn.Addr().String()
	_ = originLn.Close()

	originFw, err := firewall.New("", nil, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating origin firewall: %v", err)
	}
	originCfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	originLogger := slog.New(slog.NewTextHandler(io.Discard, nil))
	originSrv := New(originCfg, originFw, []ListenerData{
		{
			ID:            1,
			Addr:          originAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"multihop.test": {
					Backend: backendAddr,
					Mode:    "l7",
					TLSCert: certPath,
					TLSKey:  keyPath,
				},
			},
		},
	}, originLogger, "origin-test")
	originCtx, originCancel := context.WithCancel(context.Background())
	var originWg sync.WaitGroup
	originWg.Add(1)
	go func() {
		defer originWg.Done()
		_ = originSrv.Run(originCtx)
	}()
	time.Sleep(50 * time.Millisecond)
	defer func() {
		originCancel()
		originWg.Wait()
	}()

	// Edge proxy: L4 passthrough with send_proxy_protocol=true to origin.
	// It reuses the origin's config and (discarding) logger.
	edgeLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("edge listen: %v", err)
	}
	edgeAddr := edgeLn.Addr().String()
	_ = edgeLn.Close()

	edgeFw, err := firewall.New("", nil, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating edge firewall: %v", err)
	}
	edgeSrv := New(originCfg, edgeFw, []ListenerData{
		{
			ID:   2,
			Addr: edgeAddr,
			Routes: map[string]RouteInfo{
				"multihop.test": {
					Backend:           originAddr,
					Mode:              "l4",
					SendProxyProtocol: true,
				},
			},
		},
	}, originLogger, "edge-test")
	edgeCtx, edgeCancel := context.WithCancel(context.Background())
	var edgeWg sync.WaitGroup
	edgeWg.Add(1)
	go func() {
		defer edgeWg.Done()
		_ = edgeSrv.Run(edgeCtx)
	}()
	time.Sleep(50 * time.Millisecond)
	defer func() {
		edgeCancel()
		edgeWg.Wait()
	}()

	// Client connects to edge proxy with TLS, as if it were a real client.
	// The certificate is self-signed, so verification is skipped.
	tlsConf := &tls.Config{
		ServerName:         "multihop.test",
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
	}
	conn, err := tls.DialWithDialer(
		&net.Dialer{Timeout: 5 * time.Second},
		"tcp", edgeAddr, tlsConf,
	)
	if err != nil {
		t.Fatalf("TLS dial edge: %v", err)
	}
	defer func() { _ = conn.Close() }()

	tr := &http2.Transport{}
	h2conn, err := tr.NewClientConn(conn)
	if err != nil {
		t.Fatalf("h2 client conn: %v", err)
	}
	req, err := http.NewRequest("GET", "https://multihop.test/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	resp, err := h2conn.RoundTrip(req)
	if err != nil {
		t.Fatalf("RoundTrip: %v", err)
	}
	defer func() { _ = resp.Body.Close() }()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if resp.StatusCode != 200 {
		t.Fatalf("status = %d, want 200, body = %s", resp.StatusCode, body)
	}

	// The X-Forwarded-For should be 127.0.0.1 (the edge proxy's TCP client IP),
	// carried through via PROXY protocol from edge to origin.
	got := string(body)
	if got != "xff=127.0.0.1" {
		t.Fatalf("body = %q, want %q", got, "xff=127.0.0.1")
	}
}
// TestMultiHopFirewallBlocksRealIP plays the edge proxy by hand. It sends a
// PROXY v2 header naming a blocked client IP to an origin listener with
// proxy_protocol=true, and checks that the origin closes the connection
// without contacting the backend. The TCP peer itself (127.0.0.1) is not
// blocked.
func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer func() { _ = backendLn.Close() }()
	reached := make(chan struct{}, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		_ = conn.Close()
		reached <- struct{}{}
	}()

	originLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("origin listen: %v", err)
	}
	originAddr := originLn.Addr().String()
	_ = originLn.Close()

	// Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header.
	originFw, err := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	originSrv := New(cfg, originFw, []ListenerData{
		{
			ID:            1,
			Addr:          originAddr,
			ProxyProtocol: true,
			Routes: map[string]RouteInfo{
				"blocked.test": l4Route(backendLn.Addr().String()),
			},
		},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		_ = originSrv.Run(ctx)
	}()
	// Stop the server on every exit path, including early t.Fatal returns.
	defer func() {
		cancel()
		wg.Wait()
	}()
	time.Sleep(50 * time.Millisecond)

	// Connect to origin and send PROXY header with the blocked IP.
	conn, err := net.DialTimeout("tcp", originAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial origin: %v", err)
	}
	defer func() { _ = conn.Close() }()
	var ppBuf bytes.Buffer
	if err := proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("198.51.100.99:12345"),
		netip.MustParseAddrPort("10.0.0.1:443"),
	); err != nil {
		t.Fatalf("building PROXY v2 header: %v", err)
	}
	// These writes may race with the origin closing the blocked connection,
	// so their errors are deliberately ignored.
	_, _ = conn.Write(ppBuf.Bytes())
	_, _ = conn.Write(buildClientHello("blocked.test"))

	// Connection should be dropped by firewall.
	_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
	_, err = conn.Read(make([]byte, 1))
	if err == nil {
		t.Fatal("expected connection to be closed")
	}
	// A deadline timeout means the connection stayed open, which is not a block.
	if ne, ok := err.(net.Error); ok && ne.Timeout() {
		t.Fatal("connection was still open when the read deadline expired")
	}

	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall block on real IP")
	case <-time.After(200 * time.Millisecond):
	}
}
// --- L7 server integration tests ---
// testCert generates a self-signed ECDSA P-256 certificate for hostname.
// It writes the certificate and its private key as PEM files into a per-test
// temporary directory and returns their paths. The certificate is valid from
// one hour ago until one hour from now. Any failure aborts the test.
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
	t.Helper()
	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generating key: %v", err)
	}
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: hostname},
		DNSNames:     []string{hostname},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
	// Self-signed: the template is both the subject and the issuer.
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("creating certificate: %v", err)
	}
	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("marshaling private key: %v", err)
	}

	dir := t.TempDir()
	certPath = filepath.Join(dir, "cert.pem")
	keyPath = filepath.Join(dir, "key.pem")

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		t.Fatalf("writing certificate: %v", err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		t.Fatalf("writing private key: %v", err)
	}
	return certPath, keyPath
}
// startH2CBackend starts an h2c backend for testing.
func startH2CBackend(t *testing.T, handler http.Handler) string {
t.Helper()
h2s := &http2.Server{}
srv := &http.Server{
Handler: h2c.NewHandler(handler, h2s),
ReadHeaderTimeout: 5 * time.Second,
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() { _ = srv.Close(); _ = ln.Close() })
go func() { _ = srv.Serve(ln) }()
return ln.Addr().String()
}
// TestL7ThroughServer sends an HTTP/2-over-TLS request to an L7 route. It
// checks that the proxy forwards the request to the h2c backend and that the
// backend sees X-Forwarded-For set to the client's TCP source IP.
func TestL7ThroughServer(t *testing.T) {
	certPath, keyPath := testCert(t, "l7srv.test")
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
	}))

	// Pick a free port for the proxy, then release it so the server can bind it.
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"l7srv.test": {
					Backend: backendAddr,
					Mode:    "l7",
					TLSCert: certPath,
					TLSKey:  keyPath,
				},
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// Connect with TLS and make an HTTP/2 request.
	tlsConf := &tls.Config{
		ServerName:         "l7srv.test",
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
	}
	conn, err := tls.DialWithDialer(
		&net.Dialer{Timeout: 5 * time.Second},
		"tcp", proxyAddr, tlsConf,
	)
	if err != nil {
		t.Fatalf("TLS dial: %v", err)
	}
	defer func() { _ = conn.Close() }()

	tr := &http2.Transport{}
	h2conn, err := tr.NewClientConn(conn)
	if err != nil {
		t.Fatalf("h2 client conn: %v", err)
	}
	req, err := http.NewRequest("GET", "https://l7srv.test/hello", nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	resp, err := h2conn.RoundTrip(req)
	if err != nil {
		t.Fatalf("RoundTrip: %v", err)
	}
	defer func() { _ = resp.Body.Close() }()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if resp.StatusCode != 200 {
		t.Fatalf("status = %d, want 200, body = %s", resp.StatusCode, body)
	}

	// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
	// no PROXY protocol is in use.
	const want = "ok path=/hello xff=127.0.0.1"
	if got := string(body); got != want {
		t.Fatalf("body = %q, want %q", got, want)
	}
}
// TestMixedL4L7SameListener configures one listener with an L4 passthrough
// route and an L7 route. It checks that the listener dispatches each client
// by SNI:
//   - l4echo.test reaches a raw TCP echo backend, which returns exactly the
//     ClientHello the client sent.
//   - l7mixed.test is served through the h2c backend.
func TestMixedL4L7SameListener(t *testing.T) {
	certPath, keyPath := testCert(t, "l7mixed.test")

	// L4 backend: echo server.
	l4BackendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("l4 backend listen: %v", err)
	}
	defer func() { _ = l4BackendLn.Close() }()
	go echoServer(t, l4BackendLn)

	// L7 backend: h2c HTTP server.
	l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, _ = fmt.Fprint(w, "l7-response")
	}))

	// Pick a free port for the proxy, then release it so the server can bind it.
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	proxyAddr := proxyLn.Addr().String()
	_ = proxyLn.Close()

	srv := newTestServer(t, []ListenerData{
		{
			ID:   1,
			Addr: proxyAddr,
			Routes: map[string]RouteInfo{
				"l4echo.test": l4Route(l4BackendLn.Addr().String()),
				"l7mixed.test": {
					Backend: l7BackendAddr,
					Mode:    "l7",
					TLSCert: certPath,
					TLSKey:  keyPath,
				},
			},
		},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// Test L4 route works: send ClientHello, get the same bytes echoed back.
	l4Conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial L4: %v", err)
	}
	defer func() { _ = l4Conn.Close() }()
	hello := buildClientHello("l4echo.test")
	if _, err := l4Conn.Write(hello); err != nil {
		t.Fatalf("L4 write: %v", err)
	}
	echoed := make([]byte, len(hello))
	_ = l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(l4Conn, echoed); err != nil {
		t.Fatalf("L4 echo read: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatalf("L4 echo mismatch:\n got  %x\n want %x", echoed, hello)
	}

	// Test L7 route works: TLS + HTTP/2.
	tlsConf := &tls.Config{
		ServerName:         "l7mixed.test",
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
	}
	l7Conn, err := tls.DialWithDialer(
		&net.Dialer{Timeout: 5 * time.Second},
		"tcp", proxyAddr, tlsConf,
	)
	if err != nil {
		t.Fatalf("TLS dial L7: %v", err)
	}
	defer func() { _ = l7Conn.Close() }()

	tr := &http2.Transport{}
	h2conn, err := tr.NewClientConn(l7Conn)
	if err != nil {
		t.Fatalf("h2 client conn: %v", err)
	}
	req, err := http.NewRequest("GET", "https://l7mixed.test/", nil)
	if err != nil {
		t.Fatalf("NewRequest: %v", err)
	}
	resp, err := h2conn.RoundTrip(req)
	if err != nil {
		t.Fatalf("L7 RoundTrip: %v", err)
	}
	defer func() { _ = resp.Body.Close() }()
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("L7 reading body: %v", err)
	}
	if resp.StatusCode != 200 {
		t.Fatalf("L7 status = %d, want 200, body = %s", resp.StatusCode, body)
	}
	if string(body) != "l7-response" {
		t.Fatalf("L7 body = %q, want %q", body, "l7-response")
	}
}
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
func buildClientHello(serverName string) []byte {
return buildClientHelloWithExtensions(sniExtension(serverName))
}
// buildClientHelloWithExtensions wraps a minimal ClientHello in a handshake
// message and a TLS record, and returns the record.
//
// The hello advertises version 0x0303, an all-zero random, an empty session
// ID, a single cipher suite (0x009C) and a single null compression method.
// extensions is copied verbatim into the extensions block. When extensions is
// empty, the extensions length field is omitted entirely.
func buildClientHelloWithExtensions(extensions []byte) []byte {
	// ClientHello body.
	// legacy_version: TLS 1.2.
	body := []byte{0x03, 0x03}
	// random: 32 zero bytes.
	body = append(body, make([]byte, 32)...)
	// legacy_session_id: zero length.
	body = append(body, 0x00)
	// cipher_suites: 2-byte length, then suite 0x009C.
	body = append(body, 0x00, 0x02, 0x00, 0x9C)
	// compression_methods: one method, null.
	body = append(body, 0x01, 0x00)
	if len(extensions) != 0 {
		body = binary.BigEndian.AppendUint16(body, uint16(len(extensions)))
		body = append(body, extensions...)
	}

	// Handshake header: type 0x01 (client_hello) and a 24-bit body length.
	n := len(body)
	msg := make([]byte, 0, 4+n)
	msg = append(msg, 0x01, byte(n>>16), byte(n>>8), byte(n))
	msg = append(msg, body...)

	// Record header: content type 0x16 (handshake), version 0x0301, 16-bit length.
	out := make([]byte, 0, 5+len(msg))
	out = append(out, 0x16, 0x03, 0x01)
	out = binary.BigEndian.AppendUint16(out, uint16(len(msg)))
	return append(out, msg...)
}
// sniExtension encodes a server_name extension (type 0x0000) containing a
// ServerNameList with a single host_name (type 0x00) entry for serverName.
func sniExtension(serverName string) []byte {
	nameLen := len(serverName)
	// One list entry: name_type (1 byte) + name length (2 bytes) + name.
	entryLen := 1 + 2 + nameLen
	// Extension data: list length (2 bytes) + the entry.
	dataLen := 2 + entryLen

	ext := make([]byte, 0, 4+dataLen)
	ext = binary.BigEndian.AppendUint16(ext, 0x0000)
	ext = binary.BigEndian.AppendUint16(ext, uint16(dataLen))
	ext = binary.BigEndian.AppendUint16(ext, uint16(entryLen))
	ext = append(ext, 0x00)
	ext = binary.BigEndian.AppendUint16(ext, uint16(nameLen))
	return append(ext, serverName...)
}