- Fix 314 errcheck violations (blank identifier for unrecoverable errors) - Fix errorlint violation (errors.Is for io.EOF) - Remove unused serveL7Route test helper - Simplify Duration.Seconds() selectors in tests - Remove unnecessary fmt.Sprintf in test - Migrate exclusion rules from issues.exclusions to linters.exclusions (v2 schema) - Add gosec test exclusions (G115, G304, G402, G705) - Disable fieldalignment govet analyzer (optimization, not correctness) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
281 lines
8.0 KiB
Go
281 lines
8.0 KiB
Go
// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying.
|
|
package l7
|
|
|
|
import (
|
|
"context"
|
|
"crypto/tls"
|
|
"errors"
|
|
"fmt"
|
|
"log/slog"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/netip"
|
|
"net/url"
|
|
"strconv"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/mc/mc-proxy/internal/metrics"
|
|
"git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
|
|
"golang.org/x/net/http2"
|
|
)
|
|
|
|
// RouteConfig holds the L7 route parameters needed by the l7 package.
|
|
type RouteConfig struct {
|
|
Hostname string
|
|
Backend string
|
|
TLSCert string
|
|
TLSKey string
|
|
BackendTLS bool
|
|
SendProxyProtocol bool
|
|
ConnectTimeout time.Duration
|
|
Policies []PolicyRule
|
|
}
|
|
|
|
// statusRecorder decorates an http.ResponseWriter so that the status code
// passed to WriteHeader can be read back after the handler has finished.
// The status field holds the most recent code given to WriteHeader; callers
// seed it with the code to report when WriteHeader is never invoked.
type statusRecorder struct {
	http.ResponseWriter
	status int
}

// WriteHeader remembers code and then forwards it to the wrapped writer.
func (rec *statusRecorder) WriteHeader(code int) {
	rec.status = code
	rec.ResponseWriter.WriteHeader(code)
}

// Unwrap exposes the wrapped writer so that http.ResponseController (and
// therefore Flush, Hijack, and friends) can reach the underlying
// implementation through this decorator.
func (rec *statusRecorder) Unwrap() http.ResponseWriter {
	return rec.ResponseWriter
}
|
|
|
|
// contextKey is a private type for this package's context keys, so they can
// never collide with keys defined by other packages.
type contextKey int

