package server import ( "context" "encoding/binary" "io" "log/slog" "net" "sync" "testing" "time" "git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" ) // echoServer accepts one connection, copies everything back, then closes. func echoServer(t *testing.T, ln net.Listener) { t.Helper() conn, err := ln.Accept() if err != nil { return } defer conn.Close() io.Copy(conn, conn) } // newTestServer creates a Server with the given listener data and no firewall rules. func newTestServer(t *testing.T, listeners []ListenerData) *Server { t.Helper() fw, err := firewall.New("", nil, nil, nil) if err != nil { t.Fatalf("creating firewall: %v", err) } cfg := &config.Config{ Proxy: config.Proxy{ ConnectTimeout: config.Duration{Duration: 5 * time.Second}, IdleTimeout: config.Duration{Duration: 30 * time.Second}, ShutdownTimeout: config.Duration{Duration: 5 * time.Second}, }, } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) return New(cfg, fw, listeners, logger, "test") } // startAndStop starts the server in a goroutine and returns a cancel function // that shuts it down and waits for it to exit. func startAndStop(t *testing.T, srv *Server) context.CancelFunc { t.Helper() ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() if err := srv.Run(ctx); err != nil { t.Errorf("server.Run: %v", err) } }() // Give the listeners a moment to bind. time.Sleep(50 * time.Millisecond) return func() { cancel() wg.Wait() } } func TestProxyRoundTrip(t *testing.T) { // Start an echo backend. backendLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("backend listen: %v", err) } defer backendLn.Close() go echoServer(t, backendLn) // Pick a free port for the proxy listener. proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]string{ "echo.test": backendLn.Addr().String(), }, }, }) stop := startAndStop(t, srv) defer stop() // Connect through the proxy with a fake ClientHello. conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy: %v", err) } defer conn.Close() hello := buildClientHello("echo.test") if _, err := conn.Write(hello); err != nil { t.Fatalf("write ClientHello: %v", err) } // The backend will echo our ClientHello back. Read it. echoed := make([]byte, len(hello)) conn.SetReadDeadline(time.Now().Add(5 * time.Second)) if _, err := io.ReadFull(conn, echoed); err != nil { t.Fatalf("read echoed data: %v", err) } // Send some additional data. payload := []byte("hello from client") if _, err := conn.Write(payload); err != nil { t.Fatalf("write payload: %v", err) } buf := make([]byte, len(payload)) if _, err := io.ReadFull(conn, buf); err != nil { t.Fatalf("read echoed payload: %v", err) } if string(buf) != string(payload) { t.Fatalf("got %q, want %q", buf, payload) } } func TestNoRouteResets(t *testing.T) { // Proxy listener with no routes for the requested hostname. proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]string{ "other.test": "127.0.0.1:1", // exists but won't match }, }, }) stop := startAndStop(t, srv) defer stop() conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy: %v", err) } defer conn.Close() hello := buildClientHello("unknown.test") if _, err := conn.Write(hello); err != nil { t.Fatalf("write ClientHello: %v", err) } // The proxy should close the connection (no route match). conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _, err = conn.Read(make([]byte, 1)) if err == nil { t.Fatal("expected connection to be closed, but read succeeded") } } func TestFirewallBlocks(t *testing.T) { // Start a backend that should never be reached. backendLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("backend listen: %v", err) } defer backendLn.Close() reached := make(chan struct{}, 1) go func() { conn, err := backendLn.Accept() if err != nil { return } conn.Close() reached <- struct{}{} }() proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() // Create a firewall that blocks 127.0.0.1 (the test client). fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil) if err != nil { t.Fatalf("creating firewall: %v", err) } cfg := &config.Config{ Proxy: config.Proxy{ ConnectTimeout: config.Duration{Duration: 5 * time.Second}, IdleTimeout: config.Duration{Duration: 30 * time.Second}, ShutdownTimeout: config.Duration{Duration: 5 * time.Second}, }, } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) srv := New(cfg, fw, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]string{ "echo.test": backendLn.Addr().String(), }, }, }, logger, "test") ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup wg.Add(1) go func() { defer wg.Done() srv.Run(ctx) }() time.Sleep(50 * time.Millisecond) conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy: %v", err) } defer conn.Close() hello := buildClientHello("echo.test") conn.Write(hello) // Connection should be closed (blocked by firewall). conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _, err = conn.Read(make([]byte, 1)) if err == nil { t.Fatal("expected connection to be closed by firewall") } // Backend should not have been reached. select { case <-reached: t.Fatal("backend was reached despite firewall block") case <-time.After(200 * time.Millisecond): // Expected. } cancel() wg.Wait() } func TestNotTLSResets(t *testing.T) { proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]string{"x.test": "127.0.0.1:1"}, }, }) stop := startAndStop(t, srv) defer stop() conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy: %v", err) } defer conn.Close() // Send HTTP, not TLS. conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n")) conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _, err = conn.Read(make([]byte, 1)) if err == nil { t.Fatal("expected connection to be closed for non-TLS data") } } func TestConnectionTracking(t *testing.T) { // Backend that holds connections open until we close it. backendLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("backend listen: %v", err) } defer backendLn.Close() var backendConns []net.Conn var mu sync.Mutex go func() { for { conn, err := backendLn.Accept() if err != nil { return } mu.Lock() backendConns = append(backendConns, conn) mu.Unlock() // Hold connection open, drain input. go io.Copy(io.Discard, conn) } }() proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]string{ "conn.test": backendLn.Addr().String(), }, }, }) stop := startAndStop(t, srv) defer stop() if got := srv.TotalConnections(); got != 0 { t.Fatalf("expected 0 connections before any clients, got %d", got) } // Open two connections through the proxy. var clientConns []net.Conn for i := range 2 { conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy %d: %v", i, err) } hello := buildClientHello("conn.test") if _, err := conn.Write(hello); err != nil { t.Fatalf("write ClientHello %d: %v", i, err) } clientConns = append(clientConns, conn) } // Give connections time to be established. time.Sleep(100 * time.Millisecond) if got := srv.TotalConnections(); got != 2 { t.Fatalf("expected 2 active connections, got %d", got) } // Close one client and its corresponding backend connection. clientConns[0].Close() mu.Lock() if len(backendConns) > 0 { backendConns[0].Close() } mu.Unlock() // Wait for the relay goroutines to detect the close. deadline := time.Now().Add(5 * time.Second) for time.Now().Before(deadline) { if srv.TotalConnections() == 1 { break } time.Sleep(50 * time.Millisecond) } if got := srv.TotalConnections(); got != 1 { t.Fatalf("expected 1 active connection after closing one, got %d", got) } // Clean up. clientConns[1].Close() mu.Lock() for _, c := range backendConns { c.Close() } mu.Unlock() } func TestMultipleListeners(t *testing.T) { // Two backends. backendA, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("backend A listen: %v", err) } defer backendA.Close() backendB, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("backend B listen: %v", err) } defer backendB.Close() // Each backend writes its identity and closes. serve := func(ln net.Listener, id string) { conn, err := ln.Accept() if err != nil { return } defer conn.Close() // Drain the incoming data, then write identity. go io.Copy(io.Discard, conn) conn.Write([]byte(id)) } go serve(backendA, "A") go serve(backendB, "B") // Two proxy listeners, same hostname, different backends. ln1, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port 1: %v", err) } addr1 := ln1.Addr().String() ln1.Close() ln2, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port 2: %v", err) } addr2 := ln2.Addr().String() ln2.Close() srv := newTestServer(t, []ListenerData{ {ID: 1, Addr: addr1, Routes: map[string]string{"svc.test": backendA.Addr().String()}}, {ID: 2, Addr: addr2, Routes: map[string]string{"svc.test": backendB.Addr().String()}}, }) stop := startAndStop(t, srv) defer stop() readID := func(proxyAddr string) string { t.Helper() conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial %s: %v", proxyAddr, err) } defer conn.Close() hello := buildClientHello("svc.test") conn.Write(hello) conn.SetReadDeadline(time.Now().Add(5 * time.Second)) buf := make([]byte, 128) // Read what the backend sends back: echoed ClientHello + ID. // The backend drains input and writes the ID, so we read until we // find the ID byte at the end. var all []byte for { n, err := conn.Read(buf) all = append(all, buf[:n]...) if err != nil { break } } if len(all) == 0 { t.Fatalf("no data from %s", proxyAddr) } // The ID is the last byte. return string(all[len(all)-1:]) } idA := readID(addr1) idB := readID(addr2) if idA != "A" { t.Fatalf("listener 1: got backend %q, want A", idA) } if idB != "B" { t.Fatalf("listener 2: got backend %q, want B", idB) } } func TestCaseInsensitiveRouting(t *testing.T) { backendLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("backend listen: %v", err) } defer backendLn.Close() go echoServer(t, backendLn) proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]string{ "echo.test": backendLn.Addr().String(), }, }, }) stop := startAndStop(t, srv) defer stop() // SNI extraction lowercases the hostname, so "ECHO.TEST" should match // the route for "echo.test". conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy: %v", err) } defer conn.Close() hello := buildClientHello("ECHO.TEST") if _, err := conn.Write(hello); err != nil { t.Fatalf("write ClientHello: %v", err) } echoed := make([]byte, len(hello)) conn.SetReadDeadline(time.Now().Add(5 * time.Second)) if _, err := io.ReadFull(conn, echoed); err != nil { t.Fatalf("read echoed data: %v", err) } } func TestBackendUnreachable(t *testing.T) { // Find a port that nothing is listening on. ln, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } deadAddr := ln.Addr().String() ln.Close() proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() srv := newTestServer(t, []ListenerData{ { ID: 1, Addr: proxyAddr, Routes: map[string]string{ "dead.test": deadAddr, }, }, }) stop := startAndStop(t, srv) defer stop() conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy: %v", err) } defer conn.Close() hello := buildClientHello("dead.test") conn.Write(hello) // Proxy should close the connection after failing to dial backend. conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _, err = conn.Read(make([]byte, 1)) if err == nil { t.Fatal("expected connection to be closed when backend is unreachable") } } func TestGracefulShutdown(t *testing.T) { // Backend that holds the connection open. backendLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("backend listen: %v", err) } defer backendLn.Close() go func() { conn, err := backendLn.Accept() if err != nil { return } defer conn.Close() io.Copy(io.Discard, conn) }() proxyLn, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("finding free port: %v", err) } proxyAddr := proxyLn.Addr().String() proxyLn.Close() fw, err := firewall.New("", nil, nil, nil) if err != nil { t.Fatalf("creating firewall: %v", err) } cfg := &config.Config{ Proxy: config.Proxy{ ConnectTimeout: config.Duration{Duration: 5 * time.Second}, IdleTimeout: config.Duration{Duration: 30 * time.Second}, ShutdownTimeout: config.Duration{Duration: 2 * time.Second}, }, } logger := slog.New(slog.NewTextHandler(io.Discard, nil)) srv := New(cfg, fw, []ListenerData{ {ID: 1, Addr: proxyAddr, Routes: map[string]string{"hold.test": backendLn.Addr().String()}}, }, logger, "test") ctx, cancel := context.WithCancel(context.Background()) done := make(chan error, 1) go func() { done <- srv.Run(ctx) }() time.Sleep(50 * time.Millisecond) // Establish a connection that will be in-flight during shutdown. conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) if err != nil { t.Fatalf("dial proxy: %v", err) } defer conn.Close() hello := buildClientHello("hold.test") conn.Write(hello) time.Sleep(50 * time.Millisecond) // Trigger shutdown. cancel() // Server should exit within the shutdown timeout. select { case err := <-done: if err != nil { t.Fatalf("server.Run returned error: %v", err) } case <-time.After(5 * time.Second): t.Fatal("server did not shut down within 5 seconds") } } func TestListenerStateRoutes(t *testing.T) { ls := &ListenerState{ ID: 1, Addr: ":443", routes: map[string]string{ "a.test": "127.0.0.1:1", }, } // AddRoute if err := ls.AddRoute("b.test", "127.0.0.1:2"); err != nil { t.Fatalf("AddRoute: %v", err) } // AddRoute duplicate if err := ls.AddRoute("b.test", "127.0.0.1:3"); err == nil { t.Fatal("expected error for duplicate route") } // Routes snapshot routes := ls.Routes() if len(routes) != 2 { t.Fatalf("expected 2 routes, got %d", len(routes)) } // RemoveRoute if err := ls.RemoveRoute("a.test"); err != nil { t.Fatalf("RemoveRoute: %v", err) } // RemoveRoute not found if err := ls.RemoveRoute("nonexistent.test"); err == nil { t.Fatal("expected error for removing nonexistent route") } routes = ls.Routes() if len(routes) != 1 { t.Fatalf("expected 1 route, got %d", len(routes)) } if routes["b.test"] != "127.0.0.1:2" { t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"]) } } // --- ClientHello builder helpers (mirrors internal/sni test helpers) --- func buildClientHello(serverName string) []byte { return buildClientHelloWithExtensions(sniExtension(serverName)) } func buildClientHelloWithExtensions(extensions []byte) []byte { var hello []byte hello = append(hello, 0x03, 0x03) // TLS 1.2 hello = append(hello, make([]byte, 32)...) // random hello = append(hello, 0x00) // session ID: empty hello = append(hello, 0x00, 0x02, 0x00, 0x9C) // cipher suites hello = append(hello, 0x01, 0x00) // compression methods if len(extensions) > 0 { hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions))) hello = append(hello, extensions...) } handshake := []byte{0x01, 0x00, 0x00, 0x00} handshake[1] = byte(len(hello) >> 16) handshake[2] = byte(len(hello) >> 8) handshake[3] = byte(len(hello)) handshake = append(handshake, hello...) record := []byte{0x16, 0x03, 0x01} record = binary.BigEndian.AppendUint16(record, uint16(len(handshake))) record = append(record, handshake...) return record } func sniExtension(serverName string) []byte { name := []byte(serverName) var entry []byte entry = append(entry, 0x00) entry = binary.BigEndian.AppendUint16(entry, uint16(len(name))) entry = append(entry, name...) var list []byte list = binary.BigEndian.AppendUint16(list, uint16(len(entry))) list = append(list, entry...) var ext []byte ext = binary.BigEndian.AppendUint16(ext, 0x0000) ext = binary.BigEndian.AppendUint16(ext, uint16(len(list))) ext = append(ext, list...) return ext }