package server

import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/binary"
	"encoding/pem"
	"fmt"
	"io"
	"log/slog"
	"math/big"
	"net"
	"net/http"
	"net/netip"
	"os"
	"path/filepath"
	"sync"
	"testing"
	"time"

	"git.wntrmute.dev/kyle/mc-proxy/internal/config"
	"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
	"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
	"golang.org/x/net/http2"
	"golang.org/x/net/http2/h2c"
)

// l4Route creates a RouteInfo for an L4 passthrough route.
func l4Route(backend string) RouteInfo {
	return RouteInfo{Backend: backend, Mode: "l4"}
}

// echoServer accepts one connection on ln, copies everything it receives
// back to the sender until the read side ends, then closes the connection.
// An Accept error (e.g. ln closed at teardown) makes it return silently.
func echoServer(t *testing.T, ln net.Listener) {
	t.Helper()
	conn, err := ln.Accept()
	if err != nil {
		return
	}
	defer conn.Close()
	io.Copy(conn, conn)
}

// reserveLoopbackAddr binds an ephemeral port on 127.0.0.1, releases it and
// returns its "host:port" so a server under test can bind it. Another process
// could take the port in between; the tests accept that small race.
func reserveLoopbackAddr(t *testing.T) string {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("finding free port: %v", err)
	}
	addr := ln.Addr().String()
	ln.Close()
	return addr
}

// proxyTestConfig returns the proxy configuration shared by these tests:
// 5s connect timeout, 30s idle timeout, and the given shutdown timeout.
func proxyTestConfig(shutdown time.Duration) *config.Config {
	return &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: shutdown},
		},
	}
}

// newTestServerBlocking creates a Server with a 5s shutdown timeout whose
// firewall is built with blocked as its address blocklist (the second
// argument of firewall.New). Log output is discarded.
func newTestServerBlocking(t *testing.T, blocked []string, listeners []ListenerData) *Server {
	t.Helper()
	fw, err := firewall.New("", blocked, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	return New(proxyTestConfig(5*time.Second), fw, listeners, logger, "test")
}

// newTestServer creates a Server with the given listener data and no firewall rules.
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
	t.Helper()
	return newTestServerBlocking(t, nil, listeners)
}

// startAndStop starts srv.Run in a goroutine and returns a function that
// cancels Run's context and waits for it to return. A non-nil error from Run
// is reported via t.Errorf.
func startAndStop(t *testing.T, srv *Server) context.CancelFunc {
	t.Helper()
	ctx, cancel := context.WithCancel(context.Background())
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		if err := srv.Run(ctx); err != nil {
			t.Errorf("server.Run: %v", err)
		}
	}()
	// Give the listeners a moment to bind.
	time.Sleep(50 * time.Millisecond)
	return func() {
		cancel()
		wg.Wait()
	}
}

// requireConnClosed fails the test with msg if a one-byte read on conn
// succeeds within timeout. Any read error passes, including the deadline
// expiring.
// NOTE(review): a proxy that merely stalls therefore also passes; tightening
// this to reject timeouts would need confirming Server's close behavior.
func requireConnClosed(t *testing.T, conn net.Conn, timeout time.Duration, msg string) {
	t.Helper()
	conn.SetReadDeadline(time.Now().Add(timeout))
	if _, err := conn.Read(make([]byte, 1)); err == nil {
		t.Fatal(msg)
	}
}

// neverReachedBackend listens on loopback and sends on the returned channel
// (buffer 1) if a connection is ever accepted. The listener is closed during
// test cleanup.
func neverReachedBackend(t *testing.T) (string, <-chan struct{}) {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	t.Cleanup(func() { ln.Close() })
	reached := make(chan struct{}, 1)
	go func() {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		conn.Close()
		reached <- struct{}{}
	}()
	return ln.Addr().String(), reached
}

