package l7 import ( "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" "fmt" "io" "log/slog" "math/big" "net" "net/http" "net/netip" "os" "path/filepath" "testing" "time" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" ) // testCert generates a self-signed TLS certificate for the given hostname // and writes the cert/key to temporary files, returning their paths. func testCert(t *testing.T, hostname string) (certPath, keyPath string) { t.Helper() key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { t.Fatalf("generating key: %v", err) } tmpl := &x509.Certificate{ SerialNumber: big.NewInt(1), Subject: pkix.Name{CommonName: hostname}, DNSNames: []string{hostname}, NotBefore: time.Now().Add(-time.Hour), NotAfter: time.Now().Add(time.Hour), KeyUsage: x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, } certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) if err != nil { t.Fatalf("creating certificate: %v", err) } dir := t.TempDir() certPath = filepath.Join(dir, "cert.pem") keyPath = filepath.Join(dir, "key.pem") certFile, err := os.Create(certPath) if err != nil { t.Fatalf("creating cert file: %v", err) } _ = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) _ = certFile.Close() keyDER, err := x509.MarshalECPrivateKey(key) if err != nil { t.Fatalf("marshaling key: %v", err) } keyFile, err := os.Create(keyPath) if err != nil { t.Fatalf("creating key file: %v", err) } _ = pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) _ = keyFile.Close() return certPath, keyPath } // startH2CBackend starts an h2c (HTTP/2 cleartext) backend server that // responds with the given status and body. Returns the listener address. func startH2CBackend(t *testing.T, handler http.Handler) string { t.Helper() h2s := &http2.Server{} srv := &http.Server{ Handler: h2c.NewHandler(handler, h2s), ReadHeaderTimeout: 5 * time.Second, } ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } t.Cleanup(func() { _ = srv.Close() _ = ln.Close() }) go func() { _ = srv.Serve(ln) }() return ln.Addr().String() } // dialTLSToProxy dials a TCP connection to the proxy, does a TLS handshake // with the given serverName (skipping cert verification for self-signed), // and returns an *http.Client configured to use that connection for HTTP/2. func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client { t.Helper() tlsConf := &tls.Config{ ServerName: serverName, InsecureSkipVerify: true, NextProtos: []string{"h2"}, } conn, err := tls.DialWithDialer( &net.Dialer{Timeout: 5 * time.Second}, "tcp", proxyAddr, tlsConf, ) if err != nil { t.Fatalf("TLS dial: %v", err) } t.Cleanup(func() { _ = conn.Close() }) // Create an HTTP/2 client transport over this single connection. tr := &http2.Transport{} h2conn, err := tr.NewClientConn(conn) if err != nil { t.Fatalf("creating h2 client conn: %v", err) } return &http.Client{ Transport: &singleConnRoundTripper{cc: h2conn}, } } // singleConnRoundTripper is an http.RoundTripper that uses a single HTTP/2 // client connection. type singleConnRoundTripper struct { cc *http2.ClientConn } func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return s.cc.RoundTrip(req) } func TestL7H2CBackend(t *testing.T) { certPath, keyPath := testCert(t, "l7.test") // Start an h2c backend. backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Backend", "ok") _, _ = fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path) })) // Start a TCP listener for the L7 proxy. proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } // Accept one connection and run L7 serve. go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") // No peeked bytes — the client is connecting directly with TLS. _ = Serve(context.Background(), conn, nil, route, clientAddr, logger) }() // Connect as an HTTP/2 TLS client. client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test") resp, err := client.Get("https://l7.test/foo") if err != nil { t.Fatalf("GET: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != 200 { t.Fatalf("status = %d, want 200", resp.StatusCode) } body, _ := io.ReadAll(resp.Body) if got := string(body); got != "hello from backend, path=/foo" { t.Fatalf("body = %q, want %q", got, "hello from backend, path=/foo") } if resp.Header.Get("X-Backend") != "ok" { t.Fatalf("X-Backend header missing or wrong") } } func TestL7ForwardingHeaders(t *testing.T) { certPath, keyPath := testCert(t, "headers.test") // Backend that echoes the forwarding headers. backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprintf(w, "xff=%s xfp=%s xri=%s", r.Header.Get("X-Forwarded-For"), r.Header.Get("X-Forwarded-Proto"), r.Header.Get("X-Real-IP"), ) })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") _ = Serve(context.Background(), conn, nil, route, clientAddr, logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test") resp, err := client.Get("https://headers.test/") if err != nil { t.Fatalf("GET: %v", err) } defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) want := "xff=203.0.113.50 xfp=https xri=203.0.113.50" if string(body) != want { t.Fatalf("body = %q, want %q", body, want) } } func TestL7BackendUnreachable(t *testing.T) { certPath, keyPath := testCert(t, "unreachable.test") // Find a port that nothing is listening on. ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } deadAddr := ln.Addr().String() _ = ln.Close() proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: deadAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 1 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") _ = Serve(context.Background(), conn, nil, route, clientAddr, logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test") resp, err := client.Get("https://unreachable.test/") if err != nil { t.Fatalf("GET: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusBadGateway { t.Fatalf("status = %d, want 502", resp.StatusCode) } } func TestIsTimeoutError(t *testing.T) { // context.DeadlineExceeded is a timeout. if !isTimeoutError(context.DeadlineExceeded) { t.Fatal("expected DeadlineExceeded to be a timeout error") } // A net timeout error is a timeout. netErr := &net.OpError{Op: "dial", Err: &timeoutErr{}} if !isTimeoutError(netErr) { t.Fatal("expected net timeout to be a timeout error") } // A regular error is not a timeout. if isTimeoutError(fmt.Errorf("connection refused")) { t.Fatal("expected non-timeout error to return false") } } // timeoutErr implements net.Error with Timeout() = true. type timeoutErr struct{} func (e *timeoutErr) Error() string { return "timeout" } func (e *timeoutErr) Timeout() bool { return true } func (e *timeoutErr) Temporary() bool { return false } func TestL7MultipleRequests(t *testing.T) { certPath, keyPath := testCert(t, "multi.test") var reqCount int backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reqCount++ _, _ = fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path) })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") _ = Serve(context.Background(), conn, nil, route, clientAddr, logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test") // Send multiple requests over the same HTTP/2 connection. for i := range 3 { path := fmt.Sprintf("/req%d", i) resp, err := client.Get(fmt.Sprintf("https://multi.test%s", path)) if err != nil { t.Fatalf("GET %s: %v", path, err) } body, _ := io.ReadAll(resp.Body) _ = resp.Body.Close() want := fmt.Sprintf("req=%d path=%s", i+1, path) if string(body) != want { t.Fatalf("request %d: body = %q, want %q", i, body, want) } } } func TestL7LargeResponse(t *testing.T) { certPath, keyPath := testCert(t, "large.test") // Backend sends a 1 MB response. largeBody := make([]byte, 1<<20) for i := range largeBody { largeBody[i] = byte(i % 256) } backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = w.Write(largeBody) })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test") resp, err := client.Get("https://large.test/") if err != nil { t.Fatalf("GET: %v", err) } defer func() { _ = resp.Body.Close() }() body, _ := io.ReadAll(resp.Body) if len(body) != len(largeBody) { t.Fatalf("got %d bytes, want %d", len(body), len(largeBody)) } } func TestL7GRPCTrailers(t *testing.T) { certPath, keyPath := testCert(t, "trailers.test") // Backend that sets HTTP trailers (used by gRPC for status). backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Trailer", "Grpc-Status, Grpc-Message") w.Header().Set("Content-Type", "application/grpc") w.WriteHeader(200) // Flush to send headers. if f, ok := w.(http.Flusher); ok { f.Flush() } // Set trailers. w.Header().Set("Grpc-Status", "0") w.Header().Set("Grpc-Message", "OK") })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test") req, _ := http.NewRequest("POST", "https://trailers.test/grpc.test.Service/Method", nil) req.Header.Set("Content-Type", "application/grpc") resp, err := client.Do(req) if err != nil { t.Fatalf("POST: %v", err) } defer func() { _ = resp.Body.Close() }() // Read body to trigger trailer delivery. _, _ = io.ReadAll(resp.Body) // Verify trailers were forwarded through the proxy. grpcStatus := resp.Trailer.Get("Grpc-Status") if grpcStatus != "0" { t.Fatalf("Grpc-Status trailer = %q, want %q", grpcStatus, "0") } grpcMessage := resp.Trailer.Get("Grpc-Message") if grpcMessage != "OK" { t.Fatalf("Grpc-Message trailer = %q, want %q", grpcMessage, "OK") } } func TestL7HTTP11Fallback(t *testing.T) { certPath, keyPath := testCert(t, "http11.test") backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprintf(w, "proto=%s", r.Proto) })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) }() // Connect with HTTP/1.1 only (no h2 ALPN). tlsConf := &tls.Config{ ServerName: "http11.test", InsecureSkipVerify: true, NextProtos: []string{"http/1.1"}, } tr := &http.Transport{TLSClientConfig: tlsConf} client := &http.Client{Transport: tr} resp, err := client.Get(fmt.Sprintf("https://%s/", proxyLn.Addr().String())) if err != nil { t.Fatalf("GET: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != 200 { t.Fatalf("status = %d, want 200", resp.StatusCode) } body, _ := io.ReadAll(resp.Body) // The backend sees HTTP/2 (proxied via h2c) regardless of client protocol. // Just verify we got a response — the protocol the backend sees depends // on the h2c transport. if len(body) == 0 { t.Fatal("empty response body") } } func TestL7PolicyBlocksUserAgentE2E(t *testing.T) { certPath, keyPath := testCert(t, "policy.test") backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprint(w, "should-not-reach") })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, Policies: []PolicyRule{ {Type: "block_user_agent", Value: "EvilBot"}, }, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test") req, _ := http.NewRequest("GET", "https://policy.test/", nil) req.Header.Set("User-Agent", "EvilBot/1.0") resp, err := client.Do(req) if err != nil { t.Fatalf("GET: %v", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != 403 { t.Fatalf("status = %d, want 403", resp.StatusCode) } } func TestL7PolicyRequiresHeaderE2E(t *testing.T) { certPath, keyPath := testCert(t, "reqhdr.test") backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { _, _ = fmt.Fprint(w, "ok") })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer func() { _ = proxyLn.Close() }() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, Policies: []PolicyRule{ {Type: "require_header", Value: "X-Auth-Token"}, }, } // Accept two connections (one blocked, one allowed). go func() { for range 2 { conn, err := proxyLn.Accept() if err != nil { return } go func() { logger := slog.New(slog.NewTextHandler(io.Discard, nil)) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) }() } }() // Without the required header → 403. client1 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test") resp1, err := client1.Get("https://reqhdr.test/") if err != nil { t.Fatalf("GET without header: %v", err) } _ = resp1.Body.Close() if resp1.StatusCode != 403 { t.Fatalf("without header: status = %d, want 403", resp1.StatusCode) } // With the required header → 200. client2 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test") req, _ := http.NewRequest("GET", "https://reqhdr.test/", nil) req.Header.Set("X-Auth-Token", "valid-token") resp2, err := client2.Do(req) if err != nil { t.Fatalf("GET with header: %v", err) } defer func() { _ = resp2.Body.Close() }() body, _ := io.ReadAll(resp2.Body) if resp2.StatusCode != 200 { t.Fatalf("with header: status = %d, want 200", resp2.StatusCode) } if string(body) != "ok" { t.Fatalf("body = %q, want %q", body, "ok") } }