// Package proxyproto implements PROXY protocol v1 and v2 parsing and v2 writing. // // See https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt package proxyproto import ( "encoding/binary" "fmt" "io" "net" "net/netip" "strconv" "strings" "time" ) // Command represents the PROXY protocol command. type Command byte const ( CommandLocal Command = 0x00 CommandProxy Command = 0x01 ) // Header is a parsed PROXY protocol header. type Header struct { Version int // 1 or 2 Command Command SrcAddr netip.AddrPort DstAddr netip.AddrPort } // v2 binary signature (12 bytes). var v2Signature = [12]byte{ 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, } const ( // maxV1Length is the maximum length of a v1 header line (including CRLF). maxV1Length = 108 // v2HeaderLen is the fixed portion of a v2 header (signature + ver/cmd + fam + len). v2HeaderLen = 16 // maxV2AddrLen limits the address block we'll read. maxV2AddrLen = 536 ) // Parse reads a PROXY protocol header from conn, using deadline for timeout. // It reads only the exact bytes of the PROXY header, leaving the connection // positioned at the first byte after the header (e.g., TLS ClientHello). func Parse(conn net.Conn, deadline time.Time) (Header, error) { conn.SetReadDeadline(deadline) defer conn.SetReadDeadline(time.Time{}) // Read the first byte to determine version. var first [1]byte if _, err := io.ReadFull(conn, first[:]); err != nil { return Header{}, fmt.Errorf("reading first byte: %w", err) } if first[0] == 'P' { return parseV1(conn) } if first[0] == v2Signature[0] { return parseV2(conn, first[0]) } return Header{}, fmt.Errorf("invalid PROXY protocol header: unexpected first byte 0x%02x", first[0]) } // parseV1 parses a v1 text header. The leading 'P' has already been consumed. // Reads byte-by-byte until CRLF to avoid buffering past the header. func parseV1(conn net.Conn) (Header, error) { // We already consumed 'P'. Read until \r\n, byte by byte. var buf []byte buf = append(buf, 'P') var b [1]byte for len(buf) < maxV1Length { if _, err := io.ReadFull(conn, b[:]); err != nil { return Header{}, fmt.Errorf("reading v1 header: %w", err) } buf = append(buf, b[0]) if len(buf) >= 2 && buf[len(buf)-2] == '\r' && buf[len(buf)-1] == '\n' { break } } line := string(buf) if !strings.HasSuffix(line, "\r\n") { return Header{}, fmt.Errorf("v1 header too long or missing CRLF") } line = strings.TrimSuffix(line, "\r\n") parts := strings.Split(line, " ") if len(parts) != 6 { return Header{}, fmt.Errorf("v1 header: expected 6 fields, got %d", len(parts)) } if parts[0] != "PROXY" { return Header{}, fmt.Errorf("v1 header: expected PROXY prefix, got %q", parts[0]) } proto := parts[1] if proto != "TCP4" && proto != "TCP6" { return Header{}, fmt.Errorf("v1 header: unsupported protocol %q", proto) } srcIP, err := netip.ParseAddr(parts[2]) if err != nil { return Header{}, fmt.Errorf("v1 header: invalid source IP %q: %w", parts[2], err) } dstIP, err := netip.ParseAddr(parts[3]) if err != nil { return Header{}, fmt.Errorf("v1 header: invalid destination IP %q: %w", parts[3], err) } srcPort, err := strconv.ParseUint(parts[4], 10, 16) if err != nil { return Header{}, fmt.Errorf("v1 header: invalid source port %q: %w", parts[4], err) } dstPort, err := strconv.ParseUint(parts[5], 10, 16) if err != nil { return Header{}, fmt.Errorf("v1 header: invalid destination port %q: %w", parts[5], err) } return Header{ Version: 1, Command: CommandProxy, SrcAddr: netip.AddrPortFrom(srcIP, uint16(srcPort)), DstAddr: netip.AddrPortFrom(dstIP, uint16(dstPort)), }, nil } // parseV2 parses a v2 binary header. The first byte has already been consumed. func parseV2(conn net.Conn, firstByte byte) (Header, error) { // Read the remaining 15 bytes of the fixed header. var hdr [v2HeaderLen]byte hdr[0] = firstByte if _, err := io.ReadFull(conn, hdr[1:]); err != nil { return Header{}, fmt.Errorf("reading v2 header: %w", err) } // Verify signature. if [12]byte(hdr[:12]) != v2Signature { return Header{}, fmt.Errorf("v2 header: invalid signature") } // Version (upper 4 bits of byte 12) must be 0x2. verCmd := hdr[12] if verCmd>>4 != 0x02 { return Header{}, fmt.Errorf("v2 header: unsupported version %d", verCmd>>4) } cmd := Command(verCmd & 0x0F) if cmd != CommandLocal && cmd != CommandProxy { return Header{}, fmt.Errorf("v2 header: unsupported command %d", cmd) } addrLen := binary.BigEndian.Uint16(hdr[14:16]) if int(addrLen) > maxV2AddrLen { return Header{}, fmt.Errorf("v2 header: address length %d exceeds maximum", addrLen) } // Read the exact address block. addrBuf := make([]byte, addrLen) if addrLen > 0 { if _, err := io.ReadFull(conn, addrBuf); err != nil { return Header{}, fmt.Errorf("reading v2 addresses: %w", err) } } // LOCAL command: no addresses to parse. if cmd == CommandLocal { return Header{ Version: 2, Command: CommandLocal, }, nil } // PROXY command: parse addresses based on family/protocol. famProto := hdr[13] family := famProto >> 4 protocol := famProto & 0x0F // We only support STREAM (TCP) protocol = 0x01. if protocol != 0x01 { return Header{}, fmt.Errorf("v2 header: unsupported protocol %d (only TCP/stream supported)", protocol) } switch family { case 0x01: // AF_INET (IPv4) if len(addrBuf) < 12 { return Header{}, fmt.Errorf("v2 header: IPv4 address block too short (%d bytes)", len(addrBuf)) } srcIP := netip.AddrFrom4([4]byte(addrBuf[0:4])) dstIP := netip.AddrFrom4([4]byte(addrBuf[4:8])) srcPort := binary.BigEndian.Uint16(addrBuf[8:10]) dstPort := binary.BigEndian.Uint16(addrBuf[10:12]) return Header{ Version: 2, Command: CommandProxy, SrcAddr: netip.AddrPortFrom(srcIP, srcPort), DstAddr: netip.AddrPortFrom(dstIP, dstPort), }, nil case 0x02: // AF_INET6 (IPv6) if len(addrBuf) < 36 { return Header{}, fmt.Errorf("v2 header: IPv6 address block too short (%d bytes)", len(addrBuf)) } srcIP := netip.AddrFrom16([16]byte(addrBuf[0:16])) dstIP := netip.AddrFrom16([16]byte(addrBuf[16:32])) srcPort := binary.BigEndian.Uint16(addrBuf[32:34]) dstPort := binary.BigEndian.Uint16(addrBuf[34:36]) return Header{ Version: 2, Command: CommandProxy, SrcAddr: netip.AddrPortFrom(srcIP, srcPort), DstAddr: netip.AddrPortFrom(dstIP, dstPort), }, nil default: return Header{}, fmt.Errorf("v2 header: unsupported address family %d", family) } } // WriteV2 writes a PROXY protocol v2 header with the PROXY command. // src is the original client address; dst is the backend/proxy address. func WriteV2(w io.Writer, src, dst netip.AddrPort) error { var buf []byte // Signature (12 bytes). buf = append(buf, v2Signature[:]...) // Version (0x2) and command (PROXY = 0x1) → 0x21. buf = append(buf, 0x21) srcAddr := src.Addr() dstAddr := dst.Addr() if srcAddr.Is4() && dstAddr.Is4() { // AF_INET, STREAM → 0x11. buf = append(buf, 0x11) // Address length: 4+4+2+2 = 12. buf = binary.BigEndian.AppendUint16(buf, 12) src4 := srcAddr.As4() dst4 := dstAddr.As4() buf = append(buf, src4[:]...) buf = append(buf, dst4[:]...) buf = binary.BigEndian.AppendUint16(buf, src.Port()) buf = binary.BigEndian.AppendUint16(buf, dst.Port()) } else { // AF_INET6, STREAM → 0x21. buf = append(buf, 0x21) // Address length: 16+16+2+2 = 36. buf = binary.BigEndian.AppendUint16(buf, 36) src16 := srcAddr.As16() dst16 := dstAddr.As16() buf = append(buf, src16[:]...) buf = append(buf, dst16[:]...) buf = binary.BigEndian.AppendUint16(buf, src.Port()) buf = binary.BigEndian.AppendUint16(buf, dst.Port()) } _, err := w.Write(buf) return err }