// clientAddrKey is the context key under which the client's netip.AddrPort
// is stored for each proxied request.
const clientAddrKey contextKey = iota
|
|
|
|
// Serve handles an L7 (TLS-terminating) connection. It completes the TLS
|
|
// handshake with the client using the route's certificate, then reverse
|
|
// proxies HTTP/2 (or HTTP/1.1) traffic to the backend.
|
|
//
|
|
// peeked contains the TLS ClientHello bytes that were read during SNI
|
|
// extraction. They are replayed into the TLS handshake via PrefixConn.
|
|
func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error {
|
|
// Load the TLS certificate for this route.
|
|
cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey)
|
|
if err != nil {
|
|
return fmt.Errorf("loading TLS cert/key: %w", err)
|
|
}
|
|
|
|
// Wrap the connection to replay the peeked ClientHello.
|
|
pc := NewPrefixConn(conn, peeked)
|
|
|
|
tlsConf := &tls.Config{
|
|
Certificates: []tls.Certificate{cert},
|
|
NextProtos: []string{"h2", "http/1.1"},
|
|
MinVersion: tls.VersionTLS12,
|
|
}
|
|
|
|
tlsConn := tls.Server(pc, tlsConf)
|
|
|
|
// Complete the TLS handshake with a timeout.
|
|
if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
|
return fmt.Errorf("setting handshake deadline: %w", err)
|
|
}
|
|
if err := tlsConn.Handshake(); err != nil {
|
|
return fmt.Errorf("TLS handshake: %w", err)
|
|
}
|
|
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
|
|
return fmt.Errorf("clearing handshake deadline: %w", err)
|
|
}
|
|
|
|
// Build the reverse proxy handler.
|
|
rp, err := newReverseProxy(route, logger)
|
|
if err != nil {
|
|
return fmt.Errorf("creating reverse proxy: %w", err)
|
|
}
|
|
|
|
// Build handler chain: context injection → metrics → L7 policies → reverse proxy.
|
|
var inner http.Handler = rp
|
|
inner = PolicyMiddleware(route.Policies, route.Hostname, inner)
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
|
|
sr := &statusRecorder{ResponseWriter: w, status: http.StatusOK}
|
|
inner.ServeHTTP(sr, r)
|
|
metrics.L7ResponsesTotal.WithLabelValues(route.Hostname, strconv.Itoa(sr.status)).Inc()
|
|
})
|
|
|
|
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
|
|
// otherwise fall back to HTTP/1.1.
|
|
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
|
|
|
if proto == "h2" {
|
|
h2srv := &http2.Server{}
|
|
h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
|
|
Context: ctx,
|
|
Handler: handler,
|
|
})
|
|
} else {
|
|
// HTTP/1.1 fallback: serve a single connection.
|
|
srv := &http.Server{
|
|
Handler: handler,
|
|
ReadHeaderTimeout: 30 * time.Second,
|
|
}
|
|
singleConn := newSingleConnListener(tlsConn)
|
|
_ = srv.Serve(singleConn)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// newReverseProxy creates an httputil.ReverseProxy for the given route.
|
|
func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) {
|
|
scheme := "http"
|
|
if route.BackendTLS {
|
|
scheme = "https"
|
|
}
|
|
|
|
target, err := url.Parse(fmt.Sprintf("%s://%s", scheme, route.Backend))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("parsing backend URL: %w", err)
|
|
}
|
|
|
|
transport, err := newTransport(route)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rp := &httputil.ReverseProxy{
|
|
Rewrite: func(pr *httputil.ProxyRequest) {
|
|
pr.SetURL(target)
|
|
// Preserve the original Host header from the client.
|
|
pr.Out.Host = pr.In.Host
|
|
// Inject forwarding headers from the real client IP.
|
|
if addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort); ok {
|
|
setForwardingHeaders(pr.Out, addr)
|
|
}
|
|
},
|
|
Transport: transport,
|
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
|
logger.Error("reverse proxy error", "backend", route.Backend, "error", err)
|
|
if isTimeoutError(err) {
|
|
w.WriteHeader(http.StatusGatewayTimeout)
|
|
} else {
|
|
w.WriteHeader(http.StatusBadGateway)
|
|
}
|
|
},
|
|
}
|
|
|
|
return rp, nil
|
|
}
|
|
|
|
// newTransport creates the HTTP transport for connecting to the backend.
|
|
func newTransport(route RouteConfig) (http.RoundTripper, error) {
|
|
connectTimeout := route.ConnectTimeout
|
|
if connectTimeout == 0 {
|
|
connectTimeout = 5 * time.Second
|
|
}
|
|
|
|
if route.BackendTLS {
|
|
// TLS to backend (h2 over TLS). Backend cert verification is
|
|
// skipped — the proxy connects to trusted internal backends
|
|
// that may use IP addresses or self-signed certificates.
|
|
return &http2.Transport{
|
|
TLSClientConfig: &tls.Config{
|
|
MinVersion: tls.VersionTLS12,
|
|
InsecureSkipVerify: true, //nolint:gosec // trusted backend
|
|
},
|
|
}, nil
|
|
}
|
|
|
|
// Plain HTTP backend. Use standard http.Transport which speaks
|
|
// HTTP/1.1 by default and can upgrade to h2c if the backend
|
|
// supports it. This handles backends like Gitea that only speak
|
|
// HTTP/1.1.
|
|
return &http.Transport{
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return dialBackend(ctx, network, addr, connectTimeout, route)
|
|
},
|
|
MaxIdleConnsPerHost: 10,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
}, nil
|
|
}
|
|
|
|
// dialBackend connects to the backend, optionally sending a PROXY protocol header.
|
|
func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) {
|
|
d := &net.Dialer{Timeout: timeout}
|
|
conn, err := d.DialContext(ctx, network, addr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if route.SendProxyProtocol {
|
|
// Get the real client IP from the context if available.
|
|
clientAddr, _ := ctx.Value(clientAddrKey).(netip.AddrPort)
|
|
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
|
|
if clientAddr.IsValid() {
|
|
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
|
|
_ = conn.Close()
|
|
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return conn, nil
|
|
}
|
|
|
|
// setForwardingHeaders stamps r with the headers that tell the backend who
// the real client is: X-Forwarded-For and X-Real-IP carry the client's IP
// address (without the port) and X-Forwarded-Proto is always "https". Any
// values already present for these headers are replaced.
func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) {
	ip := clientAddr.Addr().String()
	for _, name := range [...]string{"X-Forwarded-For", "X-Real-IP"} {
		r.Header.Set(name, ip)
	}
	r.Header.Set("X-Forwarded-Proto", "https")
}
|
|
|
|
// isTimeoutError reports whether err represents a timeout: either
// context.DeadlineExceeded anywhere in its chain, or a net.Error in the
// chain whose Timeout method returns true. A nil error is not a timeout.
func isTimeoutError(err error) bool {
	var timeoutErr net.Error
	switch {
	case errors.Is(err, context.DeadlineExceeded):
		return true
	case errors.As(err, &timeoutErr):
		return timeoutErr.Timeout()
	default:
		return false
	}
}
|
|
|
|
// singleConnListener is a net.Listener that hands out exactly one
// pre-established connection. The first Accept returns that connection;
// every later Accept blocks until the listener is closed. It lets an
// http.Server serve HTTP/1.1 on a single TLS connection.
type singleConnListener struct {
	conn net.Conn      // the connection being served; also supplies Addr
	ch   chan net.Conn // holds conn until the first Accept takes it
	done chan struct{} // closed by Close to release blocked Accept calls
}

// Compile-time check that the type satisfies net.Listener.
var _ net.Listener = (*singleConnListener)(nil)

// newSingleConnListener returns a listener whose first Accept yields conn.
func newSingleConnListener(conn net.Conn) *singleConnListener {
	ln := &singleConnListener{
		conn: conn,
		ch:   make(chan net.Conn, 1),
		done: make(chan struct{}),
	}
	// The buffer of one lets the connection be queued without a receiver.
	ln.ch <- conn
	return ln
}

// Accept returns the wrapped connection the first time it is called. Once
// the connection has been handed out it blocks until Close is called and
// then returns net.ErrClosed.
func (ln *singleConnListener) Accept() (net.Conn, error) {
	select {
	case <-ln.done:
		return nil, net.ErrClosed
	case queued := <-ln.ch:
		return queued, nil
	}
}

// Close releases any Accept call blocked on the listener. It does not close
// the wrapped connection. Calling Close again is a no-op; it always returns
// nil.
func (ln *singleConnListener) Close() error {
	select {
	case <-ln.done:
		// Already closed.
		return nil
	default:
	}
	close(ln.done)
	return nil
}

// Addr reports the local address of the wrapped connection.
func (ln *singleConnListener) Addr() net.Addr {
	return ln.conn.LocalAddr()
}
|