Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
260 lines
5.7 KiB
Go
260 lines
5.7 KiB
Go
package proxy
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"io"
|
|
"net"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func TestRelayBasic(t *testing.T) {
|
|
// Set up a TCP listener to act as the backend.
|
|
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
defer backendLn.Close()
|
|
|
|
peeked := []byte("peeked-hello-bytes")
|
|
clientData := []byte("data from client")
|
|
backendData := []byte("data from backend")
|
|
|
|
// Backend goroutine: accept, read peeked+client data, send response, close.
|
|
backendDone := make(chan []byte, 1)
|
|
go func() {
|
|
conn, err := backendLn.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Read everything the backend receives.
|
|
received, _ := io.ReadAll(conn)
|
|
backendDone <- received
|
|
|
|
// This won't work since ReadAll waits for EOF.
|
|
// Instead, restructure: read expected bytes, write response, close write.
|
|
}()
|
|
|
|
// Restructure: use a more controlled flow.
|
|
backendLn.Close()
|
|
|
|
// Use a real TCP pair for proper half-close.
|
|
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
defer backendLn2.Close()
|
|
|
|
go func() {
|
|
conn, err := backendLn2.Accept()
|
|
if err != nil {
|
|
return
|
|
}
|
|
defer conn.Close()
|
|
|
|
// Read peeked + client data.
|
|
buf := make([]byte, len(peeked)+len(clientData))
|
|
n, _ := io.ReadFull(conn, buf)
|
|
backendDone <- buf[:n]
|
|
|
|
// Send response.
|
|
conn.Write(backendData)
|
|
|
|
// Close write side to signal EOF.
|
|
if tc, ok := conn.(*net.TCPConn); ok {
|
|
tc.CloseWrite()
|
|
}
|
|
}()
|
|
|
|
// Dial the backend.
|
|
backendConn, err := net.Dial("tcp", backendLn2.Addr().String())
|
|
if err != nil {
|
|
t.Fatalf("dial backend: %v", err)
|
|
}
|
|
|
|
// Create a client-side TCP pair.
|
|
clientLn, err := net.Listen("tcp", "127.0.0.1:0")
|
|
if err != nil {
|
|
t.Fatalf("listen: %v", err)
|
|
}
|
|
defer clientLn.Close()
|
|
|
|
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
|
|
if err != nil {
|
|
t.Fatalf("dial client: %v", err)
|
|
}
|
|
serverSideClient, err := clientLn.Accept()
|
|
if err != nil {
|
|
t.Fatalf("accept client: %v", err)
|
|
}
|
|
|
|
// Client sends data then closes write.
|
|
go func() {
|
|
clientConn.Write(clientData)
|
|
if tc, ok := clientConn.(*net.TCPConn); ok {
|
|
tc.CloseWrite()
|
|
}
|
|
}()
|
|
|
|
// Run relay.
|
|
result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("relay error: %v", err)
|
|
}
|
|
|
|
// Verify backend received peeked + client data.
|
|
received := <-backendDone
|
|
expected := append(peeked, clientData...)
|
|
if !bytes.Equal(received, expected) {
|
|
t.Fatalf("backend received %q, want %q", received, expected)
|
|
}
|
|
|
|
// Verify client received backend data.
|
|
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
clientReceived, _ := io.ReadAll(clientConn)
|
|
if !bytes.Equal(clientReceived, backendData) {
|
|
t.Fatalf("client received %q, want %q", clientReceived, backendData)
|
|
}
|
|
|
|
if result.ClientBytes != int64(len(clientData)) {
|
|
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
|
|
}
|
|
if result.BackendBytes != int64(len(backendData)) {
|
|
t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
|
|
}
|
|
}
|
|
|
|
func TestRelayIdleTimeout(t *testing.T) {
|
|
// Two connected pairs via TCP.
|
|
clientA, clientB := tcpPair(t)
|
|
defer clientA.Close()
|
|
defer clientB.Close()
|
|
|
|
backendA, backendB := tcpPair(t)
|
|
defer backendA.Close()
|
|
defer backendB.Close()
|
|
|
|
start := time.Now()
|
|
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
|
|
elapsed := time.Since(start)
|
|
|
|
// Should return due to idle timeout.
|
|
if err == nil {
|
|
t.Fatal("expected error from idle timeout")
|
|
}
|
|
|
|
if elapsed > 2*time.Second {
|
|
t.Fatalf("relay took %v, expected ~100ms", elapsed)
|
|
}
|
|
}
|
|
|
|
func TestRelayContextCancel(t *testing.T) {
|
|
clientA, clientB := tcpPair(t)
|
|
defer clientA.Close()
|
|
defer clientB.Close()
|
|
|
|
backendA, backendB := tcpPair(t)
|
|
defer backendA.Close()
|
|
defer backendB.Close()
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
Relay(ctx, clientB, backendA, nil, time.Minute)
|
|
close(done)
|
|
}()
|
|
|
|
// Cancel after a short delay.
|
|
time.Sleep(50 * time.Millisecond)
|
|
cancel()
|
|
|
|
select {
|
|
case <-done:
|
|
// OK
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("relay did not return after context cancel")
|
|
}
|
|
|
|
_ = backendB // keep reference
|
|
}
|
|
|
|
func TestRelayLargeTransfer(t *testing.T) {
|
|
clientA, clientB := tcpPair(t)
|
|
defer clientA.Close()
|
|
defer clientB.Close()
|
|
|
|
backendA, backendB := tcpPair(t)
|
|
defer backendA.Close()
|
|
defer backendB.Close()
|
|
|
|
// 1 MB of random data.
|
|
data := make([]byte, 1<<20)
|
|
if _, err := rand.Read(data); err != nil {
|
|
t.Fatalf("rand read: %v", err)
|
|
}
|
|
|
|
go func() {
|
|
clientA.Write(data)
|
|
if tc, ok := clientA.(*net.TCPConn); ok {
|
|
tc.CloseWrite()
|
|
}
|
|
}()
|
|
|
|
// Backend reads and echoes chunks, then closes when client EOF arrives.
|
|
go func() {
|
|
buf := make([]byte, 32*1024)
|
|
for {
|
|
n, err := backendB.Read(buf)
|
|
if n > 0 {
|
|
backendB.Write(buf[:n])
|
|
}
|
|
if err != nil {
|
|
break
|
|
}
|
|
}
|
|
if tc, ok := backendB.(*net.TCPConn); ok {
|
|
tc.CloseWrite()
|
|
}
|
|
}()
|
|
|
|
result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
|
|
if err != nil {
|
|
t.Fatalf("relay error: %v", err)
|
|
}
|
|
|
|
if result.ClientBytes != int64(len(data)) {
|
|
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
|
|
}
|
|
}
|
|
|
|
// tcpPair returns the two ends of a fresh loopback TCP connection: the dialing
// side first, then the accepted side. Both come from the "tcp" network, so
// callers may type-assert them to *net.TCPConn (e.g. for CloseWrite). The
// caller owns both connections and must close them. A listen, dial or accept
// failure aborts the test via t.Fatalf instead of returning a nil conn.
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()

	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	// Closing the listener also unblocks the accept goroutine if Dial fails.
	defer ln.Close()

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the goroutine can always deliver its result and exit.
	accepted := make(chan acceptResult, 1)
	go func() {
		conn, err := ln.Accept()
		accepted <- acceptResult{conn: conn, err: err}
	}()

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("dial: %v", err)
	}

	res := <-accepted
	if res.err != nil {
		clientConn.Close()
		t.Fatalf("accept: %v", res.err)
	}
	return clientConn, res.conn
}
|