Files
mc-proxy/internal/proxy/proxy_test.go
Kyle Isom c7024dcdf0 Initial implementation of mc-proxy
Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking),
per-listener route tables, bidirectional TCP relay with half-close
propagation, and a gRPC admin API (routes, firewall, status) with
TLS/mTLS support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-17 02:56:24 -07:00

260 lines
5.7 KiB
Go

package proxy
import (
"bytes"
"context"
"crypto/rand"
"io"
"net"
"testing"
"time"
)
// TestRelayBasic drives one complete exchange through Relay over real
// loopback TCP connections and checks that:
//   - the backend receives the peeked bytes followed by the client payload,
//   - the client receives the backend's reply,
//   - half-closes propagate so that both relay directions terminate, and
//   - the reported byte counts match what each peer sent (the peeked bytes
//     are not counted in ClientBytes).
func TestRelayBasic(t *testing.T) {
	peeked := []byte("peeked-hello-bytes")
	clientData := []byte("data from client")
	backendData := []byte("data from backend")

	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer backendLn.Close()

	// Backend: read exactly the peeked bytes plus the client payload, reply,
	// then half-close so the relay sees EOF from the backend. Reading a fixed
	// length (instead of io.ReadAll) lets the backend answer before its read
	// side has been closed.
	backendDone := make(chan []byte, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			return
		}
		defer conn.Close()
		buf := make([]byte, len(peeked)+len(clientData))
		n, _ := io.ReadFull(conn, buf)
		backendDone <- buf[:n]
		conn.Write(backendData)
		if tc, ok := conn.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	backendConn, err := net.Dial("tcp", backendLn.Addr().String())
	if err != nil {
		t.Fatalf("dial backend: %v", err)
	}
	defer backendConn.Close()

	// Client side: clientConn plays the remote client; serverSideClient is
	// the accepted end that is handed to Relay.
	clientLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer clientLn.Close()
	clientConn, err := net.Dial("tcp", clientLn.Addr().String())
	if err != nil {
		t.Fatalf("dial client: %v", err)
	}
	defer clientConn.Close()
	serverSideClient, err := clientLn.Accept()
	if err != nil {
		t.Fatalf("accept client: %v", err)
	}
	defer serverSideClient.Close()

	// Client sends its payload, then half-closes to signal it is done.
	go func() {
		clientConn.Write(clientData)
		if tc, ok := clientConn.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
	if err != nil {
		t.Fatalf("relay error: %v", err)
	}

	// Build the expected stream in a fresh slice so append can never write
	// into peeked's backing array.
	expected := make([]byte, 0, len(peeked)+len(clientData))
	expected = append(expected, peeked...)
	expected = append(expected, clientData...)

	// Bounded wait: if the backend's Accept failed, nothing is ever sent.
	select {
	case received := <-backendDone:
		if !bytes.Equal(received, expected) {
			t.Fatalf("backend received %q, want %q", received, expected)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("backend never reported the data it received")
	}

	if err := clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)); err != nil {
		t.Fatalf("set read deadline: %v", err)
	}
	clientReceived, readErr := io.ReadAll(clientConn)
	if !bytes.Equal(clientReceived, backendData) {
		t.Fatalf("client received %q (read error: %v), want %q", clientReceived, readErr, backendData)
	}

	if result.ClientBytes != int64(len(clientData)) {
		t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
	}
	if result.BackendBytes != int64(len(backendData)) {
		t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
	}
}
// TestRelayIdleTimeout checks that Relay gives up with an error when neither
// peer sends anything for longer than the idle timeout. Relay runs in its own
// goroutine so that a broken idle timeout fails the test after two seconds
// instead of hanging until the global `go test` timeout.
func TestRelayIdleTimeout(t *testing.T) {
	// Two connected pairs via TCP; only clientB and backendA are relayed,
	// and the far ends stay silent.
	clientA, clientB := tcpPair(t)
	defer clientA.Close()
	defer clientB.Close()
	backendA, backendB := tcpPair(t)
	defer backendA.Close()
	defer backendB.Close()

	const idle = 100 * time.Millisecond

	// Buffered so the goroutine can always deliver its result and exit, even
	// if the test has already failed on the timeout branch.
	errc := make(chan error, 1)
	start := time.Now()
	go func() {
		_, err := Relay(context.Background(), clientB, backendA, nil, idle)
		errc <- err
	}()

	select {
	case err := <-errc:
		if err == nil {
			t.Fatal("expected error from idle timeout")
		}
	case <-time.After(2 * time.Second):
		t.Fatalf("relay still running after %v, expected ~%v", time.Since(start), idle)
	}
}
// TestRelayContextCancel checks that cancelling the context given to Relay
// makes it return promptly. No peer sends or closes anything, and the idle
// timeout is a full minute, so cancellation is the only thing that can end
// the relay within the two-second window.
func TestRelayContextCancel(t *testing.T) {
	clientPeer, clientSide := tcpPair(t)
	defer clientPeer.Close()
	defer clientSide.Close()
	backendSide, backendPeer := tcpPair(t)
	defer backendSide.Close()
	defer backendPeer.Close()

	ctx, cancel := context.WithCancel(context.Background())

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		Relay(ctx, clientSide, backendSide, nil, time.Minute)
	}()

	// Let the relay start running before pulling the context out from under it.
	time.Sleep(50 * time.Millisecond)
	cancel()

	select {
	case <-time.After(2 * time.Second):
		t.Fatal("relay did not return after context cancel")
	case <-finished:
	}
}
// TestRelayLargeTransfer pushes 1 MiB of random data from the client through
// Relay to an echoing backend and back again. It checks the byte counts in
// both directions and that the echoed stream is byte-for-byte identical to
// what was sent.
//
// The client end is drained concurrently. Without a reader there, the echoed
// bytes back up through the socket buffers. Depending on the platform's
// buffer sizes, that can stall both relay directions until the idle timeout
// fires.
func TestRelayLargeTransfer(t *testing.T) {
	clientA, clientB := tcpPair(t)
	defer clientA.Close()
	defer clientB.Close()
	backendA, backendB := tcpPair(t)
	defer backendA.Close()
	defer backendB.Close()

	// 1 MiB of random data.
	data := make([]byte, 1<<20)
	if _, err := rand.Read(data); err != nil {
		t.Fatalf("rand read: %v", err)
	}

	// Client writer: send everything, then half-close to signal EOF.
	go func() {
		clientA.Write(data)
		if tc, ok := clientA.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	// Client reader: collect the echo until the relay propagates the
	// backend's half-close. Buffered so the goroutine never blocks on send.
	echoed := make(chan []byte, 1)
	go func() {
		got, _ := io.ReadAll(clientA)
		echoed <- got
	}()

	// Backend reads and echoes chunks, then closes when client EOF arrives.
	go func() {
		buf := make([]byte, 32*1024)
		for {
			n, err := backendB.Read(buf)
			if n > 0 {
				backendB.Write(buf[:n])
			}
			if err != nil {
				break
			}
		}
		if tc, ok := backendB.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
	if err != nil {
		t.Fatalf("relay error: %v", err)
	}
	if result.ClientBytes != int64(len(data)) {
		t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
	}
	// The backend echoes exactly what it read, so the same volume flows back.
	if result.BackendBytes != int64(len(data)) {
		t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(data))
	}

	select {
	case got := <-echoed:
		if !bytes.Equal(got, data) {
			t.Fatalf("echoed data mismatch: client got %d bytes, want the %d bytes sent", len(got), len(data))
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for the echoed data at the client")
	}
}
// tcpPair returns the two ends of a freshly established loopback TCP
// connection: the dialing end first, the accepted end second. Callers can
// type-assert either end to *net.TCPConn to half-close it with CloseWrite.
// The caller owns both connections and must close them. The test fails
// immediately if the listen, dial, or accept step fails.
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	// Closing the listener also unblocks the Accept goroutine if Dial fails.
	defer ln.Close()

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the accept goroutine can always deliver its result and exit.
	accepted := make(chan acceptResult, 1)
	go func() {
		c, err := ln.Accept()
		accepted <- acceptResult{conn: c, err: err}
	}()

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	res := <-accepted
	if res.err != nil {
		clientConn.Close()
		t.Fatalf("accept: %v", res.err)
	}
	return clientConn, res.conn
}