Add documentation, Docker setup, and tests for server and gRPC packages

Rewrite README with project overview and quick start. Add RUNBOOK with
operational procedures and incident playbooks. Fix Dockerfile for Go 1.25
with version injection. Add docker-compose.yml. Clean up golangci.yaml
for mc-proxy. Add server tests (10) covering the full proxy pipeline with
TCP echo backends, and grpcserver tests (13) covering all admin API RPCs
with bufconn and write-through DB verification.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-17 11:24:35 -07:00
parent f1e9834bd3
commit e84093b7fb
8 changed files with 1688 additions and 23 deletions

View File

@@ -0,0 +1,746 @@
package server
import (
"context"
"encoding/binary"
"io"
"log/slog"
"net"
"sync"
"testing"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
)
// echoServer accepts one connection, copies everything back, then closes.
func echoServer(t *testing.T, ln net.Listener) {
t.Helper()
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
io.Copy(conn, conn)
}
// newTestServer creates a Server with the given listener data and no firewall rules.
func newTestServer(t *testing.T, listeners []ListenerData) *Server {
t.Helper()
fw, err := firewall.New("", nil, nil, nil)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}
cfg := &config.Config{
Proxy: config.Proxy{
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
IdleTimeout: config.Duration{Duration: 30 * time.Second},
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
},
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
return New(cfg, fw, listeners, logger, "test")
}
// startAndStop starts the server in a goroutine and returns a cancel function
// that shuts it down and waits for it to exit.
func startAndStop(t *testing.T, srv *Server) context.CancelFunc {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := srv.Run(ctx); err != nil {
t.Errorf("server.Run: %v", err)
}
}()
// Give the listeners a moment to bind.
time.Sleep(50 * time.Millisecond)
return func() {
cancel()
wg.Wait()
}
}
func TestProxyRoundTrip(t *testing.T) {
// Start an echo backend.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
go echoServer(t, backendLn)
// Pick a free port for the proxy listener.
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"echo.test": backendLn.Addr().String(),
},
},
})
stop := startAndStop(t, srv)
defer stop()
// Connect through the proxy with a fake ClientHello.
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
hello := buildClientHello("echo.test")
if _, err := conn.Write(hello); err != nil {
t.Fatalf("write ClientHello: %v", err)
}
// The backend will echo our ClientHello back. Read it.
echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err)
}
// Send some additional data.
payload := []byte("hello from client")
if _, err := conn.Write(payload); err != nil {
t.Fatalf("write payload: %v", err)
}
buf := make([]byte, len(payload))
if _, err := io.ReadFull(conn, buf); err != nil {
t.Fatalf("read echoed payload: %v", err)
}
if string(buf) != string(payload) {
t.Fatalf("got %q, want %q", buf, payload)
}
}
func TestNoRouteResets(t *testing.T) {
// Proxy listener with no routes for the requested hostname.
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"other.test": "127.0.0.1:1", // exists but won't match
},
},
})
stop := startAndStop(t, srv)
defer stop()
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
hello := buildClientHello("unknown.test")
if _, err := conn.Write(hello); err != nil {
t.Fatalf("write ClientHello: %v", err)
}
// The proxy should close the connection (no route match).
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed, but read succeeded")
}
}
func TestFirewallBlocks(t *testing.T) {
// Start a backend that should never be reached.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
reached := make(chan struct{}, 1)
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
conn.Close()
reached <- struct{}{}
}()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
// Create a firewall that blocks 127.0.0.1 (the test client).
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}
cfg := &config.Config{
Proxy: config.Proxy{
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
IdleTimeout: config.Duration{Duration: 30 * time.Second},
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
},
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
srv := New(cfg, fw, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"echo.test": backendLn.Addr().String(),
},
},
}, logger, "test")
ctx, cancel := context.WithCancel(context.Background())
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
srv.Run(ctx)
}()
time.Sleep(50 * time.Millisecond)
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
hello := buildClientHello("echo.test")
conn.Write(hello)
// Connection should be closed (blocked by firewall).
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed by firewall")
}
// Backend should not have been reached.
select {
case <-reached:
t.Fatal("backend was reached despite firewall block")
case <-time.After(200 * time.Millisecond):
// Expected.
}
cancel()
wg.Wait()
}
func TestNotTLSResets(t *testing.T) {
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{"x.test": "127.0.0.1:1"},
},
})
stop := startAndStop(t, srv)
defer stop()
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
// Send HTTP, not TLS.
conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed for non-TLS data")
}
}
func TestConnectionTracking(t *testing.T) {
// Backend that holds connections open until we close it.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
var backendConns []net.Conn
var mu sync.Mutex
go func() {
for {
conn, err := backendLn.Accept()
if err != nil {
return
}
mu.Lock()
backendConns = append(backendConns, conn)
mu.Unlock()
// Hold connection open, drain input.
go io.Copy(io.Discard, conn)
}
}()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"conn.test": backendLn.Addr().String(),
},
},
})
stop := startAndStop(t, srv)
defer stop()
if got := srv.TotalConnections(); got != 0 {
t.Fatalf("expected 0 connections before any clients, got %d", got)
}
// Open two connections through the proxy.
var clientConns []net.Conn
for i := range 2 {
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy %d: %v", i, err)
}
hello := buildClientHello("conn.test")
if _, err := conn.Write(hello); err != nil {
t.Fatalf("write ClientHello %d: %v", i, err)
}
clientConns = append(clientConns, conn)
}
// Give connections time to be established.
time.Sleep(100 * time.Millisecond)
if got := srv.TotalConnections(); got != 2 {
t.Fatalf("expected 2 active connections, got %d", got)
}
// Close one client and its corresponding backend connection.
clientConns[0].Close()
mu.Lock()
if len(backendConns) > 0 {
backendConns[0].Close()
}
mu.Unlock()
// Wait for the relay goroutines to detect the close.
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
if srv.TotalConnections() == 1 {
break
}
time.Sleep(50 * time.Millisecond)
}
if got := srv.TotalConnections(); got != 1 {
t.Fatalf("expected 1 active connection after closing one, got %d", got)
}
// Clean up.
clientConns[1].Close()
mu.Lock()
for _, c := range backendConns {
c.Close()
}
mu.Unlock()
}
func TestMultipleListeners(t *testing.T) {
// Two backends.
backendA, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend A listen: %v", err)
}
defer backendA.Close()
backendB, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend B listen: %v", err)
}
defer backendB.Close()
// Each backend writes its identity and closes.
serve := func(ln net.Listener, id string) {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
// Drain the incoming data, then write identity.
go io.Copy(io.Discard, conn)
conn.Write([]byte(id))
}
go serve(backendA, "A")
go serve(backendB, "B")
// Two proxy listeners, same hostname, different backends.
ln1, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port 1: %v", err)
}
addr1 := ln1.Addr().String()
ln1.Close()
ln2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port 2: %v", err)
}
addr2 := ln2.Addr().String()
ln2.Close()
srv := newTestServer(t, []ListenerData{
{ID: 1, Addr: addr1, Routes: map[string]string{"svc.test": backendA.Addr().String()}},
{ID: 2, Addr: addr2, Routes: map[string]string{"svc.test": backendB.Addr().String()}},
})
stop := startAndStop(t, srv)
defer stop()
readID := func(proxyAddr string) string {
t.Helper()
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial %s: %v", proxyAddr, err)
}
defer conn.Close()
hello := buildClientHello("svc.test")
conn.Write(hello)
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := make([]byte, 128)
// Read what the backend sends back: echoed ClientHello + ID.
// The backend drains input and writes the ID, so we read until we
// find the ID byte at the end.
var all []byte
for {
n, err := conn.Read(buf)
all = append(all, buf[:n]...)
if err != nil {
break
}
}
if len(all) == 0 {
t.Fatalf("no data from %s", proxyAddr)
}
// The ID is the last byte.
return string(all[len(all)-1:])
}
idA := readID(addr1)
idB := readID(addr2)
if idA != "A" {
t.Fatalf("listener 1: got backend %q, want A", idA)
}
if idB != "B" {
t.Fatalf("listener 2: got backend %q, want B", idB)
}
}
func TestCaseInsensitiveRouting(t *testing.T) {
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
go echoServer(t, backendLn)
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"echo.test": backendLn.Addr().String(),
},
},
})
stop := startAndStop(t, srv)
defer stop()
// SNI extraction lowercases the hostname, so "ECHO.TEST" should match
// the route for "echo.test".
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
hello := buildClientHello("ECHO.TEST")
if _, err := conn.Write(hello); err != nil {
t.Fatalf("write ClientHello: %v", err)
}
echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err)
}
}
func TestBackendUnreachable(t *testing.T) {
// Find a port that nothing is listening on.
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
deadAddr := ln.Addr().String()
ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
ID: 1,
Addr: proxyAddr,
Routes: map[string]string{
"dead.test": deadAddr,
},
},
})
stop := startAndStop(t, srv)
defer stop()
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
hello := buildClientHello("dead.test")
conn.Write(hello)
// Proxy should close the connection after failing to dial backend.
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed when backend is unreachable")
}
}
func TestGracefulShutdown(t *testing.T) {
// Backend that holds the connection open.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
io.Copy(io.Discard, conn)
}()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
fw, err := firewall.New("", nil, nil, nil)
if err != nil {
t.Fatalf("creating firewall: %v", err)
}
cfg := &config.Config{
Proxy: config.Proxy{
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
IdleTimeout: config.Duration{Duration: 30 * time.Second},
ShutdownTimeout: config.Duration{Duration: 2 * time.Second},
},
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
srv := New(cfg, fw, []ListenerData{
{ID: 1, Addr: proxyAddr, Routes: map[string]string{"hold.test": backendLn.Addr().String()}},
}, logger, "test")
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
go func() {
done <- srv.Run(ctx)
}()
time.Sleep(50 * time.Millisecond)
// Establish a connection that will be in-flight during shutdown.
conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second)
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
hello := buildClientHello("hold.test")
conn.Write(hello)
time.Sleep(50 * time.Millisecond)
// Trigger shutdown.
cancel()
// Server should exit within the shutdown timeout.
select {
case err := <-done:
if err != nil {
t.Fatalf("server.Run returned error: %v", err)
}
case <-time.After(5 * time.Second):
t.Fatal("server did not shut down within 5 seconds")
}
}
func TestListenerStateRoutes(t *testing.T) {
ls := &ListenerState{
ID: 1,
Addr: ":443",
routes: map[string]string{
"a.test": "127.0.0.1:1",
},
}
// AddRoute
if err := ls.AddRoute("b.test", "127.0.0.1:2"); err != nil {
t.Fatalf("AddRoute: %v", err)
}
// AddRoute duplicate
if err := ls.AddRoute("b.test", "127.0.0.1:3"); err == nil {
t.Fatal("expected error for duplicate route")
}
// Routes snapshot
routes := ls.Routes()
if len(routes) != 2 {
t.Fatalf("expected 2 routes, got %d", len(routes))
}
// RemoveRoute
if err := ls.RemoveRoute("a.test"); err != nil {
t.Fatalf("RemoveRoute: %v", err)
}
// RemoveRoute not found
if err := ls.RemoveRoute("nonexistent.test"); err == nil {
t.Fatal("expected error for removing nonexistent route")
}
routes = ls.Routes()
if len(routes) != 1 {
t.Fatalf("expected 1 route, got %d", len(routes))
}
if routes["b.test"] != "127.0.0.1:2" {
t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"])
}
}
// --- ClientHello builder helpers (mirrors internal/sni test helpers) ---
func buildClientHello(serverName string) []byte {
return buildClientHelloWithExtensions(sniExtension(serverName))
}
func buildClientHelloWithExtensions(extensions []byte) []byte {
var hello []byte
hello = append(hello, 0x03, 0x03) // TLS 1.2
hello = append(hello, make([]byte, 32)...) // random
hello = append(hello, 0x00) // session ID: empty
hello = append(hello, 0x00, 0x02, 0x00, 0x9C) // cipher suites
hello = append(hello, 0x01, 0x00) // compression methods
if len(extensions) > 0 {
hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions)))
hello = append(hello, extensions...)
}
handshake := []byte{0x01, 0x00, 0x00, 0x00}
handshake[1] = byte(len(hello) >> 16)
handshake[2] = byte(len(hello) >> 8)
handshake[3] = byte(len(hello))
handshake = append(handshake, hello...)
record := []byte{0x16, 0x03, 0x01}
record = binary.BigEndian.AppendUint16(record, uint16(len(handshake)))
record = append(record, handshake...)
return record
}
func sniExtension(serverName string) []byte {
name := []byte(serverName)
var entry []byte
entry = append(entry, 0x00)
entry = binary.BigEndian.AppendUint16(entry, uint16(len(name)))
entry = append(entry, name...)
var list []byte
list = binary.BigEndian.AppendUint16(list, uint16(len(entry)))
list = append(list, entry...)
var ext []byte
ext = binary.BigEndian.AppendUint16(ext, 0x0000)
ext = binary.BigEndian.AppendUint16(ext, uint16(len(list)))
ext = append(ext, list...)
return ext
}