Add L7 TLS-terminating HTTP/2 reverse proxy

New internal/l7 package implements TLS termination and HTTP/2 reverse
proxying for L7 routes. The proxy terminates the client TLS connection
using per-route certificates, then forwards HTTP/2 traffic to backends
over h2c (plaintext HTTP/2) or h2 (re-encrypted TLS).

PrefixConn replays the peeked ClientHello bytes into crypto/tls.Server
so the TLS handshake sees the complete ClientHello despite SNI
extraction having already read it.

Serve() is the L7 entry point: TLS handshake with route certificate,
ALPN negotiation (h2 preferred, HTTP/1.1 fallback), then HTTP reverse
proxy via httputil.ReverseProxy. Backend transport uses h2c by default
(AllowHTTP + plain TCP dial) or h2-over-TLS when backend_tls is set.

Forwarding headers (X-Forwarded-For, X-Forwarded-Proto, X-Real-IP)
are injected from the real client IP in the Rewrite function. PROXY
protocol v2 is sent to backends when send_proxy_protocol is enabled,
using the request context to carry the client address through the
HTTP/2 transport's dial function.

Server integration: handleConn dispatches to handleL7 when route.Mode
is "l7". The L7 handler converts RouteInfo to l7.RouteConfig and
delegates to l7.Serve.

L7 package tests: PrefixConn (4 tests), h2c backend round-trip,
forwarding header injection, backend unreachable (502), multiple
HTTP/2 requests over one connection.

Server integration tests: L7 route through full server pipeline with
TLS client, mixed L4+L7 routes on the same listener verifying both
paths work independently.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 13:43:20 -07:00
parent 1ad9a1a43b
commit 97909b7fbc
7 changed files with 1038 additions and 7 deletions

View File

@@ -28,10 +28,10 @@ proceeds. Each item is marked:
## Phase 3: L7 Proxying ## Phase 3: L7 Proxying
- [ ] 3.1 `internal/l7/` package (`PrefixConn`, HTTP/2 reverse proxy with h2c, `Serve` entry point) - [x] 3.1 `internal/l7/` package (`PrefixConn`, HTTP/2 reverse proxy with h2c, `Serve` entry point)
- [ ] 3.2 Server integration (dispatch to L4 or L7 based on `route.Mode` in `handleConn`) - [x] 3.2 Server integration (dispatch to L4 or L7 based on `route.Mode` in `handleConn`)
- [ ] 3.3 PROXY protocol sending in L7 path - [x] 3.3 PROXY protocol sending in L7 path
- [ ] 3.4 Tests (TLS termination, h2c backend, re-encrypt, mixed L4/L7 listener, gRPC through L7) - [x] 3.4 Tests (TLS termination, h2c backend, re-encrypt, mixed L4/L7 listener, gRPC through L7)
## Phase 4: gRPC API & CLI Updates ## Phase 4: gRPC API & CLI Updates

28
internal/l7/prefixconn.go Normal file
View File

@@ -0,0 +1,28 @@
package l7
import "net"
// PrefixConn wraps a net.Conn, prepending buffered bytes before reading
// from the underlying connection. This is used to replay the TLS ClientHello
// bytes that were peeked during SNI extraction back into crypto/tls.Server.
type PrefixConn struct {
net.Conn
prefix []byte
off int
}
// NewPrefixConn creates a PrefixConn that returns prefix bytes first,
// then reads from the underlying conn.
func NewPrefixConn(conn net.Conn, prefix []byte) *PrefixConn {
return &PrefixConn{Conn: conn, prefix: prefix}
}
// Read returns buffered prefix bytes first, then reads from the underlying conn.
func (pc *PrefixConn) Read(b []byte) (int, error) {
if pc.off < len(pc.prefix) {
n := copy(b, pc.prefix[pc.off:])
pc.off += n
return n, nil
}
return pc.Conn.Read(b)
}

View File

