Initial implementation of mc-proxy
Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
105
internal/proxy/proxy.go
Normal file
105
internal/proxy/proxy.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Result holds the outcome of a relay operation.
|
||||
type Result struct {
|
||||
ClientBytes int64 // bytes sent from client to backend
|
||||
BackendBytes int64 // bytes sent from backend to client
|
||||
}
|
||||
|
||||
// Relay performs bidirectional byte copying between client and backend.
|
||||
// The peeked bytes (the TLS ClientHello) are written to the backend first.
|
||||
// Relay blocks until both directions are done or ctx is cancelled.
|
||||
func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTimeout time.Duration) (Result, error) {
|
||||
// Forward the buffered ClientHello to the backend.
|
||||
if len(peeked) > 0 {
|
||||
if _, err := backend.Write(peeked); err != nil {
|
||||
return Result{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel context closes both connections to unblock copy goroutines.
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
client.Close()
|
||||
backend.Close()
|
||||
}()
|
||||
|
||||
var (
|
||||
result Result
|
||||
wg sync.WaitGroup
|
||||
errC2B error
|
||||
errB2C error
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
// client → backend
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
|
||||
// Half-close backend's write side.
|
||||
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
|
||||
hc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// backend → client
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
|
||||
// Half-close client's write side.
|
||||
if hc, ok := client.(interface{ CloseWrite() error }); ok {
|
||||
hc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// If context was cancelled, that's the primary error.
|
||||
if ctx.Err() != nil {
|
||||
return result, ctx.Err()
|
||||
}
|
||||
|
||||
// Return the first meaningful error, if any.
|
||||
if errC2B != nil {
|
||||
return result, errC2B
|
||||
}
|
||||
return result, errB2C
|
||||
}
|
||||
|
||||
// copyWithIdleTimeout copies from src to dst, resetting the idle deadline
|
||||
// on each successful read.
|
||||
func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, error) {
|
||||
buf := make([]byte, 32*1024)
|
||||
var total int64
|
||||
|
||||
for {
|
||||
src.SetReadDeadline(time.Now().Add(idleTimeout))
|
||||
nr, readErr := src.Read(buf)
|
||||
if nr > 0 {
|
||||
dst.SetWriteDeadline(time.Now().Add(idleTimeout))
|
||||
nw, writeErr := dst.Write(buf[:nr])
|
||||
total += int64(nw)
|
||||
if writeErr != nil {
|
||||
return total, writeErr
|
||||
}
|
||||
}
|
||||
if readErr != nil {
|
||||
if readErr == io.EOF {
|
||||
return total, nil
|
||||
}
|
||||
return total, readErr
|
||||
}
|
||||
}
|
||||
}
|
||||
259
internal/proxy/proxy_test.go
Normal file
259
internal/proxy/proxy_test.go
Normal file
@@ -0,0 +1,259 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRelayBasic(t *testing.T) {
|
||||
// Set up a TCP listener to act as the backend.
|
||||
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer backendLn.Close()
|
||||
|
||||
peeked := []byte("peeked-hello-bytes")
|
||||
clientData := []byte("data from client")
|
||||
backendData := []byte("data from backend")
|
||||
|
||||
// Backend goroutine: accept, read peeked+client data, send response, close.
|
||||
backendDone := make(chan []byte, 1)
|
||||
go func() {
|
||||
conn, err := backendLn.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read everything the backend receives.
|
||||
received, _ := io.ReadAll(conn)
|
||||
backendDone <- received
|
||||
|
||||
// This won't work since ReadAll waits for EOF.
|
||||
// Instead, restructure: read expected bytes, write response, close write.
|
||||
}()
|
||||
|
||||
// Restructure: use a more controlled flow.
|
||||
backendLn.Close()
|
||||
|
||||
// Use a real TCP pair for proper half-close.
|
||||
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer backendLn2.Close()
|
||||
|
||||
go func() {
|
||||
conn, err := backendLn2.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Read peeked + client data.
|
||||
buf := make([]byte, len(peeked)+len(clientData))
|
||||
n, _ := io.ReadFull(conn, buf)
|
||||
backendDone <- buf[:n]
|
||||
|
||||
// Send response.
|
||||
conn.Write(backendData)
|
||||
|
||||
// Close write side to signal EOF.
|
||||
if tc, ok := conn.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Dial the backend.
|
||||
backendConn, err := net.Dial("tcp", backendLn2.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial backend: %v", err)
|
||||
}
|
||||
|
||||
// Create a client-side TCP pair.
|
||||
clientLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer clientLn.Close()
|
||||
|
||||
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial client: %v", err)
|
||||
}
|
||||
serverSideClient, err := clientLn.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("accept client: %v", err)
|
||||
}
|
||||
|
||||
// Client sends data then closes write.
|
||||
go func() {
|
||||
clientConn.Write(clientData)
|
||||
if tc, ok := clientConn.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Run relay.
|
||||
result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("relay error: %v", err)
|
||||
}
|
||||
|
||||
// Verify backend received peeked + client data.
|
||||
received := <-backendDone
|
||||
expected := append(peeked, clientData...)
|
||||
if !bytes.Equal(received, expected) {
|
||||
t.Fatalf("backend received %q, want %q", received, expected)
|
||||
}
|
||||
|
||||
// Verify client received backend data.
|
||||
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||
clientReceived, _ := io.ReadAll(clientConn)
|
||||
if !bytes.Equal(clientReceived, backendData) {
|
||||
t.Fatalf("client received %q, want %q", clientReceived, backendData)
|
||||
}
|
||||
|
||||
if result.ClientBytes != int64(len(clientData)) {
|
||||
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
|
||||
}
|
||||
if result.BackendBytes != int64(len(backendData)) {
|
||||
t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayIdleTimeout(t *testing.T) {
|
||||
// Two connected pairs via TCP.
|
||||
clientA, clientB := tcpPair(t)
|
||||
defer clientA.Close()
|
||||
defer clientB.Close()
|
||||
|
||||
backendA, backendB := tcpPair(t)
|
||||
defer backendA.Close()
|
||||
defer backendB.Close()
|
||||
|
||||
start := time.Now()
|
||||
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
// Should return due to idle timeout.
|
||||
if err == nil {
|
||||
t.Fatal("expected error from idle timeout")
|
||||
}
|
||||
|
||||
if elapsed > 2*time.Second {
|
||||
t.Fatalf("relay took %v, expected ~100ms", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelayContextCancel(t *testing.T) {
|
||||
clientA, clientB := tcpPair(t)
|
||||
defer clientA.Close()
|
||||
defer clientB.Close()
|
||||
|
||||
backendA, backendB := tcpPair(t)
|
||||
defer backendA.Close()
|
||||
defer backendB.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
Relay(ctx, clientB, backendA, nil, time.Minute)
|
||||
close(done)
|
||||
}()
|
||||
|
||||
// Cancel after a short delay.
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// OK
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("relay did not return after context cancel")
|
||||
}
|
||||
|
||||
_ = backendB // keep reference
|
||||
}
|
||||
|
||||
func TestRelayLargeTransfer(t *testing.T) {
|
||||
clientA, clientB := tcpPair(t)
|
||||
defer clientA.Close()
|
||||
defer clientB.Close()
|
||||
|
||||
backendA, backendB := tcpPair(t)
|
||||
defer backendA.Close()
|
||||
defer backendB.Close()
|
||||
|
||||
// 1 MB of random data.
|
||||
data := make([]byte, 1<<20)
|
||||
if _, err := rand.Read(data); err != nil {
|
||||
t.Fatalf("rand read: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
clientA.Write(data)
|
||||
if tc, ok := clientA.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
// Backend reads and echoes chunks, then closes when client EOF arrives.
|
||||
go func() {
|
||||
buf := make([]byte, 32*1024)
|
||||
for {
|
||||
n, err := backendB.Read(buf)
|
||||
if n > 0 {
|
||||
backendB.Write(buf[:n])
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if tc, ok := backendB.(*net.TCPConn); ok {
|
||||
tc.CloseWrite()
|
||||
}
|
||||
}()
|
||||
|
||||
result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("relay error: %v", err)
|
||||
}
|
||||
|
||||
if result.ClientBytes != int64(len(data)) {
|
||||
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
|
||||
}
|
||||
}
|
||||
|
||||
// tcpPair returns two connected TCP connections.
|
||||
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
|
||||
t.Helper()
|
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
var serverConn net.Conn
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
serverConn, _ = ln.Accept()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
clientConn, err := net.Dial("tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
<-done
|
||||
return clientConn, serverConn
|
||||
}
|
||||
Reference in New Issue
Block a user