Add PROXY protocol v1/v2 support for multi-hop deployments

New internal/proxyproto package implements PROXY protocol parsing and
writing without buffering past the header boundary (reads exact byte
counts so the connection is correctly positioned for SNI extraction).

Parser: auto-detects v1 (text) and v2 (binary) by first byte. Parses
TCP4/TCP6 for both versions plus v2 LOCAL command. Enforces max header
sizes and read deadlines.

Writer: generates v2 binary headers for IPv4 and IPv6 with PROXY
command.

Server integration:
- Receive: when listener.ProxyProtocol is true, parses PROXY header
  before firewall check. Real client IP from header is used for
  firewall evaluation and logging. Malformed headers cause RST.
- Send: when route.SendProxyProtocol is true, writes PROXY v2 header
  to backend before forwarding the ClientHello bytes.

Tests cover v1/v2 parsing, malformed rejection, timeout, round-trip
write+parse, and five server integration tests: receive with valid
header, receive with garbage, send verification, send-disabled
verification, and firewall evaluation using the real client IP.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 13:28:49 -07:00
parent ed94548dfa
commit 1ad9a1a43b
5 changed files with 1025 additions and 6 deletions

View File

@@ -0,0 +1,264 @@
// Package proxyproto implements PROXY protocol v1 and v2 parsing and v2 writing.
//
// See https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt
package proxyproto
import (
"encoding/binary"
"fmt"
"io"
"net"
"net/netip"
"strconv"
"strings"
"time"
)
// Command represents the PROXY protocol command.
type Command byte
const (
CommandLocal Command = 0x00
CommandProxy Command = 0x01
)
// Header is a parsed PROXY protocol header.
type Header struct {
Version int // 1 or 2
Command Command
SrcAddr netip.AddrPort
DstAddr netip.AddrPort
}
// v2 binary signature (12 bytes).
var v2Signature = [12]byte{
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A,
}
const (
// maxV1Length is the maximum length of a v1 header line (including CRLF).
maxV1Length = 108
// v2HeaderLen is the fixed portion of a v2 header (signature + ver/cmd + fam + len).
v2HeaderLen = 16
// maxV2AddrLen limits the address block we'll read.
maxV2AddrLen = 536
)
// Parse reads a PROXY protocol header from conn, using deadline for timeout.
// It reads only the exact bytes of the PROXY header, leaving the connection
// positioned at the first byte after the header (e.g., TLS ClientHello).
func Parse(conn net.Conn, deadline time.Time) (Header, error) {
conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{})
// Read the first byte to determine version.
var first [1]byte
if _, err := io.ReadFull(conn, first[:]); err != nil {
return Header{}, fmt.Errorf("reading first byte: %w", err)
}
if first[0] == 'P' {
return parseV1(conn)
}
if first[0] == v2Signature[0] {
return parseV2(conn, first[0])
}
return Header{}, fmt.Errorf("invalid PROXY protocol header: unexpected first byte 0x%02x", first[0])
}
// parseV1 parses a v1 text header. The leading 'P' has already been consumed.
// Reads byte-by-byte until CRLF to avoid buffering past the header.
func parseV1(conn net.Conn) (Header, error) {
// We already consumed 'P'. Read until \r\n, byte by byte.
var buf []byte
buf = append(buf, 'P')
var b [1]byte
for len(buf) < maxV1Length {
if _, err := io.ReadFull(conn, b[:]); err != nil {
return Header{}, fmt.Errorf("reading v1 header: %w", err)
}
buf = append(buf, b[0])
if len(buf) >= 2 && buf[len(buf)-2] == '\r' && buf[len(buf)-1] == '\n' {
break
}
}
line := string(buf)
if !strings.HasSuffix(line, "\r\n") {
return Header{}, fmt.Errorf("v1 header too long or missing CRLF")
}
line = strings.TrimSuffix(line, "\r\n")
parts := strings.Split(line, " ")
if len(parts) != 6 {
return Header{}, fmt.Errorf("v1 header: expected 6 fields, got %d", len(parts))
}
if parts[0] != "PROXY" {
return Header{}, fmt.Errorf("v1 header: expected PROXY prefix, got %q", parts[0])
}
proto := parts[1]
if proto != "TCP4" && proto != "TCP6" {
return Header{}, fmt.Errorf("v1 header: unsupported protocol %q", proto)
}
srcIP, err := netip.ParseAddr(parts[2])
if err != nil {
return Header{}, fmt.Errorf("v1 header: invalid source IP %q: %w", parts[2], err)
}
dstIP, err := netip.ParseAddr(parts[3])
if err != nil {
return Header{}, fmt.Errorf("v1 header: invalid destination IP %q: %w", parts[3], err)
}
srcPort, err := strconv.ParseUint(parts[4], 10, 16)
if err != nil {
return Header{}, fmt.Errorf("v1 header: invalid source port %q: %w", parts[4], err)
}
dstPort, err := strconv.ParseUint(parts[5], 10, 16)
if err != nil {
return Header{}, fmt.Errorf("v1 header: invalid destination port %q: %w", parts[5], err)
}
return Header{
Version: 1,
Command: CommandProxy,
SrcAddr: netip.AddrPortFrom(srcIP, uint16(srcPort)),
DstAddr: netip.AddrPortFrom(dstIP, uint16(dstPort)),
}, nil
}
// parseV2 parses a v2 binary header. The first byte has already been consumed.
func parseV2(conn net.Conn, firstByte byte) (Header, error) {
// Read the remaining 15 bytes of the fixed header.
var hdr [v2HeaderLen]byte
hdr[0] = firstByte
if _, err := io.ReadFull(conn, hdr[1:]); err != nil {
return Header{}, fmt.Errorf("reading v2 header: %w", err)
}
// Verify signature.
if [12]byte(hdr[:12]) != v2Signature {
return Header{}, fmt.Errorf("v2 header: invalid signature")
}
// Version (upper 4 bits of byte 12) must be 0x2.
verCmd := hdr[12]
if verCmd>>4 != 0x02 {
return Header{}, fmt.Errorf("v2 header: unsupported version %d", verCmd>>4)
}
cmd := Command(verCmd & 0x0F)
if cmd != CommandLocal && cmd != CommandProxy {
return Header{}, fmt.Errorf("v2 header: unsupported command %d", cmd)
}
addrLen := binary.BigEndian.Uint16(hdr[14:16])
if int(addrLen) > maxV2AddrLen {
return Header{}, fmt.Errorf("v2 header: address length %d exceeds maximum", addrLen)
}
// Read the exact address block.
addrBuf := make([]byte, addrLen)
if addrLen > 0 {
if _, err := io.ReadFull(conn, addrBuf); err != nil {
return Header{}, fmt.Errorf("reading v2 addresses: %w", err)
}
}
// LOCAL command: no addresses to parse.
if cmd == CommandLocal {
return Header{
Version: 2,
Command: CommandLocal,
}, nil
}
// PROXY command: parse addresses based on family/protocol.
famProto := hdr[13]
family := famProto >> 4
protocol := famProto & 0x0F
// We only support STREAM (TCP) protocol = 0x01.
if protocol != 0x01 {
return Header{}, fmt.Errorf("v2 header: unsupported protocol %d (only TCP/stream supported)", protocol)
}
switch family {
case 0x01: // AF_INET (IPv4)
if len(addrBuf) < 12 {
return Header{}, fmt.Errorf("v2 header: IPv4 address block too short (%d bytes)", len(addrBuf))
}
srcIP := netip.AddrFrom4([4]byte(addrBuf[0:4]))
dstIP := netip.AddrFrom4([4]byte(addrBuf[4:8]))
srcPort := binary.BigEndian.Uint16(addrBuf[8:10])
dstPort := binary.BigEndian.Uint16(addrBuf[10:12])
return Header{
Version: 2,
Command: CommandProxy,
SrcAddr: netip.AddrPortFrom(srcIP, srcPort),
DstAddr: netip.AddrPortFrom(dstIP, dstPort),
}, nil
case 0x02: // AF_INET6 (IPv6)
if len(addrBuf) < 36 {
return Header{}, fmt.Errorf("v2 header: IPv6 address block too short (%d bytes)", len(addrBuf))
}
srcIP := netip.AddrFrom16([16]byte(addrBuf[0:16]))
dstIP := netip.AddrFrom16([16]byte(addrBuf[16:32]))
srcPort := binary.BigEndian.Uint16(addrBuf[32:34])
dstPort := binary.BigEndian.Uint16(addrBuf[34:36])
return Header{
Version: 2,
Command: CommandProxy,
SrcAddr: netip.AddrPortFrom(srcIP, srcPort),
DstAddr: netip.AddrPortFrom(dstIP, dstPort),
}, nil
default:
return Header{}, fmt.Errorf("v2 header: unsupported address family %d", family)
}
}
// WriteV2 writes a PROXY protocol v2 header with the PROXY command.
// src is the original client address; dst is the backend/proxy address.
func WriteV2(w io.Writer, src, dst netip.AddrPort) error {
var buf []byte
// Signature (12 bytes).
buf = append(buf, v2Signature[:]...)
// Version (0x2) and command (PROXY = 0x1) → 0x21.
buf = append(buf, 0x21)
srcAddr := src.Addr()
dstAddr := dst.Addr()
if srcAddr.Is4() && dstAddr.Is4() {
// AF_INET, STREAM → 0x11.
buf = append(buf, 0x11)
// Address length: 4+4+2+2 = 12.
buf = binary.BigEndian.AppendUint16(buf, 12)
src4 := srcAddr.As4()
dst4 := dstAddr.As4()
buf = append(buf, src4[:]...)
buf = append(buf, dst4[:]...)
buf = binary.BigEndian.AppendUint16(buf, src.Port())
buf = binary.BigEndian.AppendUint16(buf, dst.Port())
} else {
// AF_INET6, STREAM → 0x21.
buf = append(buf, 0x21)
// Address length: 16+16+2+2 = 36.
buf = binary.BigEndian.AppendUint16(buf, 36)
src16 := srcAddr.As16()
dst16 := dstAddr.As16()
buf = append(buf, src16[:]...)
buf = append(buf, dst16[:]...)
buf = binary.BigEndian.AppendUint16(buf, src.Port())
buf = binary.BigEndian.AppendUint16(buf, dst.Port())
}
_, err := w.Write(buf)
return err
}

