1. DB migration: add CHECK(mode IN ('l4', 'l7')) constraint on the
routes.mode column. ARCHITECTURE.md documented this constraint but
migration v2 omitted it. Enforces mode validity at the database
level in addition to application-level validation.
2. L7 reverse proxy: distinguish timeout errors from connection errors
in the ErrorHandler. Backend timeouts now return HTTP 504 Gateway
Timeout instead of 502. Detection uses errors.Is(err, context.DeadlineExceeded)
and net.Error.Timeout(). Adds isTimeoutError unit tests.
3. Config validation: warn when L4 routes have tls_cert or tls_key set
(they are silently ignored). ARCHITECTURE.md documented this warning
but config.validate() did not emit it. Uses slog.Warn.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
255 lines
6.9 KiB
Go
255 lines
6.9 KiB
Go
// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying.
|
|
package l7
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/netip"
|
|
"net/url"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// RouteConfig holds the L7 route parameters needed by the l7 package.
type RouteConfig struct {
	// Backend is the backend address; it is joined with an "http://" or
	// "https://" scheme to form the proxy target (presumably "host:port").
	Backend string

	// TLSCert and TLSKey are the paths of the PEM certificate and private
	// key presented to clients when terminating TLS.
	TLSCert, TLSKey string

	// BackendTLS selects HTTP/2 over TLS to the backend instead of h2c.
	// SendProxyProtocol requests a PROXY protocol v2 header carrying the
	// real client address at the start of backend connections.
	BackendTLS, SendProxyProtocol bool

	// ConnectTimeout is the dial timeout passed to dialBackend; zero selects
	// the 5-second default applied in newTransport.
	ConnectTimeout time.Duration
}
|
|
|
|
// contextKey is the unexported type used for this package's context keys.
// Because the type is private, no other package can construct an equal key,
// so values stored under these keys cannot collide with anyone else's.
type contextKey int