// TestProxyRoundTrip checks that an L4 route forwards the ClientHello
// unmodified and then relays further data in both directions.
func TestProxyRoundTrip(t *testing.T) {
	// Start an echo backend.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	go echoServer(t, backendLn)

	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"echo.test": l4Route(backendLn.Addr().String()),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	// Connect through the proxy with a fake ClientHello.
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	hello := buildClientHello("echo.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// The backend echoes the ClientHello; it must arrive byte-for-byte.
	echoed := make([]byte, len(hello))
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatal("echoed ClientHello differs from the one sent")
	}

	// Send some additional data.
	payload := []byte("hello from client")
	if _, err := conn.Write(payload); err != nil {
		t.Fatalf("write payload: %v", err)
	}
	buf := make([]byte, len(payload))
	if _, err := io.ReadFull(conn, buf); err != nil {
		t.Fatalf("read echoed payload: %v", err)
	}
	if string(buf) != string(payload) {
		t.Fatalf("got %q, want %q", buf, payload)
	}
}

// TestNoRouteResets checks that a ClientHello whose SNI matches no route
// gets its connection closed.
func TestNoRouteResets(t *testing.T) {
	// Proxy listener with no routes for the requested hostname.
	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"other.test": l4Route("127.0.0.1:1"), // exists but won't match
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	if _, err := conn.Write(buildClientHello("unknown.test")); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// The proxy should close the connection (no route match).
	requireConnClosed(t, conn, 2*time.Second, "expected connection to be closed, but read succeeded")
}

// TestFirewallBlocks checks that a blocked client IP is dropped before the
// backend is dialed.
func TestFirewallBlocks(t *testing.T) {
	// Start a backend that should never be reached.
	backendAddr, reached := neverReachedBackend(t)
	proxyAddr := reserveLoopbackAddr(t)

	// Create a firewall that blocks 127.0.0.1 (the test client).
	srv := newTestServerBlocking(t, []string{"127.0.0.1"}, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"echo.test": l4Route(backendAddr),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// The write error is ignored: the proxy may already have dropped us.
	conn.Write(buildClientHello("echo.test"))

	// Connection should be closed (blocked by firewall).
	requireConnClosed(t, conn, 2*time.Second, "expected connection to be closed by firewall")

	// Backend should not have been reached.
	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall block")
	case <-time.After(200 * time.Millisecond):
		// Expected.
	}
}

// TestNotTLSResets checks that non-TLS input gets the connection closed.
func TestNotTLSResets(t *testing.T) {
	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:     1,
		Addr:   proxyAddr,
		Routes: map[string]RouteInfo{"x.test": l4Route("127.0.0.1:1")},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send HTTP, not TLS.
	conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))

	requireConnClosed(t, conn, 2*time.Second, "expected connection to be closed for non-TLS data")
}

// TestConnectionTracking checks that TotalConnections counts relayed
// connections and drops as they close.
func TestConnectionTracking(t *testing.T) {
	const numClients = 2

	// Backend that holds connections open until the test closes them. Each
	// accepted connection is handed to the test in accept order.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()

	accepted := make(chan net.Conn, numClients)
	go func() {
		for {
			conn, err := backendLn.Accept()
			if err != nil {
				return
			}
			// Hold connection open, drain input.
			go io.Copy(io.Discard, conn)
			accepted <- conn
		}
	}()

	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"conn.test": l4Route(backendLn.Addr().String()),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	if got := srv.TotalConnections(); got != 0 {
		t.Fatalf("expected 0 connections before any clients, got %d", got)
	}

	var clientConns, backendConns []net.Conn
	defer func() {
		for _, c := range clientConns {
			c.Close()
		}
		for _, c := range backendConns {
			c.Close()
		}
	}()

	// Open the clients one at a time, waiting for the proxy to reach the
	// backend before the next dial, so backendConns[i] is the backend side
	// of clientConns[i].
	for i := range numClients {
		conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
		if err != nil {
			t.Fatalf("dial proxy %d: %v", i, err)
		}
		clientConns = append(clientConns, conn)
		if _, err := conn.Write(buildClientHello("conn.test")); err != nil {
			t.Fatalf("write ClientHello %d: %v", i, err)
		}
		select {
		case bc := <-accepted:
			backendConns = append(backendConns, bc)
		case <-time.After(5 * time.Second):
			t.Fatalf("timeout waiting for backend connection %d", i)
		}
	}

	// Poll instead of sleeping a fixed time; the counter may update slightly
	// after the backend accept is observed.
	deadline := time.Now().Add(5 * time.Second)
	for srv.TotalConnections() != numClients && time.Now().Before(deadline) {
		time.Sleep(10 * time.Millisecond)
	}
	if got := srv.TotalConnections(); got != numClients {
		t.Fatalf("expected 2 active connections, got %d", got)
	}

	// Close one client and its corresponding backend connection.
	clientConns[0].Close()
	backendConns[0].Close()

	// Wait for the relay goroutines to detect the close.
	deadline = time.Now().Add(5 * time.Second)
	for srv.TotalConnections() != 1 && time.Now().Before(deadline) {
		time.Sleep(50 * time.Millisecond)
	}
	if got := srv.TotalConnections(); got != 1 {
		t.Fatalf("expected 1 active connection after closing one, got %d", got)
	}
}