View File

@@ -0,0 +1,389 @@
package proxyproto
import (
"bytes"
"encoding/binary"
"net"
"net/netip"
"testing"
"time"
)
// pipeWithDeadline returns a net.Conn pair where the reader side supports deadlines.
func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() { ln.Close() })
ch := make(chan net.Conn, 1)
go func() {
c, err := ln.Accept()
if err != nil {
return
}
ch <- c
}()
w, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
t.Cleanup(func() { w.Close() })
r := <-ch
t.Cleanup(func() { r.Close() })
return r, w
}
func TestParseV1TCP4(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if hdr.Version != 1 {
t.Fatalf("version = %d, want 1", hdr.Version)
}
if hdr.Command != CommandProxy {
t.Fatalf("command = %d, want CommandProxy", hdr.Command)
}
if got := hdr.SrcAddr.String(); got != "192.168.1.1:56324" {
t.Fatalf("src = %s, want 192.168.1.1:56324", got)
}
if got := hdr.DstAddr.String(); got != "10.0.0.1:443" {
t.Fatalf("dst = %s, want 10.0.0.1:443", got)
}
}
func TestParseV1TCP6(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if hdr.Version != 1 {
t.Fatalf("version = %d, want 1", hdr.Version)
}
wantSrc := "[2001:db8::1]:56324"
if got := hdr.SrcAddr.String(); got != wantSrc {
t.Fatalf("src = %s, want %s", got, wantSrc)
}
}
func TestParseV2TCP4(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
var buf []byte
buf = append(buf, v2Signature[:]...)
buf = append(buf, 0x21) // version 2, PROXY command
buf = append(buf, 0x11) // AF_INET, STREAM
buf = binary.BigEndian.AppendUint16(buf, 12)
buf = append(buf, 192, 168, 1, 100) // src IP
buf = append(buf, 10, 0, 0, 1) // dst IP
buf = binary.BigEndian.AppendUint16(buf, 12345)
buf = binary.BigEndian.AppendUint16(buf, 443)
writer.Write(buf)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if hdr.Version != 2 {
t.Fatalf("version = %d, want 2", hdr.Version)
}
if hdr.Command != CommandProxy {
t.Fatalf("command = %d, want CommandProxy", hdr.Command)
}
if got := hdr.SrcAddr.String(); got != "192.168.1.100:12345" {
t.Fatalf("src = %s, want 192.168.1.100:12345", got)
}
if got := hdr.DstAddr.String(); got != "10.0.0.1:443" {
t.Fatalf("dst = %s, want 10.0.0.1:443", got)
}
}
func TestParseV2TCP6(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
var buf []byte
buf = append(buf, v2Signature[:]...)
buf = append(buf, 0x21) // version 2, PROXY command
buf = append(buf, 0x21) // AF_INET6, STREAM
buf = binary.BigEndian.AppendUint16(buf, 36)
// src: 2001:db8::1
src := netip.MustParseAddr("2001:db8::1").As16()
buf = append(buf, src[:]...)
// dst: 2001:db8::2
dst := netip.MustParseAddr("2001:db8::2").As16()
buf = append(buf, dst[:]...)
buf = binary.BigEndian.AppendUint16(buf, 56324)
buf = binary.BigEndian.AppendUint16(buf, 8443)
writer.Write(buf)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if hdr.Version != 2 {
t.Fatalf("version = %d, want 2", hdr.Version)
}
wantSrc := "[2001:db8::1]:56324"
if got := hdr.SrcAddr.String(); got != wantSrc {
t.Fatalf("src = %s, want %s", got, wantSrc)
}
}
func TestParseV2Local(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
var buf []byte
buf = append(buf, v2Signature[:]...)
buf = append(buf, 0x20) // version 2, LOCAL command
buf = append(buf, 0x00) // unspec family, unspec protocol
buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if hdr.Version != 2 {
t.Fatalf("version = %d, want 2", hdr.Version)
}
if hdr.Command != CommandLocal {
t.Fatalf("command = %d, want CommandLocal", hdr.Command)
}
}
func TestParseV1Malformed(t *testing.T) {
tests := []struct {
name string
data string
}{
{"wrong prefix", "XROXY TCP4 1.2.3.4 5.6.7.8 1 2\r\n"},
{"missing fields", "PROXY TCP4 1.2.3.4\r\n"},
{"bad protocol", "PROXY UDP4 1.2.3.4 5.6.7.8 1 2\r\n"},
{"bad src IP", "PROXY TCP4 not-an-ip 5.6.7.8 1 2\r\n"},
{"bad src port", "PROXY TCP4 1.2.3.4 5.6.7.8 notport 2\r\n"},
{"missing CRLF", "PROXY TCP4 1.2.3.4 5.6.7.8 1 2\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte(tt.data))
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
t.Fatal("expected error")
}
})
}
}
func TestParseV2Malformed(t *testing.T) {
t.Run("bad signature", func(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
bad := make([]byte, v2HeaderLen)
bad[0] = v2Signature[0] // first byte matches but rest doesn't
writer.Write(bad)
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
t.Fatal("expected error for bad signature")
}
})
t.Run("bad version", func(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
var buf []byte
buf = append(buf, v2Signature[:]...)
buf = append(buf, 0x31) // version 3, PROXY command
buf = append(buf, 0x11)
buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf)
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
t.Fatal("expected error for bad version")
}
})
t.Run("truncated address", func(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
var buf []byte
buf = append(buf, v2Signature[:]...)
buf = append(buf, 0x21) // version 2, PROXY
buf = append(buf, 0x11) // AF_INET, STREAM
buf = binary.BigEndian.AppendUint16(buf, 12)
buf = append(buf, 1, 2, 3) // only 3 bytes, need 12
writer.Write(buf)
writer.Close()
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
t.Fatal("expected error for truncated address")
}
})
t.Run("unsupported family", func(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
var buf []byte
buf = append(buf, v2Signature[:]...)
buf = append(buf, 0x21) // version 2, PROXY
buf = append(buf, 0x31) // AF_UNIX (3), STREAM
buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf)
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
t.Fatal("expected error for unsupported family")
}
})
}
func TestParseTimeout(t *testing.T) {
reader, _ := pipeWithDeadline(t)
// No data sent — should hit deadline.
_, err := Parse(reader, time.Now().Add(100*time.Millisecond))
if err == nil {
t.Fatal("expected timeout error")
}
}
func TestWriteV2IPv4(t *testing.T) {
src := netip.MustParseAddrPort("192.168.1.1:56324")
dst := netip.MustParseAddrPort("10.0.0.1:443")
var buf bytes.Buffer
if err := WriteV2(&buf, src, dst); err != nil {
t.Fatalf("WriteV2: %v", err)
}
b := buf.Bytes()
// Signature (12) + ver/cmd (1) + fam (1) + len (2) + 4+4+2+2 = 28
if len(b) != 28 {
t.Fatalf("wrote %d bytes, want 28", len(b))
}
// Verify signature.
if [12]byte(b[:12]) != v2Signature {
t.Fatal("bad signature")
}
if b[12] != 0x21 {
t.Fatalf("ver/cmd = 0x%02x, want 0x21", b[12])
}
if b[13] != 0x11 {
t.Fatalf("fam/proto = 0x%02x, want 0x11", b[13])
}
}
func TestWriteV2IPv6(t *testing.T) {
src := netip.MustParseAddrPort("[2001:db8::1]:56324")
dst := netip.MustParseAddrPort("[2001:db8::2]:443")
var buf bytes.Buffer
if err := WriteV2(&buf, src, dst); err != nil {
t.Fatalf("WriteV2: %v", err)
}
b := buf.Bytes()
// Signature (12) + ver/cmd (1) + fam (1) + len (2) + 16+16+2+2 = 52
if len(b) != 52 {
t.Fatalf("wrote %d bytes, want 52", len(b))
}
if b[13] != 0x21 {
t.Fatalf("fam/proto = 0x%02x, want 0x21", b[13])
}
}
func TestRoundTripV2IPv4(t *testing.T) {
src := netip.MustParseAddrPort("203.0.113.50:12345")
dst := netip.MustParseAddrPort("198.51.100.1:443")
reader, writer := pipeWithDeadline(t)
go func() {
WriteV2(writer, src, dst)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if hdr.Version != 2 {
t.Fatalf("version = %d, want 2", hdr.Version)
}
if hdr.Command != CommandProxy {
t.Fatalf("command = %d, want CommandProxy", hdr.Command)
}
if hdr.SrcAddr != src {
t.Fatalf("src = %s, want %s", hdr.SrcAddr, src)
}
if hdr.DstAddr != dst {
t.Fatalf("dst = %s, want %s", hdr.DstAddr, dst)
}
}
func TestRoundTripV2IPv6(t *testing.T) {
src := netip.MustParseAddrPort("[2001:db8:cafe::1]:40000")
dst := netip.MustParseAddrPort("[2001:db8:beef::2]:8443")
reader, writer := pipeWithDeadline(t)
go func() {
WriteV2(writer, src, dst)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("Parse: %v", err)
}
if hdr.SrcAddr != src {
t.Fatalf("src = %s, want %s", hdr.SrcAddr, src)
}
if hdr.DstAddr != dst {
t.Fatalf("dst = %s, want %s", hdr.DstAddr, dst)
}
}
func TestParseGarbageFirstByte(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte{0xFF, 0x00, 0x01})
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
t.Fatal("expected error for garbage first byte")
}
}