package proxy

import (
	"bytes"
	"context"
	"crypto/rand"
	"io"
	"net"
	"testing"
	"time"
)

// TestRelayBasic checks three things about Relay:
//   - it forwards the peeked bytes followed by the client's stream to the backend;
//   - it forwards the backend's reply to the client;
//   - it reports per-direction byte counts.
//
// ClientBytes is expected to count only the bytes read from the client, not
// the peeked prefix.
func TestRelayBasic(t *testing.T) {
	peeked := []byte("peeked-hello-bytes")
	clientData := []byte("data from client")
	backendData := []byte("data from backend")

	backendLn, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer backendLn.Close()

	// Backend: read exactly the expected request bytes, reply, then half-close
	// the write side so the relay observes EOF on the backend->client path.
	// Reading a fixed length (rather than io.ReadAll) means the backend does
	// not depend on the relay propagating the client's half-close.
	backendDone := make(chan []byte, 1)
	go func() {
		conn, err := backendLn.Accept()
		if err != nil {
			backendDone <- nil // unblock the test; the comparison below reports it
			return
		}
		defer conn.Close()
		buf := make([]byte, len(peeked)+len(clientData))
		n, _ := io.ReadFull(conn, buf)
		backendDone <- buf[:n]
		conn.Write(backendData)
		if tc, ok := conn.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	backendConn, err := net.Dial("tcp", backendLn.Addr().String())
	if err != nil {
		t.Fatalf("dial backend: %v", err)
	}
	defer backendConn.Close()

	// clientConn plays the client; serverSideClient is the accepted end that
	// is handed to Relay.
	clientConn, serverSideClient := tcpPair(t)
	defer clientConn.Close()
	defer serverSideClient.Close()

	// Client sends its data, then half-closes so the relay sees EOF.
	// Write errors surface as a mismatch in the assertions below.
	go func() {
		clientConn.Write(clientData)
		if tc, ok := clientConn.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
	if err != nil {
		t.Fatalf("relay error: %v", err)
	}

	// Verify backend received peeked + client data.
	var received []byte
	select {
	case received = <-backendDone:
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for backend to receive data")
	}
	// Build expected in a fresh slice: append(peeked, ...) could write into
	// peeked's spare capacity, and peeked was also passed to Relay.
	expected := bytes.Join([][]byte{peeked, clientData}, nil)
	if !bytes.Equal(received, expected) {
		t.Fatalf("backend received %q, want %q", received, expected)
	}

	// Verify client received backend data.
	clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
	clientReceived, _ := io.ReadAll(clientConn)
	if !bytes.Equal(clientReceived, backendData) {
		t.Fatalf("client received %q, want %q", clientReceived, backendData)
	}

	if result.ClientBytes != int64(len(clientData)) {
		t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
	}
	if result.BackendBytes != int64(len(backendData)) {
		t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
	}
}

// TestRelayIdleTimeout checks that Relay returns an error when neither side
// sends anything within the idle timeout.
func TestRelayIdleTimeout(t *testing.T) {
	// clientA and backendB are the silent peers. They stay open (no EOF), so
	// only the idle timeout can end the relay.
	clientA, clientB := tcpPair(t)
	defer clientA.Close()
	defer clientB.Close()
	backendA, backendB := tcpPair(t)
	defer backendA.Close()
	defer backendB.Close()

	start := time.Now()
	_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
	elapsed := time.Since(start)

	// Should return due to idle timeout.
	if err == nil {
		t.Fatal("expected error from idle timeout")
	}
	// Generous upper bound to tolerate slow machines.
	if elapsed > 2*time.Second {
		t.Fatalf("relay took %v, expected ~100ms", elapsed)
	}
}

// TestRelayContextCancel checks that cancelling ctx makes Relay return
// promptly, even though both peers are idle and the idle timeout is a minute.
func TestRelayContextCancel(t *testing.T) {
	clientA, clientB := tcpPair(t)
	defer clientA.Close()
	defer clientB.Close()
	backendA, backendB := tcpPair(t)
	defer backendA.Close()
	defer backendB.Close()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	done := make(chan struct{})
	go func() {
		defer close(done)
		// Only a prompt return is asserted, not the specific error value.
		Relay(ctx, clientB, backendA, nil, time.Minute)
	}()

	// Give Relay a moment to start blocking before cancelling.
	time.Sleep(50 * time.Millisecond)
	cancel()

	select {
	case <-done:
		// OK
	case <-time.After(2 * time.Second):
		t.Fatal("relay did not return after context cancel")
	}
}

// TestRelayLargeTransfer pushes 1 MiB of random data through Relay to an echo
// backend. It checks that the client gets back exactly what it sent and that
// both byte counters match the payload size.
func TestRelayLargeTransfer(t *testing.T) {
	clientA, clientB := tcpPair(t)
	defer clientA.Close()
	defer clientB.Close()
	backendA, backendB := tcpPair(t)
	defer backendA.Close()
	defer backendB.Close()

	// 1 MiB of random data.
	data := make([]byte, 1<<20)
	if _, err := rand.Read(data); err != nil {
		t.Fatalf("rand read: %v", err)
	}

	// Client: send the payload, then half-close so the relay sees EOF.
	go func() {
		clientA.Write(data)
		if tc, ok := clientA.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	// Client: drain the echoed payload concurrently. Without a reader, the
	// backend->client path can fill the socket buffers and stall the relay.
	// ReadFull of a known length avoids depending on the relay closing clientB.
	echoed := make(chan []byte, 1)
	go func() {
		clientA.SetReadDeadline(time.Now().Add(10 * time.Second))
		buf := make([]byte, len(data))
		n, _ := io.ReadFull(clientA, buf)
		echoed <- buf[:n]
	}()

	// Backend: echo chunks back until client EOF arrives, then half-close.
	go func() {
		buf := make([]byte, 32*1024)
		for {
			n, err := backendB.Read(buf)
			if n > 0 {
				if _, werr := backendB.Write(buf[:n]); werr != nil {
					break
				}
			}
			if err != nil {
				break
			}
		}
		if tc, ok := backendB.(*net.TCPConn); ok {
			tc.CloseWrite()
		}
	}()

	result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
	if err != nil {
		t.Fatalf("relay error: %v", err)
	}
	if result.ClientBytes != int64(len(data)) {
		t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
	}
	if result.BackendBytes != int64(len(data)) {
		t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(data))
	}

	select {
	case got := <-echoed:
		if !bytes.Equal(got, data) {
			t.Fatalf("client received %d echoed bytes not matching the %d sent", len(got), len(data))
		}
	case <-time.After(5 * time.Second):
		t.Fatal("timed out waiting for echoed data")
	}
}

// tcpPair returns two connected TCP connections over loopback: the dialed end
// first, the accepted end second. It fails the test if either side cannot be
// established, so it never returns a nil conn.
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
	t.Helper()
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	defer ln.Close()

	type acceptResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the accept goroutine never blocks if we bail out early.
	accepted := make(chan acceptResult, 1)
	go func() {
		conn, err := ln.Accept()
		accepted <- acceptResult{conn: conn, err: err}
	}()

	clientConn, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		t.Fatalf("dial: %v", err)
	}
	res := <-accepted
	if res.err != nil {
		clientConn.Close()
		t.Fatalf("accept: %v", res.err)
	}
	return clientConn, res.conn
}