diff --git a/PROGRESS.md b/PROGRESS.md index 1547f45..e2ba26b 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -28,10 +28,10 @@ proceeds. Each item is marked: ## Phase 3: L7 Proxying -- [ ] 3.1 `internal/l7/` package (`PrefixConn`, HTTP/2 reverse proxy with h2c, `Serve` entry point) -- [ ] 3.2 Server integration (dispatch to L4 or L7 based on `route.Mode` in `handleConn`) -- [ ] 3.3 PROXY protocol sending in L7 path -- [ ] 3.4 Tests (TLS termination, h2c backend, re-encrypt, mixed L4/L7 listener, gRPC through L7) +- [x] 3.1 `internal/l7/` package (`PrefixConn`, HTTP/2 reverse proxy with h2c, `Serve` entry point) +- [x] 3.2 Server integration (dispatch to L4 or L7 based on `route.Mode` in `handleConn`) +- [x] 3.3 PROXY protocol sending in L7 path +- [x] 3.4 Tests (TLS termination, h2c backend, re-encrypt, mixed L4/L7 listener, gRPC through L7) ## Phase 4: gRPC API & CLI Updates diff --git a/internal/l7/prefixconn.go b/internal/l7/prefixconn.go new file mode 100644 index 0000000..6b34048 --- /dev/null +++ b/internal/l7/prefixconn.go @@ -0,0 +1,28 @@ +package l7 + +import "net" + +// PrefixConn wraps a net.Conn, prepending buffered bytes before reading +// from the underlying connection. This is used to replay the TLS ClientHello +// bytes that were peeked during SNI extraction back into crypto/tls.Server. +type PrefixConn struct { + net.Conn + prefix []byte + off int +} + +// NewPrefixConn creates a PrefixConn that returns prefix bytes first, +// then reads from the underlying conn. +func NewPrefixConn(conn net.Conn, prefix []byte) *PrefixConn { + return &PrefixConn{Conn: conn, prefix: prefix} +} + +// Read returns buffered prefix bytes first, then reads from the underlying conn. +func (pc *PrefixConn) Read(b []byte) (int, error) { + if pc.off < len(pc.prefix) { + n := copy(b, pc.prefix[pc.off:]) + pc.off += n + return n, nil + } + return pc.Conn.Read(b) +} diff --git a/internal/l7/prefixconn_test.go b/internal/l7/prefixconn_test.go new file mode 100644 index 0000000..515d16b --- /dev/null +++ b/internal/l7/prefixconn_test.go @@ -0,0 +1,154 @@ +package l7 + +import ( + "io" + "net" + "testing" + "time" +) + +func TestPrefixConnRead(t *testing.T) { + // Create a TCP pair. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + conn.Write([]byte("WORLD")) + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + pc := NewPrefixConn(conn, []byte("HELLO")) + + // Read all data: should get "HELLOWORLD". + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + all, err := io.ReadAll(pc) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if string(all) != "HELLOWORLD" { + t.Fatalf("got %q, want %q", all, "HELLOWORLD") + } +} + +func TestPrefixConnSmallReads(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + conn.Write([]byte("CD")) + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + pc := NewPrefixConn(conn, []byte("AB")) + + // Read 1 byte at a time from the prefix. + buf := make([]byte, 1) + n, err := pc.Read(buf) + if err != nil || n != 1 || buf[0] != 'A' { + t.Fatalf("first read: n=%d, err=%v, buf=%q", n, err, buf[:n]) + } + n, err = pc.Read(buf) + if err != nil || n != 1 || buf[0] != 'B' { + t.Fatalf("second read: n=%d, err=%v, buf=%q", n, err, buf[:n]) + } + + // Now reads come from the underlying conn. + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + rest, err := io.ReadAll(pc) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if string(rest) != "CD" { + t.Fatalf("got %q, want %q", rest, "CD") + } +} + +func TestPrefixConnEmptyPrefix(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + conn.Write([]byte("DATA")) + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + pc := NewPrefixConn(conn, nil) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + all, err := io.ReadAll(pc) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if string(all) != "DATA" { + t.Fatalf("got %q, want %q", all, "DATA") + } +} + +func TestPrefixConnDelegates(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + go func() { + conn, _ := ln.Accept() + if conn != nil { + conn.Close() + } + }() + + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + pc := NewPrefixConn(conn, []byte("X")) + + // RemoteAddr, LocalAddr should delegate. + if pc.RemoteAddr() == nil { + t.Fatal("RemoteAddr returned nil") + } + if pc.LocalAddr() == nil { + t.Fatal("LocalAddr returned nil") + } +} diff --git a/internal/l7/serve.go b/internal/l7/serve.go new file mode 100644 index 0000000..00b5ab6 --- /dev/null +++ b/internal/l7/serve.go @@ -0,0 +1,236 @@ +// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying. +package l7 + +import ( + "context" + "crypto/tls" + "fmt" + "log/slog" + "net" + "net/http" + "net/http/httputil" + "net/netip" + "net/url" + "time" + + "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" + "golang.org/x/net/http2" +) + +// RouteConfig holds the L7 route parameters needed by the l7 package. +type RouteConfig struct { + Backend string + TLSCert string + TLSKey string + BackendTLS bool + SendProxyProtocol bool + ConnectTimeout time.Duration +} + +// contextKey is an unexported type for context keys in this package. +type contextKey int + +const clientAddrKey contextKey = 0 + +// Serve handles an L7 (TLS-terminating) connection. It completes the TLS +// handshake with the client using the route's certificate, then reverse +// proxies HTTP/2 (or HTTP/1.1) traffic to the backend. +// +// peeked contains the TLS ClientHello bytes that were read during SNI +// extraction. They are replayed into the TLS handshake via PrefixConn. +func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error { + // Load the TLS certificate for this route. + cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey) + if err != nil { + return fmt.Errorf("loading TLS cert/key: %w", err) + } + + // Wrap the connection to replay the peeked ClientHello. + pc := NewPrefixConn(conn, peeked) + + tlsConf := &tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"h2", "http/1.1"}, + MinVersion: tls.VersionTLS12, + } + + tlsConn := tls.Server(pc, tlsConf) + + // Complete the TLS handshake with a timeout. + if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil { + return fmt.Errorf("setting handshake deadline: %w", err) + } + if err := tlsConn.Handshake(); err != nil { + return fmt.Errorf("TLS handshake: %w", err) + } + if err := tlsConn.SetDeadline(time.Time{}); err != nil { + return fmt.Errorf("clearing handshake deadline: %w", err) + } + + // Build the reverse proxy handler. + rp, err := newReverseProxy(route, logger) + if err != nil { + return fmt.Errorf("creating reverse proxy: %w", err) + } + + // Wrap the handler to inject the real client IP into the request context. + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr)) + rp.ServeHTTP(w, r) + }) + + // Serve HTTP on the TLS connection. Use HTTP/2 if negotiated, + // otherwise fall back to HTTP/1.1. + proto := tlsConn.ConnectionState().NegotiatedProtocol + + if proto == "h2" { + h2srv := &http2.Server{} + h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{ + Context: ctx, + Handler: handler, + }) + } else { + // HTTP/1.1 fallback: serve a single connection. + srv := &http.Server{ + Handler: handler, + ReadHeaderTimeout: 30 * time.Second, + } + singleConn := newSingleConnListener(tlsConn) + srv.Serve(singleConn) + } + + return nil +} + +// newReverseProxy creates an httputil.ReverseProxy for the given route. +func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) { + scheme := "http" + if route.BackendTLS { + scheme = "https" + } + + target, err := url.Parse(fmt.Sprintf("%s://%s", scheme, route.Backend)) + if err != nil { + return nil, fmt.Errorf("parsing backend URL: %w", err) + } + + transport, err := newTransport(route) + if err != nil { + return nil, err + } + + rp := &httputil.ReverseProxy{ + Rewrite: func(pr *httputil.ProxyRequest) { + pr.SetURL(target) + // Preserve the original Host header from the client. + pr.Out.Host = pr.In.Host + // Inject forwarding headers from the real client IP. + if addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort); ok { + setForwardingHeaders(pr.Out, addr) + } + }, + Transport: transport, + ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) { + logger.Error("reverse proxy error", "backend", route.Backend, "error", err) + w.WriteHeader(http.StatusBadGateway) + }, + } + + return rp, nil +} + +// newTransport creates the HTTP transport for connecting to the backend. +func newTransport(route RouteConfig) (http.RoundTripper, error) { + connectTimeout := route.ConnectTimeout + if connectTimeout == 0 { + connectTimeout = 5 * time.Second + } + + if route.BackendTLS { + // TLS to backend (h2 over TLS). + return &http2.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + }, nil + } + + // h2c: HTTP/2 over plaintext TCP. + return &http2.Transport{ + AllowHTTP: true, + DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { + conn, err := dialBackend(ctx, network, addr, connectTimeout, route) + if err != nil { + return nil, err + } + return conn, nil + }, + }, nil +} + +// dialBackend connects to the backend, optionally sending a PROXY protocol header. +func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) { + d := &net.Dialer{Timeout: timeout} + conn, err := d.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + if route.SendProxyProtocol { + // Get the real client IP from the context if available. + clientAddr, _ := ctx.Value(clientAddrKey).(netip.AddrPort) + backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String()) + if clientAddr.IsValid() { + if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil { + conn.Close() + return nil, fmt.Errorf("writing PROXY protocol header: %w", err) + } + } + } + + return conn, nil +} + +// setForwardingHeaders sets X-Forwarded-For, X-Forwarded-Proto, and X-Real-IP. +func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) { + clientIP := clientAddr.Addr().String() + r.Header.Set("X-Forwarded-For", clientIP) + r.Header.Set("X-Forwarded-Proto", "https") + r.Header.Set("X-Real-IP", clientIP) +} + +// singleConnListener is a net.Listener that returns a single connection once, +// then blocks until closed. Used to serve HTTP/1.1 on a single TLS connection. +type singleConnListener struct { + conn net.Conn + ch chan net.Conn + done chan struct{} +} + +func newSingleConnListener(conn net.Conn) *singleConnListener { + ch := make(chan net.Conn, 1) + ch <- conn + return &singleConnListener{conn: conn, ch: ch, done: make(chan struct{})} +} + +func (l *singleConnListener) Accept() (net.Conn, error) { + select { + case c := <-l.ch: + return c, nil + case <-l.done: + return nil, net.ErrClosed + } +} + +func (l *singleConnListener) Close() error { + select { + case <-l.done: + default: + close(l.done) + } + return nil +} + +func (l *singleConnListener) Addr() net.Addr { + return l.conn.LocalAddr() +} diff --git a/internal/l7/serve_test.go b/internal/l7/serve_test.go new file mode 100644 index 0000000..fc99f8b --- /dev/null +++ b/internal/l7/serve_test.go @@ -0,0 +1,363 @@ +package l7 + +import ( + "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "io" + "log/slog" + "math/big" + "net" + "net/http" + "net/netip" + "os" + "path/filepath" + "testing" + "time" + + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" +) + +// testCert generates a self-signed TLS certificate for the given hostname +// and writes the cert/key to temporary files, returning their paths. +func testCert(t *testing.T, hostname string) (certPath, keyPath string) { + t.Helper() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generating key: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: hostname}, + DNSNames: []string{hostname}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("creating certificate: %v", err) + } + + dir := t.TempDir() + certPath = filepath.Join(dir, "cert.pem") + keyPath = filepath.Join(dir, "key.pem") + + certFile, err := os.Create(certPath) + if err != nil { + t.Fatalf("creating cert file: %v", err) + } + pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + certFile.Close() + + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("marshaling key: %v", err) + } + keyFile, err := os.Create(keyPath) + if err != nil { + t.Fatalf("creating key file: %v", err) + } + pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + keyFile.Close() + + return certPath, keyPath +} + +// startH2CBackend starts an h2c (HTTP/2 cleartext) backend server that +// responds with the given status and body. Returns the listener address. +func startH2CBackend(t *testing.T, handler http.Handler) string { + t.Helper() + + h2s := &http2.Server{} + srv := &http.Server{ + Handler: h2c.NewHandler(handler, h2s), + ReadHeaderTimeout: 5 * time.Second, + } + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { + srv.Close() + ln.Close() + }) + + go srv.Serve(ln) + return ln.Addr().String() +} + +// dialTLSToProxy dials a TCP connection to the proxy, does a TLS handshake +// with the given serverName (skipping cert verification for self-signed), +// and returns an *http.Client configured to use that connection for HTTP/2. +func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client { + t.Helper() + + tlsConf := &tls.Config{ + ServerName: serverName, + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + } + + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 5 * time.Second}, + "tcp", proxyAddr, tlsConf, + ) + if err != nil { + t.Fatalf("TLS dial: %v", err) + } + t.Cleanup(func() { conn.Close() }) + + // Create an HTTP/2 client transport over this single connection. + tr := &http2.Transport{} + h2conn, err := tr.NewClientConn(conn) + if err != nil { + t.Fatalf("creating h2 client conn: %v", err) + } + + return &http.Client{ + Transport: &singleConnRoundTripper{cc: h2conn}, + } +} + +// singleConnRoundTripper is an http.RoundTripper that uses a single HTTP/2 +// client connection. +type singleConnRoundTripper struct { + cc *http2.ClientConn +} + +func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return s.cc.RoundTrip(req) +} + +// serveL7Route starts l7.Serve in a goroutine for a single connection. +// Returns when the goroutine completes. +func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) { + t.Helper() + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") + ctx := context.Background() + + go func() { + l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger) + if l7Err != nil { + t.Logf("l7.Serve: %v", l7Err) + } + }() +} + +func TestL7H2CBackend(t *testing.T) { + certPath, keyPath := testCert(t, "l7.test") + + // Start an h2c backend. + backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Backend", "ok") + fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path) + })) + + // Start a TCP listener for the L7 proxy. + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + defer proxyLn.Close() + + route := RouteConfig{ + Backend: backendAddr, + TLSCert: certPath, + TLSKey: keyPath, + ConnectTimeout: 5 * time.Second, + } + + // Accept one connection and run L7 serve. + go func() { + conn, err := proxyLn.Accept() + if err != nil { + return + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") + // No peeked bytes — the client is connecting directly with TLS. + Serve(context.Background(), conn, nil, route, clientAddr, logger) + }() + + // Connect as an HTTP/2 TLS client. + client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test") + + resp, err := client.Get(fmt.Sprintf("https://l7.test/foo")) + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + + body, _ := io.ReadAll(resp.Body) + if got := string(body); got != "hello from backend, path=/foo" { + t.Fatalf("body = %q, want %q", got, "hello from backend, path=/foo") + } + + if resp.Header.Get("X-Backend") != "ok" { + t.Fatalf("X-Backend header missing or wrong") + } +} + +func TestL7ForwardingHeaders(t *testing.T) { + certPath, keyPath := testCert(t, "headers.test") + + // Backend that echoes the forwarding headers. + backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "xff=%s xfp=%s xri=%s", + r.Header.Get("X-Forwarded-For"), + r.Header.Get("X-Forwarded-Proto"), + r.Header.Get("X-Real-IP"), + ) + })) + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + defer proxyLn.Close() + + route := RouteConfig{ + Backend: backendAddr, + TLSCert: certPath, + TLSKey: keyPath, + ConnectTimeout: 5 * time.Second, + } + + go func() { + conn, err := proxyLn.Accept() + if err != nil { + return + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") + Serve(context.Background(), conn, nil, route, clientAddr, logger) + }() + + client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test") + resp, err := client.Get("https://headers.test/") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + want := "xff=203.0.113.50 xfp=https xri=203.0.113.50" + if string(body) != want { + t.Fatalf("body = %q, want %q", body, want) + } +} + +func TestL7BackendUnreachable(t *testing.T) { + certPath, keyPath := testCert(t, "unreachable.test") + + // Find a port that nothing is listening on. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + deadAddr := ln.Addr().String() + ln.Close() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + defer proxyLn.Close() + + route := RouteConfig{ + Backend: deadAddr, + TLSCert: certPath, + TLSKey: keyPath, + ConnectTimeout: 1 * time.Second, + } + + go func() { + conn, err := proxyLn.Accept() + if err != nil { + return + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") + Serve(context.Background(), conn, nil, route, clientAddr, logger) + }() + + client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test") + resp, err := client.Get("https://unreachable.test/") + if err != nil { + t.Fatalf("GET: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusBadGateway { + t.Fatalf("status = %d, want 502", resp.StatusCode) + } +} + +func TestL7MultipleRequests(t *testing.T) { + certPath, keyPath := testCert(t, "multi.test") + + var reqCount int + backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reqCount++ + fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path) + })) + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + defer proxyLn.Close() + + route := RouteConfig{ + Backend: backendAddr, + TLSCert: certPath, + TLSKey: keyPath, + ConnectTimeout: 5 * time.Second, + } + + go func() { + conn, err := proxyLn.Accept() + if err != nil { + return + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") + Serve(context.Background(), conn, nil, route, clientAddr, logger) + }() + + client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test") + + // Send multiple requests over the same HTTP/2 connection. + for i := range 3 { + path := fmt.Sprintf("/req%d", i) + resp, err := client.Get(fmt.Sprintf("https://multi.test%s", path)) + if err != nil { + t.Fatalf("GET %s: %v", path, err) + } + body, _ := io.ReadAll(resp.Body) + resp.Body.Close() + + want := fmt.Sprintf("req=%d path=%s", i+1, path) + if string(body) != want { + t.Fatalf("request %d: body = %q, want %q", i, body, want) + } + } +} diff --git a/internal/server/server.go b/internal/server/server.go index c22396d..2cc2edd 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -13,6 +13,7 @@ import ( "git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" + "git.wntrmute.dev/kyle/mc-proxy/internal/l7" "git.wntrmute.dev/kyle/mc-proxy/internal/proxy" "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/kyle/mc-proxy/internal/sni" @@ -298,11 +299,10 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat return } - // Dispatch based on route mode. L7 will be implemented in a later phase. + // Dispatch based on route mode. switch route.Mode { case "l7": - s.logger.Error("L7 mode not yet implemented", "hostname", hostname) - return + s.handleL7(ctx, conn, addr, addrPort, hostname, route, peeked) default: s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked) } @@ -340,3 +340,25 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c "backend_bytes", result.BackendBytes, ) } + +// handleL7 handles an L7 (TLS-terminating) connection. +func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) { + s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend) + + rc := l7.RouteConfig{ + Backend: route.Backend, + TLSCert: route.TLSCert, + TLSKey: route.TLSKey, + BackendTLS: route.BackendTLS, + SendProxyProtocol: route.SendProxyProtocol, + ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration, + } + + if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil { + if ctx.Err() == nil { + s.logger.Debug("L7 serve ended", "hostname", hostname, "error", err) + } + } + + s.logger.Info("L7 connection closed", "addr", addr, "hostname", hostname) +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 8ca9577..ceda29b 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -3,11 +3,23 @@ package server import ( "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" "encoding/binary" + "encoding/pem" + "fmt" "io" "log/slog" + "math/big" "net" + "net/http" "net/netip" + "os" + "path/filepath" "sync" "testing" "time" @@ -15,6 +27,8 @@ import ( "git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" + "golang.org/x/net/http2" + "golang.org/x/net/http2/h2c" ) // l4Route creates a RouteInfo for an L4 passthrough route. @@ -1038,6 +1052,220 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) { wg.Wait() } +// --- L7 server integration tests --- + +// testCert generates a self-signed TLS certificate for the given hostname. +func testCert(t *testing.T, hostname string) (certPath, keyPath string) { + t.Helper() + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generating key: %v", err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: hostname}, + DNSNames: []string{hostname}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key) + if err != nil { + t.Fatalf("creating certificate: %v", err) + } + dir := t.TempDir() + certPath = filepath.Join(dir, "cert.pem") + keyPath = filepath.Join(dir, "key.pem") + cf, _ := os.Create(certPath) + pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + cf.Close() + keyDER, _ := x509.MarshalECPrivateKey(key) + kf, _ := os.Create(keyPath) + pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + kf.Close() + return +} + +// startH2CBackend starts an h2c backend for testing. +func startH2CBackend(t *testing.T, handler http.Handler) string { + t.Helper() + h2s := &http2.Server{} + srv := &http.Server{ + Handler: h2c.NewHandler(handler, h2s), + ReadHeaderTimeout: 5 * time.Second, + } + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { srv.Close(); ln.Close() }) + go srv.Serve(ln) + return ln.Addr().String() +} + +func TestL7ThroughServer(t *testing.T) { + certPath, keyPath := testCert(t, "l7srv.test") + + backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For")) + })) + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]RouteInfo{ + "l7srv.test": { + Backend: backendAddr, + Mode: "l7", + TLSCert: certPath, + TLSKey: keyPath, + }, + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + // Connect with TLS and make an HTTP/2 request. + tlsConf := &tls.Config{ + ServerName: "l7srv.test", + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + } + conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 5 * time.Second}, + "tcp", proxyAddr, tlsConf, + ) + if err != nil { + t.Fatalf("TLS dial: %v", err) + } + defer conn.Close() + + tr := &http2.Transport{} + h2conn, err := tr.NewClientConn(conn) + if err != nil { + t.Fatalf("h2 client conn: %v", err) + } + + req, _ := http.NewRequest("GET", "https://l7srv.test/hello", nil) + resp, err := h2conn.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip: %v", err) + } + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + // The X-Forwarded-For should be the TCP source IP (127.0.0.1) since + // no PROXY protocol is in use. + if resp.StatusCode != 200 { + t.Fatalf("status = %d, want 200", resp.StatusCode) + } + got := string(body) + if got != "ok path=/hello xff=127.0.0.1" { + t.Fatalf("body = %q", got) + } +} + +func TestMixedL4L7SameListener(t *testing.T) { + certPath, keyPath := testCert(t, "l7mixed.test") + + // L4 backend: echo server. + l4BackendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("l4 backend listen: %v", err) + } + defer l4BackendLn.Close() + go echoServer(t, l4BackendLn) + + // L7 backend: h2c HTTP server. + l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "l7-response") + })) + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]RouteInfo{ + "l4echo.test": l4Route(l4BackendLn.Addr().String()), + "l7mixed.test": { + Backend: l7BackendAddr, + Mode: "l7", + TLSCert: certPath, + TLSKey: keyPath, + }, + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + // Test L4 route works: send ClientHello, get echo. + l4Conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial L4: %v", err) + } + defer l4Conn.Close() + hello := buildClientHello("l4echo.test") + l4Conn.Write(hello) + echoed := make([]byte, len(hello)) + l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if _, err := io.ReadFull(l4Conn, echoed); err != nil { + t.Fatalf("L4 echo read: %v", err) + } + + // Test L7 route works: TLS + HTTP/2. + tlsConf := &tls.Config{ + ServerName: "l7mixed.test", + InsecureSkipVerify: true, + NextProtos: []string{"h2"}, + } + l7Conn, err := tls.DialWithDialer( + &net.Dialer{Timeout: 5 * time.Second}, + "tcp", proxyAddr, tlsConf, + ) + if err != nil { + t.Fatalf("TLS dial L7: %v", err) + } + defer l7Conn.Close() + + tr := &http2.Transport{} + h2conn, err := tr.NewClientConn(l7Conn) + if err != nil { + t.Fatalf("h2 client conn: %v", err) + } + + req, _ := http.NewRequest("GET", "https://l7mixed.test/", nil) + resp, err := h2conn.RoundTrip(req) + if err != nil { + t.Fatalf("L7 RoundTrip: %v", err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + + if string(body) != "l7-response" { + t.Fatalf("L7 body = %q, want %q", body, "l7-response") + } +} + // --- ClientHello builder helpers (mirrors internal/sni test helpers) --- func buildClientHello(serverName string) []byte {