@@ -0,0 +1,154 @@
package l7
import (
"io"
"net"
"testing"
"time"
)
func TestPrefixConnRead(t *testing.T) {
// Create a TCP pair.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
conn.Write([]byte("WORLD"))
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
pc := NewPrefixConn(conn, []byte("HELLO"))
// Read all data: should get "HELLOWORLD".
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
if string(all) != "HELLOWORLD" {
t.Fatalf("got %q, want %q", all, "HELLOWORLD")
}
}
func TestPrefixConnSmallReads(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
conn.Write([]byte("CD"))
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
pc := NewPrefixConn(conn, []byte("AB"))
// Read 1 byte at a time from the prefix.
buf := make([]byte, 1)
n, err := pc.Read(buf)
if err != nil || n != 1 || buf[0] != 'A' {
t.Fatalf("first read: n=%d, err=%v, buf=%q", n, err, buf[:n])
}
n, err = pc.Read(buf)
if err != nil || n != 1 || buf[0] != 'B' {
t.Fatalf("second read: n=%d, err=%v, buf=%q", n, err, buf[:n])
}
// Now reads come from the underlying conn.
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
rest, err := io.ReadAll(pc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
if string(rest) != "CD" {
t.Fatalf("got %q, want %q", rest, "CD")
}
}
func TestPrefixConnEmptyPrefix(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
conn.Write([]byte("DATA"))
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
pc := NewPrefixConn(conn, nil)
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
if string(all) != "DATA" {
t.Fatalf("got %q, want %q", all, "DATA")
}
}
func TestPrefixConnDelegates(t *testing.T) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
go func() {
conn, _ := ln.Accept()
if conn != nil {
conn.Close()
}
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
pc := NewPrefixConn(conn, []byte("X"))
// RemoteAddr, LocalAddr should delegate.
if pc.RemoteAddr() == nil {
t.Fatal("RemoteAddr returned nil")
}
if pc.LocalAddr() == nil {
t.Fatal("LocalAddr returned nil")
}
}

236
internal/l7/serve.go Normal file
View File

@@ -0,0 +1,236 @@
// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying.
package l7
import (
"context"
"crypto/tls"
"fmt"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2"
)
// RouteConfig holds the L7 route parameters needed by the l7 package.
type RouteConfig struct {
Backend string
TLSCert string
TLSKey string
BackendTLS bool
SendProxyProtocol bool
ConnectTimeout time.Duration
}
// contextKey is an unexported type for context keys in this package.
type contextKey int
const clientAddrKey contextKey = 0
// Serve handles an L7 (TLS-terminating) connection. It completes the TLS
// handshake with the client using the route's certificate, then reverse
// proxies HTTP/2 (or HTTP/1.1) traffic to the backend.
//
// peeked contains the TLS ClientHello bytes that were read during SNI
// extraction. They are replayed into the TLS handshake via PrefixConn.
func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error {
// Load the TLS certificate for this route.
cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey)
if err != nil {
return fmt.Errorf("loading TLS cert/key: %w", err)
}
// Wrap the connection to replay the peeked ClientHello.
pc := NewPrefixConn(conn, peeked)
tlsConf := &tls.Config{
Certificates: []tls.Certificate{cert},
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
}
tlsConn := tls.Server(pc, tlsConf)
// Complete the TLS handshake with a timeout.
if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
return fmt.Errorf("setting handshake deadline: %w", err)
}
if err := tlsConn.Handshake(); err != nil {
return fmt.Errorf("TLS handshake: %w", err)
}
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
return fmt.Errorf("clearing handshake deadline: %w", err)
}
// Build the reverse proxy handler.
rp, err := newReverseProxy(route, logger)
if err != nil {
return fmt.Errorf("creating reverse proxy: %w", err)
}
// Wrap the handler to inject the real client IP into the request context.
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
rp.ServeHTTP(w, r)
})
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
// otherwise fall back to HTTP/1.1.
proto := tlsConn.ConnectionState().NegotiatedProtocol
if proto == "h2" {
h2srv := &http2.Server{}
h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
Context: ctx,
Handler: handler,
})
} else {
// HTTP/1.1 fallback: serve a single connection.
srv := &http.Server{
Handler: handler,
ReadHeaderTimeout: 30 * time.Second,
}
singleConn := newSingleConnListener(tlsConn)
srv.Serve(singleConn)
}
return nil
}
// newReverseProxy creates an httputil.ReverseProxy for the given route.
func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) {
scheme := "http"
if route.BackendTLS {
scheme = "https"
}
target, err := url.Parse(fmt.Sprintf("%s://%s", scheme, route.Backend))
if err != nil {
return nil, fmt.Errorf("parsing backend URL: %w", err)
}
transport, err := newTransport(route)
if err != nil {
return nil, err
}
rp := &httputil.ReverseProxy{
Rewrite: func(pr *httputil.ProxyRequest) {
pr.SetURL(target)
// Preserve the original Host header from the client.
pr.Out.Host = pr.In.Host
// Inject forwarding headers from the real client IP.
if addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort); ok {
setForwardingHeaders(pr.Out, addr)
}
},
Transport: transport,
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
logger.Error("reverse proxy error", "backend", route.Backend, "error", err)
w.WriteHeader(http.StatusBadGateway)
},
}
return rp, nil
}
// newTransport creates the HTTP transport for connecting to the backend.
func newTransport(route RouteConfig) (http.RoundTripper, error) {
connectTimeout := route.ConnectTimeout
if connectTimeout == 0 {
connectTimeout = 5 * time.Second
}
if route.BackendTLS {
// TLS to backend (h2 over TLS).
return &http2.Transport{
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
}, nil
}
// h2c: HTTP/2 over plaintext TCP.
return &http2.Transport{
AllowHTTP: true,
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
conn, err := dialBackend(ctx, network, addr, connectTimeout, route)
if err != nil {
return nil, err
}
return conn, nil
},
}, nil
}
// dialBackend connects to the backend, optionally sending a PROXY protocol header.
func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) {
d := &net.Dialer{Timeout: timeout}
conn, err := d.DialContext(ctx, network, addr)
if err != nil {
return nil, err
}
if route.SendProxyProtocol {
// Get the real client IP from the context if available.
clientAddr, _ := ctx.Value(clientAddrKey).(netip.AddrPort)
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
if clientAddr.IsValid() {
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
conn.Close()
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
}
}
}
return conn, nil
}
// setForwardingHeaders sets X-Forwarded-For, X-Forwarded-Proto, and X-Real-IP.
func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) {
clientIP := clientAddr.Addr().String()
r.Header.Set("X-Forwarded-For", clientIP)
r.Header.Set("X-Forwarded-Proto", "https")
r.Header.Set("X-Real-IP", clientIP)
}
// singleConnListener is a net.Listener that returns a single connection once,
// then blocks until closed. Used to serve HTTP/1.1 on a single TLS connection.
type singleConnListener struct {
conn net.Conn
ch chan net.Conn
done chan struct{}
}
func newSingleConnListener(conn net.Conn) *singleConnListener {
ch := make(chan net.Conn, 1)
ch <- conn
return &singleConnListener{conn: conn, ch: ch, done: make(chan struct{})}
}
func (l *singleConnListener) Accept() (net.Conn, error) {
select {
case c := <-l.ch:
return c, nil
case <-l.done:
return nil, net.ErrClosed
}
}
func (l *singleConnListener) Close() error {
select {
case <-l.done:
default:
close(l.done)
}
return nil
}
func (l *singleConnListener) Addr() net.Addr {
return l.conn.LocalAddr()
}

