Initial implementation of mc-proxy

Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking),
per-listener route tables, bidirectional TCP relay with half-close
propagation, and a gRPC admin API (routes, firewall, status) with
TLS/mTLS support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-17 02:56:24 -07:00
commit c7024dcdf0
23 changed files with 2693 additions and 0 deletions

105
internal/proxy/proxy.go Normal file
View File

@@ -0,0 +1,105 @@
package proxy
import (
"context"
"io"
"net"
"sync"
"time"
)
// Result holds the outcome of a relay operation.
type Result struct {
ClientBytes int64 // bytes sent from client to backend
BackendBytes int64 // bytes sent from backend to client
}
// Relay performs bidirectional byte copying between client and backend.
// The peeked bytes (the TLS ClientHello) are written to the backend first.
// Relay blocks until both directions are done or ctx is cancelled.
func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTimeout time.Duration) (Result, error) {
// Forward the buffered ClientHello to the backend.
if len(peeked) > 0 {
if _, err := backend.Write(peeked); err != nil {
return Result{}, err
}
}
// Cancel context closes both connections to unblock copy goroutines.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
client.Close()
backend.Close()
}()
var (
result Result
wg sync.WaitGroup
errC2B error
errB2C error
)
wg.Add(2)
// client → backend
go func() {
defer wg.Done()
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
// Half-close backend's write side.
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
}
}()
// backend → client
go func() {
defer wg.Done()
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
// Half-close client's write side.
if hc, ok := client.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
}
}()
wg.Wait()
// If context was cancelled, that's the primary error.
if ctx.Err() != nil {
return result, ctx.Err()
}
// Return the first meaningful error, if any.
if errC2B != nil {
return result, errC2B
}
return result, errB2C
}
// copyWithIdleTimeout copies from src to dst, resetting the idle deadline
// on each successful read.
func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, error) {
buf := make([]byte, 32*1024)
var total int64
for {
src.SetReadDeadline(time.Now().Add(idleTimeout))
nr, readErr := src.Read(buf)
if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(idleTimeout))
nw, writeErr := dst.Write(buf[:nr])
total += int64(nw)
if writeErr != nil {
return total, writeErr
}
}
if readErr != nil {
if readErr == io.EOF {
return total, nil
}
return total, readErr
}
}
}

View File

@@ -0,0 +1,259 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"io"
"net"
"testing"
"time"
)
func TestRelayBasic(t *testing.T) {
// Set up a TCP listener to act as the backend.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn.Close()
peeked := []byte("peeked-hello-bytes")
clientData := []byte("data from client")
backendData := []byte("data from backend")
// Backend goroutine: accept, read peeked+client data, send response, close.
backendDone := make(chan []byte, 1)
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
// Read everything the backend receives.
received, _ := io.ReadAll(conn)
backendDone <- received
// This won't work since ReadAll waits for EOF.
// Instead, restructure: read expected bytes, write response, close write.
}()
// Restructure: use a more controlled flow.
backendLn.Close()
// Use a real TCP pair for proper half-close.
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn2.Close()
go func() {
conn, err := backendLn2.Accept()
if err != nil {
return
}
defer conn.Close()
// Read peeked + client data.
buf := make([]byte, len(peeked)+len(clientData))
n, _ := io.ReadFull(conn, buf)
backendDone <- buf[:n]
// Send response.
conn.Write(backendData)
// Close write side to signal EOF.
if tc, ok := conn.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Dial the backend.
backendConn, err := net.Dial("tcp", backendLn2.Addr().String())
if err != nil {
t.Fatalf("dial backend: %v", err)
}
// Create a client-side TCP pair.
clientLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer clientLn.Close()
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
if err != nil {
t.Fatalf("dial client: %v", err)
}
serverSideClient, err := clientLn.Accept()
if err != nil {
t.Fatalf("accept client: %v", err)
}
// Client sends data then closes write.
go func() {
clientConn.Write(clientData)
if tc, ok := clientConn.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Run relay.
result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
if err != nil {
t.Fatalf("relay error: %v", err)
}
// Verify backend received peeked + client data.
received := <-backendDone
expected := append(peeked, clientData...)
if !bytes.Equal(received, expected) {
t.Fatalf("backend received %q, want %q", received, expected)
}
// Verify client received backend data.
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
clientReceived, _ := io.ReadAll(clientConn)
if !bytes.Equal(clientReceived, backendData) {
t.Fatalf("client received %q, want %q", clientReceived, backendData)
}
if result.ClientBytes != int64(len(clientData)) {
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
}
if result.BackendBytes != int64(len(backendData)) {
t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
}
}
func TestRelayIdleTimeout(t *testing.T) {
// Two connected pairs via TCP.
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
start := time.Now()
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
elapsed := time.Since(start)
// Should return due to idle timeout.
if err == nil {
t.Fatal("expected error from idle timeout")
}
if elapsed > 2*time.Second {
t.Fatalf("relay took %v, expected ~100ms", elapsed)
}
}
func TestRelayContextCancel(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
Relay(ctx, clientB, backendA, nil, time.Minute)
close(done)
}()
// Cancel after a short delay.
time.Sleep(50 * time.Millisecond)
cancel()
select {
case <-done:
// OK
case <-time.After(2 * time.Second):
t.Fatal("relay did not return after context cancel")
}
_ = backendB // keep reference
}
func TestRelayLargeTransfer(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
// 1 MB of random data.
data := make([]byte, 1<<20)
if _, err := rand.Read(data); err != nil {
t.Fatalf("rand read: %v", err)
}
go func() {
clientA.Write(data)
if tc, ok := clientA.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Backend reads and echoes chunks, then closes when client EOF arrives.
go func() {
buf := make([]byte, 32*1024)
for {
n, err := backendB.Read(buf)
if n > 0 {
backendB.Write(buf[:n])
}
if err != nil {
break
}
}
if tc, ok := backendB.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
if err != nil {
t.Fatalf("relay error: %v", err)
}
if result.ClientBytes != int64(len(data)) {
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
}
}
// tcpPair returns two connected TCP connections.
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var serverConn net.Conn
done := make(chan struct{})
go func() {
serverConn, _ = ln.Accept()
close(done)
}()
clientConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
<-done
return clientConn, serverConn
}