// TestMultipleListeners checks that two listeners with the same hostname
// each use their own route table.
func TestMultipleListeners(t *testing.T) {
	// Two backends.
	backendA, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend A listen: %v", err)
	}
	defer backendA.Close()
	backendB, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend B listen: %v", err)
	}
	defer backendB.Close()

	// Each backend reads the forwarded ClientHello record (5-byte TLS record
	// header, then the length it declares), writes its one-byte identity, and
	// closes. Reading the input first matters: closing a socket with unread
	// input can make the kernel reset the connection, which races with the
	// proxy relaying the identity byte.
	serve := func(ln net.Listener, id string) {
		conn, err := ln.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
		hdr := make([]byte, 5)
		if _, err := io.ReadFull(conn, hdr); err != nil {
			return
		}
		body := make([]byte, binary.BigEndian.Uint16(hdr[3:5]))
		if _, err := io.ReadFull(conn, body); err != nil {
			return
		}
		conn.Write([]byte(id))
	}
	go serve(backendA, "A")
	go serve(backendB, "B")

	// Two proxy listeners, same hostname, different backends.
	addr1 := reserveLoopbackAddr(t)
	addr2 := reserveLoopbackAddr(t)

	srv := newTestServer(t, []ListenerData{
		{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
		{ID: 2, Addr: addr2, Routes: map[string]RouteInfo{"svc.test": l4Route(backendB.Addr().String())}},
	})
	stop := startAndStop(t, srv)
	defer stop()

	// readID sends a ClientHello through proxyAddr and returns the last byte
	// relayed back, which is the answering backend's identity.
	readID := func(proxyAddr string) string {
		t.Helper()
		conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
		if err != nil {
			t.Fatalf("dial %s: %v", proxyAddr, err)
		}
		defer conn.Close()

		if _, err := conn.Write(buildClientHello("svc.test")); err != nil {
			t.Fatalf("write ClientHello to %s: %v", proxyAddr, err)
		}

		// The backend sends only its identity and then closes, so read until
		// EOF (or any error, including the deadline) and keep what arrived.
		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
		all, _ := io.ReadAll(conn)
		if len(all) == 0 {
			t.Fatalf("no data from %s", proxyAddr)
		}
		// The ID is the last byte.
		return string(all[len(all)-1:])
	}

	if idA := readID(addr1); idA != "A" {
		t.Fatalf("listener 1: got backend %q, want A", idA)
	}
	if idB := readID(addr2); idB != "B" {
		t.Fatalf("listener 2: got backend %q, want B", idB)
	}
}

// TestCaseInsensitiveRouting checks that an upper-case SNI matches a
// lower-case route and that the original ClientHello is forwarded unchanged.
func TestCaseInsensitiveRouting(t *testing.T) {
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	go echoServer(t, backendLn)

	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"echo.test": l4Route(backendLn.Addr().String()),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	// SNI extraction lowercases the hostname, so "ECHO.TEST" should match
	// the route for "echo.test".
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	hello := buildClientHello("ECHO.TEST")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	echoed := make([]byte, len(hello))
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatal("echoed ClientHello differs from the one sent")
	}
}

