Per-route HTTP-level blocking policies for L7 routes. Two rule types: block_user_agent (substring match against User-Agent, returns 403) and require_header (named header must be present, returns 403). Config: L7Policy struct with type/value fields, added as L7Policies slice on Route. Validated in config (type enum, non-empty value, warning if set on L4 routes). DB: Migration 4 creates l7_policies table with route_id FK (cascade delete), type CHECK constraint, UNIQUE(route_id, type, value). New l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy, GetRouteID. Seed updated to persist policies from config. L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates rules in order, returns 403 on first match, no-op if empty. Composed into the handler chain between context injection and reverse proxy. Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy mutation methods on ListenerState. handleL7 threads policies into l7.RouteConfig. Startup loads policies per L7 route from DB. Proto: L7Policy message, repeated l7_policies on Route. Three new RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the write-through pattern. Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy methods. CLI: mcproxyctl policies list/add/remove subcommands. Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match, header present/absent, multiple rules). 4 DB tests (CRUD, cascade, duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation). 2 end-to-end L7 tests (UA block, required header with allow/deny). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
258 lines
7.0 KiB
Go
258 lines
7.0 KiB
Go
// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying.
|
|
package l7
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/netip"
|
|
"net/url"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// RouteConfig holds the L7 route parameters needed by the l7 package.
|
|
type RouteConfig struct {
|
|
Backend string
|
|
TLSCert string
|
|
TLSKey string
|
|
BackendTLS bool
|
|
SendProxyProtocol bool
|
|
ConnectTimeout time.Duration
|
|
Policies []PolicyRule
|
|
}
|
|
|
|
// contextKey is the private type for this package's context.Context keys.
// Because the type is unexported, keys of this type cannot collide with
// context values stored by any other package.
type contextKey int

