Files
mc-proxy/internal/l7/serve.go
Kyle Isom 42c7fffc3e Add L7 policies for user-agent blocking and required headers
Per-route HTTP-level blocking policies for L7 routes. Two rule types:
block_user_agent (substring match against User-Agent, returns 403)
and require_header (named header must be present, returns 403).

Config: L7Policy struct with type/value fields, added as L7Policies
slice on Route. Validated in config (type enum, non-empty value,
warning if set on L4 routes).

DB: Migration 4 creates l7_policies table with route_id FK (cascade
delete), type CHECK constraint, UNIQUE(route_id, type, value). New
l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy,
GetRouteID. Seed updated to persist policies from config.

L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates
rules in order, returns 403 on first match, no-op if empty. Composed
into the handler chain between context injection and reverse proxy.

Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy
mutation methods on ListenerState. handleL7 threads policies into
l7.RouteConfig. Startup loads policies per L7 route from DB.

Proto: L7Policy message, repeated l7_policies on Route. Three new
RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the
write-through pattern.

Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy
methods. CLI: mcproxyctl policies list/add/remove subcommands.

Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match,
header present/absent, multiple rules). 4 DB tests (CRUD, cascade,
duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation).
2 end-to-end L7 tests (UA block, required header with allow/deny).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:11:05 -07:00

258 lines
7.0 KiB
Go

// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying.
package l7
import (
"context"
"crypto/tls"
"errors"
"fmt"
"log/slog"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2"
)
// RouteConfig holds the L7 route parameters needed by the l7 package.
type RouteConfig struct {
// Backend is the backend address; newReverseProxy places it after the
// "http://" or "https://" scheme to form the proxy target URL.
Backend string
// TLSCert and TLSKey are passed to tls.LoadX509KeyPair by Serve to obtain
// the certificate presented to clients.
TLSCert string
TLSKey string
// BackendTLS selects the "https" scheme and the TLS backend transport
// instead of plaintext h2c.
BackendTLS bool
// SendProxyProtocol makes dialBackend write a PROXY protocol v2 header to
// the backend connection when the client address is present in the dial
// context.
SendProxyProtocol bool
// ConnectTimeout is the dial timeout handed to dialBackend; newTransport
// substitutes 5 seconds when it is zero.
ConnectTimeout time.Duration
// Policies is the rule list Serve passes to PolicyMiddleware, which sits
// in front of the reverse proxy in the handler chain.
Policies []PolicyRule
}
// contextKey is an unexported type for context keys in this package.
// Using a package-private type keeps these keys from colliding with
// context keys defined by other packages.
type contextKey int
// clientAddrKey is the context key under which Serve stores the client's
// netip.AddrPort on each request context; newReverseProxy and dialBackend
// read it back from there.
const clientAddrKey contextKey = 0
// Serve handles an L7 (TLS-terminating) connection. It completes the TLS
// handshake with the client using the route's certificate, then reverse
// proxies HTTP/2 (or HTTP/1.1) traffic to the backend.
//
// peeked contains the TLS ClientHello bytes that were read during SNI
// extraction. They are replayed into the TLS handshake via PrefixConn.
//
// clientAddr is attached to every request context under clientAddrKey so the
// reverse proxy can set forwarding headers and dialBackend can emit a PROXY
// protocol header. ctx is used as the base context for requests on both the
// HTTP/2 and the HTTP/1.1 path.
//
// Serve blocks until the connection has been served and then returns nil.
// It returns an error when the certificate cannot be loaded, the handshake
// fails, or the reverse proxy cannot be built; on those error returns conn
// is left open.
func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error {
	// Load the TLS certificate for this route.
	cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey)
	if err != nil {
		return fmt.Errorf("loading TLS cert/key: %w", err)
	}

	// Wrap the connection to replay the peeked ClientHello.
	pc := NewPrefixConn(conn, peeked)

	tlsConf := &tls.Config{
		Certificates: []tls.Certificate{cert},
		NextProtos:   []string{"h2", "http/1.1"},
		MinVersion:   tls.VersionTLS12,
	}
	tlsConn := tls.Server(pc, tlsConf)

	// Complete the TLS handshake with a timeout, then clear the deadline so
	// it does not cut off the HTTP traffic that follows.
	if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
		return fmt.Errorf("setting handshake deadline: %w", err)
	}
	if err := tlsConn.Handshake(); err != nil {
		return fmt.Errorf("TLS handshake: %w", err)
	}
	if err := tlsConn.SetDeadline(time.Time{}); err != nil {
		return fmt.Errorf("clearing handshake deadline: %w", err)
	}

	// Build the reverse proxy handler.
	rp, err := newReverseProxy(route, logger)
	if err != nil {
		return fmt.Errorf("creating reverse proxy: %w", err)
	}

	// Build handler chain: context injection → L7 policies → reverse proxy.
	var inner http.Handler = rp
	inner = PolicyMiddleware(route.Policies, inner)
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
		inner.ServeHTTP(w, r)
	})

	// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
	// otherwise fall back to HTTP/1.1.
	if tlsConn.ConnectionState().NegotiatedProtocol == "h2" {
		h2srv := &http2.Server{}
		h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
			Context: ctx,
			Handler: handler,
		})
		return nil
	}

	// HTTP/1.1 fallback: serve a single connection through http.Server.
	//
	// http.Server.Serve keeps calling Accept, and the single-connection
	// listener blocks on its second Accept until it is closed. The listener
	// must therefore be closed once the connection is finished, otherwise
	// Serve would never return and this goroutine would leak.
	ln := newSingleConnListener(tlsConn)

	// hijacked is closed when the connection is taken over by a handler
	// (e.g. a protocol upgrade). Only one connection is ever served, so the
	// StateHijacked transition happens at most once.
	hijacked := make(chan struct{})

	srv := &http.Server{
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// The server reports no StateClosed for a hijacked connection,
			// so release the listener when the hijacking handler is done.
			defer func() {
				select {
				case <-hijacked:
					ln.Close()
				default:
				}
			}()
			handler.ServeHTTP(w, r)
		}),
		ReadHeaderTimeout: 30 * time.Second,
		// Derive request contexts from ctx, matching the HTTP/2 branch.
		BaseContext: func(net.Listener) context.Context { return ctx },
		ConnState: func(_ net.Conn, state http.ConnState) {
			switch state {
			case http.StateHijacked:
				close(hijacked)
			case http.StateClosed:
				// The only connection is gone: unblock Accept so that
				// srv.Serve returns.
				ln.Close()
			}
		},
	}

	// A closed listener is the normal way out of srv.Serve here.
	if err := srv.Serve(ln); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
		return fmt.Errorf("serving HTTP/1.1: %w", err)
	}
	return nil
}
// newReverseProxy builds the httputil.ReverseProxy that forwards requests to
// route.Backend. The target scheme is "https" when route.BackendTLS is set
// and "http" otherwise. Proxy failures are logged through logger and answered
// with 504 for timeouts or 502 for everything else.
func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) {
	rawTarget := "http://" + route.Backend
	if route.BackendTLS {
		rawTarget = "https://" + route.Backend
	}
	target, err := url.Parse(rawTarget)
	if err != nil {
		return nil, fmt.Errorf("parsing backend URL: %w", err)
	}

	transport, err := newTransport(route)
	if err != nil {
		return nil, err
	}

	rewrite := func(pr *httputil.ProxyRequest) {
		pr.SetURL(target)

		// SetURL points the outbound Host at the backend; restore the Host
		// the client actually asked for.
		pr.Out.Host = pr.In.Host

		// Forwarding headers are derived from the client address that Serve
		// stored on the inbound request context.
		addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort)
		if !ok {
			return
		}
		setForwardingHeaders(pr.Out, addr)
	}

	onError := func(w http.ResponseWriter, r *http.Request, err error) {
		logger.Error("reverse proxy error", "backend", route.Backend, "error", err)

		status := http.StatusBadGateway
		if isTimeoutError(err) {
			status = http.StatusGatewayTimeout
		}
		w.WriteHeader(status)
	}

	return &httputil.ReverseProxy{
		Rewrite:      rewrite,
		Transport:    transport,
		ErrorHandler: onError,
	}, nil
}
// newTransport creates the HTTP transport for connecting to the backend.
//
// Both variants speak HTTP/2 and dial through dialBackend, so the route's
// connect timeout (5 seconds when ConnectTimeout is zero) and its PROXY
// protocol setting apply whether or not the backend uses TLS:
//
//   - BackendTLS: TCP dial, optional PROXY header, then a TLS handshake that
//     must negotiate "h2" via ALPN.
//   - otherwise: h2c, i.e. HTTP/2 over the plaintext TCP connection.
func newTransport(route RouteConfig) (http.RoundTripper, error) {
	connectTimeout := route.ConnectTimeout
	if connectTimeout == 0 {
		connectTimeout = 5 * time.Second
	}

	if route.BackendTLS {
		// TLS to backend (h2 over TLS).
		return &http2.Transport{
			TLSClientConfig: &tls.Config{
				MinVersion: tls.VersionTLS12,
			},
			// Supplying the dialer keeps connectTimeout and SendProxyProtocol
			// in effect; the transport's built-in TLS dialer knows about
			// neither. cfg is the per-host TLS config prepared by the
			// transport from TLSClientConfig.
			DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
				conn, err := dialBackend(ctx, network, addr, connectTimeout, route)
				if err != nil {
					return nil, err
				}
				tlsConn := tls.Client(conn, cfg)
				if err := tlsConn.HandshakeContext(ctx); err != nil {
					conn.Close()
					return nil, fmt.Errorf("backend TLS handshake: %w", err)
				}
				// A custom dialer bypasses the transport's own ALPN check,
				// so make sure the backend really agreed to HTTP/2.
				if p := tlsConn.ConnectionState().NegotiatedProtocol; p != http2.NextProtoTLS {
					tlsConn.Close()
					return nil, fmt.Errorf("backend negotiated ALPN protocol %q, want %q", p, http2.NextProtoTLS)
				}
				return tlsConn, nil
			},
		}, nil
	}

	// h2c: HTTP/2 over plaintext TCP. AllowHTTP lets the transport accept
	// "http" URLs; the dial hook returns a plain connection and ignores the
	// TLS config it is handed.
	return &http2.Transport{
		AllowHTTP: true,
		DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
			return dialBackend(ctx, network, addr, connectTimeout, route)
		},
	}, nil
}
// dialBackend opens a connection to addr, giving up after timeout. When the
// route has SendProxyProtocol set and ctx carries a valid client address
// under clientAddrKey, a PROXY protocol v2 header is written to the new
// connection before it is returned. If writing that header fails the
// connection is closed and an error is returned.
func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) {
	dialer := net.Dialer{Timeout: timeout}
	conn, err := dialer.DialContext(ctx, network, addr)
	if err != nil {
		return nil, err
	}
	if !route.SendProxyProtocol {
		return conn, nil
	}

	// Without a usable client address there is nothing to announce, so the
	// connection is handed back without a header.
	clientAddr, ok := ctx.Value(clientAddrKey).(netip.AddrPort)
	if !ok || !clientAddr.IsValid() {
		return conn, nil
	}

	// A parse failure is not treated as an error here: backendAddr is then
	// the zero AddrPort and is passed to WriteV2 unchanged.
	backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())

	if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
		conn.Close()
		return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
	}
	return conn, nil
}
// setForwardingHeaders overwrites X-Forwarded-For and X-Real-IP on r with the
// IP part of clientAddr (the port is dropped) and sets X-Forwarded-Proto to
// "https".
func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) {
	ip := clientAddr.Addr().String()
	for _, name := range [...]string{"X-Forwarded-For", "X-Real-IP"} {
		r.Header.Set(name, ip)
	}
	r.Header.Set("X-Forwarded-Proto", "https")
}
// isTimeoutError reports whether err, or any error it wraps, is
// context.DeadlineExceeded or a net.Error whose Timeout method returns true.
func isTimeoutError(err error) bool {
	var netErr net.Error
	return errors.Is(err, context.DeadlineExceeded) ||
		(errors.As(err, &netErr) && netErr.Timeout())
}
// singleConnListener adapts one already-established connection to the
// net.Listener interface. The first Accept hands out that connection; every
// later Accept blocks until Close is called and then reports net.ErrClosed.
// It is used to serve HTTP/1.1 on a single TLS connection.
type singleConnListener struct {
	conn net.Conn      // the wrapped connection; also supplies Addr
	ch   chan net.Conn // holds conn until the first Accept takes it
	done chan struct{} // closed by Close to release blocked Accept calls
}

// newSingleConnListener returns a listener whose first Accept yields conn.
func newSingleConnListener(conn net.Conn) *singleConnListener {
	l := &singleConnListener{
		conn: conn,
		ch:   make(chan net.Conn, 1),
		done: make(chan struct{}),
	}
	// The buffer of one lets the connection be queued without a receiver.
	l.ch <- conn
	return l
}

// Accept returns the wrapped connection if it has not been handed out yet;
// otherwise it waits for Close and returns net.ErrClosed.
func (l *singleConnListener) Accept() (net.Conn, error) {
	select {
	case <-l.done:
		return nil, net.ErrClosed
	case pending := <-l.ch:
		return pending, nil
	}
}

// Close releases any Accept call that is waiting. Calling it again after it
// has returned is a no-op. It always returns nil and does not close the
// wrapped connection.
func (l *singleConnListener) Close() error {
	select {
	case <-l.done:
		// Already closed.
		return nil
	default:
	}
	close(l.done)
	return nil
}

// Addr returns the local address of the wrapped connection.
func (l *singleConnListener) Addr() net.Addr {
	return l.conn.LocalAddr()
}