// TestBackendUnreachable checks that the client connection is closed when
// the route's backend refuses connections.
func TestBackendUnreachable(t *testing.T) {
	// Find a port that nothing is listening on.
	deadAddr := reserveLoopbackAddr(t)
	proxyAddr := reserveLoopbackAddr(t)

	srv := newTestServer(t, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"dead.test": l4Route(deadAddr),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	conn.Write(buildClientHello("dead.test"))

	// Proxy should close the connection after failing to dial backend.
	requireConnClosed(t, conn, 5*time.Second, "expected connection to be closed when backend is unreachable")
}

// TestGracefulShutdown checks that Run returns without error within 5s of
// cancellation while a relayed connection is still open (ShutdownTimeout 2s).
func TestGracefulShutdown(t *testing.T) {
	// Backend that holds the connection open and reports when it has one.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()

	accepted := make(chan struct{})
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		close(accepted)
		io.Copy(io.Discard, conn)
	}()

	proxyAddr := reserveLoopbackAddr(t)

	fw, err := firewall.New("", nil, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("creating firewall: %v", err)
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	srv := New(proxyTestConfig(2*time.Second), fw, []ListenerData{
		{ID: 1, Addr: proxyAddr, Routes: map[string]RouteInfo{"hold.test": l4Route(backendLn.Addr().String())}},
	}, logger, "test")

	ctx, cancel := context.WithCancel(context.Background())
	// Stop the server even if the test fails before the explicit cancel.
	defer cancel()
	done := make(chan error, 1)
	go func() {
		done <- srv.Run(ctx)
	}()
	time.Sleep(50 * time.Millisecond)

	// Establish a connection that will be in-flight during shutdown.
	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	if _, err := conn.Write(buildClientHello("hold.test")); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// Make sure the proxy really has a relayed connection open before
	// shutting down; otherwise the test would not exercise draining.
	select {
	case <-accepted:
	case <-time.After(5 * time.Second):
		t.Fatal("backend never received the proxied connection")
	}

	// Trigger shutdown.
	cancel()

	// Server should exit within the shutdown timeout.
	select {
	case err := <-done:
		if err != nil {
			t.Fatalf("server.Run returned error: %v", err)
		}
	case <-time.After(5 * time.Second):
		t.Fatal("server did not shut down within 5 seconds")
	}
}

// TestListenerStateRoutes exercises AddRoute, RemoveRoute and Routes on a
// ListenerState, including duplicate-add and missing-remove errors.
func TestListenerStateRoutes(t *testing.T) {
	ls := &ListenerState{
		ID:   1,
		Addr: ":443",
		routes: map[string]RouteInfo{
			"a.test": l4Route("127.0.0.1:1"),
		},
	}

	// AddRoute
	if err := ls.AddRoute("b.test", l4Route("127.0.0.1:2")); err != nil {
		t.Fatalf("AddRoute: %v", err)
	}

	// AddRoute duplicate
	if err := ls.AddRoute("b.test", l4Route("127.0.0.1:3")); err == nil {
		t.Fatal("expected error for duplicate route")
	}

	// Routes snapshot
	routes := ls.Routes()
	if len(routes) != 2 {
		t.Fatalf("expected 2 routes, got %d", len(routes))
	}

	// RemoveRoute
	if err := ls.RemoveRoute("a.test"); err != nil {
		t.Fatalf("RemoveRoute: %v", err)
	}

	// RemoveRoute not found
	if err := ls.RemoveRoute("nonexistent.test"); err == nil {
		t.Fatal("expected error for removing nonexistent route")
	}

	routes = ls.Routes()
	if len(routes) != 1 {
		t.Fatalf("expected 1 route, got %d", len(routes))
	}
	if routes["b.test"].Backend != "127.0.0.1:2" {
		t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"].Backend)
	}
}

// TestProxyProtocolReceive checks that a listener with ProxyProtocol
// consumes the PROXY v2 header and forwards only the TLS bytes to the backend.
func TestProxyProtocolReceive(t *testing.T) {
	// Backend echoes everything back.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()
	go echoServer(t, backendLn)

	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:            1,
		Addr:          proxyAddr,
		ProxyProtocol: true,
		Routes: map[string]RouteInfo{
			"echo.test": l4Route(backendLn.Addr().String()),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send PROXY v2 header followed by TLS ClientHello.
	var ppBuf bytes.Buffer
	proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("203.0.113.50:12345"),
		netip.MustParseAddrPort("198.51.100.1:443"),
	)
	if _, err := conn.Write(ppBuf.Bytes()); err != nil {
		t.Fatalf("write PROXY header: %v", err)
	}
	hello := buildClientHello("echo.test")
	if _, err := conn.Write(hello); err != nil {
		t.Fatalf("write ClientHello: %v", err)
	}

	// Backend should echo the ClientHello back (not the PROXY header). The
	// byte comparison is what proves the header was stripped.
	echoed := make([]byte, len(hello))
	conn.SetReadDeadline(time.Now().Add(5 * time.Second))
	if _, err := io.ReadFull(conn, echoed); err != nil {
		t.Fatalf("read echoed data: %v", err)
	}
	if !bytes.Equal(echoed, hello) {
		t.Fatal("backend did not receive exactly the ClientHello (PROXY header leaked or data altered)")
	}
}

// TestProxyProtocolReceiveGarbage checks that a listener expecting a PROXY
// header closes the connection when it gets something else.
func TestProxyProtocolReceiveGarbage(t *testing.T) {
	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:            1,
		Addr:          proxyAddr,
		ProxyProtocol: true,
		Routes: map[string]RouteInfo{
			"echo.test": l4Route("127.0.0.1:1"),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send garbage instead of a valid PROXY header.
	conn.Write([]byte("NOT A PROXY HEADER\r\n"))

	// Connection should be closed.
	requireConnClosed(t, conn, 2*time.Second, "expected connection to be closed for invalid PROXY header")
}

// TestProxyProtocolSend checks that a route with SendProxyProtocol prefixes
// the backend stream with a PROXY v2 header followed by the TLS record.
func TestProxyProtocolSend(t *testing.T) {
	// Backend that captures the first bytes it receives.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()

	received := make(chan []byte, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		// Read all available data; the proxy sends PROXY header + ClientHello.
		conn.SetReadDeadline(time.Now().Add(5 * time.Second))
		var all []byte
		buf := make([]byte, 4096)
		for {
			n, err := conn.Read(buf)
			all = append(all, buf[:n]...)
			if err != nil {
				break
			}
			// We expect at least pp header (28) + some TLS data.
			if len(all) >= 28+5 {
				break
			}
		}
		received <- all
	}()

	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"pp.test": {
				Backend:           backendLn.Addr().String(),
				Mode:              "l4",
				SendProxyProtocol: true,
			},
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	conn.Write(buildClientHello("pp.test"))

	// The backend should receive: PROXY v2 header + ClientHello.
	select {
	case data := <-received:
		// PROXY v2 IPv4 header is 28 bytes (12 sig + 1 ver/cmd + 1 fam + 2 len + 12 addrs).
		const ppLen = 28
		if len(data) < ppLen {
			t.Fatalf("backend received only %d bytes, want at least 28", len(data))
		}
		// Check PROXY v2 signature.
		sig := [12]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
		if [12]byte(data[:12]) != sig {
			t.Fatal("backend data does not start with PROXY v2 signature")
		}
		// For TCP over IPv4 the source address occupies bytes 16-19; the
		// test client dials from loopback.
		if src := netip.AddrFrom4([4]byte(data[16:20])); src != netip.MustParseAddr("127.0.0.1") {
			t.Fatalf("PROXY header source address = %v, want 127.0.0.1", src)
		}
		// Verify TLS record header follows the PROXY header.
		if len(data) <= ppLen {
			t.Fatalf("backend received only PROXY header, no TLS data")
		}
		if data[ppLen] != 0x16 {
			t.Fatalf("expected TLS record (0x16) after PROXY header, got 0x%02x", data[ppLen])
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for backend data")
	}
}

// TestProxyProtocolNotSent checks that a plain L4 route forwards the TLS
// record without a PROXY header.
func TestProxyProtocolNotSent(t *testing.T) {
	// Backend captures first bytes.
	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("backend listen: %v", err)
	}
	defer backendLn.Close()

	received := make(chan []byte, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		buf := make([]byte, 4096)
		n, _ := conn.Read(buf)
		received <- buf[:n]
	}()

	proxyAddr := reserveLoopbackAddr(t)
	srv := newTestServer(t, []ListenerData{{
		ID:   1,
		Addr: proxyAddr,
		Routes: map[string]RouteInfo{
			"nopp.test": l4Route(backendLn.Addr().String()),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	conn.Write(buildClientHello("nopp.test"))

	select {
	case data := <-received:
		// Guard the index below: a failed read yields an empty slice.
		if len(data) == 0 {
			t.Fatal("backend received no data")
		}
		// First byte should be TLS record header, not PROXY signature.
		if data[0] != 0x16 {
			t.Fatalf("expected TLS record (0x16) as first byte, got 0x%02x", data[0])
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for backend data")
	}
}

// TestProxyProtocolFirewallUsesRealIP checks that the firewall evaluates the
// source address from the PROXY header, not the TCP peer address.
func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
	// Backend that should never be reached.
	backendAddr, reached := neverReachedBackend(t)
	proxyAddr := reserveLoopbackAddr(t)

	// Block 203.0.113.50 (the "real" client IP from PROXY header).
	// 127.0.0.1 (the actual TCP peer) is NOT blocked.
	srv := newTestServerBlocking(t, []string{"203.0.113.50"}, []ListenerData{{
		ID:            1,
		Addr:          proxyAddr,
		ProxyProtocol: true,
		Routes: map[string]RouteInfo{
			"blocked.test": l4Route(backendAddr),
		},
	}})
	stop := startAndStop(t, srv)
	defer stop()

	conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial proxy: %v", err)
	}
	defer conn.Close()

	// Send PROXY v2 with the blocked real IP. Write errors are ignored: the
	// proxy may drop the connection as soon as it sees the header.
	var ppBuf bytes.Buffer
	proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("203.0.113.50:12345"),
		netip.MustParseAddrPort("198.51.100.1:443"),
	)
	conn.Write(ppBuf.Bytes())
	conn.Write(buildClientHello("blocked.test"))

	// Connection should be closed (firewall blocks real IP).
	requireConnClosed(t, conn, 2*time.Second, "expected connection to be closed by firewall")

	// Backend should not have been reached.
	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall blocking real IP")
	case <-time.After(200 * time.Millisecond):
		// Expected.
	}
}

// --- Multi-hop integration tests ---

// TestMultiHopProxyProtocol checks that the client IP survives an edge (L4,
// sends PROXY) → origin (accepts PROXY, L7) chain and reaches the backend as
// X-Forwarded-For.
func TestMultiHopProxyProtocol(t *testing.T) {
	// Simulates edge → origin deployment with PROXY protocol.
	//
	//   Client → [edge proxy] → PROXY v2 + TLS → [origin proxy] → h2c backend
	//                                             proxy_protocol=true
	//                                             L7 route
	certPath, keyPath := testCert(t, "multihop.test")

	// h2c backend on origin that echoes the X-Forwarded-For.
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
	}))

	// Origin proxy: proxy_protocol=true listener, L7 route to backend.
	originAddr := reserveLoopbackAddr(t)
	originSrv := newTestServer(t, []ListenerData{{
		ID:            1,
		Addr:          originAddr,
		ProxyProtocol: true,
		Routes: map[string]RouteInfo{
			"multihop.test": {
				Backend: backendAddr,
				Mode:    "l7",
				TLSCert: certPath,
				TLSKey:  keyPath,
			},
		},
	}})
	stopOrigin := startAndStop(t, originSrv)
	defer stopOrigin()

	// Edge proxy: L4 passthrough with send_proxy_protocol=true to origin.
	edgeAddr := reserveLoopbackAddr(t)
	edgeSrv := newTestServer(t, []ListenerData{{
		ID:   2,
		Addr: edgeAddr,
		Routes: map[string]RouteInfo{
			"multihop.test": {
				Backend:           originAddr,
				Mode:              "l4",
				SendProxyProtocol: true,
			},
		},
	}})
	stopEdge := startAndStop(t, edgeSrv)
	defer stopEdge()

	// Client connects to edge proxy with TLS, as if it were a real client.
	tlsConf := &tls.Config{
		ServerName:         "multihop.test",
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
	}
	conn, err := tls.DialWithDialer(
		&net.Dialer{Timeout: 5 * time.Second},
		"tcp", edgeAddr, tlsConf,
	)
	if err != nil {
		t.Fatalf("TLS dial edge: %v", err)
	}
	defer conn.Close()

	tr := &http2.Transport{}
	h2conn, err := tr.NewClientConn(conn)
	if err != nil {
		t.Fatalf("h2 client conn: %v", err)
	}

	req, err := http.NewRequest("GET", "https://multihop.test/test", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	resp, err := h2conn.RoundTrip(req)
	if err != nil {
		t.Fatalf("RoundTrip: %v", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading response body: %v", err)
	}
	if resp.StatusCode != 200 {
		t.Fatalf("status = %d, want 200, body = %s", resp.StatusCode, body)
	}

	// The X-Forwarded-For should be 127.0.0.1 (the edge proxy's TCP client IP),
	// carried through via PROXY protocol from edge to origin.
	got := string(body)
	if got != "xff=127.0.0.1" {
		t.Fatalf("body = %q, want %q", got, "xff=127.0.0.1")
	}
}

// TestMultiHopFirewallBlocksRealIP checks that an origin listener with
// ProxyProtocol drops a connection whose PROXY-reported source is blocked,
// even though the TCP peer is not.
func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
	// Origin proxy with proxy_protocol=true and a firewall that blocks
	// the real client IP. The TCP peer (edge proxy) is NOT blocked.
	backendAddr, reached := neverReachedBackend(t)
	originAddr := reserveLoopbackAddr(t)

	// Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header.
	originSrv := newTestServerBlocking(t, []string{"198.51.100.99"}, []ListenerData{{
		ID:            1,
		Addr:          originAddr,
		ProxyProtocol: true,
		Routes: map[string]RouteInfo{
			"blocked.test": l4Route(backendAddr),
		},
	}})
	stop := startAndStop(t, originSrv)
	defer stop()

	// Connect to origin and send PROXY header with the blocked IP.
	conn, err := net.DialTimeout("tcp", originAddr, 2*time.Second)
	if err != nil {
		t.Fatalf("dial origin: %v", err)
	}
	defer conn.Close()

	var ppBuf bytes.Buffer
	proxyproto.WriteV2(&ppBuf,
		netip.MustParseAddrPort("198.51.100.99:12345"),
		netip.MustParseAddrPort("10.0.0.1:443"),
	)
	conn.Write(ppBuf.Bytes())
	conn.Write(buildClientHello("blocked.test"))

	// Connection should be dropped by firewall.
	requireConnClosed(t, conn, 2*time.Second, "expected connection to be closed")

	select {
	case <-reached:
		t.Fatal("backend was reached despite firewall block on real IP")
	case <-time.After(200 * time.Millisecond):
	}
}

// --- L7 server integration tests ---

// testCert generates a self-signed ECDSA P-256 certificate for hostname,
// valid from one hour ago to one hour from now. It writes the certificate and
// key as PEM files in a per-test temporary directory and returns their paths.
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generating key: %v", err)
	}

	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: hostname},
		DNSNames:     []string{hostname},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("creating certificate: %v", err)
	}
	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("marshaling private key: %v", err)
	}

	dir := t.TempDir()
	certPath = filepath.Join(dir, "cert.pem")
	keyPath = filepath.Join(dir, "key.pem")

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		t.Fatalf("writing certificate: %v", err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		t.Fatalf("writing private key: %v", err)
	}
	return certPath, keyPath
}

