Files
mc-proxy/internal/l7/serve.go
Kyle Isom 5bc8f4fc8e Fix three doc-vs-implementation gaps found during audit
1. DB migration: add CHECK(mode IN ('l4', 'l7')) constraint on the
   routes.mode column. ARCHITECTURE.md documented this constraint but
   migration v2 omitted it. Enforces mode validity at the database
   level in addition to application-level validation.

2. L7 reverse proxy: distinguish timeout errors from connection errors
   in the ErrorHandler. Backend timeouts now return HTTP 504 Gateway
   Timeout instead of 502. Uses errors.Is(context.DeadlineExceeded)
   and net.Error.Timeout() detection. Added isTimeoutError unit tests.

3. Config validation: warn when L4 routes have tls_cert or tls_key set
   (they are silently ignored). ARCHITECTURE.md documented this warning
   but config.validate() did not emit it. Uses slog.Warn.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 14:25:41 -07:00

255 lines
6.9 KiB
Go

// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying.
package l7
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2"
)
// RouteConfig holds the L7 route parameters needed by the l7 package.
type RouteConfig struct {
	// Backend is the upstream address. newReverseProxy prefixes it with
	// "http://" or "https://" (chosen by BackendTLS) to build the target URL,
	// so it is presumably a host or host:port — TODO confirm against the
	// route configuration docs.
	Backend string

	// TLSCert and TLSKey are the certificate and private key files presented
	// to clients; Serve loads them with tls.LoadX509KeyPair on every call.
	TLSCert string
	TLSKey  string

	// BackendTLS selects HTTP/2 over TLS to the backend. When false, the
	// backend is reached over h2c (HTTP/2 on cleartext TCP).
	BackendTLS bool

	// SendProxyProtocol asks dialBackend to write a PROXY protocol header
	// carrying the real client address (when the request context has one)
	// on each backend connection it dials.
	SendProxyProtocol bool

	// ConnectTimeout bounds the backend TCP connect made by dialBackend.
	// Zero selects the 5 second default applied in newTransport.
	ConnectTimeout time.Duration
}
// contextKey is a package-private type for context keys, so values stored by
// this package can never collide with keys defined by other packages.
type contextKey int

