package l7

import (
	"context"
	"crypto/ecdsa"
	"crypto/elliptic"
	"crypto/rand"
	"crypto/tls"
	"crypto/x509"
	"crypto/x509/pkix"
	"encoding/pem"
	"fmt"
	"io"
	"log/slog"
	"math/big"
	"net"
	"net/http"
	"net/netip"
	"os"
	"path/filepath"
	"testing"
	"time"

	"golang.org/x/net/http2"
	"golang.org/x/net/http2/h2c"
)

// testCert generates a self-signed ECDSA P-256 TLS server certificate for
// hostname and writes the PEM-encoded certificate and private key to
// "cert.pem" and "key.pem" inside a per-test temporary directory.
//
// The certificate is valid from one hour before to one hour after the call
// and carries hostname both as the subject common name and as a DNS SAN.
// It returns the paths of the two files; any failure aborts the test via
// t.Fatalf.
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generating key: %v", err)
	}

	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: hostname},
		DNSNames:     []string{hostname},
		NotBefore:    now.Add(-time.Hour),
		NotAfter:     now.Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}

	// The template is passed as its own parent, which makes the certificate
	// self-signed.
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("creating certificate: %v", err)
	}

	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("marshaling key: %v", err)
	}

	dir := t.TempDir()
	certPath = filepath.Join(dir, "cert.pem")
	keyPath = filepath.Join(dir, "key.pem")

	// os.WriteFile reports write and close failures, so a truncated PEM file
	// fails here instead of surfacing later as an opaque TLS error.
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		t.Fatalf("writing cert file: %v", err)
	}

	// The private key is only readable by the current user.
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		t.Fatalf("writing key file: %v", err)
	}

	return certPath, keyPath
}

// startH2CBackend starts an h2c (HTTP/2 cleartext) backend server on a
// loopback port that serves requests with the given handler. It returns the
// listener address; the server is shut down when the test finishes.
func startH2CBackend(t *testing.T, handler http.Handler) string { t.Helper() h2s := &http2.Server{} srv := &http.Server{ Handler: h2c.NewHandler(handler, h2s), ReadHeaderTimeout: 5 * time.Second, } ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } t.Cleanup(func() { srv.Close() ln.Close() }) go srv.Serve(ln) return ln.Addr().String() } // dialTLSToProxy dials a TCP connection to the proxy, does a TLS handshake // with the given serverName (skipping cert verification for self-signed), // and returns an *http.Client configured to use that connection for HTTP/2. func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client { t.Helper() tlsConf := &tls.Config{ ServerName: serverName, InsecureSkipVerify: true, NextProtos: []string{"h2"}, } conn, err := tls.DialWithDialer( &net.Dialer{Timeout: 5 * time.Second}, "tcp", proxyAddr, tlsConf, ) if err != nil { t.Fatalf("TLS dial: %v", err) } t.Cleanup(func() { conn.Close() }) // Create an HTTP/2 client transport over this single connection. tr := &http2.Transport{} h2conn, err := tr.NewClientConn(conn) if err != nil { t.Fatalf("creating h2 client conn: %v", err) } return &http.Client{ Transport: &singleConnRoundTripper{cc: h2conn}, } } // singleConnRoundTripper is an http.RoundTripper that uses a single HTTP/2 // client connection. type singleConnRoundTripper struct { cc *http2.ClientConn } func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { return s.cc.RoundTrip(req) } // serveL7Route starts l7.Serve in a goroutine for a single connection. // Returns when the goroutine completes. 
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) { t.Helper() logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") ctx := context.Background() go func() { l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger) if l7Err != nil { t.Logf("l7.Serve: %v", l7Err) } }() } func TestL7H2CBackend(t *testing.T) { certPath, keyPath := testCert(t, "l7.test") // Start an h2c backend. backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Backend", "ok") fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path) })) // Start a TCP listener for the L7 proxy. proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer proxyLn.Close() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } // Accept one connection and run L7 serve. go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") // No peeked bytes — the client is connecting directly with TLS. Serve(context.Background(), conn, nil, route, clientAddr, logger) }() // Connect as an HTTP/2 TLS client. 
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test") resp, err := client.Get(fmt.Sprintf("https://l7.test/foo")) if err != nil { t.Fatalf("GET: %v", err) } defer resp.Body.Close() if resp.StatusCode != 200 { t.Fatalf("status = %d, want 200", resp.StatusCode) } body, _ := io.ReadAll(resp.Body) if got := string(body); got != "hello from backend, path=/foo" { t.Fatalf("body = %q, want %q", got, "hello from backend, path=/foo") } if resp.Header.Get("X-Backend") != "ok" { t.Fatalf("X-Backend header missing or wrong") } } func TestL7ForwardingHeaders(t *testing.T) { certPath, keyPath := testCert(t, "headers.test") // Backend that echoes the forwarding headers. backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "xff=%s xfp=%s xri=%s", r.Header.Get("X-Forwarded-For"), r.Header.Get("X-Forwarded-Proto"), r.Header.Get("X-Real-IP"), ) })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer proxyLn.Close() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") Serve(context.Background(), conn, nil, route, clientAddr, logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test") resp, err := client.Get("https://headers.test/") if err != nil { t.Fatalf("GET: %v", err) } defer resp.Body.Close() body, _ := io.ReadAll(resp.Body) want := "xff=203.0.113.50 xfp=https xri=203.0.113.50" if string(body) != want { t.Fatalf("body = %q, want %q", body, want) } } func TestL7BackendUnreachable(t *testing.T) { certPath, keyPath := testCert(t, "unreachable.test") // Find a port that nothing is listening on. 
ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } deadAddr := ln.Addr().String() ln.Close() proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer proxyLn.Close() route := RouteConfig{ Backend: deadAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 1 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") Serve(context.Background(), conn, nil, route, clientAddr, logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test") resp, err := client.Get("https://unreachable.test/") if err != nil { t.Fatalf("GET: %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusBadGateway { t.Fatalf("status = %d, want 502", resp.StatusCode) } } func TestL7MultipleRequests(t *testing.T) { certPath, keyPath := testCert(t, "multi.test") var reqCount int backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { reqCount++ fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path) })) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("proxy listen: %v", err) } defer proxyLn.Close() route := RouteConfig{ Backend: backendAddr, TLSCert: certPath, TLSKey: keyPath, ConnectTimeout: 5 * time.Second, } go func() { conn, err := proxyLn.Accept() if err != nil { return } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") Serve(context.Background(), conn, nil, route, clientAddr, logger) }() client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test") // Send multiple requests over the same HTTP/2 connection. 
for i := range 3 { path := fmt.Sprintf("/req%d", i) resp, err := client.Get(fmt.Sprintf("https://multi.test%s", path)) if err != nil { t.Fatalf("GET %s: %v", path, err) } body, _ := io.ReadAll(resp.Body) resp.Body.Close() want := fmt.Sprintf("req=%d path=%s", i+1, path) if string(body) != want { t.Fatalf("request %d: body = %q, want %q", i, body, want) } } }