363
internal/l7/serve_test.go Normal file
View File

@@ -0,0 +1,363 @@
package l7
import (
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"log/slog"
"math/big"
"net"
"net/http"
"net/netip"
"os"
"path/filepath"
"testing"
"time"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
)
// testCert generates a self-signed TLS certificate for the given hostname
// and writes the cert/key to temporary files, returning their paths.
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generating key: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: hostname},
DNSNames: []string{hostname},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("creating certificate: %v", err)
}
dir := t.TempDir()
certPath = filepath.Join(dir, "cert.pem")
keyPath = filepath.Join(dir, "key.pem")
certFile, err := os.Create(certPath)
if err != nil {
t.Fatalf("creating cert file: %v", err)
}
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
certFile.Close()
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
t.Fatalf("marshaling key: %v", err)
}
keyFile, err := os.Create(keyPath)
if err != nil {
t.Fatalf("creating key file: %v", err)
}
pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
keyFile.Close()
return certPath, keyPath
}
// startH2CBackend starts an h2c (HTTP/2 cleartext) backend server that
// responds with the given status and body. Returns the listener address.
func startH2CBackend(t *testing.T, handler http.Handler) string {
t.Helper()
h2s := &http2.Server{}
srv := &http.Server{
Handler: h2c.NewHandler(handler, h2s),
ReadHeaderTimeout: 5 * time.Second,
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() {
srv.Close()
ln.Close()
})
go srv.Serve(ln)
return ln.Addr().String()
}
// dialTLSToProxy dials a TCP connection to the proxy, does a TLS handshake
// with the given serverName (skipping cert verification for self-signed),
// and returns an *http.Client configured to use that connection for HTTP/2.
func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
t.Helper()
tlsConf := &tls.Config{
ServerName: serverName,
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
}
conn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 5 * time.Second},
"tcp", proxyAddr, tlsConf,
)
if err != nil {
t.Fatalf("TLS dial: %v", err)
}
t.Cleanup(func() { conn.Close() })
// Create an HTTP/2 client transport over this single connection.
tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn)
if err != nil {
t.Fatalf("creating h2 client conn: %v", err)
}
return &http.Client{
Transport: &singleConnRoundTripper{cc: h2conn},
}
}
// singleConnRoundTripper is an http.RoundTripper that uses a single HTTP/2
// client connection.
type singleConnRoundTripper struct {
cc *http2.ClientConn
}
func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
return s.cc.RoundTrip(req)
}
// serveL7Route starts l7.Serve in a goroutine for a single connection.
// Returns when the goroutine completes.
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
t.Helper()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
ctx := context.Background()
go func() {
l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger)
if l7Err != nil {
t.Logf("l7.Serve: %v", l7Err)
}
}()
}
func TestL7H2CBackend(t *testing.T) {
certPath, keyPath := testCert(t, "l7.test")
// Start an h2c backend.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "ok")
fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
}))
// Start a TCP listener for the L7 proxy.
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: backendAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 5 * time.Second,
}
// Accept one connection and run L7 serve.
go func() {
conn, err := proxyLn.Accept()
if err != nil {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
// No peeked bytes — the client is connecting directly with TLS.
Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
// Connect as an HTTP/2 TLS client.
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
resp, err := client.Get(fmt.Sprintf("https://l7.test/foo"))
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode)
}
body, _ := io.ReadAll(resp.Body)
if got := string(body); got != "hello from backend, path=/foo" {
t.Fatalf("body = %q, want %q", got, "hello from backend, path=/foo")
}
if resp.Header.Get("X-Backend") != "ok" {
t.Fatalf("X-Backend header missing or wrong")
}
}
func TestL7ForwardingHeaders(t *testing.T) {
certPath, keyPath := testCert(t, "headers.test")
// Backend that echoes the forwarding headers.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
r.Header.Get("X-Forwarded-For"),
r.Header.Get("X-Forwarded-Proto"),
r.Header.Get("X-Real-IP"),
)
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: backendAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 5 * time.Second,
}
go func() {
conn, err := proxyLn.Accept()
if err != nil {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
resp, err := client.Get("https://headers.test/")
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
if string(body) != want {
t.Fatalf("body = %q, want %q", body, want)
}
}
func TestL7BackendUnreachable(t *testing.T) {
certPath, keyPath := testCert(t, "unreachable.test")
// Find a port that nothing is listening on.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
deadAddr := ln.Addr().String()
ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: deadAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 1 * time.Second,
}
go func() {
conn, err := proxyLn.Accept()
if err != nil {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
resp, err := client.Get("https://unreachable.test/")
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("status = %d, want 502", resp.StatusCode)
}
}
func TestL7MultipleRequests(t *testing.T) {
certPath, keyPath := testCert(t, "multi.test")
var reqCount int
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqCount++
fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: backendAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 5 * time.Second,
}
go func() {
conn, err := proxyLn.Accept()
if err != nil {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
// Send multiple requests over the same HTTP/2 connection.
for i := range 3 {
path := fmt.Sprintf("/req%d", i)
resp, err := client.Get(fmt.Sprintf("https://multi.test%s", path))
if err != nil {
t.Fatalf("GET %s: %v", path, err)
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
want := fmt.Sprintf("req=%d path=%s", i+1, path)
if string(body) != want {
t.Fatalf("request %d: body = %q, want %q", i, body, want)
}
}
}

View File

@@ -13,6 +13,7 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/l7"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy" "git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
"git.wntrmute.dev/kyle/mc-proxy/internal/sni" "git.wntrmute.dev/kyle/mc-proxy/internal/sni"
@@ -298,11 +299,10 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
return return
} }
// Dispatch based on route mode. L7 will be implemented in a later phase. // Dispatch based on route mode.
switch route.Mode { switch route.Mode {
case "l7": case "l7":
s.logger.Error("L7 mode not yet implemented", "hostname", hostname) s.handleL7(ctx, conn, addr, addrPort, hostname, route, peeked)
return
default: default:
s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked) s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked)
} }
@@ -340,3 +340,25 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
"backend_bytes", result.BackendBytes, "backend_bytes", result.BackendBytes,
) )
} }
// handleL7 handles an L7 (TLS-terminating) connection.
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
rc := l7.RouteConfig{
Backend: route.Backend,
TLSCert: route.TLSCert,
TLSKey: route.TLSKey,
BackendTLS: route.BackendTLS,
SendProxyProtocol: route.SendProxyProtocol,
ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration,
}
if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil {
if ctx.Err() == nil {
s.logger.Debug("L7 serve ended", "hostname", hostname, "error", err)
}
}
s.logger.Info("L7 connection closed", "addr", addr, "hostname", hostname)
}