// startH2CBackend starts an h2c backend for testing.
func startH2CBackend(t *testing.T, handler http.Handler) string { t.Helper() h2s := &http2.Server{} srv := &http.Server{ Handler: h2c.NewHandler(handler, h2s), ReadHeaderTimeout: 5 * time.Second, } ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } t.Cleanup(func() { srv.Close(); ln.Close() }) go srv.Serve(ln) return ln.Addr().String() } func TestL7ThroughServer(t *testing.T) { certPath, keyPath := testCert(t, "l7srv.test") backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For")) })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]RouteInfo{ "l7srv.test": { Backend: backendAddr, Mode: "l7", TLSCert: certPath, TLSKey: keyPath, }, }, }, }) stop := startAndStop(t, srv) defer stop() // Connect with TLS and make an HTTP/2 request. tlsConf := &tls.Config{ ServerName: "l7srv.test", InsecureSkipVerify: true, NextProtos: []string{"h2"}, } conn, err := tls.DialWithDialer( &net.Dialer{Timeout: 5 * time.Second}, "tcp", proxyAddr, tlsConf, ) if err != nil { t.Fatalf("TLS dial: %v", err) } defer conn.Close() tr := &http2.Transport{} h2conn, err := tr.NewClientConn(conn) if err != nil { t.Fatalf("h2 client conn: %v", err) } req, _ := http.NewRequest("GET", "https://l7srv.test/hello", nil) resp, err := h2conn.RoundTrip(req) if err != nil { t.Fatalf("RoundTrip: %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) // The X-Forwarded-For should be the TCP source IP (127.0.0.1) since // no PROXY protocol is in use. 
if resp.StatusCode != 200 { t.Fatalf("status = %d, want 200", resp.StatusCode) } got := string(body) if got != "ok path=/hello xff=127.0.0.1" { t.Fatalf("body = %q", got) } } func TestMixedL4L7SameListener(t *testing.T) { certPath, keyPath := testCert(t, "l7mixed.test") // L4 backend: echo server. l4BackendLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("l4 backend listen: %v", err) } defer l4BackendLn.Close() go echoServer(t, l4BackendLn) // L7 backend: h2c HTTP server. l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, "l7-response") })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]RouteInfo{ "l4echo.test": l4Route(l4BackendLn.Addr().String()), "l7mixed.test": { Backend: l7BackendAddr, Mode: "l7", TLSCert: certPath, TLSKey: keyPath, }, }, }, }) stop := startAndStop(t, srv) defer stop() // Test L4 route works: send ClientHello, get echo. l4Conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial L4: %v", err) } defer l4Conn.Close() hello := buildClientHello("l4echo.test") l4Conn.Write(hello) echoed := make([]byte, len(hello)) l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second)) if _, err := io.ReadFull(l4Conn, echoed); err != nil { t.Fatalf("L4 echo read: %v", err) } // Test L7 route works: TLS + HTTP/2. 
tlsConf := &tls.Config{ ServerName: "l7mixed.test", InsecureSkipVerify: true, NextProtos: []string{"h2"}, } l7Conn, err := tls.DialWithDialer( &net.Dialer{Timeout: 5 * time.Second}, "tcp", proxyAddr, tlsConf, ) if err != nil { t.Fatalf("TLS dial L7: %v", err) } defer l7Conn.Close() tr := &http2.Transport{} h2conn, err := tr.NewClientConn(l7Conn) if err != nil { t.Fatalf("h2 client conn: %v", err) } req, _ := http.NewRequest("GET", "https://l7mixed.test/", nil) resp, err := h2conn.RoundTrip(req) if err != nil { t.Fatalf("L7 RoundTrip: %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) if string(body) != "l7-response" { t.Fatalf("L7 body = %q, want %q", body, "l7-response") } } // --- ClientHello builder helpers (mirrors internal/sni test helpers) --- func buildClientHello(serverName string) []byte { return buildClientHelloWithExtensions(sniExtension(serverName)) } func buildClientHelloWithExtensions(extensions []byte) []byte { var hello []byte hello = append(hello, 0x03, 0x03) // TLS 1.2 hello = append(hello, make([]byte, 32)...) // random hello = append(hello, 0x00) // session ID: empty hello = append(hello, 0x00, 0x02, 0x00, 0x9C) // cipher suites hello = append(hello, 0x01, 0x00) // compression methods if len(extensions) > 0 { hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions))) hello = append(hello, extensions...) } handshake := []byte{0x01, 0x00, 0x00, 0x00} handshake[1] = byte(len(hello) >> 16) handshake[2] = byte(len(hello) >> 8) handshake[3] = byte(len(hello)) handshake = append(handshake, hello...) record := []byte{0x16, 0x03, 0x01} record = binary.BigEndian.AppendUint16(record, uint16(len(handshake))) record = append(record, handshake...) return record } func sniExtension(serverName string) []byte { name := []byte(serverName) var entry []byte entry = append(entry, 0x00) entry = binary.BigEndian.AppendUint16(entry, uint16(len(name))) entry = append(entry, name...) 
var list []byte list = binary.BigEndian.AppendUint16(list, uint16(len(entry))) list = append(list, entry...) var ext []byte ext = binary.BigEndian.AppendUint16(ext, 0x0000) ext = binary.BigEndian.AppendUint16(ext, uint16(len(list))) ext = append(ext, list...) return ext }