// Context keys defined by this package.
const (
	// clientAddrKey stores the real client netip.AddrPort on a request's
	// context so the proxy's Rewrite hook and dialer can read it.
	clientAddrKey contextKey = iota
)
|
|
|
|
// Serve handles an L7 (TLS-terminating) connection. It completes the TLS
|
|
// handshake with the client using the route's certificate, then reverse
|
|
// proxies HTTP/2 (or HTTP/1.1) traffic to the backend.
|
|
//
|
|
// peeked contains the TLS ClientHello bytes that were read during SNI
|
|
// extraction. They are replayed into the TLS handshake via PrefixConn.
|
|
func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error {
|
|
// Load the TLS certificate for this route.
|
|
cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey)
|
|
if err != nil {
|
|
return fmt.Errorf("loading TLS cert/key: %w", err)
|
|
}
|
|
|
|
// Wrap the connection to replay the peeked ClientHello.
|
|
pc := NewPrefixConn(conn, peeked)
|
|
|
|
tlsConf := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
tlsConn := tls.Server(pc, tlsConf)
|
|
|
|
// Complete the TLS handshake with a timeout.
|
|
if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
|
return fmt.Errorf("setting handshake deadline: %w", err)
|
|
}
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
return fmt.Errorf("TLS handshake: %w", err)
|
|
}
|
|
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
|
|
return fmt.Errorf("clearing handshake deadline: %w", err)
|
|
}
|
|
|
|
// Build the reverse proxy handler.
|
|
rp, err := newReverseProxy(route, logger)
|
|
if err != nil {
|
|
return fmt.Errorf("creating reverse proxy: %w", err)
|
|
}
|
|
|
|
// Wrap the handler to inject the real client IP into the request context.
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
|
|
rp.ServeHTTP(w, r)
|
|
})
|
|
|
|
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
|
|
// otherwise fall back to HTTP/1.1.
|
|
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
|
|
|
if proto == "h2" {
|
|
h2srv := &http2.Server{}
|
|
h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
|
|
Context: ctx,
|
|
Handler: handler,
|
|
})
|
|
} else {
|
|
// HTTP/1.1 fallback: serve a single connection.
|
|
srv := &http.Server{
|
|
Handler: handler,
|
|
ReadHeaderTimeout: 30 * time.Second,
|
|
}
|
|
singleConn := newSingleConnListener(tlsConn)
|
|
srv.Serve(singleConn)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// newReverseProxy creates an httputil.ReverseProxy for the given route.
|
|
func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) {
|
|
scheme := "http"
|
|
if route.BackendTLS {
|
|
scheme = "https"
|
|
}
|
|
|
|
target, err := url.Parse(fmt.Sprintf("%s://%s", scheme, route.Backend))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing backend URL: %w", err)
|
|
}
|
|
|
|
transport, err := newTransport(route)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rp := &httputil.ReverseProxy{
|
|
Rewrite: func(pr *httputil.ProxyRequest) {
|
|
pr.SetURL(target)
|
|
// Preserve the original Host header from the client.
|
|
pr.Out.Host = pr.In.Host
|
|
// Inject forwarding headers from the real client IP.
|
|
if addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort); ok {
|
|
setForwardingHeaders(pr.Out, addr)
|
|
}
|
|
},
|
|
Transport: transport,
|
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
logger.Error("reverse proxy error", "backend", route.Backend, "error", err)
|
|
if isTimeoutError(err) {
|
|
w.WriteHeader(http.StatusGatewayTimeout)
|
|
} else {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
}
|
|
},
|
|
}
|
|
|
|
return rp, nil
|
|
}
|
|
|
|
// newTransport creates the HTTP transport for connecting to the backend.
|
|
func newTransport(route RouteConfig) (http.RoundTripper, error) {
|
|
connectTimeout := route.ConnectTimeout
|
|
if connectTimeout == 0 {
|
|
connectTimeout = 5 * time.Second
|
|
}
|
|
|
|
if route.BackendTLS {
|
|
// TLS to backend (h2 over TLS).
|
|
return &http2.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// h2c: HTTP/2 over plaintext TCP.
|
|
return &http2.Transport{
|
|
AllowHTTP: true,
|
|
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
|
conn, err := dialBackend(ctx, network, addr, connectTimeout, route)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return conn, nil
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// dialBackend connects to the backend, optionally sending a PROXY protocol header.
|
|
func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) {
|
|
d := &net.Dialer{Timeout: timeout}
|
|
conn, err := d.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if route.SendProxyProtocol {
|
|
// Get the real client IP from the context if available.
|
|
clientAddr, _ := ctx.Value(clientAddrKey).(netip.AddrPort)
|
|
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
|
|
if clientAddr.IsValid() {
|
|
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
|
|
conn.Close()
|
|
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// setForwardingHeaders sets X-Forwarded-For, X-Forwarded-Proto, and X-Real-IP
// on r from the real client address. Existing values are replaced, not
// appended to, so a client cannot inject its own. X-Forwarded-Proto is
// always "https" because this proxy only serves TLS-terminated traffic.
//
// IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) are reported as plain IPv4.
// This matters because that is how IPv4 clients can appear when the address
// comes from a dual-stack socket.
func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) {
	clientIP := clientAddr.Addr().Unmap().String()
	r.Header.Set("X-Forwarded-For", clientIP)
	r.Header.Set("X-Forwarded-Proto", "https")
	r.Header.Set("X-Real-IP", clientIP)
}
|
|
|
|
// isTimeoutError reports whether err (or anything it wraps) represents a
// timeout. That covers context.DeadlineExceeded and any net.Error whose
// Timeout method returns true. A nil error is never a timeout.
func isTimeoutError(err error) bool {
	var netErr net.Error
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return true
	case errors.As(err, &netErr):
		return netErr.Timeout()
	default:
		return false
	}
}
|
|
|
|
// singleConnListener is a net.Listener that returns a single connection once,
// then blocks until closed. Used to serve HTTP/1.1 on a single TLS connection.
//
// Closing the listener does not close the connection; whoever accepted the
// connection owns it.
type singleConnListener struct {
	conn net.Conn      // the connection to hand out; also supplies Addr
	ch   chan net.Conn // buffered; holds conn until the first Accept
	done chan struct{} // closed by Close to unblock Accept
	// closeTok holds exactly one token. Receiving it grants the right to
	// close done, which makes Close idempotent and safe for concurrent use
	// (the previous check-then-close could double-close and panic).
	closeTok chan struct{}
}

// newSingleConnListener returns a listener whose first Accept yields conn.
func newSingleConnListener(conn net.Conn) *singleConnListener {
	l := &singleConnListener{
		conn:     conn,
		ch:       make(chan net.Conn, 1),
		done:     make(chan struct{}),
		closeTok: make(chan struct{}, 1),
	}
	l.ch <- conn
	l.closeTok <- struct{}{}
	return l
}

// Accept returns the connection on its first call. Afterwards it blocks until
// Close is called, then returns net.ErrClosed.
func (l *singleConnListener) Accept() (net.Conn, error) {
	// Check for closure first: select chooses randomly among ready cases, so
	// without this a closed listener could still hand out its connection.
	select {
	case <-l.done:
		return nil, net.ErrClosed
	default:
	}

	select {
	case c := <-l.ch:
		return c, nil
	case <-l.done:
		return nil, net.ErrClosed
	}
}

// Close unblocks any pending Accept. Calling it more than once, including
// concurrently, is safe; it always returns nil.
func (l *singleConnListener) Close() error {
	select {
	case <-l.closeTok:
		close(l.done)
	default:
		// Another call already took the token and closes (or closed) done.
	}
	return nil
}

// Addr returns the local address of the wrapped connection.
func (l *singleConnListener) Addr() net.Addr {
	return l.conn.LocalAddr()
}
|