Add PROXY protocol v1/v2 support for multi-hop deployments
New internal/proxyproto package implements PROXY protocol parsing and writing without buffering past the header boundary (reads exact byte counts so the connection is correctly positioned for SNI extraction). Parser: auto-detects v1 (text) and v2 (binary) by first byte. Parses TCP4/TCP6 for both versions plus v2 LOCAL command. Enforces max header sizes and read deadlines. Writer: generates v2 binary headers for IPv4 and IPv6 with PROXY command. Server integration: - Receive: when listener.ProxyProtocol is true, parses PROXY header before firewall check. Real client IP from header is used for firewall evaluation and logging. Malformed headers cause RST. - Send: when route.SendProxyProtocol is true, writes PROXY v2 header to backend before forwarding the ClientHello bytes. Tests cover v1/v2 parsing, malformed rejection, timeout, round-trip write+parse, and five server integration tests: receive with valid header, receive with garbage, send verification, send-disabled verification, and firewall evaluation using the real client IP. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,17 +1,20 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
||||
)
|
||||
|
||||
// l4Route creates a RouteInfo for an L4 passthrough route.
|
||||
@@ -696,6 +699,345 @@ func TestListenerStateRoutes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolReceive(t *testing.T) {
|
||||
// Backend echoes everything back.
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("backend listen: %v", err)
|
||||
}
|
||||
defer backendLn.Close()
|
||||
go echoServer(t, backendLn)
|
||||
|
||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("finding free port: %v", err)
|
||||
}
|
||||
proxyAddr := proxyLn.Addr().String()
|
||||
proxyLn.Close()
|
||||
|
||||
srv := newTestServer(t, []ListenerData{
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
ProxyProtocol: true,
|
||||
Routes: map[string]RouteInfo{
|
||||
"echo.test": l4Route(backendLn.Addr().String()),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
stop := startAndStop(t, srv)
|
||||
defer stop()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send PROXY v2 header followed by TLS ClientHello.
|
||||
var ppBuf bytes.Buffer
|
||||
proxyproto.WriteV2(&ppBuf,
|
||||
netip.MustParseAddrPort("203.0.113.50:12345"),
|
||||
netip.MustParseAddrPort("198.51.100.1:443"),
|
||||
)
|
||||
conn.Write(ppBuf.Bytes())
|
||||
|
||||
hello := buildClientHello("echo.test")
|
||||
conn.Write(hello)
|
||||
|
||||
// Backend should echo the ClientHello back (not the PROXY header).
|
||||
echoed := make([]byte, len(hello))
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
if _, err := io.ReadFull(conn, echoed); err != nil {
|
||||
t.Fatalf("read echoed data: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolReceiveGarbage(t *testing.T) {
|
||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("finding free port: %v", err)
|
||||
}
|
||||
proxyAddr := proxyLn.Addr().String()
|
||||
proxyLn.Close()
|
||||
|
||||
srv := newTestServer(t, []ListenerData{
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
ProxyProtocol: true,
|
||||
Routes: map[string]RouteInfo{
|
||||
"echo.test": l4Route("127.0.0.1:1"),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
stop := startAndStop(t, srv)
|
||||
defer stop()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send garbage instead of a valid PROXY header.
|
||||
conn.Write([]byte("NOT A PROXY HEADER\r\n"))
|
||||
|
||||
// Connection should be closed.
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Fatal("expected connection to be closed for invalid PROXY header")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolSend(t *testing.T) {
|
||||
// Backend that captures the first bytes it receives.
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("backend listen: %v", err)
|
||||
}
|
||||
defer backendLn.Close()
|
||||
|
||||
received := make(chan []byte, 1)
|
||||
go func() {
|
||||
conn, err := backendLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
// Read all available data; the proxy sends PROXY header + ClientHello.
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
var all []byte
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := conn.Read(buf)
|
||||
all = append(all, buf[:n]...)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
// We expect at least pp header (28) + some TLS data.
|
||||
if len(all) >= 28+5 {
|
||||
break
|
||||
}
|
||||
}
|
||||
received <- all
|
||||
}()
|
||||
|
||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("finding free port: %v", err)
|
||||
}
|
||||
proxyAddr := proxyLn.Addr().String()
|
||||
proxyLn.Close()
|
||||
|
||||
srv := newTestServer(t, []ListenerData{
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]RouteInfo{
|
||||
"pp.test": {
|
||||
Backend: backendLn.Addr().String(),
|
||||
Mode: "l4",
|
||||
SendProxyProtocol: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
stop := startAndStop(t, srv)
|
||||
defer stop()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
hello := buildClientHello("pp.test")
|
||||
conn.Write(hello)
|
||||
|
||||
// The backend should receive: PROXY v2 header + ClientHello.
|
||||
select {
|
||||
case data := <-received:
|
||||
// PROXY v2 IPv4 header is 28 bytes (12 sig + 1 ver/cmd + 1 fam + 2 len + 12 addrs).
|
||||
if len(data) < 28 {
|
||||
t.Fatalf("backend received only %d bytes, want at least 28", len(data))
|
||||
}
|
||||
// Check PROXY v2 signature.
|
||||
sig := [12]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
|
||||
if [12]byte(data[:12]) != sig {
|
||||
t.Fatal("backend data does not start with PROXY v2 signature")
|
||||
}
|
||||
// Verify TLS record header follows the PROXY header.
|
||||
ppLen := 28 // v2 IPv4
|
||||
if len(data) <= ppLen {
|
||||
t.Fatalf("backend received only PROXY header, no TLS data")
|
||||
}
|
||||
if data[ppLen] != 0x16 {
|
||||
t.Fatalf("expected TLS record (0x16) after PROXY header, got 0x%02x", data[ppLen])
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for backend data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolNotSent(t *testing.T) {
|
||||
// Backend captures first bytes.
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("backend listen: %v", err)
|
||||
}
|
||||
defer backendLn.Close()
|
||||
|
||||
received := make(chan []byte, 1)
|
||||
go func() {
|
||||
conn, err := backendLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 4096)
|
||||
n, _ := conn.Read(buf)
|
||||
received <- buf[:n]
|
||||
}()
|
||||
|
||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("finding free port: %v", err)
|
||||
}
|
||||
proxyAddr := proxyLn.Addr().String()
|
||||
proxyLn.Close()
|
||||
|
||||
srv := newTestServer(t, []ListenerData{
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
Routes: map[string]RouteInfo{
|
||||
"nopp.test": l4Route(backendLn.Addr().String()),
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
stop := startAndStop(t, srv)
|
||||
defer stop()
|
||||
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
hello := buildClientHello("nopp.test")
|
||||
conn.Write(hello)
|
||||
|
||||
select {
|
||||
case data := <-received:
|
||||
// First byte should be TLS record header, not PROXY signature.
|
||||
if data[0] != 0x16 {
|
||||
t.Fatalf("expected TLS record (0x16) as first byte, got 0x%02x", data[0])
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for backend data")
|
||||
}
|
||||
}
|
||||
|
||||
func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
||||
// Backend that should never be reached.
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("backend listen: %v", err)
|
||||
}
|
||||
defer backendLn.Close()
|
||||
|
||||
reached := make(chan struct{}, 1)
|
||||
go func() {
|
||||
conn, err := backendLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
conn.Close()
|
||||
reached <- struct{}{}
|
||||
}()
|
||||
|
||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("finding free port: %v", err)
|
||||
}
|
||||
proxyAddr := proxyLn.Addr().String()
|
||||
proxyLn.Close()
|
||||
|
||||
// Block 203.0.113.50 (the "real" client IP from PROXY header).
|
||||
// 127.0.0.1 (the actual TCP peer) is NOT blocked.
|
||||
fw, err := firewall.New("", []string{"203.0.113.50"}, nil, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("creating firewall: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Proxy: config.Proxy{
|
||||
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
|
||||
IdleTimeout: config.Duration{Duration: 30 * time.Second},
|
||||
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
|
||||
},
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
srv := New(cfg, fw, []ListenerData{
|
||||
{
|
||||
ID: 1,
|
||||
Addr: proxyAddr,
|
||||
ProxyProtocol: true,
|
||||
Routes: map[string]RouteInfo{
|
||||
"blocked.test": l4Route(backendLn.Addr().String()),
|
||||
},
|
||||
},
|
||||
}, logger, "test")
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
srv.Run(ctx)
|
||||
}()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("dial proxy: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Send PROXY v2 with the blocked real IP.
|
||||
var ppBuf bytes.Buffer
|
||||
proxyproto.WriteV2(&ppBuf,
|
||||
netip.MustParseAddrPort("203.0.113.50:12345"),
|
||||
netip.MustParseAddrPort("198.51.100.1:443"),
|
||||
)
|
||||
conn.Write(ppBuf.Bytes())
|
||||
conn.Write(buildClientHello("blocked.test"))
|
||||
|
||||
// Connection should be closed (firewall blocks real IP).
|
||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
_, err = conn.Read(make([]byte, 1))
|
||||
if err == nil {
|
||||
t.Fatal("expected connection to be closed by firewall")
|
||||
}
|
||||
|
||||
// Backend should not have been reached.
|
||||
select {
|
||||
case <-reached:
|
||||
t.Fatal("backend was reached despite firewall blocking real IP")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// Expected.
|
||||
}
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
|
||||
|
||||
func buildClientHello(serverName string) []byte {
|
||||
|
||||
Reference in New Issue
Block a user