Add PROXY protocol v1/v2 support for multi-hop deployments
New internal/proxyproto package implements PROXY protocol parsing and writing without buffering past the header boundary (reads exact byte counts so the connection is correctly positioned for SNI extraction). Parser: auto-detects v1 (text) and v2 (binary) by first byte. Parses TCP4/TCP6 for both versions plus v2 LOCAL command. Enforces max header sizes and read deadlines. Writer: generates v2 binary headers for IPv4 and IPv6 with PROXY command. Server integration: - Receive: when listener.ProxyProtocol is true, parses PROXY header before firewall check. Real client IP from header is used for firewall evaluation and logging. Malformed headers cause RST. - Send: when route.SendProxyProtocol is true, writes PROXY v2 header to backend before forwarding the ClientHello bytes. Tests cover v1/v2 parsing, malformed rejection, timeout, round-trip write+parse, and five server integration tests: receive with valid header, receive with garbage, send verification, send-disabled verification, and firewall evaluation using the real client IP. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,10 +21,10 @@ proceeds. Each item is marked:
|
|||||||
|
|
||||||
## Phase 2: PROXY Protocol
|
## Phase 2: PROXY Protocol
|
||||||
|
|
||||||
- [ ] 2.1 `internal/proxyproto/` package (v1/v2 parser, v2 writer)
|
- [x] 2.1 `internal/proxyproto/` package (v1/v2 parser, v2 writer)
|
||||||
- [ ] 2.2 Server integration — receive (parse PROXY header before firewall on enabled listeners)
|
- [x] 2.2 Server integration — receive (parse PROXY header before firewall on enabled listeners)
|
||||||
- [ ] 2.3 Server integration — send on L4 (write PROXY v2 header before ClientHello on enabled routes)
|
- [x] 2.3 Server integration — send on L4 (write PROXY v2 header before ClientHello on enabled routes)
|
||||||
- [ ] 2.4 Tests (receive, send, firewall uses real IP, malformed header rejection)
|
- [x] 2.4 Tests (receive, send, firewall uses real IP, malformed header rejection)
|
||||||
|
|
||||||
## Phase 3: L7 Proxying
|
## Phase 3: L7 Proxying
|
||||||
|
|
||||||
|
|||||||
264
internal/proxyproto/proxyproto.go
Normal file
264
internal/proxyproto/proxyproto.go
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
// Package proxyproto implements PROXY protocol v1 and v2 parsing and v2 writing.
|
||||||
|
//
|
||||||
|
// See https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt
|
||||||
|
package proxyproto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Command represents the PROXY protocol command.
|
||||||
|
type Command byte
|
||||||
|
|
||||||
|
const (
|
||||||
|
CommandLocal Command = 0x00
|
||||||
|
CommandProxy Command = 0x01
|
||||||
|
)
|
||||||
|
|
||||||
|
// Header is a parsed PROXY protocol header.
|
||||||
|
type Header struct {
|
||||||
|
Version int // 1 or 2
|
||||||
|
Command Command
|
||||||
|
SrcAddr netip.AddrPort
|
||||||
|
DstAddr netip.AddrPort
|
||||||
|
}
|
||||||
|
|
||||||
|
// v2 binary signature (12 bytes).
|
||||||
|
var v2Signature = [12]byte{
|
||||||
|
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// maxV1Length is the maximum length of a v1 header line (including CRLF).
|
||||||
|
maxV1Length = 108
|
||||||
|
|
||||||
|
// v2HeaderLen is the fixed portion of a v2 header (signature + ver/cmd + fam + len).
|
||||||
|
v2HeaderLen = 16
|
||||||
|
|
||||||
|
// maxV2AddrLen limits the address block we'll read.
|
||||||
|
maxV2AddrLen = 536
|
||||||
|
)
|
||||||
|
|
||||||
|
// Parse reads a PROXY protocol header from conn, using deadline for timeout.
|
||||||
|
// It reads only the exact bytes of the PROXY header, leaving the connection
|
||||||
|
// positioned at the first byte after the header (e.g., TLS ClientHello).
|
||||||
|
func Parse(conn net.Conn, deadline time.Time) (Header, error) {
|
||||||
|
conn.SetReadDeadline(deadline)
|
||||||
|
defer conn.SetReadDeadline(time.Time{})
|
||||||
|
|
||||||
|
// Read the first byte to determine version.
|
||||||
|
var first [1]byte
|
||||||
|
if _, err := io.ReadFull(conn, first[:]); err != nil {
|
||||||
|
return Header{}, fmt.Errorf("reading first byte: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if first[0] == 'P' {
|
||||||
|
return parseV1(conn)
|
||||||
|
}
|
||||||
|
if first[0] == v2Signature[0] {
|
||||||
|
return parseV2(conn, first[0])
|
||||||
|
}
|
||||||
|
return Header{}, fmt.Errorf("invalid PROXY protocol header: unexpected first byte 0x%02x", first[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseV1 parses a v1 text header. The leading 'P' has already been consumed.
|
||||||
|
// Reads byte-by-byte until CRLF to avoid buffering past the header.
|
||||||
|
func parseV1(conn net.Conn) (Header, error) {
|
||||||
|
// We already consumed 'P'. Read until \r\n, byte by byte.
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, 'P')
|
||||||
|
var b [1]byte
|
||||||
|
for len(buf) < maxV1Length {
|
||||||
|
if _, err := io.ReadFull(conn, b[:]); err != nil {
|
||||||
|
return Header{}, fmt.Errorf("reading v1 header: %w", err)
|
||||||
|
}
|
||||||
|
buf = append(buf, b[0])
|
||||||
|
if len(buf) >= 2 && buf[len(buf)-2] == '\r' && buf[len(buf)-1] == '\n' {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
line := string(buf)
|
||||||
|
if !strings.HasSuffix(line, "\r\n") {
|
||||||
|
return Header{}, fmt.Errorf("v1 header too long or missing CRLF")
|
||||||
|
}
|
||||||
|
line = strings.TrimSuffix(line, "\r\n")
|
||||||
|
|
||||||
|
parts := strings.Split(line, " ")
|
||||||
|
if len(parts) != 6 {
|
||||||
|
return Header{}, fmt.Errorf("v1 header: expected 6 fields, got %d", len(parts))
|
||||||
|
}
|
||||||
|
if parts[0] != "PROXY" {
|
||||||
|
return Header{}, fmt.Errorf("v1 header: expected PROXY prefix, got %q", parts[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := parts[1]
|
||||||
|
if proto != "TCP4" && proto != "TCP6" {
|
||||||
|
return Header{}, fmt.Errorf("v1 header: unsupported protocol %q", proto)
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP, err := netip.ParseAddr(parts[2])
|
||||||
|
if err != nil {
|
||||||
|
return Header{}, fmt.Errorf("v1 header: invalid source IP %q: %w", parts[2], err)
|
||||||
|
}
|
||||||
|
dstIP, err := netip.ParseAddr(parts[3])
|
||||||
|
if err != nil {
|
||||||
|
return Header{}, fmt.Errorf("v1 header: invalid destination IP %q: %w", parts[3], err)
|
||||||
|
}
|
||||||
|
srcPort, err := strconv.ParseUint(parts[4], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return Header{}, fmt.Errorf("v1 header: invalid source port %q: %w", parts[4], err)
|
||||||
|
}
|
||||||
|
dstPort, err := strconv.ParseUint(parts[5], 10, 16)
|
||||||
|
if err != nil {
|
||||||
|
return Header{}, fmt.Errorf("v1 header: invalid destination port %q: %w", parts[5], err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Header{
|
||||||
|
Version: 1,
|
||||||
|
Command: CommandProxy,
|
||||||
|
SrcAddr: netip.AddrPortFrom(srcIP, uint16(srcPort)),
|
||||||
|
DstAddr: netip.AddrPortFrom(dstIP, uint16(dstPort)),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseV2 parses a v2 binary header. The first byte has already been consumed.
|
||||||
|
func parseV2(conn net.Conn, firstByte byte) (Header, error) {
|
||||||
|
// Read the remaining 15 bytes of the fixed header.
|
||||||
|
var hdr [v2HeaderLen]byte
|
||||||
|
hdr[0] = firstByte
|
||||||
|
if _, err := io.ReadFull(conn, hdr[1:]); err != nil {
|
||||||
|
return Header{}, fmt.Errorf("reading v2 header: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify signature.
|
||||||
|
if [12]byte(hdr[:12]) != v2Signature {
|
||||||
|
return Header{}, fmt.Errorf("v2 header: invalid signature")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version (upper 4 bits of byte 12) must be 0x2.
|
||||||
|
verCmd := hdr[12]
|
||||||
|
if verCmd>>4 != 0x02 {
|
||||||
|
return Header{}, fmt.Errorf("v2 header: unsupported version %d", verCmd>>4)
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := Command(verCmd & 0x0F)
|
||||||
|
if cmd != CommandLocal && cmd != CommandProxy {
|
||||||
|
return Header{}, fmt.Errorf("v2 header: unsupported command %d", cmd)
|
||||||
|
}
|
||||||
|
|
||||||
|
addrLen := binary.BigEndian.Uint16(hdr[14:16])
|
||||||
|
if int(addrLen) > maxV2AddrLen {
|
||||||
|
return Header{}, fmt.Errorf("v2 header: address length %d exceeds maximum", addrLen)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the exact address block.
|
||||||
|
addrBuf := make([]byte, addrLen)
|
||||||
|
if addrLen > 0 {
|
||||||
|
if _, err := io.ReadFull(conn, addrBuf); err != nil {
|
||||||
|
return Header{}, fmt.Errorf("reading v2 addresses: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// LOCAL command: no addresses to parse.
|
||||||
|
if cmd == CommandLocal {
|
||||||
|
return Header{
|
||||||
|
Version: 2,
|
||||||
|
Command: CommandLocal,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PROXY command: parse addresses based on family/protocol.
|
||||||
|
famProto := hdr[13]
|
||||||
|
family := famProto >> 4
|
||||||
|
protocol := famProto & 0x0F
|
||||||
|
|
||||||
|
// We only support STREAM (TCP) protocol = 0x01.
|
||||||
|
if protocol != 0x01 {
|
||||||
|
return Header{}, fmt.Errorf("v2 header: unsupported protocol %d (only TCP/stream supported)", protocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch family {
|
||||||
|
case 0x01: // AF_INET (IPv4)
|
||||||
|
if len(addrBuf) < 12 {
|
||||||
|
return Header{}, fmt.Errorf("v2 header: IPv4 address block too short (%d bytes)", len(addrBuf))
|
||||||
|
}
|
||||||
|
srcIP := netip.AddrFrom4([4]byte(addrBuf[0:4]))
|
||||||
|
dstIP := netip.AddrFrom4([4]byte(addrBuf[4:8]))
|
||||||
|
srcPort := binary.BigEndian.Uint16(addrBuf[8:10])
|
||||||
|
dstPort := binary.BigEndian.Uint16(addrBuf[10:12])
|
||||||
|
return Header{
|
||||||
|
Version: 2,
|
||||||
|
Command: CommandProxy,
|
||||||
|
SrcAddr: netip.AddrPortFrom(srcIP, srcPort),
|
||||||
|
DstAddr: netip.AddrPortFrom(dstIP, dstPort),
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case 0x02: // AF_INET6 (IPv6)
|
||||||
|
if len(addrBuf) < 36 {
|
||||||
|
return Header{}, fmt.Errorf("v2 header: IPv6 address block too short (%d bytes)", len(addrBuf))
|
||||||
|
}
|
||||||
|
srcIP := netip.AddrFrom16([16]byte(addrBuf[0:16]))
|
||||||
|
dstIP := netip.AddrFrom16([16]byte(addrBuf[16:32]))
|
||||||
|
srcPort := binary.BigEndian.Uint16(addrBuf[32:34])
|
||||||
|
dstPort := binary.BigEndian.Uint16(addrBuf[34:36])
|
||||||
|
return Header{
|
||||||
|
Version: 2,
|
||||||
|
Command: CommandProxy,
|
||||||
|
SrcAddr: netip.AddrPortFrom(srcIP, srcPort),
|
||||||
|
DstAddr: netip.AddrPortFrom(dstIP, dstPort),
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return Header{}, fmt.Errorf("v2 header: unsupported address family %d", family)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteV2 writes a PROXY protocol v2 header with the PROXY command.
|
||||||
|
// src is the original client address; dst is the backend/proxy address.
|
||||||
|
func WriteV2(w io.Writer, src, dst netip.AddrPort) error {
|
||||||
|
var buf []byte
|
||||||
|
|
||||||
|
// Signature (12 bytes).
|
||||||
|
buf = append(buf, v2Signature[:]...)
|
||||||
|
|
||||||
|
// Version (0x2) and command (PROXY = 0x1) → 0x21.
|
||||||
|
buf = append(buf, 0x21)
|
||||||
|
|
||||||
|
srcAddr := src.Addr()
|
||||||
|
dstAddr := dst.Addr()
|
||||||
|
|
||||||
|
if srcAddr.Is4() && dstAddr.Is4() {
|
||||||
|
// AF_INET, STREAM → 0x11.
|
||||||
|
buf = append(buf, 0x11)
|
||||||
|
// Address length: 4+4+2+2 = 12.
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 12)
|
||||||
|
src4 := srcAddr.As4()
|
||||||
|
dst4 := dstAddr.As4()
|
||||||
|
buf = append(buf, src4[:]...)
|
||||||
|
buf = append(buf, dst4[:]...)
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, src.Port())
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, dst.Port())
|
||||||
|
} else {
|
||||||
|
// AF_INET6, STREAM → 0x21.
|
||||||
|
buf = append(buf, 0x21)
|
||||||
|
// Address length: 16+16+2+2 = 36.
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 36)
|
||||||
|
src16 := srcAddr.As16()
|
||||||
|
dst16 := dstAddr.As16()
|
||||||
|
buf = append(buf, src16[:]...)
|
||||||
|
buf = append(buf, dst16[:]...)
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, src.Port())
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, dst.Port())
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := w.Write(buf)
|
||||||
|
return err
|
||||||
|
}
|
||||||
389
internal/proxyproto/proxyproto_test.go
Normal file
389
internal/proxyproto/proxyproto_test.go
Normal file
@@ -0,0 +1,389 @@
|
|||||||
|
package proxyproto
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// pipeWithDeadline returns a net.Conn pair where the reader side supports deadlines.
|
||||||
|
func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
|
||||||
|
t.Helper()
|
||||||
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("listen: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { ln.Close() })
|
||||||
|
|
||||||
|
ch := make(chan net.Conn, 1)
|
||||||
|
go func() {
|
||||||
|
c, err := ln.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ch <- c
|
||||||
|
}()
|
||||||
|
|
||||||
|
w, err := net.Dial("tcp", ln.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { w.Close() })
|
||||||
|
|
||||||
|
r := <-ch
|
||||||
|
t.Cleanup(func() { r.Close() })
|
||||||
|
return r, w
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseV1TCP4(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Version != 1 {
|
||||||
|
t.Fatalf("version = %d, want 1", hdr.Version)
|
||||||
|
}
|
||||||
|
if hdr.Command != CommandProxy {
|
||||||
|
t.Fatalf("command = %d, want CommandProxy", hdr.Command)
|
||||||
|
}
|
||||||
|
if got := hdr.SrcAddr.String(); got != "192.168.1.1:56324" {
|
||||||
|
t.Fatalf("src = %s, want 192.168.1.1:56324", got)
|
||||||
|
}
|
||||||
|
if got := hdr.DstAddr.String(); got != "10.0.0.1:443" {
|
||||||
|
t.Fatalf("dst = %s, want 10.0.0.1:443", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseV1TCP6(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
|
||||||
|
}()
|
||||||
|
|
||||||
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Version != 1 {
|
||||||
|
t.Fatalf("version = %d, want 1", hdr.Version)
|
||||||
|
}
|
||||||
|
wantSrc := "[2001:db8::1]:56324"
|
||||||
|
if got := hdr.SrcAddr.String(); got != wantSrc {
|
||||||
|
t.Fatalf("src = %s, want %s", got, wantSrc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseV2TCP4(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, v2Signature[:]...)
|
||||||
|
buf = append(buf, 0x21) // version 2, PROXY command
|
||||||
|
buf = append(buf, 0x11) // AF_INET, STREAM
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 12)
|
||||||
|
buf = append(buf, 192, 168, 1, 100) // src IP
|
||||||
|
buf = append(buf, 10, 0, 0, 1) // dst IP
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 12345)
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 443)
|
||||||
|
writer.Write(buf)
|
||||||
|
}()
|
||||||
|
|
||||||
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Version != 2 {
|
||||||
|
t.Fatalf("version = %d, want 2", hdr.Version)
|
||||||
|
}
|
||||||
|
if hdr.Command != CommandProxy {
|
||||||
|
t.Fatalf("command = %d, want CommandProxy", hdr.Command)
|
||||||
|
}
|
||||||
|
if got := hdr.SrcAddr.String(); got != "192.168.1.100:12345" {
|
||||||
|
t.Fatalf("src = %s, want 192.168.1.100:12345", got)
|
||||||
|
}
|
||||||
|
if got := hdr.DstAddr.String(); got != "10.0.0.1:443" {
|
||||||
|
t.Fatalf("dst = %s, want 10.0.0.1:443", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseV2TCP6(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, v2Signature[:]...)
|
||||||
|
buf = append(buf, 0x21) // version 2, PROXY command
|
||||||
|
buf = append(buf, 0x21) // AF_INET6, STREAM
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 36)
|
||||||
|
// src: 2001:db8::1
|
||||||
|
src := netip.MustParseAddr("2001:db8::1").As16()
|
||||||
|
buf = append(buf, src[:]...)
|
||||||
|
// dst: 2001:db8::2
|
||||||
|
dst := netip.MustParseAddr("2001:db8::2").As16()
|
||||||
|
buf = append(buf, dst[:]...)
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 56324)
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 8443)
|
||||||
|
writer.Write(buf)
|
||||||
|
}()
|
||||||
|
|
||||||
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Version != 2 {
|
||||||
|
t.Fatalf("version = %d, want 2", hdr.Version)
|
||||||
|
}
|
||||||
|
wantSrc := "[2001:db8::1]:56324"
|
||||||
|
if got := hdr.SrcAddr.String(); got != wantSrc {
|
||||||
|
t.Fatalf("src = %s, want %s", got, wantSrc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseV2Local(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, v2Signature[:]...)
|
||||||
|
buf = append(buf, 0x20) // version 2, LOCAL command
|
||||||
|
buf = append(buf, 0x00) // unspec family, unspec protocol
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 0)
|
||||||
|
writer.Write(buf)
|
||||||
|
}()
|
||||||
|
|
||||||
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Version != 2 {
|
||||||
|
t.Fatalf("version = %d, want 2", hdr.Version)
|
||||||
|
}
|
||||||
|
if hdr.Command != CommandLocal {
|
||||||
|
t.Fatalf("command = %d, want CommandLocal", hdr.Command)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseV1Malformed(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
data string
|
||||||
|
}{
|
||||||
|
{"wrong prefix", "XROXY TCP4 1.2.3.4 5.6.7.8 1 2\r\n"},
|
||||||
|
{"missing fields", "PROXY TCP4 1.2.3.4\r\n"},
|
||||||
|
{"bad protocol", "PROXY UDP4 1.2.3.4 5.6.7.8 1 2\r\n"},
|
||||||
|
{"bad src IP", "PROXY TCP4 not-an-ip 5.6.7.8 1 2\r\n"},
|
||||||
|
{"bad src port", "PROXY TCP4 1.2.3.4 5.6.7.8 notport 2\r\n"},
|
||||||
|
{"missing CRLF", "PROXY TCP4 1.2.3.4 5.6.7.8 1 2\n"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
go func() {
|
||||||
|
writer.Write([]byte(tt.data))
|
||||||
|
}()
|
||||||
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseV2Malformed(t *testing.T) {
|
||||||
|
t.Run("bad signature", func(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
go func() {
|
||||||
|
bad := make([]byte, v2HeaderLen)
|
||||||
|
bad[0] = v2Signature[0] // first byte matches but rest doesn't
|
||||||
|
writer.Write(bad)
|
||||||
|
}()
|
||||||
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for bad signature")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("bad version", func(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
go func() {
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, v2Signature[:]...)
|
||||||
|
buf = append(buf, 0x31) // version 3, PROXY command
|
||||||
|
buf = append(buf, 0x11)
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 0)
|
||||||
|
writer.Write(buf)
|
||||||
|
}()
|
||||||
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for bad version")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("truncated address", func(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
go func() {
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, v2Signature[:]...)
|
||||||
|
buf = append(buf, 0x21) // version 2, PROXY
|
||||||
|
buf = append(buf, 0x11) // AF_INET, STREAM
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 12)
|
||||||
|
buf = append(buf, 1, 2, 3) // only 3 bytes, need 12
|
||||||
|
writer.Write(buf)
|
||||||
|
writer.Close()
|
||||||
|
}()
|
||||||
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for truncated address")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unsupported family", func(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
go func() {
|
||||||
|
var buf []byte
|
||||||
|
buf = append(buf, v2Signature[:]...)
|
||||||
|
buf = append(buf, 0x21) // version 2, PROXY
|
||||||
|
buf = append(buf, 0x31) // AF_UNIX (3), STREAM
|
||||||
|
buf = binary.BigEndian.AppendUint16(buf, 0)
|
||||||
|
writer.Write(buf)
|
||||||
|
}()
|
||||||
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for unsupported family")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseTimeout(t *testing.T) {
|
||||||
|
reader, _ := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
// No data sent — should hit deadline.
|
||||||
|
_, err := Parse(reader, time.Now().Add(100*time.Millisecond))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected timeout error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteV2IPv4(t *testing.T) {
|
||||||
|
src := netip.MustParseAddrPort("192.168.1.1:56324")
|
||||||
|
dst := netip.MustParseAddrPort("10.0.0.1:443")
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := WriteV2(&buf, src, dst); err != nil {
|
||||||
|
t.Fatalf("WriteV2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b := buf.Bytes()
|
||||||
|
// Signature (12) + ver/cmd (1) + fam (1) + len (2) + 4+4+2+2 = 28
|
||||||
|
if len(b) != 28 {
|
||||||
|
t.Fatalf("wrote %d bytes, want 28", len(b))
|
||||||
|
}
|
||||||
|
// Verify signature.
|
||||||
|
if [12]byte(b[:12]) != v2Signature {
|
||||||
|
t.Fatal("bad signature")
|
||||||
|
}
|
||||||
|
if b[12] != 0x21 {
|
||||||
|
t.Fatalf("ver/cmd = 0x%02x, want 0x21", b[12])
|
||||||
|
}
|
||||||
|
if b[13] != 0x11 {
|
||||||
|
t.Fatalf("fam/proto = 0x%02x, want 0x11", b[13])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWriteV2IPv6(t *testing.T) {
|
||||||
|
src := netip.MustParseAddrPort("[2001:db8::1]:56324")
|
||||||
|
dst := netip.MustParseAddrPort("[2001:db8::2]:443")
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := WriteV2(&buf, src, dst); err != nil {
|
||||||
|
t.Fatalf("WriteV2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b := buf.Bytes()
|
||||||
|
// Signature (12) + ver/cmd (1) + fam (1) + len (2) + 16+16+2+2 = 52
|
||||||
|
if len(b) != 52 {
|
||||||
|
t.Fatalf("wrote %d bytes, want 52", len(b))
|
||||||
|
}
|
||||||
|
if b[13] != 0x21 {
|
||||||
|
t.Fatalf("fam/proto = 0x%02x, want 0x21", b[13])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripV2IPv4(t *testing.T) {
|
||||||
|
src := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
|
dst := netip.MustParseAddrPort("198.51.100.1:443")
|
||||||
|
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
WriteV2(writer, src, dst)
|
||||||
|
}()
|
||||||
|
|
||||||
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.Version != 2 {
|
||||||
|
t.Fatalf("version = %d, want 2", hdr.Version)
|
||||||
|
}
|
||||||
|
if hdr.Command != CommandProxy {
|
||||||
|
t.Fatalf("command = %d, want CommandProxy", hdr.Command)
|
||||||
|
}
|
||||||
|
if hdr.SrcAddr != src {
|
||||||
|
t.Fatalf("src = %s, want %s", hdr.SrcAddr, src)
|
||||||
|
}
|
||||||
|
if hdr.DstAddr != dst {
|
||||||
|
t.Fatalf("dst = %s, want %s", hdr.DstAddr, dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRoundTripV2IPv6(t *testing.T) {
|
||||||
|
src := netip.MustParseAddrPort("[2001:db8:cafe::1]:40000")
|
||||||
|
dst := netip.MustParseAddrPort("[2001:db8:beef::2]:8443")
|
||||||
|
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
WriteV2(writer, src, dst)
|
||||||
|
}()
|
||||||
|
|
||||||
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Parse: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if hdr.SrcAddr != src {
|
||||||
|
t.Fatalf("src = %s, want %s", hdr.SrcAddr, src)
|
||||||
|
}
|
||||||
|
if hdr.DstAddr != dst {
|
||||||
|
t.Fatalf("dst = %s, want %s", hdr.DstAddr, dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseGarbageFirstByte(t *testing.T) {
|
||||||
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
go func() {
|
||||||
|
writer.Write([]byte{0xFF, 0x00, 0x01})
|
||||||
|
}()
|
||||||
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for garbage first byte")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
|
||||||
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -266,6 +267,20 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
|
|||||||
}
|
}
|
||||||
addr := addrPort.Addr()
|
addr := addrPort.Addr()
|
||||||
|
|
||||||
|
// Parse PROXY protocol header if enabled on this listener.
|
||||||
|
if ls.ProxyProtocol {
|
||||||
|
hdr, err := proxyproto.Parse(conn, time.Now().Add(5*time.Second))
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Debug("PROXY protocol parse failed", "addr", addr, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if hdr.Command == proxyproto.CommandProxy {
|
||||||
|
addr = hdr.SrcAddr.Addr()
|
||||||
|
addrPort = hdr.SrcAddr
|
||||||
|
s.logger.Debug("PROXY protocol", "real_addr", addr, "peer_addr", remoteAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if s.fw.Blocked(addr) {
|
if s.fw.Blocked(addr) {
|
||||||
s.logger.Debug("blocked by firewall", "addr", addr)
|
s.logger.Debug("blocked by firewall", "addr", addr)
|
||||||
return
|
return
|
||||||
@@ -289,12 +304,12 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
|
|||||||
s.logger.Error("L7 mode not yet implemented", "hostname", hostname)
|
s.logger.Error("L7 mode not yet implemented", "hostname", hostname)
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
s.handleL4(ctx, conn, ls, addr, hostname, route, peeked)
|
s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleL4 handles an L4 (passthrough) connection.
|
// handleL4 handles an L4 (passthrough) connection.
|
||||||
func (s *Server) handleL4(ctx context.Context, conn net.Conn, _ *ListenerState, addr netip.Addr, hostname string, route RouteInfo, peeked []byte) {
|
func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
|
||||||
backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration)
|
backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
|
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
|
||||||
@@ -302,6 +317,15 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, _ *ListenerState,
|
|||||||
}
|
}
|
||||||
defer backendConn.Close()
|
defer backendConn.Close()
|
||||||
|
|
||||||
|
// Send PROXY protocol v2 header to backend if configured.
|
||||||
|
if route.SendProxyProtocol {
|
||||||
|
backendAddrPort, _ := netip.ParseAddrPort(backendConn.RemoteAddr().String())
|
||||||
|
if err := proxyproto.WriteV2(backendConn, clientAddrPort, backendAddrPort); err != nil {
|
||||||
|
s.logger.Error("writing PROXY protocol header", "hostname", hostname, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
|
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
|
||||||
|
|
||||||
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
|
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
|
||||||
|
|||||||
@@ -1,17 +1,20 @@
|
|||||||
package server
|
package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"io"
|
"io"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||||
|
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
||||||
)
|
)
|
||||||
|
|
||||||
// l4Route creates a RouteInfo for an L4 passthrough route.
|
// l4Route creates a RouteInfo for an L4 passthrough route.
|
||||||
@@ -696,6 +699,345 @@ func TestListenerStateRoutes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolReceive(t *testing.T) {
|
||||||
|
// Backend echoes everything back.
|
||||||
|
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("backend listen: %v", err)
|
||||||
|
}
|
||||||
|
defer backendLn.Close()
|
||||||
|
go echoServer(t, backendLn)
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("finding free port: %v", err)
|
||||||
|
}
|
||||||
|
proxyAddr := proxyLn.Addr().String()
|
||||||
|
proxyLn.Close()
|
||||||
|
|
||||||
|
srv := newTestServer(t, []ListenerData{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Addr: proxyAddr,
|
||||||
|
ProxyProtocol: true,
|
||||||
|
Routes: map[string]RouteInfo{
|
||||||
|
"echo.test": l4Route(backendLn.Addr().String()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stop := startAndStop(t, srv)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Send PROXY v2 header followed by TLS ClientHello.
|
||||||
|
var ppBuf bytes.Buffer
|
||||||
|
proxyproto.WriteV2(&ppBuf,
|
||||||
|
netip.MustParseAddrPort("203.0.113.50:12345"),
|
||||||
|
netip.MustParseAddrPort("198.51.100.1:443"),
|
||||||
|
)
|
||||||
|
conn.Write(ppBuf.Bytes())
|
||||||
|
|
||||||
|
hello := buildClientHello("echo.test")
|
||||||
|
conn.Write(hello)
|
||||||
|
|
||||||
|
// Backend should echo the ClientHello back (not the PROXY header).
|
||||||
|
echoed := make([]byte, len(hello))
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
if _, err := io.ReadFull(conn, echoed); err != nil {
|
||||||
|
t.Fatalf("read echoed data: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolReceiveGarbage(t *testing.T) {
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("finding free port: %v", err)
|
||||||
|
}
|
||||||
|
proxyAddr := proxyLn.Addr().String()
|
||||||
|
proxyLn.Close()
|
||||||
|
|
||||||
|
srv := newTestServer(t, []ListenerData{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Addr: proxyAddr,
|
||||||
|
ProxyProtocol: true,
|
||||||
|
Routes: map[string]RouteInfo{
|
||||||
|
"echo.test": l4Route("127.0.0.1:1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stop := startAndStop(t, srv)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Send garbage instead of a valid PROXY header.
|
||||||
|
conn.Write([]byte("NOT A PROXY HEADER\r\n"))
|
||||||
|
|
||||||
|
// Connection should be closed.
|
||||||
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
_, err = conn.Read(make([]byte, 1))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected connection to be closed for invalid PROXY header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolSend(t *testing.T) {
|
||||||
|
// Backend that captures the first bytes it receives.
|
||||||
|
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("backend listen: %v", err)
|
||||||
|
}
|
||||||
|
defer backendLn.Close()
|
||||||
|
|
||||||
|
received := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := backendLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
// Read all available data; the proxy sends PROXY header + ClientHello.
|
||||||
|
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
|
var all []byte
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
for {
|
||||||
|
n, err := conn.Read(buf)
|
||||||
|
all = append(all, buf[:n]...)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// We expect at least pp header (28) + some TLS data.
|
||||||
|
if len(all) >= 28+5 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
received <- all
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("finding free port: %v", err)
|
||||||
|
}
|
||||||
|
proxyAddr := proxyLn.Addr().String()
|
||||||
|
proxyLn.Close()
|
||||||
|
|
||||||
|
srv := newTestServer(t, []ListenerData{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Addr: proxyAddr,
|
||||||
|
Routes: map[string]RouteInfo{
|
||||||
|
"pp.test": {
|
||||||
|
Backend: backendLn.Addr().String(),
|
||||||
|
Mode: "l4",
|
||||||
|
SendProxyProtocol: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stop := startAndStop(t, srv)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
hello := buildClientHello("pp.test")
|
||||||
|
conn.Write(hello)
|
||||||
|
|
||||||
|
// The backend should receive: PROXY v2 header + ClientHello.
|
||||||
|
select {
|
||||||
|
case data := <-received:
|
||||||
|
// PROXY v2 IPv4 header is 28 bytes (12 sig + 1 ver/cmd + 1 fam + 2 len + 12 addrs).
|
||||||
|
if len(data) < 28 {
|
||||||
|
t.Fatalf("backend received only %d bytes, want at least 28", len(data))
|
||||||
|
}
|
||||||
|
// Check PROXY v2 signature.
|
||||||
|
sig := [12]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A}
|
||||||
|
if [12]byte(data[:12]) != sig {
|
||||||
|
t.Fatal("backend data does not start with PROXY v2 signature")
|
||||||
|
}
|
||||||
|
// Verify TLS record header follows the PROXY header.
|
||||||
|
ppLen := 28 // v2 IPv4
|
||||||
|
if len(data) <= ppLen {
|
||||||
|
t.Fatalf("backend received only PROXY header, no TLS data")
|
||||||
|
}
|
||||||
|
if data[ppLen] != 0x16 {
|
||||||
|
t.Fatalf("expected TLS record (0x16) after PROXY header, got 0x%02x", data[ppLen])
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for backend data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolNotSent(t *testing.T) {
|
||||||
|
// Backend captures first bytes.
|
||||||
|
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("backend listen: %v", err)
|
||||||
|
}
|
||||||
|
defer backendLn.Close()
|
||||||
|
|
||||||
|
received := make(chan []byte, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := backendLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
buf := make([]byte, 4096)
|
||||||
|
n, _ := conn.Read(buf)
|
||||||
|
received <- buf[:n]
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("finding free port: %v", err)
|
||||||
|
}
|
||||||
|
proxyAddr := proxyLn.Addr().String()
|
||||||
|
proxyLn.Close()
|
||||||
|
|
||||||
|
srv := newTestServer(t, []ListenerData{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Addr: proxyAddr,
|
||||||
|
Routes: map[string]RouteInfo{
|
||||||
|
"nopp.test": l4Route(backendLn.Addr().String()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
stop := startAndStop(t, srv)
|
||||||
|
defer stop()
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
hello := buildClientHello("nopp.test")
|
||||||
|
conn.Write(hello)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case data := <-received:
|
||||||
|
// First byte should be TLS record header, not PROXY signature.
|
||||||
|
if data[0] != 0x16 {
|
||||||
|
t.Fatalf("expected TLS record (0x16) as first byte, got 0x%02x", data[0])
|
||||||
|
}
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
t.Fatal("timeout waiting for backend data")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
||||||
|
// Backend that should never be reached.
|
||||||
|
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("backend listen: %v", err)
|
||||||
|
}
|
||||||
|
defer backendLn.Close()
|
||||||
|
|
||||||
|
reached := make(chan struct{}, 1)
|
||||||
|
go func() {
|
||||||
|
conn, err := backendLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn.Close()
|
||||||
|
reached <- struct{}{}
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("finding free port: %v", err)
|
||||||
|
}
|
||||||
|
proxyAddr := proxyLn.Addr().String()
|
||||||
|
proxyLn.Close()
|
||||||
|
|
||||||
|
// Block 203.0.113.50 (the "real" client IP from PROXY header).
|
||||||
|
// 127.0.0.1 (the actual TCP peer) is NOT blocked.
|
||||||
|
fw, err := firewall.New("", []string{"203.0.113.50"}, nil, nil, 0, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("creating firewall: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
Proxy: config.Proxy{
|
||||||
|
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
|
||||||
|
IdleTimeout: config.Duration{Duration: 30 * time.Second},
|
||||||
|
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
srv := New(cfg, fw, []ListenerData{
|
||||||
|
{
|
||||||
|
ID: 1,
|
||||||
|
Addr: proxyAddr,
|
||||||
|
ProxyProtocol: true,
|
||||||
|
Routes: map[string]RouteInfo{
|
||||||
|
"blocked.test": l4Route(backendLn.Addr().String()),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}, logger, "test")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
srv.Run(ctx)
|
||||||
|
}()
|
||||||
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
|
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("dial proxy: %v", err)
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
// Send PROXY v2 with the blocked real IP.
|
||||||
|
var ppBuf bytes.Buffer
|
||||||
|
proxyproto.WriteV2(&ppBuf,
|
||||||
|
netip.MustParseAddrPort("203.0.113.50:12345"),
|
||||||
|
netip.MustParseAddrPort("198.51.100.1:443"),
|
||||||
|
)
|
||||||
|
conn.Write(ppBuf.Bytes())
|
||||||
|
conn.Write(buildClientHello("blocked.test"))
|
||||||
|
|
||||||
|
// Connection should be closed (firewall blocks real IP).
|
||||||
|
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
|
_, err = conn.Read(make([]byte, 1))
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected connection to be closed by firewall")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Backend should not have been reached.
|
||||||
|
select {
|
||||||
|
case <-reached:
|
||||||
|
t.Fatal("backend was reached despite firewall blocking real IP")
|
||||||
|
case <-time.After(200 * time.Millisecond):
|
||||||
|
// Expected.
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
|
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
|
||||||
|
|
||||||
func buildClientHello(serverName string) []byte {
|
func buildClientHello(serverName string) []byte {
|
||||||
|
|||||||
Reference in New Issue
Block a user