Files
mc-proxy/internal/l7/serve_test.go
Kyle Isom 5bc8f4fc8e Fix three doc-vs-implementation gaps found during audit
1. DB migration: add CHECK(mode IN ('l4', 'l7')) constraint on the
   routes.mode column. ARCHITECTURE.md documented this constraint but
   migration v2 omitted it. Enforces mode validity at the database
   level in addition to application-level validation.

2. L7 reverse proxy: distinguish timeout errors from connection errors
   in the ErrorHandler. Backend timeouts now return HTTP 504 Gateway
   Timeout instead of 502. Uses errors.Is(context.DeadlineExceeded)
   and net.Error.Timeout() detection. Added isTimeoutError unit tests.

3. Config validation: warn when L4 routes have tls_cert or tls_key set
   (they are silently ignored). ARCHITECTURE.md documented this warning
   but config.validate() did not emit it. Uses slog.Warn.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 14:25:41 -07:00

554 lines
14 KiB
Go

package l7
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"log/slog"
"math/big"
"net"
"net/http"
"net/netip"
"os"
"path/filepath"
"testing"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
// testCert generates a self-signed ECDSA P-256 TLS certificate for hostname
// (listed as both the subject CommonName and the sole DNS SAN, valid from
// one hour ago until one hour from now) and writes the PEM-encoded
// certificate and EC private key into a per-test temporary directory.
// It returns the paths of the certificate and key files. Any failure aborts
// the test via t.Fatalf.
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
	t.Helper()

	key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		t.Fatalf("generating key: %v", err)
	}

	tmpl := &x509.Certificate{
		SerialNumber: big.NewInt(1),
		Subject:      pkix.Name{CommonName: hostname},
		DNSNames:     []string{hostname},
		NotBefore:    time.Now().Add(-time.Hour),
		NotAfter:     time.Now().Add(time.Hour),
		KeyUsage:     x509.KeyUsageDigitalSignature,
		ExtKeyUsage:  []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
	// Self-signed: the template acts as both the certificate and its issuer.
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		t.Fatalf("creating certificate: %v", err)
	}

	keyDER, err := x509.MarshalECPrivateKey(key)
	if err != nil {
		t.Fatalf("marshaling key: %v", err)
	}

	dir := t.TempDir()
	certPath = filepath.Join(dir, "cert.pem")
	keyPath = filepath.Join(dir, "key.pem")

	// Encode in memory and write with os.WriteFile so that write and close
	// errors are reported instead of silently leaving a truncated PEM file.
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		t.Fatalf("writing cert file: %v", err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		t.Fatalf("writing key file: %v", err)
	}
	return certPath, keyPath
}
// startH2CBackend starts an h2c (HTTP/2 cleartext) backend server that
// responds with the given status and body. Returns the listener address.
func startH2CBackend(t *testing.T, handler http.Handler) string {
t.Helper()
h2s := &http2.Server{}
srv := &http.Server{
Handler: h2c.NewHandler(handler, h2s),
ReadHeaderTimeout: 5 * time.Second,
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() {
srv.Close()
ln.Close()
})
go srv.Serve(ln)
return ln.Addr().String()
}
// dialTLSToProxy opens a TLS connection to proxyAddr using serverName for
// SNI. Certificate verification is disabled because the test certificates
// are self-signed. The client offers only "h2" via ALPN, and the test fails
// unless the server negotiates it. The returned *http.Client sends every
// request as an HTTP/2 stream over that single connection, which is closed
// via t.Cleanup.
func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
	t.Helper()
	tlsConf := &tls.Config{
		ServerName:         serverName,
		InsecureSkipVerify: true,
		NextProtos:         []string{"h2"},
	}
	// The dialer timeout bounds the TCP connect and the TLS handshake together.
	conn, err := tls.DialWithDialer(
		&net.Dialer{Timeout: 5 * time.Second},
		"tcp", proxyAddr, tlsConf,
	)
	if err != nil {
		t.Fatalf("TLS dial: %v", err)
	}
	t.Cleanup(func() { conn.Close() })

	// Fail with a specific message if the proxy did not agree to HTTP/2.
	// Otherwise the mismatch would only surface later as a less specific
	// HTTP/2 error.
	if proto := conn.ConnectionState().NegotiatedProtocol; proto != "h2" {
		t.Fatalf("ALPN negotiated %q, want %q", proto, "h2")
	}

	// Create an HTTP/2 client transport over this single connection.
	tr := &http2.Transport{}
	h2conn, err := tr.NewClientConn(conn)
	if err != nil {
		t.Fatalf("creating h2 client conn: %v", err)
	}
	return &http.Client{
		Transport: &singleConnRoundTripper{cc: h2conn},
	}
}
// singleConnRoundTripper implements http.RoundTripper by sending every
// request over one pre-established HTTP/2 client connection. This lets a
// test issue several requests that are all multiplexed on the same TLS
// connection to the proxy.
type singleConnRoundTripper struct {
	cc *http2.ClientConn // the only connection requests are sent on
}

