diff --git a/PROGRESS.md b/PROGRESS.md index 1cdd835..1547f45 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -21,10 +21,10 @@ proceeds. Each item is marked: ## Phase 2: PROXY Protocol -- [ ] 2.1 `internal/proxyproto/` package (v1/v2 parser, v2 writer) -- [ ] 2.2 Server integration — receive (parse PROXY header before firewall on enabled listeners) -- [ ] 2.3 Server integration — send on L4 (write PROXY v2 header before ClientHello on enabled routes) -- [ ] 2.4 Tests (receive, send, firewall uses real IP, malformed header rejection) +- [x] 2.1 `internal/proxyproto/` package (v1/v2 parser, v2 writer) +- [x] 2.2 Server integration — receive (parse PROXY header before firewall on enabled listeners) +- [x] 2.3 Server integration — send on L4 (write PROXY v2 header before ClientHello on enabled routes) +- [x] 2.4 Tests (receive, send, firewall uses real IP, malformed header rejection) ## Phase 3: L7 Proxying diff --git a/internal/proxyproto/proxyproto.go b/internal/proxyproto/proxyproto.go new file mode 100644 index 0000000..acffed8 --- /dev/null +++ b/internal/proxyproto/proxyproto.go @@ -0,0 +1,264 @@ +// Package proxyproto implements PROXY protocol v1 and v2 parsing and v2 writing. +// +// See https://www.haproxy.org/download/2.9/doc/proxy-protocol.txt +package proxyproto + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "net/netip" + "strconv" + "strings" + "time" +) + +// Command represents the PROXY protocol command. +type Command byte + +const ( + CommandLocal Command = 0x00 + CommandProxy Command = 0x01 +) + +// Header is a parsed PROXY protocol header. +type Header struct { + Version int // 1 or 2 + Command Command + SrcAddr netip.AddrPort + DstAddr netip.AddrPort +} + +// v2 binary signature (12 bytes). +var v2Signature = [12]byte{ + 0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A, +} + +const ( + // maxV1Length is the maximum length of a v1 header line (including CRLF). + maxV1Length = 108 + + // v2HeaderLen is the fixed portion of a v2 header (signature + ver/cmd + fam + len). + v2HeaderLen = 16 + + // maxV2AddrLen limits the address block we'll read. + maxV2AddrLen = 536 +) + +// Parse reads a PROXY protocol header from conn, using deadline for timeout. +// It reads only the exact bytes of the PROXY header, leaving the connection +// positioned at the first byte after the header (e.g., TLS ClientHello). +func Parse(conn net.Conn, deadline time.Time) (Header, error) { + conn.SetReadDeadline(deadline) + defer conn.SetReadDeadline(time.Time{}) + + // Read the first byte to determine version. + var first [1]byte + if _, err := io.ReadFull(conn, first[:]); err != nil { + return Header{}, fmt.Errorf("reading first byte: %w", err) + } + + if first[0] == 'P' { + return parseV1(conn) + } + if first[0] == v2Signature[0] { + return parseV2(conn, first[0]) + } + return Header{}, fmt.Errorf("invalid PROXY protocol header: unexpected first byte 0x%02x", first[0]) +} + +// parseV1 parses a v1 text header. The leading 'P' has already been consumed. +// Reads byte-by-byte until CRLF to avoid buffering past the header. +func parseV1(conn net.Conn) (Header, error) { + // We already consumed 'P'. Read until \r\n, byte by byte. + var buf []byte + buf = append(buf, 'P') + var b [1]byte + for len(buf) < maxV1Length { + if _, err := io.ReadFull(conn, b[:]); err != nil { + return Header{}, fmt.Errorf("reading v1 header: %w", err) + } + buf = append(buf, b[0]) + if len(buf) >= 2 && buf[len(buf)-2] == '\r' && buf[len(buf)-1] == '\n' { + break + } + } + + line := string(buf) + if !strings.HasSuffix(line, "\r\n") { + return Header{}, fmt.Errorf("v1 header too long or missing CRLF") + } + line = strings.TrimSuffix(line, "\r\n") + + parts := strings.Split(line, " ") + if len(parts) != 6 { + return Header{}, fmt.Errorf("v1 header: expected 6 fields, got %d", len(parts)) + } + if parts[0] != "PROXY" { + return Header{}, fmt.Errorf("v1 header: expected PROXY prefix, got %q", parts[0]) + } + + proto := parts[1] + if proto != "TCP4" && proto != "TCP6" { + return Header{}, fmt.Errorf("v1 header: unsupported protocol %q", proto) + } + + srcIP, err := netip.ParseAddr(parts[2]) + if err != nil { + return Header{}, fmt.Errorf("v1 header: invalid source IP %q: %w", parts[2], err) + } + dstIP, err := netip.ParseAddr(parts[3]) + if err != nil { + return Header{}, fmt.Errorf("v1 header: invalid destination IP %q: %w", parts[3], err) + } + srcPort, err := strconv.ParseUint(parts[4], 10, 16) + if err != nil { + return Header{}, fmt.Errorf("v1 header: invalid source port %q: %w", parts[4], err) + } + dstPort, err := strconv.ParseUint(parts[5], 10, 16) + if err != nil { + return Header{}, fmt.Errorf("v1 header: invalid destination port %q: %w", parts[5], err) + } + + return Header{ + Version: 1, + Command: CommandProxy, + SrcAddr: netip.AddrPortFrom(srcIP, uint16(srcPort)), + DstAddr: netip.AddrPortFrom(dstIP, uint16(dstPort)), + }, nil +} + +// parseV2 parses a v2 binary header. The first byte has already been consumed. +func parseV2(conn net.Conn, firstByte byte) (Header, error) { + // Read the remaining 15 bytes of the fixed header. + var hdr [v2HeaderLen]byte + hdr[0] = firstByte + if _, err := io.ReadFull(conn, hdr[1:]); err != nil { + return Header{}, fmt.Errorf("reading v2 header: %w", err) + } + + // Verify signature. + if [12]byte(hdr[:12]) != v2Signature { + return Header{}, fmt.Errorf("v2 header: invalid signature") + } + + // Version (upper 4 bits of byte 12) must be 0x2. + verCmd := hdr[12] + if verCmd>>4 != 0x02 { + return Header{}, fmt.Errorf("v2 header: unsupported version %d", verCmd>>4) + } + + cmd := Command(verCmd & 0x0F) + if cmd != CommandLocal && cmd != CommandProxy { + return Header{}, fmt.Errorf("v2 header: unsupported command %d", cmd) + } + + addrLen := binary.BigEndian.Uint16(hdr[14:16]) + if int(addrLen) > maxV2AddrLen { + return Header{}, fmt.Errorf("v2 header: address length %d exceeds maximum", addrLen) + } + + // Read the exact address block. + addrBuf := make([]byte, addrLen) + if addrLen > 0 { + if _, err := io.ReadFull(conn, addrBuf); err != nil { + return Header{}, fmt.Errorf("reading v2 addresses: %w", err) + } + } + + // LOCAL command: no addresses to parse. + if cmd == CommandLocal { + return Header{ + Version: 2, + Command: CommandLocal, + }, nil + } + + // PROXY command: parse addresses based on family/protocol. + famProto := hdr[13] + family := famProto >> 4 + protocol := famProto & 0x0F + + // We only support STREAM (TCP) protocol = 0x01. + if protocol != 0x01 { + return Header{}, fmt.Errorf("v2 header: unsupported protocol %d (only TCP/stream supported)", protocol) + } + + switch family { + case 0x01: // AF_INET (IPv4) + if len(addrBuf) < 12 { + return Header{}, fmt.Errorf("v2 header: IPv4 address block too short (%d bytes)", len(addrBuf)) + } + srcIP := netip.AddrFrom4([4]byte(addrBuf[0:4])) + dstIP := netip.AddrFrom4([4]byte(addrBuf[4:8])) + srcPort := binary.BigEndian.Uint16(addrBuf[8:10]) + dstPort := binary.BigEndian.Uint16(addrBuf[10:12]) + return Header{ + Version: 2, + Command: CommandProxy, + SrcAddr: netip.AddrPortFrom(srcIP, srcPort), + DstAddr: netip.AddrPortFrom(dstIP, dstPort), + }, nil + + case 0x02: // AF_INET6 (IPv6) + if len(addrBuf) < 36 { + return Header{}, fmt.Errorf("v2 header: IPv6 address block too short (%d bytes)", len(addrBuf)) + } + srcIP := netip.AddrFrom16([16]byte(addrBuf[0:16])) + dstIP := netip.AddrFrom16([16]byte(addrBuf[16:32])) + srcPort := binary.BigEndian.Uint16(addrBuf[32:34]) + dstPort := binary.BigEndian.Uint16(addrBuf[34:36]) + return Header{ + Version: 2, + Command: CommandProxy, + SrcAddr: netip.AddrPortFrom(srcIP, srcPort), + DstAddr: netip.AddrPortFrom(dstIP, dstPort), + }, nil + + default: + return Header{}, fmt.Errorf("v2 header: unsupported address family %d", family) + } +} + +// WriteV2 writes a PROXY protocol v2 header with the PROXY command. +// src is the original client address; dst is the backend/proxy address. +func WriteV2(w io.Writer, src, dst netip.AddrPort) error { + var buf []byte + + // Signature (12 bytes). + buf = append(buf, v2Signature[:]...) + + // Version (0x2) and command (PROXY = 0x1) → 0x21. + buf = append(buf, 0x21) + + srcAddr := src.Addr() + dstAddr := dst.Addr() + + if srcAddr.Is4() && dstAddr.Is4() { + // AF_INET, STREAM → 0x11. + buf = append(buf, 0x11) + // Address length: 4+4+2+2 = 12. + buf = binary.BigEndian.AppendUint16(buf, 12) + src4 := srcAddr.As4() + dst4 := dstAddr.As4() + buf = append(buf, src4[:]...) + buf = append(buf, dst4[:]...) + buf = binary.BigEndian.AppendUint16(buf, src.Port()) + buf = binary.BigEndian.AppendUint16(buf, dst.Port()) + } else { + // AF_INET6, STREAM → 0x21. + buf = append(buf, 0x21) + // Address length: 16+16+2+2 = 36. + buf = binary.BigEndian.AppendUint16(buf, 36) + src16 := srcAddr.As16() + dst16 := dstAddr.As16() + buf = append(buf, src16[:]...) + buf = append(buf, dst16[:]...) + buf = binary.BigEndian.AppendUint16(buf, src.Port()) + buf = binary.BigEndian.AppendUint16(buf, dst.Port()) + } + + _, err := w.Write(buf) + return err +} diff --git a/internal/proxyproto/proxyproto_test.go b/internal/proxyproto/proxyproto_test.go new file mode 100644 index 0000000..24706ee --- /dev/null +++ b/internal/proxyproto/proxyproto_test.go @@ -0,0 +1,389 @@ +package proxyproto + +import ( + "bytes" + "encoding/binary" + "net" + "net/netip" + "testing" + "time" +) + +// pipeWithDeadline returns a net.Conn pair where the reader side supports deadlines. +func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) { + t.Helper() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + t.Cleanup(func() { ln.Close() }) + + ch := make(chan net.Conn, 1) + go func() { + c, err := ln.Accept() + if err != nil { + return + } + ch <- c + }() + + w, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + t.Cleanup(func() { w.Close() }) + + r := <-ch + t.Cleanup(func() { r.Close() }) + return r, w +} + +func TestParseV1TCP4(t *testing.T) { + reader, writer := pipeWithDeadline(t) + + go func() { + writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n")) + }() + + hdr, err := Parse(reader, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + if hdr.Version != 1 { + t.Fatalf("version = %d, want 1", hdr.Version) + } + if hdr.Command != CommandProxy { + t.Fatalf("command = %d, want CommandProxy", hdr.Command) + } + if got := hdr.SrcAddr.String(); got != "192.168.1.1:56324" { + t.Fatalf("src = %s, want 192.168.1.1:56324", got) + } + if got := hdr.DstAddr.String(); got != "10.0.0.1:443" { + t.Fatalf("dst = %s, want 10.0.0.1:443", got) + } +} + +func TestParseV1TCP6(t *testing.T) { + reader, writer := pipeWithDeadline(t) + + go func() { + writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n")) + }() + + hdr, err := Parse(reader, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + if hdr.Version != 1 { + t.Fatalf("version = %d, want 1", hdr.Version) + } + wantSrc := "[2001:db8::1]:56324" + if got := hdr.SrcAddr.String(); got != wantSrc { + t.Fatalf("src = %s, want %s", got, wantSrc) + } +} + +func TestParseV2TCP4(t *testing.T) { + reader, writer := pipeWithDeadline(t) + + go func() { + var buf []byte + buf = append(buf, v2Signature[:]...) + buf = append(buf, 0x21) // version 2, PROXY command + buf = append(buf, 0x11) // AF_INET, STREAM + buf = binary.BigEndian.AppendUint16(buf, 12) + buf = append(buf, 192, 168, 1, 100) // src IP + buf = append(buf, 10, 0, 0, 1) // dst IP + buf = binary.BigEndian.AppendUint16(buf, 12345) + buf = binary.BigEndian.AppendUint16(buf, 443) + writer.Write(buf) + }() + + hdr, err := Parse(reader, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + if hdr.Version != 2 { + t.Fatalf("version = %d, want 2", hdr.Version) + } + if hdr.Command != CommandProxy { + t.Fatalf("command = %d, want CommandProxy", hdr.Command) + } + if got := hdr.SrcAddr.String(); got != "192.168.1.100:12345" { + t.Fatalf("src = %s, want 192.168.1.100:12345", got) + } + if got := hdr.DstAddr.String(); got != "10.0.0.1:443" { + t.Fatalf("dst = %s, want 10.0.0.1:443", got) + } +} + +func TestParseV2TCP6(t *testing.T) { + reader, writer := pipeWithDeadline(t) + + go func() { + var buf []byte + buf = append(buf, v2Signature[:]...) + buf = append(buf, 0x21) // version 2, PROXY command + buf = append(buf, 0x21) // AF_INET6, STREAM + buf = binary.BigEndian.AppendUint16(buf, 36) + // src: 2001:db8::1 + src := netip.MustParseAddr("2001:db8::1").As16() + buf = append(buf, src[:]...) + // dst: 2001:db8::2 + dst := netip.MustParseAddr("2001:db8::2").As16() + buf = append(buf, dst[:]...) + buf = binary.BigEndian.AppendUint16(buf, 56324) + buf = binary.BigEndian.AppendUint16(buf, 8443) + writer.Write(buf) + }() + + hdr, err := Parse(reader, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + if hdr.Version != 2 { + t.Fatalf("version = %d, want 2", hdr.Version) + } + wantSrc := "[2001:db8::1]:56324" + if got := hdr.SrcAddr.String(); got != wantSrc { + t.Fatalf("src = %s, want %s", got, wantSrc) + } +} + +func TestParseV2Local(t *testing.T) { + reader, writer := pipeWithDeadline(t) + + go func() { + var buf []byte + buf = append(buf, v2Signature[:]...) + buf = append(buf, 0x20) // version 2, LOCAL command + buf = append(buf, 0x00) // unspec family, unspec protocol + buf = binary.BigEndian.AppendUint16(buf, 0) + writer.Write(buf) + }() + + hdr, err := Parse(reader, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + if hdr.Version != 2 { + t.Fatalf("version = %d, want 2", hdr.Version) + } + if hdr.Command != CommandLocal { + t.Fatalf("command = %d, want CommandLocal", hdr.Command) + } +} + +func TestParseV1Malformed(t *testing.T) { + tests := []struct { + name string + data string + }{ + {"wrong prefix", "XROXY TCP4 1.2.3.4 5.6.7.8 1 2\r\n"}, + {"missing fields", "PROXY TCP4 1.2.3.4\r\n"}, + {"bad protocol", "PROXY UDP4 1.2.3.4 5.6.7.8 1 2\r\n"}, + {"bad src IP", "PROXY TCP4 not-an-ip 5.6.7.8 1 2\r\n"}, + {"bad src port", "PROXY TCP4 1.2.3.4 5.6.7.8 notport 2\r\n"}, + {"missing CRLF", "PROXY TCP4 1.2.3.4 5.6.7.8 1 2\n"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reader, writer := pipeWithDeadline(t) + go func() { + writer.Write([]byte(tt.data)) + }() + _, err := Parse(reader, time.Now().Add(2*time.Second)) + if err == nil { + t.Fatal("expected error") + } + }) + } +} + +func TestParseV2Malformed(t *testing.T) { + t.Run("bad signature", func(t *testing.T) { + reader, writer := pipeWithDeadline(t) + go func() { + bad := make([]byte, v2HeaderLen) + bad[0] = v2Signature[0] // first byte matches but rest doesn't + writer.Write(bad) + }() + _, err := Parse(reader, time.Now().Add(2*time.Second)) + if err == nil { + t.Fatal("expected error for bad signature") + } + }) + + t.Run("bad version", func(t *testing.T) { + reader, writer := pipeWithDeadline(t) + go func() { + var buf []byte + buf = append(buf, v2Signature[:]...) + buf = append(buf, 0x31) // version 3, PROXY command + buf = append(buf, 0x11) + buf = binary.BigEndian.AppendUint16(buf, 0) + writer.Write(buf) + }() + _, err := Parse(reader, time.Now().Add(2*time.Second)) + if err == nil { + t.Fatal("expected error for bad version") + } + }) + + t.Run("truncated address", func(t *testing.T) { + reader, writer := pipeWithDeadline(t) + go func() { + var buf []byte + buf = append(buf, v2Signature[:]...) + buf = append(buf, 0x21) // version 2, PROXY + buf = append(buf, 0x11) // AF_INET, STREAM + buf = binary.BigEndian.AppendUint16(buf, 12) + buf = append(buf, 1, 2, 3) // only 3 bytes, need 12 + writer.Write(buf) + writer.Close() + }() + _, err := Parse(reader, time.Now().Add(2*time.Second)) + if err == nil { + t.Fatal("expected error for truncated address") + } + }) + + t.Run("unsupported family", func(t *testing.T) { + reader, writer := pipeWithDeadline(t) + go func() { + var buf []byte + buf = append(buf, v2Signature[:]...) + buf = append(buf, 0x21) // version 2, PROXY + buf = append(buf, 0x31) // AF_UNIX (3), STREAM + buf = binary.BigEndian.AppendUint16(buf, 0) + writer.Write(buf) + }() + _, err := Parse(reader, time.Now().Add(2*time.Second)) + if err == nil { + t.Fatal("expected error for unsupported family") + } + }) +} + +func TestParseTimeout(t *testing.T) { + reader, _ := pipeWithDeadline(t) + + // No data sent — should hit deadline. + _, err := Parse(reader, time.Now().Add(100*time.Millisecond)) + if err == nil { + t.Fatal("expected timeout error") + } +} + +func TestWriteV2IPv4(t *testing.T) { + src := netip.MustParseAddrPort("192.168.1.1:56324") + dst := netip.MustParseAddrPort("10.0.0.1:443") + + var buf bytes.Buffer + if err := WriteV2(&buf, src, dst); err != nil { + t.Fatalf("WriteV2: %v", err) + } + + b := buf.Bytes() + // Signature (12) + ver/cmd (1) + fam (1) + len (2) + 4+4+2+2 = 28 + if len(b) != 28 { + t.Fatalf("wrote %d bytes, want 28", len(b)) + } + // Verify signature. + if [12]byte(b[:12]) != v2Signature { + t.Fatal("bad signature") + } + if b[12] != 0x21 { + t.Fatalf("ver/cmd = 0x%02x, want 0x21", b[12]) + } + if b[13] != 0x11 { + t.Fatalf("fam/proto = 0x%02x, want 0x11", b[13]) + } +} + +func TestWriteV2IPv6(t *testing.T) { + src := netip.MustParseAddrPort("[2001:db8::1]:56324") + dst := netip.MustParseAddrPort("[2001:db8::2]:443") + + var buf bytes.Buffer + if err := WriteV2(&buf, src, dst); err != nil { + t.Fatalf("WriteV2: %v", err) + } + + b := buf.Bytes() + // Signature (12) + ver/cmd (1) + fam (1) + len (2) + 16+16+2+2 = 52 + if len(b) != 52 { + t.Fatalf("wrote %d bytes, want 52", len(b)) + } + if b[13] != 0x21 { + t.Fatalf("fam/proto = 0x%02x, want 0x21", b[13]) + } +} + +func TestRoundTripV2IPv4(t *testing.T) { + src := netip.MustParseAddrPort("203.0.113.50:12345") + dst := netip.MustParseAddrPort("198.51.100.1:443") + + reader, writer := pipeWithDeadline(t) + + go func() { + WriteV2(writer, src, dst) + }() + + hdr, err := Parse(reader, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + if hdr.Version != 2 { + t.Fatalf("version = %d, want 2", hdr.Version) + } + if hdr.Command != CommandProxy { + t.Fatalf("command = %d, want CommandProxy", hdr.Command) + } + if hdr.SrcAddr != src { + t.Fatalf("src = %s, want %s", hdr.SrcAddr, src) + } + if hdr.DstAddr != dst { + t.Fatalf("dst = %s, want %s", hdr.DstAddr, dst) + } +} + +func TestRoundTripV2IPv6(t *testing.T) { + src := netip.MustParseAddrPort("[2001:db8:cafe::1]:40000") + dst := netip.MustParseAddrPort("[2001:db8:beef::2]:8443") + + reader, writer := pipeWithDeadline(t) + + go func() { + WriteV2(writer, src, dst) + }() + + hdr, err := Parse(reader, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("Parse: %v", err) + } + + if hdr.SrcAddr != src { + t.Fatalf("src = %s, want %s", hdr.SrcAddr, src) + } + if hdr.DstAddr != dst { + t.Fatalf("dst = %s, want %s", hdr.DstAddr, dst) + } +} + +func TestParseGarbageFirstByte(t *testing.T) { + reader, writer := pipeWithDeadline(t) + go func() { + writer.Write([]byte{0xFF, 0x00, 0x01}) + }() + _, err := Parse(reader, time.Now().Add(2*time.Second)) + if err == nil { + t.Fatal("expected error for garbage first byte") + } +} diff --git a/internal/server/server.go b/internal/server/server.go index c8712ab..c22396d 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -14,6 +14,7 @@ import ( "git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/kyle/mc-proxy/internal/proxy" + "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/kyle/mc-proxy/internal/sni" ) @@ -266,6 +267,20 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat } addr := addrPort.Addr() + // Parse PROXY protocol header if enabled on this listener. + if ls.ProxyProtocol { + hdr, err := proxyproto.Parse(conn, time.Now().Add(5*time.Second)) + if err != nil { + s.logger.Debug("PROXY protocol parse failed", "addr", addr, "error", err) + return + } + if hdr.Command == proxyproto.CommandProxy { + addr = hdr.SrcAddr.Addr() + addrPort = hdr.SrcAddr + s.logger.Debug("PROXY protocol", "real_addr", addr, "peer_addr", remoteAddr) + } + } + if s.fw.Blocked(addr) { s.logger.Debug("blocked by firewall", "addr", addr) return @@ -289,12 +304,12 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat s.logger.Error("L7 mode not yet implemented", "hostname", hostname) return default: - s.handleL4(ctx, conn, ls, addr, hostname, route, peeked) + s.handleL4(ctx, conn, addr, addrPort, hostname, route, peeked) } } // handleL4 handles an L4 (passthrough) connection. -func (s *Server) handleL4(ctx context.Context, conn net.Conn, _ *ListenerState, addr netip.Addr, hostname string, route RouteInfo, peeked []byte) { +func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) { backendConn, err := net.DialTimeout("tcp", route.Backend, s.cfg.Proxy.ConnectTimeout.Duration) if err != nil { s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err) @@ -302,6 +317,15 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, _ *ListenerState, } defer backendConn.Close() + // Send PROXY protocol v2 header to backend if configured. + if route.SendProxyProtocol { + backendAddrPort, _ := netip.ParseAddrPort(backendConn.RemoteAddr().String()) + if err := proxyproto.WriteV2(backendConn, clientAddrPort, backendAddrPort); err != nil { + s.logger.Error("writing PROXY protocol header", "hostname", hostname, "error", err) + return + } + } + s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", route.Backend) result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index e93375d..8ca9577 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1,17 +1,20 @@ package server import ( + "bytes" "context" "encoding/binary" "io" "log/slog" "net" + "net/netip" "sync" "testing" "time" "git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" + "git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" ) // l4Route creates a RouteInfo for an L4 passthrough route. @@ -696,6 +699,345 @@ func TestListenerStateRoutes(t *testing.T) { } } +func TestProxyProtocolReceive(t *testing.T) { + // Backend echoes everything back. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + go echoServer(t, backendLn) + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + ProxyProtocol: true, + Routes: map[string]RouteInfo{ + "echo.test": l4Route(backendLn.Addr().String()), + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + // Send PROXY v2 header followed by TLS ClientHello. + var ppBuf bytes.Buffer + proxyproto.WriteV2(&ppBuf, + netip.MustParseAddrPort("203.0.113.50:12345"), + netip.MustParseAddrPort("198.51.100.1:443"), + ) + conn.Write(ppBuf.Bytes()) + + hello := buildClientHello("echo.test") + conn.Write(hello) + + // Backend should echo the ClientHello back (not the PROXY header). + echoed := make([]byte, len(hello)) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if _, err := io.ReadFull(conn, echoed); err != nil { + t.Fatalf("read echoed data: %v", err) + } +} + +func TestProxyProtocolReceiveGarbage(t *testing.T) { + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + ProxyProtocol: true, + Routes: map[string]RouteInfo{ + "echo.test": l4Route("127.0.0.1:1"), + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + // Send garbage instead of a valid PROXY header. + conn.Write([]byte("NOT A PROXY HEADER\r\n")) + + // Connection should be closed. + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Fatal("expected connection to be closed for invalid PROXY header") + } +} + +func TestProxyProtocolSend(t *testing.T) { + // Backend that captures the first bytes it receives. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + + received := make(chan []byte, 1) + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + // Read all available data; the proxy sends PROXY header + ClientHello. + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + var all []byte + buf := make([]byte, 4096) + for { + n, err := conn.Read(buf) + all = append(all, buf[:n]...) + if err != nil { + break + } + // We expect at least pp header (28) + some TLS data. + if len(all) >= 28+5 { + break + } + } + received <- all + }() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]RouteInfo{ + "pp.test": { + Backend: backendLn.Addr().String(), + Mode: "l4", + SendProxyProtocol: true, + }, + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("pp.test") + conn.Write(hello) + + // The backend should receive: PROXY v2 header + ClientHello. + select { + case data := <-received: + // PROXY v2 IPv4 header is 28 bytes (12 sig + 1 ver/cmd + 1 fam + 2 len + 12 addrs). + if len(data) < 28 { + t.Fatalf("backend received only %d bytes, want at least 28", len(data)) + } + // Check PROXY v2 signature. + sig := [12]byte{0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51, 0x55, 0x49, 0x54, 0x0A} + if [12]byte(data[:12]) != sig { + t.Fatal("backend data does not start with PROXY v2 signature") + } + // Verify TLS record header follows the PROXY header. + ppLen := 28 // v2 IPv4 + if len(data) <= ppLen { + t.Fatalf("backend received only PROXY header, no TLS data") + } + if data[ppLen] != 0x16 { + t.Fatalf("expected TLS record (0x16) after PROXY header, got 0x%02x", data[ppLen]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for backend data") + } +} + +func TestProxyProtocolNotSent(t *testing.T) { + // Backend captures first bytes. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + + received := make(chan []byte, 1) + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 4096) + n, _ := conn.Read(buf) + received <- buf[:n] + }() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]RouteInfo{ + "nopp.test": l4Route(backendLn.Addr().String()), + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("nopp.test") + conn.Write(hello) + + select { + case data := <-received: + // First byte should be TLS record header, not PROXY signature. + if data[0] != 0x16 { + t.Fatalf("expected TLS record (0x16) as first byte, got 0x%02x", data[0]) + } + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for backend data") + } +} + +func TestProxyProtocolFirewallUsesRealIP(t *testing.T) { + // Backend that should never be reached. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + + reached := make(chan struct{}, 1) + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + conn.Close() + reached <- struct{}{} + }() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + // Block 203.0.113.50 (the "real" client IP from PROXY header). + // 127.0.0.1 (the actual TCP peer) is NOT blocked. + fw, err := firewall.New("", []string{"203.0.113.50"}, nil, nil, 0, 0) + if err != nil { + t.Fatalf("creating firewall: %v", err) + } + + cfg := &config.Config{ + Proxy: config.Proxy{ + ConnectTimeout: config.Duration{Duration: 5 * time.Second}, + IdleTimeout: config.Duration{Duration: 30 * time.Second}, + ShutdownTimeout: config.Duration{Duration: 5 * time.Second}, + }, + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := New(cfg, fw, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + ProxyProtocol: true, + Routes: map[string]RouteInfo{ + "blocked.test": l4Route(backendLn.Addr().String()), + }, + }, + }, logger, "test") + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.Run(ctx) + }() + time.Sleep(50 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + // Send PROXY v2 with the blocked real IP. + var ppBuf bytes.Buffer + proxyproto.WriteV2(&ppBuf, + netip.MustParseAddrPort("203.0.113.50:12345"), + netip.MustParseAddrPort("198.51.100.1:443"), + ) + conn.Write(ppBuf.Bytes()) + conn.Write(buildClientHello("blocked.test")) + + // Connection should be closed (firewall blocks real IP). + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Fatal("expected connection to be closed by firewall") + } + + // Backend should not have been reached. + select { + case <-reached: + t.Fatal("backend was reached despite firewall blocking real IP") + case <-time.After(200 * time.Millisecond): + // Expected. + } + + cancel() + wg.Wait() +} + // --- ClientHello builder helpers (mirrors internal/sni test helpers) --- func buildClientHello(serverName string) []byte {