// Context keys used by this package.
const (
	// clientAddrKey maps to the real client netip.AddrPort of the connection
	// that carried the request.
	clientAddrKey contextKey = iota
)
// Serve handles an L7 (TLS-terminating) connection. It completes the TLS
// handshake with the client using the route's certificate, then reverse
// proxies HTTP/2 (or HTTP/1.1) traffic to the backend. It blocks until the
// client connection has been fully served.
//
// peeked contains the TLS ClientHello bytes that were read during SNI
// extraction. They are replayed into the TLS handshake via PrefixConn.
//
// ctx bounds the TLS handshake and is the base context for every proxied
// request, for both HTTP/2 and HTTP/1.1 clients. clientAddr is the real
// client address; it is attached to each request's context so that the
// forwarding headers (and, if enabled, the PROXY protocol header) carry it.
//
// Serve returns an error if the certificate cannot be loaded, the handshake
// fails, the proxy cannot be built, or the HTTP/1.1 server stops for any
// reason other than the connection finishing normally.
func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error {
	// Load the TLS certificate for this route.
	cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey)
	if err != nil {
		return fmt.Errorf("loading TLS cert/key: %w", err)
	}

	// Wrap the connection to replay the peeked ClientHello.
	pc := NewPrefixConn(conn, peeked)

	tlsConf := &tls.Config{
		Certificates: []tls.Certificate{cert},
		NextProtos:   []string{"h2", "http/1.1"},
		MinVersion:   tls.VersionTLS12,
	}
	tlsConn := tls.Server(pc, tlsConf)

	// Complete the TLS handshake, bounded both by a fixed deadline (against
	// slow clients) and by ctx (so cancellation aborts a pending handshake).
	if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
		return fmt.Errorf("setting handshake deadline: %w", err)
	}
	if err := tlsConn.HandshakeContext(ctx); err != nil {
		return fmt.Errorf("TLS handshake: %w", err)
	}
	if err := tlsConn.SetDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clearing handshake deadline: %w", err)
	}

	// Build the reverse proxy handler.
	rp, err := newReverseProxy(route, logger)
	if err != nil {
		return fmt.Errorf("creating reverse proxy: %w", err)
	}
	// newReverseProxy builds a fresh transport for every client connection,
	// so nothing else will ever reuse or close its pooled backend
	// connections. Release them once this client is done; otherwise each one
	// stays open (http2.Transport has no idle timeout unless configured).
	defer func() {
		if c, ok := rp.Transport.(interface{ CloseIdleConnections() }); ok {
			c.CloseIdleConnections()
		}
	}()

	// Wrap the handler to inject the real client IP into the request context.
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
		rp.ServeHTTP(w, r)
	})

	// Serve HTTP/2 if it was negotiated via ALPN.
	if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
		h2srv := &http2.Server{}
		h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
			Context: ctx,
			Handler: handler,
		})
		return nil
	}

	// HTTP/1.1 fallback: run an http.Server over a listener that yields only
	// this one connection.
	ln := newSingleConnListener(tlsConn)
	srv := &http.Server{
		Handler:           handler,
		ReadHeaderTimeout: 30 * time.Second,
		// Match the HTTP/2 path, where ctx is the base context for requests.
		BaseContext: func(net.Listener) context.Context { return ctx },
		// The listener never yields a second connection, so once the server
		// is finished with this one, close the listener. Otherwise srv.Serve
		// stays blocked in Accept forever and this goroutine leaks.
		ConnState: func(_ net.Conn, state http.ConnState) {
			if state == http.StateClosed || state == http.StateHijacked {
				_ = ln.Close() // singleConnListener.Close always returns nil
			}
		},
	}
	// Closing the listener above is the normal way srv.Serve ends; it
	// surfaces as net.ErrClosed from Accept. Only report other failures.
	if err := srv.Serve(ln); err != nil && !errors.Is(err, net.ErrClosed) {
		return fmt.Errorf("serving HTTP/1.1: %w", err)
	}
	return nil
}
// newReverseProxy creates an httputil.ReverseProxy that forwards requests to
// route.Backend over the transport built by newTransport.
//
// Outbound requests keep the client's Host header and have X-Forwarded-For,
// X-Forwarded-Proto and X-Real-IP set from the client address stored in the
// request context under clientAddrKey.
//
// Proxy failures are answered with 504 Gateway Timeout for timeouts and 502
// Bad Gateway otherwise, and logged at error level. When the inbound
// request's own context was canceled (for example, the client went away),
// the failure is not a backend fault, so it is only logged at debug level.
func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) {
	scheme := "http"
	if route.BackendTLS {
		scheme = "https"
	}
	target, err := url.Parse(fmt.Sprintf("%s://%s", scheme, route.Backend))
	if err != nil {
		return nil, fmt.Errorf("parsing backend URL: %w", err)
	}
	transport, err := newTransport(route)
	if err != nil {
		return nil, err
	}
	rp := &httputil.ReverseProxy{
		Rewrite: func(pr *httputil.ProxyRequest) {
			pr.SetURL(target)
			// SetURL rewrites Host to the backend's; preserve the client's.
			pr.Out.Host = pr.In.Host
			// Inject forwarding headers from the real client IP.
			if addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort); ok {
				setForwardingHeaders(pr.Out, addr)
			}
		},
		Transport: transport,
		ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
			switch {
			case errors.Is(r.Context().Err(), context.Canceled):
				// The inbound request was canceled before the backend
				// answered; don't report it as a backend error.
				logger.Debug("reverse proxy request canceled", "backend", route.Backend, "error", err)
				w.WriteHeader(http.StatusBadGateway)
			case isTimeoutError(err):
				logger.Error("reverse proxy error", "backend", route.Backend, "error", err)
				w.WriteHeader(http.StatusGatewayTimeout)
			default:
				logger.Error("reverse proxy error", "backend", route.Backend, "error", err)
				w.WriteHeader(http.StatusBadGateway)
			}
		},
	}
	return rp, nil
}
// newTransport creates the HTTP/2 transport used to reach route.Backend.
//
// Every backend connection is opened by dialBackend, so route.ConnectTimeout
// (5 seconds when zero) and route.SendProxyProtocol apply in both modes:
//
//   - BackendTLS: the PROXY header, if any, is written on the raw TCP
//     connection; TLS is then negotiated on top of it, and the backend must
//     select "h2" via ALPN.
//   - otherwise: h2c (HTTP/2 over cleartext TCP). http2.Transport dials even
//     AllowHTTP connections through DialTLSContext, so the hook is used
//     there too and the tls.Config it receives is ignored.
func newTransport(route RouteConfig) (http.RoundTripper, error) {
	connectTimeout := route.ConnectTimeout
	if connectTimeout == 0 {
		connectTimeout = 5 * time.Second
	}

	if route.BackendTLS {
		// TLS to backend (h2 over TLS).
		return &http2.Transport{
			TLSClientConfig: &tls.Config{
				MinVersion: tls.VersionTLS12,
			},
			// A custom dialer replaces http2.Transport's built-in TLS dial,
			// including its ALPN check, so both are done here. cfg is the
			// per-address config the transport derives from TLSClientConfig.
			DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
				rawConn, err := dialBackend(ctx, network, addr, connectTimeout, route)
				if err != nil {
					return nil, err
				}
				tlsConn := tls.Client(rawConn, cfg)
				if err := tlsConn.HandshakeContext(ctx); err != nil {
					rawConn.Close()
					return nil, fmt.Errorf("backend TLS handshake: %w", err)
				}
				if proto := tlsConn.ConnectionState().NegotiatedProtocol; proto != http2.NextProtoTLS {
					tlsConn.Close()
					return nil, fmt.Errorf("backend negotiated ALPN protocol %q, want %q", proto, http2.NextProtoTLS)
				}
				return tlsConn, nil
			},
		}, nil
	}

	// h2c: HTTP/2 over plaintext TCP.
	return &http2.Transport{
		AllowHTTP: true,
		DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
			return dialBackend(ctx, network, addr, connectTimeout, route)
		},
	}, nil
}
// dialBackend opens a connection to the backend at addr, bounded by timeout
// and by ctx.
//
// When route.SendProxyProtocol is set and ctx carries the real client address
// under clientAddrKey, a PROXY protocol v2 header describing the client and
// backend addresses is written before the connection is returned. If ctx has
// no valid client address, no header is sent.
// NOTE(review): a backend that requires PROXY protocol will misparse such a
// connection; consider failing instead once it is confirmed that ctx always
// carries the client address.
//
// On any error after the dial succeeds, the connection is closed.
func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) {
	d := &net.Dialer{Timeout: timeout}
	conn, err := d.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	if !route.SendProxyProtocol {
		return conn, nil
	}

	// Get the real client IP from the context if available.
	clientAddr, _ := ctx.Value(clientAddrKey).(netip.AddrPort)
	if !clientAddr.IsValid() {
		return conn, nil
	}

	// The header needs a concrete destination address; don't write one
	// built from an unparseable (zero) value.
	backendAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
	if err != nil {
		conn.Close()
		return nil, fmt.Errorf("parsing backend address for PROXY protocol header: %w", err)
	}
	if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
		conn.Close()
		return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
	}
	return conn, nil
}
// setForwardingHeaders stamps r with the identity of the original client.
//
// X-Real-IP and X-Forwarded-For are both set to the client's IP (port
// dropped), replacing any values already present. X-Forwarded-Proto is always
// "https", matching the TLS-terminating role of this package.
func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) {
	h := r.Header
	ip := clientAddr.Addr().String()
	h.Set("X-Real-IP", ip)
	h.Set("X-Forwarded-For", ip)
	h.Set("X-Forwarded-Proto", "https")
}
// isTimeoutError reports whether err, or any error it wraps, represents a
// timeout: either context.DeadlineExceeded, or a net.Error whose Timeout
// method returns true. A nil err is never a timeout.
func isTimeoutError(err error) bool {
	var ne net.Error
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return true
	case errors.As(err, &ne):
		return ne.Timeout()
	default:
		return false
	}
}
// singleConnListener is a net.Listener that yields exactly one pre-accepted
// connection. The first Accept returns that connection; any later Accept
// blocks until Close is called and then fails with net.ErrClosed. It lets an
// http.Server serve HTTP/1.1 on a single, already-established TLS connection.
type singleConnListener struct {
	conn net.Conn      // the wrapped connection; also the source of Addr
	ch   chan net.Conn // holds conn until the first Accept takes it
	done chan struct{} // closed by Close to release blocked Accept calls
}

// newSingleConnListener returns a listener whose first Accept yields conn.
func newSingleConnListener(conn net.Conn) *singleConnListener {
	l := &singleConnListener{
		conn: conn,
		ch:   make(chan net.Conn, 1),
		done: make(chan struct{}),
	}
	l.ch <- conn
	return l
}

// Accept hands out the wrapped connection if it has not been taken yet;
// otherwise it waits for Close and returns net.ErrClosed.
func (l *singleConnListener) Accept() (net.Conn, error) {
	select {
	case <-l.done:
		return nil, net.ErrClosed
	case c := <-l.ch:
		return c, nil
	}
}

// Close unblocks pending and future Accept calls. It does not close the
// wrapped connection. Repeated sequential calls are harmless, but the
// check-then-close below is not atomic, so Close must not be called from
// several goroutines at once.
func (l *singleConnListener) Close() error {
	select {
	case <-l.done:
		// Already closed.
		return nil
	default:
	}
	close(l.done)
	return nil
}

// Addr reports the local address of the wrapped connection.
func (l *singleConnListener) Addr() net.Addr {
	return l.conn.LocalAddr()
}