// RoundTrip hands req to the wrapped HTTP/2 connection unchanged and
// returns whatever that connection returns.
func (rt *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	conn := rt.cc
	return conn.RoundTrip(req)
}
// serveL7Route runs Serve for a single already-accepted connection in a
// background goroutine and returns immediately, without waiting for Serve to
// finish. peeked is forwarded to Serve unchanged; pass nil when no bytes
// were pre-read from conn. Serve is given the fixed client address
// 203.0.113.50:12345 and a logger that discards its output.
//
// When the test finishes, a t.Cleanup does the following:
//   - cancels Serve's context;
//   - closes conn;
//   - waits up to 5 seconds for Serve to return;
//   - logs any error Serve returned.
//
// The goroutine never touches t itself, so a late return from Serve cannot
// trigger the "Log in goroutine after Test has completed" panic.
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
	t.Helper()
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))
	clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
	ctx, cancel := context.WithCancel(context.Background())

	// Buffered so the goroutine can always deliver its result and exit,
	// even if the cleanup has already stopped waiting for it.
	errc := make(chan error, 1)
	go func() {
		errc <- Serve(ctx, conn, peeked, route, clientAddr, logger)
	}()

	t.Cleanup(func() {
		cancel()
		conn.Close()
		select {
		case err := <-errc:
			if err != nil {
				t.Logf("l7.Serve: %v", err)
			}
		case <-time.After(5 * time.Second):
			t.Logf("l7.Serve did not return within 5s of connection close")
		}
	})
}
// TestL7H2CBackend checks the basic L7 path. A TLS HTTP/2 request is
// terminated by the proxy and forwarded to an h2c backend. The backend's
// status, body and response headers must reach the client intact.
func TestL7H2CBackend(t *testing.T) {
	certPath, keyPath := testCert(t, "l7.test")

	// Start an h2c backend that echoes the request path.
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("X-Backend", "ok")
		fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
	}))

	// Start a TCP listener for the L7 proxy.
	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	defer proxyLn.Close()

	route := RouteConfig{
		Backend:        backendAddr,
		TLSCert:        certPath,
		TLSKey:         keyPath,
		ConnectTimeout: 5 * time.Second,
	}

	// Accept one connection and run L7 serve. Serve's result is ignored
	// because this goroutine may outlive the test and so must not call t.
	go func() {
		conn, err := proxyLn.Accept()
		if err != nil {
			return
		}
		logger := slog.New(slog.NewTextHandler(io.Discard, nil))
		clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
		// No peeked bytes — the client is connecting directly with TLS.
		_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
	}()

	// Connect as an HTTP/2 TLS client.
	client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
	resp, err := client.Get("https://l7.test/foo")
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want 200", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	const wantBody = "hello from backend, path=/foo"
	if got := string(body); got != wantBody {
		t.Fatalf("body = %q, want %q", got, wantBody)
	}
	if got := resp.Header.Get("X-Backend"); got != "ok" {
		t.Fatalf("X-Backend = %q, want %q", got, "ok")
	}
}
// TestL7ForwardingHeaders verifies that the proxy sets X-Forwarded-For,
// X-Forwarded-Proto and X-Real-IP on requests it forwards to the backend.
// The expected values are the client address handed to Serve
// (203.0.113.50) and "https" as the original scheme.
func TestL7ForwardingHeaders(t *testing.T) {
	certPath, keyPath := testCert(t, "headers.test")

	// Backend that echoes the forwarding headers.
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
			r.Header.Get("X-Forwarded-For"),
			r.Header.Get("X-Forwarded-Proto"),
			r.Header.Get("X-Real-IP"),
		)
	}))

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	defer proxyLn.Close()

	route := RouteConfig{
		Backend:        backendAddr,
		TLSCert:        certPath,
		TLSKey:         keyPath,
		ConnectTimeout: 5 * time.Second,
	}

	go func() {
		conn, err := proxyLn.Accept()
		if err != nil {
			return
		}
		logger := slog.New(slog.NewTextHandler(io.Discard, nil))
		clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
		Serve(context.Background(), conn, nil, route, clientAddr, logger)
	}()

	client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
	resp, err := client.Get("https://headers.test/")
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer resp.Body.Close()

	// Check the status first so a proxy error response is not misreported
	// as a header mismatch.
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want 200", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	const want = "xff=203.0.113.50 xfp=https xri=203.0.113.50"
	if string(body) != want {
		t.Fatalf("body = %q, want %q", body, want)
	}
}
// TestL7BackendUnreachable verifies the behavior when the route's backend
// address has nothing listening. The proxy must still complete the client's
// TLS/HTTP/2 exchange and answer the request with 502 Bad Gateway.
func TestL7BackendUnreachable(t *testing.T) {
	certPath, keyPath := testCert(t, "unreachable.test")

	// Reserve an ephemeral port and release it immediately, so that (barring
	// reuse by something else in the meantime) nothing is listening there.
	probe, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	deadAddr := probe.Addr().String()
	probe.Close()

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	defer proxyLn.Close()

	// Serve exactly one proxied connection in the background.
	go func(rc RouteConfig) {
		conn, acceptErr := proxyLn.Accept()
		if acceptErr != nil {
			return
		}
		Serve(context.Background(), conn, nil, rc,
			netip.MustParseAddrPort("203.0.113.50:12345"),
			slog.New(slog.NewTextHandler(io.Discard, nil)),
		)
	}(RouteConfig{
		Backend:        deadAddr,
		TLSCert:        certPath,
		TLSKey:         keyPath,
		ConnectTimeout: 1 * time.Second,
	})

	client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
	resp, err := client.Get("https://unreachable.test/")
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer resp.Body.Close()

	if got := resp.StatusCode; got != http.StatusBadGateway {
		t.Fatalf("status = %d, want 502", got)
	}
}
func TestIsTimeoutError(t *testing.T) {
// context.DeadlineExceeded is a timeout.
if !isTimeoutError(context.DeadlineExceeded) {
t.Fatal("expected DeadlineExceeded to be a timeout error")
}
// A net timeout error is a timeout.
netErr := &net.OpError{Op: "dial", Err: &timeoutErr{}}
if !isTimeoutError(netErr) {
t.Fatal("expected net timeout to be a timeout error")
}
// A regular error is not a timeout.
if isTimeoutError(fmt.Errorf("connection refused")) {
t.Fatal("expected non-timeout error to return false")
}
}
// timeoutErr is a minimal test double satisfying net.Error. Its Timeout
// method reports true and its Temporary method reports false.
type timeoutErr struct{}

