Add L7 TLS-terminating HTTP/2 reverse proxy
New internal/l7 package implements TLS termination and HTTP/2 reverse proxying for L7 routes. The proxy terminates the client TLS connection using per-route certificates, then forwards HTTP/2 traffic to backends over h2c (plaintext HTTP/2) or h2 (re-encrypted TLS). PrefixConn replays the peeked ClientHello bytes into crypto/tls.Server so the TLS handshake sees the complete ClientHello despite SNI extraction having already read it. Serve() is the L7 entry point: TLS handshake with route certificate, ALPN negotiation (h2 preferred, HTTP/1.1 fallback), then HTTP reverse proxy via httputil.ReverseProxy. Backend transport uses h2c by default (AllowHTTP + plain TCP dial) or h2-over-TLS when backend_tls is set. Forwarding headers (X-Forwarded-For, X-Forwarded-Proto, X-Real-IP) are injected from the real client IP in the Rewrite function. PROXY protocol v2 is sent to backends when send_proxy_protocol is enabled, using the request context to carry the client address through the HTTP/2 transport's dial function. Server integration: handleConn dispatches to handleL7 when route.Mode is "l7". The L7 handler converts RouteInfo to l7.RouteConfig and delegates to l7.Serve. L7 package tests: PrefixConn (4 tests), h2c backend round-trip, forwarding header injection, backend unreachable (502), multiple HTTP/2 requests over one connection. Server integration tests: L7 route through full server pipeline with TLS client, mixed L4+L7 routes on the same listener verifying both paths work independently. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -28,10 +28,10 @@ proceeds. Each item is marked:
|
|||||||
|
|
||||||
## Phase 3: L7 Proxying
|
## Phase 3: L7 Proxying
|
||||||
|
|
||||||
- [ ] 3.1 `internal/l7/` package (`PrefixConn`, HTTP/2 reverse proxy with h2c, `Serve` entry point)
|
- [x] 3.1 `internal/l7/` package (`PrefixConn`, HTTP/2 reverse proxy with h2c, `Serve` entry point)
|
||||||
- [ ] 3.2 Server integration (dispatch to L4 or L7 based on `route.Mode` in `handleConn`)
|
- [x] 3.2 Server integration (dispatch to L4 or L7 based on `route.Mode` in `handleConn`)
|
||||||
- [ ] 3.3 PROXY protocol sending in L7 path
|
- [x] 3.3 PROXY protocol sending in L7 path
|
||||||
- [ ] 3.4 Tests (TLS termination, h2c backend, re-encrypt, mixed L4/L7 listener, gRPC through L7)
|
- [x] 3.4 Tests (TLS termination, h2c backend, re-encrypt, mixed L4/L7 listener, gRPC through L7)
|
||||||
|
|
||||||
## Phase 4: gRPC API & CLI Updates
|
## Phase 4: gRPC API & CLI Updates
|
||||||
|
|
||||||
|
|||||||
28
internal/l7/prefixconn.go
Normal file
28
internal/l7/prefixconn.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
package l7
|
||||||
|
|
||||||
|
import "net"
|
||||||
|
|
||||||
|
// PrefixConn wraps a net.Conn, prepending buffered bytes before reading
|
||||||
|
// from the underlying connection. This is used to replay the TLS ClientHello
|
||||||
|
// bytes that were peeked during SNI extraction back into crypto/tls.Server.
|
||||||
|
type PrefixConn struct {
|
||||||
|
net.Conn
|
||||||
|
prefix []byte
|
||||||
|
off int
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPrefixConn creates a PrefixConn that returns prefix bytes first,
|
||||||
|
// then reads from the underlying conn.
|
||||||
|
func NewPrefixConn(conn net.Conn, prefix []byte) *PrefixConn {
|
||||||
|
return &PrefixConn{Conn: conn, prefix: prefix}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read returns buffered prefix bytes first, then reads from the underlying conn.
|
||||||
|
func (pc *PrefixConn) Read(b []byte) (int, error) {
|
||||||
|
if pc.off < len(pc.prefix) {
|
||||||
|
n := copy(b, pc.prefix[pc.off:])
|
||||||
|
pc.off += n
|
||||||
|
return n, nil
|
||||||
|
}
|
||||||
|
return pc.Conn.Read(b)
|
||||||
|
}
|
||||||
154
internal/l7/prefixconn_test.go
Normal file
154
internal/l7/prefixconn_test.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package l7
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPrefixConnRead(t *testing.T) {
|
||||||
|
// Create a TCP pair.
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
conn.Write([]byte("WORLD"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
pc := NewPrefixConn(conn, []byte("HELLO"))
|
||||||
|
|
||||||
|
// Read all data: should get "HELLOWORLD".
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
all, err := io.ReadAll(pc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll: %v", err)
|
||||||
|
}
|
||||||
|
if string(all) != "HELLOWORLD" {
|
||||||
|
t.Fatalf("got %q, want %q", all, "HELLOWORLD")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrefixConnSmallReads(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
conn.Write([]byte("CD"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
pc := NewPrefixConn(conn, []byte("AB"))
|
||||||
|
|
||||||
|
// Read 1 byte at a time from the prefix.
|
||||||
|
buf := make([]byte, 1)
|
||||||
|
n, err := pc.Read(buf)
|
||||||
|
if err != nil || n != 1 || buf[0] != 'A' {
|
||||||
|
t.Fatalf("first read: n=%d, err=%v, buf=%q", n, err, buf[:n])
|
||||||
|
}
|
||||||
|
n, err = pc.Read(buf)
|
||||||
|
if err != nil || n != 1 || buf[0] != 'B' {
|
||||||
|
t.Fatalf("second read: n=%d, err=%v, buf=%q", n, err, buf[:n])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Now reads come from the underlying conn.
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
rest, err := io.ReadAll(pc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll: %v", err)
|
||||||
|
}
|
||||||
|
if string(rest) != "CD" {
|
||||||
|
t.Fatalf("got %q, want %q", rest, "CD")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrefixConnEmptyPrefix(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
conn.Write([]byte("DATA"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
pc := NewPrefixConn(conn, nil)
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
all, err := io.ReadAll(pc)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ReadAll: %v", err)
|
||||||
|
}
|
||||||
|
if string(all) != "DATA" {
|
||||||
|
t.Fatalf("got %q, want %q", all, "DATA")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPrefixConnDelegates(t *testing.T) {
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
defer ln.Close()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, _ := ln.Accept()
|
||||||
|
if conn != nil {
|
||||||
|
conn.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
pc := NewPrefixConn(conn, []byte("X"))
|
||||||
|
|
||||||
|
// RemoteAddr, LocalAddr should delegate.
|
||||||
|
if pc.RemoteAddr() == nil {
|
||||||
|
t.Fatal("RemoteAddr returned nil")
|
||||||
|
}
|
||||||
|
if pc.LocalAddr() == nil {
|
||||||
|
t.Fatal("LocalAddr returned nil")
|
||||||
|
}
|
||||||
|
}
|
||||||
236
internal/l7/serve.go
Normal file
236
internal/l7/serve.go
Normal file
@@ -0,0 +1,236 @@
|
|||||||
|
// Package l7 implements L7 TLS-terminating HTTP/2 reverse proxying.
|
||||||
|
package l7
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RouteConfig holds the L7 route parameters needed by the l7 package.
|
||||||
|
type RouteConfig struct {
|
||||||
|
Backend string
|
||||||
|
TLSCert string
|
||||||
|
TLSKey string
|
||||||
|
BackendTLS bool
|
||||||
|
SendProxyProtocol bool
|
||||||
|
ConnectTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// contextKey is an unexported type for context keys in this package.
|
||||||
|
type contextKey int
|
||||||
|
|
||||||
|
const clientAddrKey contextKey = 0
|
||||||
|
|
||||||
|
// Serve handles an L7 (TLS-terminating) connection. It completes the TLS
|
||||||
|
// handshake with the client using the route's certificate, then reverse
|
||||||
|
// proxies HTTP/2 (or HTTP/1.1) traffic to the backend.
|
||||||
|
//
|
||||||
|
// peeked contains the TLS ClientHello bytes that were read during SNI
|
||||||
|
// extraction. They are replayed into the TLS handshake via PrefixConn.
|
||||||
|
func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig, clientAddr netip.AddrPort, logger *slog.Logger) error {
|
||||||
|
// Load the TLS certificate for this route.
|
||||||
|
cert, err := tls.LoadX509KeyPair(route.TLSCert, route.TLSKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("loading TLS cert/key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the connection to replay the peeked ClientHello.
|
||||||
|
pc := NewPrefixConn(conn, peeked)
|
||||||
|
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{cert},
|
||||||
|
NextProtos: []string{"h2", "http/1.1"},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsConn := tls.Server(pc, tlsConf)
|
||||||
|
|
||||||
|
// Complete the TLS handshake with a timeout.
|
||||||
|
if err := tlsConn.SetDeadline(time.Now().Add(10 * time.Second)); err != nil {
|
||||||
|
return fmt.Errorf("setting handshake deadline: %w", err)
|
||||||
|
}
|
||||||
|
if err := tlsConn.Handshake(); err != nil {
|
||||||
|
return fmt.Errorf("TLS handshake: %w", err)
|
||||||
|
}
|
||||||
|
if err := tlsConn.SetDeadline(time.Time{}); err != nil {
|
||||||
|
return fmt.Errorf("clearing handshake deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the reverse proxy handler.
|
||||||
|
rp, err := newReverseProxy(route, logger)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("creating reverse proxy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap the handler to inject the real client IP into the request context.
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
|
||||||
|
rp.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
|
||||||
|
// otherwise fall back to HTTP/1.1.
|
||||||
|
proto := tlsConn.ConnectionState().NegotiatedProtocol
|
||||||
|
|
||||||
|
if proto == "h2" {
|
||||||
|
h2srv := &http2.Server{}
|
||||||
|
h2srv.ServeConn(tlsConn, &http2.ServeConnOpts{
|
||||||
|
Context: ctx,
|
||||||
|
Handler: handler,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// HTTP/1.1 fallback: serve a single connection.
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: handler,
|
||||||
|
ReadHeaderTimeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
singleConn := newSingleConnListener(tlsConn)
|
||||||
|
srv.Serve(singleConn)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newReverseProxy creates an httputil.ReverseProxy for the given route.
|
||||||
|
func newReverseProxy(route RouteConfig, logger *slog.Logger) (*httputil.ReverseProxy, error) {
|
||||||
|
scheme := "http"
|
||||||
|
if route.BackendTLS {
|
||||||
|
scheme = "https"
|
||||||
|
}
|
||||||
|
|
||||||
|
target, err := url.Parse(fmt.Sprintf("%s://%s", scheme, route.Backend))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parsing backend URL: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
transport, err := newTransport(route)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
rp := &httputil.ReverseProxy{
|
||||||
|
Rewrite: func(pr *httputil.ProxyRequest) {
|
||||||
|
pr.SetURL(target)
|
||||||
|
// Preserve the original Host header from the client.
|
||||||
|
pr.Out.Host = pr.In.Host
|
||||||
|
// Inject forwarding headers from the real client IP.
|
||||||
|
if addr, ok := pr.In.Context().Value(clientAddrKey).(netip.AddrPort); ok {
|
||||||
|
setForwardingHeaders(pr.Out, addr)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Transport: transport,
|
||||||
|
ErrorHandler: func(w http.ResponseWriter, r *http.Request, err error) {
|
||||||
|
logger.Error("reverse proxy error", "backend", route.Backend, "error", err)
|
||||||
|
w.WriteHeader(http.StatusBadGateway)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return rp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// newTransport creates the HTTP transport for connecting to the backend.
|
||||||
|
func newTransport(route RouteConfig) (http.RoundTripper, error) {
|
||||||
|
connectTimeout := route.ConnectTimeout
|
||||||
|
if connectTimeout == 0 {
|
||||||
|
connectTimeout = 5 * time.Second
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.BackendTLS {
|
||||||
|
// TLS to backend (h2 over TLS).
|
||||||
|
return &http2.Transport{
|
||||||
|
TLSClientConfig: &tls.Config{
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// h2c: HTTP/2 over plaintext TCP.
|
||||||
|
return &http2.Transport{
|
||||||
|
AllowHTTP: true,
|
||||||
|
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
||||||
|
conn, err := dialBackend(ctx, network, addr, connectTimeout, route)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialBackend connects to the backend, optionally sending a PROXY protocol header.
|
||||||
|
func dialBackend(ctx context.Context, network, addr string, timeout time.Duration, route RouteConfig) (net.Conn, error) {
|
||||||
|
d := &net.Dialer{Timeout: timeout}
|
||||||
|
conn, err := d.DialContext(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if route.SendProxyProtocol {
|
||||||
|
// Get the real client IP from the context if available.
|
||||||
|
clientAddr, _ := ctx.Value(clientAddrKey).(netip.AddrPort)
|
||||||
|
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
|
||||||
|
if clientAddr.IsValid() {
|
||||||
|
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setForwardingHeaders sets X-Forwarded-For, X-Forwarded-Proto, and X-Real-IP.
|
||||||
|
func setForwardingHeaders(r *http.Request, clientAddr netip.AddrPort) {
|
||||||
|
clientIP := clientAddr.Addr().String()
|
||||||
|
r.Header.Set("X-Forwarded-For", clientIP)
|
||||||
|
r.Header.Set("X-Forwarded-Proto", "https")
|
||||||
|
r.Header.Set("X-Real-IP", clientIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// singleConnListener is a net.Listener that returns a single connection once,
|
||||||
|
// then blocks until closed. Used to serve HTTP/1.1 on a single TLS connection.
|
||||||
|
type singleConnListener struct {
|
||||||
|
conn net.Conn
|
||||||
|
ch chan net.Conn
|
||||||
|
done chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSingleConnListener(conn net.Conn) *singleConnListener {
|
||||||
|
ch := make(chan net.Conn, 1)
|
||||||
|
ch <- conn
|
||||||
|
return &singleConnListener{conn: conn, ch: ch, done: make(chan struct{})}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *singleConnListener) Accept() (net.Conn, error) {
|
||||||
|
select {
|
||||||
|
case c := <-l.ch:
|
||||||
|
return c, nil
|
||||||
|
case <-l.done:
|
||||||
|
return nil, net.ErrClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *singleConnListener) Close() error {
|
||||||
|
select {
|
||||||
|
case <-l.done:
|
||||||
|
default:
|
||||||
|
close(l.done)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *singleConnListener) Addr() net.Addr {
|
||||||
|
return l.conn.LocalAddr()
|
||||||
|
}
|
||||||
363
internal/l7/serve_test.go
Normal file
363
internal/l7/serve_test.go
Normal file
@@ -0,0 +1,363 @@
|
|||||||
|
package l7
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log/slog"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/h2c"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testCert generates a self-signed TLS certificate for the given hostname
|
||||||
|
// and writes the cert/key to temporary files, returning their paths.
|
||||||
|
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generating key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{CommonName: hostname},
|
||||||
|
DNSNames: []string{hostname},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating certificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath = filepath.Join(dir, "cert.pem")
|
||||||
|
keyPath = filepath.Join(dir, "key.pem")
|
||||||
|
|
||||||
|
certFile, err := os.Create(certPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating cert file: %v", err)
|
||||||
|
}
|
||||||
|
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
|
certFile.Close()
|
||||||
|
|
||||||
|
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("marshaling key: %v", err)
|
||||||
|
}
|
||||||
|
keyFile, err := os.Create(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating key file: %v", err)
|
||||||
|
}
|
||||||
|
pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
keyFile.Close()
|
||||||
|
|
||||||
|
return certPath, keyPath
|
||||||
|
}
|
||||||
|
|
||||||
|
// startH2CBackend starts an h2c (HTTP/2 cleartext) backend server that
|
||||||
|
// responds with the given status and body. Returns the listener address.
|
||||||
|
func startH2CBackend(t *testing.T, handler http.Handler) string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
h2s := &http2.Server{}
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: h2c.NewHandler(handler, h2s),
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
srv.Close()
|
||||||
|
ln.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
go srv.Serve(ln)
|
||||||
|
return ln.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// dialTLSToProxy dials a TCP connection to the proxy, does a TLS handshake
|
||||||
|
// with the given serverName (skipping cert verification for self-signed),
|
||||||
|
// and returns an *http.Client configured to use that connection for HTTP/2.
|
||||||
|
func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
ServerName: serverName,
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := tls.DialWithDialer(
|
||||||
|
&net.Dialer{Timeout: 5 * time.Second},
|
||||||
|
"tcp", proxyAddr, tlsConf,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TLS dial: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { conn.Close() })
|
||||||
|
|
||||||
|
// Create an HTTP/2 client transport over this single connection.
|
||||||
|
tr := &http2.Transport{}
|
||||||
|
h2conn, err := tr.NewClientConn(conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating h2 client conn: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{
|
||||||
|
Transport: &singleConnRoundTripper{cc: h2conn},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// singleConnRoundTripper is an http.RoundTripper that uses a single HTTP/2
|
||||||
|
// client connection.
|
||||||
|
type singleConnRoundTripper struct {
|
||||||
|
cc *http2.ClientConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||||
|
return s.cc.RoundTrip(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
// serveL7Route starts l7.Serve in a goroutine for a single connection.
|
||||||
|
// Returns when the goroutine completes.
|
||||||
|
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
|
||||||
|
t.Helper()
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger)
|
||||||
|
if l7Err != nil {
|
||||||
|
t.Logf("l7.Serve: %v", l7Err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7H2CBackend(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "l7.test")
|
||||||
|
|
||||||
|
// Start an h2c backend.
|
||||||
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("X-Backend", "ok")
|
||||||
|
fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
|
||||||
|
}))
|
||||||
|
|
||||||
|
// Start a TCP listener for the L7 proxy.
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
defer proxyLn.Close()
|
||||||
|
|
||||||
|
route := RouteConfig{
|
||||||
|
Backend: backendAddr,
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept one connection and run L7 serve.
|
||||||
|
go func() {
|
||||||
|
conn, err := proxyLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
|
// No peeked bytes — the client is connecting directly with TLS.
|
||||||
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Connect as an HTTP/2 TLS client.
|
||||||
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
|
||||||
|
|
||||||
|
resp, err := client.Get(fmt.Sprintf("https://l7.test/foo"))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
if got := string(body); got != "hello from backend, path=/foo" {
|
||||||
|
t.Fatalf("body = %q, want %q", got, "hello from backend, path=/foo")
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Header.Get("X-Backend") != "ok" {
|
||||||
|
t.Fatalf("X-Backend header missing or wrong")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7ForwardingHeaders(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "headers.test")
|
||||||
|
|
||||||
|
// Backend that echoes the forwarding headers.
|
||||||
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
|
||||||
|
r.Header.Get("X-Forwarded-For"),
|
||||||
|
r.Header.Get("X-Forwarded-Proto"),
|
||||||
|
r.Header.Get("X-Real-IP"),
|
||||||
|
)
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
defer proxyLn.Close()
|
||||||
|
|
||||||
|
route := RouteConfig{
|
||||||
|
Backend: backendAddr,
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := proxyLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
|
||||||
|
resp, err := client.Get("https://headers.test/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
|
||||||
|
if string(body) != want {
|
||||||
|
t.Fatalf("body = %q, want %q", body, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7BackendUnreachable(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "unreachable.test")
|
||||||
|
|
||||||
|
// Find a port that nothing is listening on.
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
deadAddr := ln.Addr().String()
|
||||||
|
ln.Close()
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
defer proxyLn.Close()
|
||||||
|
|
||||||
|
route := RouteConfig{
|
||||||
|
Backend: deadAddr,
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
ConnectTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := proxyLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
|
||||||
|
resp, err := client.Get("https://unreachable.test/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("status = %d, want 502", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7MultipleRequests(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "multi.test")
|
||||||
|
|
||||||
|
var reqCount int
|
||||||
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reqCount++
|
||||||
|
fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
defer proxyLn.Close()
|
||||||
|
|
||||||
|
route := RouteConfig{
|
||||||
|
Backend: backendAddr,
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := proxyLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
|
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
|
||||||
|
|
||||||
|
// Send multiple requests over the same HTTP/2 connection.
|
||||||
|
for i := range 3 {
|
||||||
|
path := fmt.Sprintf("/req%d", i)
|
||||||
|
resp, err := client.Get(fmt.Sprintf("https://multi.test%s", path))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET %s: %v", path, err)
|
||||||
|
}
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
resp.Body.Close()
|
||||||
|
|
||||||
|
want := fmt.Sprintf("req=%d path=%s", i+1, path)
|
||||||
|
if string(body) != want {
|
||||||
|
t.Fatalf("request %d: body = %q, want %q", i, body, want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||||
|
"git.wntrmute.dev/kyle/mc-proxy/internal/l7"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
||||||
@@ -298,11 +299,10 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dispatch based on route mode. L7 will be implemented in a later phase.
|
// Dispatch based on route mode.
|
||||||
switch route.Mode {
|
switch route.Mode {
|
||||||
case "l7":
|
case "l7":
|
||||||
s.logger.Error("L7 mode not yet implemented", "hostname", hostname)
|
s.handleL7(ctx, conn, addr, addrPort, hostname, route, peeked)
|
||||||
return
|
|
||||||
default:
|
default:
|
||||||
s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked)
|
s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked)
|
||||||
}
|
}
|
||||||
@@ -340,3 +340,25 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
|
|||||||
"backend_bytes", result.BackendBytes,
|
"backend_bytes", result.BackendBytes,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleL7 handles an L7 (TLS-terminating) connection.
|
||||||
|
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
|
||||||
|
s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
|
||||||
|
|
||||||
|
rc := l7.RouteConfig{
|
||||||
|
Backend: route.Backend,
|
||||||
|
TLSCert: route.TLSCert,
|
||||||
|
TLSKey: route.TLSKey,
|
||||||
|
BackendTLS: route.BackendTLS,
|
||||||
|
SendProxyProtocol: route.SendProxyProtocol,
|
||||||
|
ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil {
|
||||||
|
if ctx.Err() == nil {
|
||||||
|
s.logger.Debug("L7 serve ended", "hostname", hostname, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logger.Info("L7 connection closed", "addr", addr, "hostname", hostname)
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,11 +3,23 @@ package server
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
|
"math/big"
|
||||||
"net"
|
"net"
|
||||||
|
"net/http"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -15,6 +27,8 @@ import (
|
|||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
||||||
|
"golang.org/x/net/http2"
|
||||||
|
"golang.org/x/net/http2/h2c"
|
||||||
)
|
)
|
||||||
|
|
||||||
// l4Route creates a RouteInfo for an L4 passthrough route.
|
// l4Route creates a RouteInfo for an L4 passthrough route.
|
||||||
@@ -1038,6 +1052,220 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
|||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// --- L7 server integration tests ---
|
||||||
|
|
||||||
|
// testCert generates a self-signed TLS certificate for the given hostname.
|
||||||
|
func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
|
||||||
|
t.Helper()
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("generating key: %v", err)
|
||||||
|
}
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(1),
|
||||||
|
Subject: pkix.Name{CommonName: hostname},
|
||||||
|
DNSNames: []string{hostname},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating certificate: %v", err)
|
||||||
|
}
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath = filepath.Join(dir, "cert.pem")
|
||||||
|
keyPath = filepath.Join(dir, "key.pem")
|
||||||
|
cf, _ := os.Create(certPath)
|
||||||
|
pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
|
cf.Close()
|
||||||
|
keyDER, _ := x509.MarshalECPrivateKey(key)
|
||||||
|
kf, _ := os.Create(keyPath)
|
||||||
|
pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
|
kf.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// startH2CBackend starts an h2c backend for testing.
|
||||||
|
func startH2CBackend(t *testing.T, handler http.Handler) string {
|
||||||
|
t.Helper()
|
||||||
|
h2s := &http2.Server{}
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: h2c.NewHandler(handler, h2s),
|
||||||
|
ReadHeaderTimeout: 5 * time.Second,
|
||||||
|
}
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { srv.Close(); ln.Close() })
|
||||||
|
go srv.Serve(ln)
|
||||||
|
return ln.Addr().String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7ThroughServer(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "l7srv.test")
|
||||||
|
|
||||||
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
proxyAddr := proxyLn.Addr().String()
|
||||||
|
proxyLn.Close()
|
||||||
|
|
||||||
|
srv := newTestServer(t, []ListenerData{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Addr: proxyAddr,
|
||||||
|
Routes: map[string]RouteInfo{
|
||||||
|
"l7srv.test": {
|
||||||
|
Backend: backendAddr,
|
||||||
|
Mode: "l7",
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stop := startAndStop(t, srv)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
// Connect with TLS and make an HTTP/2 request.
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
ServerName: "l7srv.test",
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
}
|
||||||
|
conn, err := tls.DialWithDialer(
|
||||||
|
&net.Dialer{Timeout: 5 * time.Second},
|
||||||
|
"tcp", proxyAddr, tlsConf,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TLS dial: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
tr := &http2.Transport{}
|
||||||
|
h2conn, err := tr.NewClientConn(conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("h2 client conn: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "https://l7srv.test/hello", nil)
|
||||||
|
resp, err := h2conn.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
|
||||||
|
// no PROXY protocol is in use.
|
||||||
|
if resp.StatusCode != 200 {
|
||||||
|
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
||||||
|
}
|
||||||
|
got := string(body)
|
||||||
|
if got != "ok path=/hello xff=127.0.0.1" {
|
||||||
|
t.Fatalf("body = %q", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMixedL4L7SameListener(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "l7mixed.test")
|
||||||
|
|
||||||
|
// L4 backend: echo server.
|
||||||
|
l4BackendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("l4 backend listen: %v", err)
|
||||||
|
}
|
||||||
|
defer l4BackendLn.Close()
|
||||||
|
go echoServer(t, l4BackendLn)
|
||||||
|
|
||||||
|
// L7 backend: h2c HTTP server.
|
||||||
|
l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, "l7-response")
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
proxyAddr := proxyLn.Addr().String()
|
||||||
|
proxyLn.Close()
|
||||||
|
|
||||||
|
srv := newTestServer(t, []ListenerData{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Addr: proxyAddr,
|
||||||
|
Routes: map[string]RouteInfo{
|
||||||
|
"l4echo.test": l4Route(l4BackendLn.Addr().String()),
|
||||||
|
"l7mixed.test": {
|
||||||
|
Backend: l7BackendAddr,
|
||||||
|
Mode: "l7",
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stop := startAndStop(t, srv)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
// Test L4 route works: send ClientHello, get echo.
|
||||||
|
l4Conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial L4: %v", err)
|
||||||
|
}
|
||||||
|
defer l4Conn.Close()
|
||||||
|
hello := buildClientHello("l4echo.test")
|
||||||
|
l4Conn.Write(hello)
|
||||||
|
echoed := make([]byte, len(hello))
|
||||||
|
l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
if _, err := io.ReadFull(l4Conn, echoed); err != nil {
|
||||||
|
t.Fatalf("L4 echo read: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test L7 route works: TLS + HTTP/2.
|
||||||
|
tlsConf := &tls.Config{
|
||||||
|
ServerName: "l7mixed.test",
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
NextProtos: []string{"h2"},
|
||||||
|
}
|
||||||
|
l7Conn, err := tls.DialWithDialer(
|
||||||
|
&net.Dialer{Timeout: 5 * time.Second},
|
||||||
|
"tcp", proxyAddr, tlsConf,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("TLS dial L7: %v", err)
|
||||||
|
}
|
||||||
|
defer l7Conn.Close()
|
||||||
|
|
||||||
|
tr := &http2.Transport{}
|
||||||
|
h2conn, err := tr.NewClientConn(l7Conn)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("h2 client conn: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, _ := http.NewRequest("GET", "https://l7mixed.test/", nil)
|
||||||
|
resp, err := h2conn.RoundTrip(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("L7 RoundTrip: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if string(body) != "l7-response" {
|
||||||
|
t.Fatalf("L7 body = %q, want %q", body, "l7-response")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
|
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
|
||||||
|
|
||||||
func buildClientHello(serverName string) []byte {
|
func buildClientHello(serverName string) []byte {
|
||||||
|
|||||||
Reference in New Issue
Block a user