View File

@@ -3,11 +3,23 @@ package server
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary" "encoding/binary"
"encoding/pem"
"fmt"
"io" "io"
"log/slog" "log/slog"
"math/big"
"net" "net"
"net/http"
"net/netip" "net/netip"
"os"
"path/filepath"
"sync" "sync"
"testing" "testing"
"time" "time"
@@ -15,6 +27,8 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2"
"golang.org/x/net/http2/h2c"
) )
// l4Route creates a RouteInfo for an L4 passthrough route. // l4Route creates a RouteInfo for an L4 passthrough route.
@@ -1038,6 +1052,220 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
wg.Wait() wg.Wait()
} }
// --- L7 server integration tests ---
// testCert generates a self-signed TLS certificate for the given hostname.
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generating key: %v", err)
}
tmpl := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: hostname},
DNSNames: []string{hostname},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
}
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
if err != nil {
t.Fatalf("creating certificate: %v", err)
}
dir := t.TempDir()
certPath = filepath.Join(dir, "cert.pem")
keyPath = filepath.Join(dir, "key.pem")
cf, _ := os.Create(certPath)
pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
cf.Close()
keyDER, _ := x509.MarshalECPrivateKey(key)
kf, _ := os.Create(keyPath)
pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
kf.Close()
return
}
// startH2CBackend starts an h2c backend for testing.
func startH2CBackend(t *testing.T, handler http.Handler) string {
t.Helper()
h2s := &http2.Server{}
srv := &http.Server{
Handler: h2c.NewHandler(handler, h2s),
ReadHeaderTimeout: 5 * time.Second,
}
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() { srv.Close(); ln.Close() })
go srv.Serve(ln)
return ln.Addr().String()
}
func TestL7ThroughServer(t *testing.T) {
certPath, keyPath := testCert(t, "l7srv.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]RouteInfo{
"l7srv.test": {
Backend: backendAddr,
Mode: "l7",
TLSCert: certPath,
TLSKey: keyPath,
},
},
},
})
stop := startAndStop(t, srv)
defer stop()
// Connect with TLS and make an HTTP/2 request.
tlsConf := &tls.Config{
ServerName: "l7srv.test",
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
}
conn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 5 * time.Second},
"tcp", proxyAddr, tlsConf,
)
if err != nil {
t.Fatalf("TLS dial: %v", err)
}
defer conn.Close()
tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn)
if err != nil {
t.Fatalf("h2 client conn: %v", err)
}
req, _ := http.NewRequest("GET", "https://l7srv.test/hello", nil)
resp, err := h2conn.RoundTrip(req)
if err != nil {
t.Fatalf("RoundTrip: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
// no PROXY protocol is in use.
if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode)
}
got := string(body)
if got != "ok path=/hello xff=127.0.0.1" {
t.Fatalf("body = %q", got)
}
}
func TestMixedL4L7SameListener(t *testing.T) {
certPath, keyPath := testCert(t, "l7mixed.test")
// L4 backend: echo server.
l4BackendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("l4 backend listen: %v", err)
}
defer l4BackendLn.Close()
go echoServer(t, l4BackendLn)
// L7 backend: h2c HTTP server.
l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "l7-response")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]RouteInfo{
"l4echo.test": l4Route(l4BackendLn.Addr().String()),
"l7mixed.test": {
Backend: l7BackendAddr,
Mode: "l7",
TLSCert: certPath,
TLSKey: keyPath,
},
},
},
})
stop := startAndStop(t, srv)
defer stop()
// Test L4 route works: send ClientHello, get echo.
l4Conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial L4: %v", err)
}
defer l4Conn.Close()
hello := buildClientHello("l4echo.test")
l4Conn.Write(hello)
echoed := make([]byte, len(hello))
l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(l4Conn, echoed); err != nil {
t.Fatalf("L4 echo read: %v", err)
}
// Test L7 route works: TLS + HTTP/2.
tlsConf := &tls.Config{
ServerName: "l7mixed.test",
InsecureSkipVerify: true,
NextProtos: []string{"h2"},
}
l7Conn, err := tls.DialWithDialer(
&net.Dialer{Timeout: 5 * time.Second},
"tcp", proxyAddr, tlsConf,
)
if err != nil {
t.Fatalf("TLS dial L7: %v", err)
}
defer l7Conn.Close()
tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(l7Conn)
if err != nil {
t.Fatalf("h2 client conn: %v", err)
}
req, _ := http.NewRequest("GET", "https://l7mixed.test/", nil)
resp, err := h2conn.RoundTrip(req)
if err != nil {
t.Fatalf("L7 RoundTrip: %v", err)
}
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if string(body) != "l7-response" {
t.Fatalf("L7 body = %q, want %q", body, "l7-response")
}
}
// --- ClientHello builder helpers (mirrors internal/sni test helpers) --- // --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
func buildClientHello(serverName string) []byte { func buildClientHello(serverName string) []byte {