// Error returns the fixed message "timeout".
func (*timeoutErr) Error() string { return "timeout" }

// Timeout always reports true; this is what marks the value as a timeout.
func (*timeoutErr) Timeout() bool { return true }

// Temporary always reports false. It is present only because net.Error
// still declares it.
func (*timeoutErr) Temporary() bool { return false }
// TestL7MultipleRequests sends several sequential requests over one HTTP/2
// client connection to the proxy. It checks that each request is forwarded
// to the backend exactly once and in order.
func TestL7MultipleRequests(t *testing.T) {
	certPath, keyPath := testCert(t, "multi.test")

	// Each backend request is handled on its own goroutine, so the counter
	// is kept in a one-slot channel that acts as a lock: a handler takes the
	// value out, increments it and puts it back. A bare int would be
	// mutated concurrently without synchronization.
	count := make(chan int, 1)
	count <- 0
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		n := <-count + 1
		count <- n
		fmt.Fprintf(w, "req=%d path=%s", n, r.URL.Path)
	}))

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	defer proxyLn.Close()

	route := RouteConfig{
		Backend:        backendAddr,
		TLSCert:        certPath,
		TLSKey:         keyPath,
		ConnectTimeout: 5 * time.Second,
	}

	go func() {
		conn, err := proxyLn.Accept()
		if err != nil {
			return
		}
		logger := slog.New(slog.NewTextHandler(io.Discard, nil))
		clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
		Serve(context.Background(), conn, nil, route, clientAddr, logger)
	}()

	client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")

	// Send multiple requests over the same HTTP/2 connection.
	for i := range 3 {
		path := fmt.Sprintf("/req%d", i)
		resp, err := client.Get("https://multi.test" + path)
		if err != nil {
			t.Fatalf("GET %s: %v", path, err)
		}
		body, readErr := io.ReadAll(resp.Body)
		resp.Body.Close()
		if readErr != nil {
			t.Fatalf("request %d: reading body: %v", i, readErr)
		}
		if resp.StatusCode != http.StatusOK {
			t.Fatalf("request %d: status = %d, want 200", i, resp.StatusCode)
		}
		want := fmt.Sprintf("req=%d path=%s", i+1, path)
		if string(body) != want {
			t.Fatalf("request %d: body = %q, want %q", i, body, want)
		}
	}
}
// TestL7LargeResponse checks that a 1 MiB backend response passes through
// the proxy byte-for-byte. Comparing only the length would miss corruption
// or reordering of the data.
func TestL7LargeResponse(t *testing.T) {
	certPath, keyPath := testCert(t, "large.test")

	// Backend sends a 1 MiB response with a repeating 0..255 byte pattern,
	// so misplaced or corrupted chunks are detectable.
	largeBody := make([]byte, 1<<20)
	for i := range largeBody {
		largeBody[i] = byte(i % 256)
	}
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Write(largeBody)
	}))

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	defer proxyLn.Close()

	route := RouteConfig{
		Backend:        backendAddr,
		TLSCert:        certPath,
		TLSKey:         keyPath,
		ConnectTimeout: 5 * time.Second,
	}

	go func() {
		conn, err := proxyLn.Accept()
		if err != nil {
			return
		}
		logger := slog.New(slog.NewTextHandler(io.Discard, nil))
		Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
	}()

	client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test")
	resp, err := client.Get("https://large.test/")
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want 200", resp.StatusCode)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	if len(body) != len(largeBody) {
		t.Fatalf("got %d bytes, want %d", len(body), len(largeBody))
	}
	for i := range body {
		if body[i] != largeBody[i] {
			t.Fatalf("body differs at byte %d: got %#x, want %#x", i, body[i], largeBody[i])
		}
	}
}
// TestL7GRPCTrailers checks that HTTP/2 trailers set by the backend are
// forwarded to the client through the proxy. gRPC uses the Grpc-Status
// and Grpc-Message trailers to report call status.
func TestL7GRPCTrailers(t *testing.T) {
	certPath, keyPath := testCert(t, "trailers.test")

	// Backend that sets HTTP trailers (used by gRPC for status). The trailer
	// names are predeclared in the "Trailer" header before WriteHeader, and
	// their values are filled in on the same header map after the body.
	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Trailer", "Grpc-Status, Grpc-Message")
		w.Header().Set("Content-Type", "application/grpc")
		w.WriteHeader(http.StatusOK)
		// Flush to send headers.
		if f, ok := w.(http.Flusher); ok {
			f.Flush()
		}
		// Set trailers.
		w.Header().Set("Grpc-Status", "0")
		w.Header().Set("Grpc-Message", "OK")
	}))

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	defer proxyLn.Close()

	route := RouteConfig{
		Backend:        backendAddr,
		TLSCert:        certPath,
		TLSKey:         keyPath,
		ConnectTimeout: 5 * time.Second,
	}

	go func() {
		conn, err := proxyLn.Accept()
		if err != nil {
			return
		}
		logger := slog.New(slog.NewTextHandler(io.Discard, nil))
		Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
	}()

	client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test")
	req, err := http.NewRequest(http.MethodPost, "https://trailers.test/grpc.test.Service/Method", nil)
	if err != nil {
		t.Fatalf("building request: %v", err)
	}
	req.Header.Set("Content-Type", "application/grpc")
	resp, err := client.Do(req)
	if err != nil {
		t.Fatalf("POST: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want 200", resp.StatusCode)
	}
	// Trailers are only populated once the body has been read to EOF.
	if _, err := io.ReadAll(resp.Body); err != nil {
		t.Fatalf("reading body: %v", err)
	}

	// Verify trailers were forwarded through the proxy.
	if got := resp.Trailer.Get("Grpc-Status"); got != "0" {
		t.Fatalf("Grpc-Status trailer = %q, want %q", got, "0")
	}
	if got := resp.Trailer.Get("Grpc-Message"); got != "OK" {
		t.Fatalf("Grpc-Message trailer = %q, want %q", got, "OK")
	}
}
// TestL7HTTP11Fallback checks that a client offering only HTTP/1.1 via ALPN
// is still served by the proxy and receives the backend's response. The
// protocol the backend sees (reported in the body as "proto=...") depends on
// the proxy's backend transport, so only the "proto=" prefix is checked.
func TestL7HTTP11Fallback(t *testing.T) {
	certPath, keyPath := testCert(t, "http11.test")

	backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprintf(w, "proto=%s", r.Proto)
	}))

	proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("proxy listen: %v", err)
	}
	defer proxyLn.Close()

	route := RouteConfig{
		Backend:        backendAddr,
		TLSCert:        certPath,
		TLSKey:         keyPath,
		ConnectTimeout: 5 * time.Second,
	}

	go func() {
		conn, err := proxyLn.Accept()
		if err != nil {
			return
		}
		logger := slog.New(slog.NewTextHandler(io.Discard, nil))
		Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
	}()

	// Connect with HTTP/1.1 only (no h2 ALPN). An http.Transport with a
	// custom TLSClientConfig does not attempt HTTP/2 unless
	// ForceAttemptHTTP2 is set.
	tlsConf := &tls.Config{
		ServerName:         "http11.test",
		InsecureSkipVerify: true,
		NextProtos:         []string{"http/1.1"},
	}
	tr := &http.Transport{TLSClientConfig: tlsConf}
	defer tr.CloseIdleConnections()
	client := &http.Client{Transport: tr, Timeout: 10 * time.Second}

	resp, err := client.Get(fmt.Sprintf("https://%s/", proxyLn.Addr().String()))
	if err != nil {
		t.Fatalf("GET: %v", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		t.Fatalf("status = %d, want 200", resp.StatusCode)
	}
	if resp.ProtoMajor != 1 {
		t.Fatalf("client protocol = %s, want HTTP/1.x", resp.Proto)
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatalf("reading body: %v", err)
	}
	const prefix = "proto="
	if len(body) <= len(prefix) || string(body[:len(prefix)]) != prefix {
		t.Fatalf("body = %q, want %q followed by the backend's protocol", body, prefix)
	}
}