// Context keys used by this package.
const (
	// clientAddrKey maps to the client's netip.AddrPort. Serve attaches it
	// to every request; the reverse proxy reads it for forwarding headers and
	// dialBackend reads it for the PROXY protocol header.
	clientAddrKey contextKey = iota
)
|
|
|
|
// Serve handles an L7 (TLS-terminating) connection. It completes the TLS
|
|
// handshake with the client using the route's certificate, then reverse
|
|
// proxies HTTP/2 (or HTTP/1.1) traffic to the backend.
|
|
//
|
|
// peeked contains the TLS ClientHello bytes that were read during SNI
|
|
// extraction. They are replayed into the TLS handshake via PrefixConn.
|
|
func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error {
|
|
// Load the TLS certificate for this route.
|
|
cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey)
|
|
if err != nil {
|
|
return fmt.Errorf("loading TLS cert/key: %w", err)
|
|
}
|
|
|
|
// Wrap the connection to replay the peeked ClientHello.
|
|
pc := NewPrefixConn(conn, peeked)
|
|
|
|
tlsConf := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
tlsConn := tls.Server(pc, tlsConf)
|
|
|
|
// Complete the TLS handshake with a timeout.
|
|
if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
|
return fmt.Errorf("setting handshake deadline: %w", err)
|
|
}
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
return fmt.Errorf("TLS handshake: %w", err)
|
|
}
|
|
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
|
|
return fmt.Errorf("clearing handshake deadline: %w", err)
|
|
}
|
|
|
|
// Build the reverse proxy handler.
|
|
rp, err := newReverseProxy(route, logger)
|
|
if err != nil {
|
|
return fmt.Errorf("creating reverse proxy: %w", err)
|
|
}
|
|
|
|
// Build handler chain: context injection → L7 policies → reverse proxy.
|
|
var inner http.Handler = rp
|
|
inner = PolicyMiddleware(route.Policies, inner)
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
|
|
inner.ServeHTTP(w, r)
|
|
})
|
|
|
|
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
|
|
// otherwise fall back to HTTP/1.1.
|
|
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
|
|
|
if proto == "h2" {
|
|
h2srv := &http2.Server{}
|
|
h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
|
|
Context: ctx,
|
|
Handler: handler,
|
|
})
|
|
} else {
|
|
// HTTP/1.1 fallback: serve a single connection.
|
|
srv := &http.Server{
|
|
Handler: handler,
|
|
ReadHeaderTimeout: 30 * time.Second,
|
|
}
|
|
singleConn := newSingleConnListener(tlsConn)
|
|
srv.Serve(singleConn)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// newReverseProxy creates an httputil.ReverseProxy for the given route.
|
|
func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) {
|
|
scheme := "http"
|
|
if route.BackendTLS {
|
|
scheme = "https"
|
|
}
|
|
|
|
target, err := url.Parse(fmt.Sprintf("%s://%s", scheme, route.Backend))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing backend URL: %w", err)
|
|
}
|
|
|
|
transport, err := newTransport(route)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rp := &httputil.ReverseProxy{
|
|
Rewrite: func(pr *httputil.ProxyRequest) {
|
|
pr.SetURL(target)
|
|
// Preserve the original Host header from the client.
|
|
pr.Out.Host = pr.In.Host
|
|
// Inject forwarding headers from the real client IP.
|
|
if addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort); ok {
|
|
setForwardingHeaders(pr.Out, addr)
|
|
}
|
|
},
|
|
Transport: transport,
|
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
logger.Error("reverse proxy error", "backend", route.Backend, "error", err)
|
|
if isTimeoutError(err) {
|
|
w.WriteHeader(http.StatusGatewayTimeout)
|
|
} else {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
}
|
|
},
|
|
}
|
|
|
|
return rp, nil
|
|
}
|
|
|
|
// newTransport creates the HTTP transport for connecting to the backend.
|
|
func newTransport(route RouteConfig) (http.RoundTripper, error) {
|
|
connectTimeout := route.ConnectTimeout
|
|
if connectTimeout == 0 {
|
|
connectTimeout = 5 * time.Second
|
|
}
|
|
|
|
if route.BackendTLS {
|
|
// TLS to backend (h2 over TLS).
|
|
return &http2.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// h2c: HTTP/2 over plaintext TCP.
|
|
return &http2.Transport{
|
|
AllowHTTP: true,
|
|
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
|
conn, err := dialBackend(ctx, network, addr, connectTimeout, route)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// dialBackend connects to the backend, optionally sending a PROXY protocol header.
|
|
func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) {
|
|
d := &net.Dialer{Timeout: timeout}
|
|
conn, err := d.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if route.SendProxyProtocol {
|
|
// Get the real client IP from the context if available.
|
|
clientAddr, _ := ctx.Value(clientAddrKey).(netip.AddrPort)
|
|
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
|
|
if clientAddr.IsValid() {
|
|
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// setForwardingHeaders stamps r with the standard forwarding headers for
// clientAddr. X-Forwarded-For and X-Real-IP carry the client IP (port
// dropped). X-Forwarded-Proto is always "https" because the client side of
// the connection is TLS-terminated. Any existing values are replaced, not
// appended to.
func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) {
	hdr := r.Header
	ip := clientAddr.Addr().String()

	hdr.Set("X-Forwarded-Proto", "https")
	hdr.Set("X-Forwarded-For", ip)
	hdr.Set("X-Real-IP", ip)
}
|
|
|
|
// isTimeoutError reports whether err, or any error it wraps, represents a
// timeout. That means either context.DeadlineExceeded or a net.Error whose
// Timeout method returns true. A nil error is not a timeout.
func isTimeoutError(err error) bool {
	var netErr net.Error
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return true
	case errors.As(err, &netErr):
		return netErr.Timeout()
	default:
		return false
	}
}
|
|
|
|
// singleConnListener adapts one already-established net.Conn to the
// net.Listener interface so http.Server can serve it. The first Accept hands
// out the connection; later calls block until Close, then report
// net.ErrClosed.
type singleConnListener struct {
	conn net.Conn      // the wrapped connection; also the source of Addr
	ch   chan net.Conn // capacity 1, pre-loaded with conn by the constructor
	done chan struct{} // closed by Close to release blocked Accept calls
}

// newSingleConnListener returns a listener whose first Accept yields conn.
func newSingleConnListener(conn net.Conn) *singleConnListener {
	l := &singleConnListener{
		conn: conn,
		ch:   make(chan net.Conn, 1),
		done: make(chan struct{}),
	}
	l.ch <- conn
	return l
}

// Accept returns the wrapped connection once, then blocks until Close.
func (l *singleConnListener) Accept() (net.Conn, error) {
	select {
	case <-l.done:
		return nil, net.ErrClosed
	case c := <-l.ch:
		return c, nil
	}
}

// Close unblocks pending and future Accept calls. Repeated sequential calls
// are no-ops. It never closes the wrapped connection and always returns nil.
func (l *singleConnListener) Close() error {
	select {
	case <-l.done:
		return nil // already closed
	default:
	}
	close(l.done)
	return nil
}

// Addr reports the wrapped connection's local address.
func (l *singleConnListener) Addr() net.Addr {
	return l.conn.LocalAddr()
}
|