diff --git a/.golangci.yaml b/.golangci.yaml index affe97e..531a674 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,34 +1,124 @@ +# golangci-lint v2 configuration for mc-proxy. +# Principle: fail loudly. Security and correctness issues are errors, not warnings. + version: "2" +run: + timeout: 5m + # Include test files so security rules apply to test helpers too. + tests: true + linters: + default: none enable: + # --- Correctness --- + # Unhandled errors are silent failures; in auth code they become vulnerabilities. - errcheck + # go vet: catches printf-verb mismatches, unreachable code, suspicious constructs. - govet + # Detects assignments whose result is never used; dead writes hide logic bugs. - ineffassign + # Detects variables and functions that are never used. - unused + + # --- Error handling --- + # Enforces proper error wrapping (errors.Is/As instead of == comparisons) and + # prevents accidental discard of wrapped sentinel errors. - errorlint + + # --- Security --- + # Primary security scanner: hardcoded secrets, weak RNG, insecure crypto + # (MD5/SHA1/DES/RC4), SQL injection, insecure TLS, file permission issues, etc. - gosec + # Deep static analysis: deprecated APIs, incorrect mutex use, unreachable code, + # incorrect string conversions, simplification suggestions, and hundreds of other checks. + # (gosimple was merged into staticcheck in golangci-lint v2) - staticcheck + + # --- Style / conventions (per CLAUDE.md) --- + # Enforces Go naming conventions and selected style rules. - revive - - gofmt - - goimports settings: errcheck: + # Do NOT flag blank-identifier assignments: `_ = rows.Close()` in defers, + # `_ = tx.Rollback()` after errors, and `_ = fs.Parse(args)` with ExitOnError + # are all legitimate patterns where the error is genuinely unrecoverable or + # irrelevant. The default errcheck (without check-blank) still catches + # unchecked returns that have no assignment at all. + check-blank: false + # Flag discarded ok-value in type assertions: `c, _ := x.(*T)` — the ok + # value should be checked so a failed assertion is not silently treated as nil. check-type-assertions: true + govet: + # Enable all analyzers except shadow. The shadow analyzer flags the idiomatic + # `if err := f(); err != nil { ... }` pattern as shadowing an outer `err`, + # which is ubiquitous in Go and does not pose a security risk in this codebase. + enable-all: true disable: - shadow + gosec: + # Treat all gosec findings as errors, not warnings. severity: medium confidence: medium excludes: + # G104 (errors unhandled) overlaps with errcheck; let errcheck own this. - G104 + errorlint: + errorf: true + asserts: true + comparison: true + + revive: + rules: + # error-return and unexported-return are correctness/API-safety rules. + - name: error-return + severity: error + - name: unexported-return + severity: error + # Style rules. + - name: error-strings + severity: warning + - name: if-return + severity: warning + - name: increment-decrement + severity: warning + - name: var-naming + severity: warning + - name: range + severity: warning + - name: time-naming + severity: warning + - name: indent-error-flow + severity: warning + - name: early-return + severity: warning + # exported and package-comments are omitted: this is a personal project, + # not a public library; godoc completeness is not a CI requirement. + +formatters: + enable: + # Enforces gofmt formatting. Non-formatted code is a CI failure. + - gofmt + # Manages import grouping and formatting; catches stray debug imports. + - goimports + issues: + # Do not cap the number of reported issues; in security code every finding matters. max-issues-per-linter: 0 - exclude-rules: - - path: _test\.go - linters: - - gosec - text: "G101" + max-same-issues: 0 + + exclusions: + paths: + - vendor + rules: + # In test files, allow hardcoded test credentials (gosec G101) since they are + # intentional fixtures, not production secrets. + - path: "_test\\.go" + linters: + - gosec + text: "G101" + diff --git a/Dockerfile b/Dockerfile index 6f755db..cffe3b6 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,15 +1,20 @@ -FROM golang:1.24-alpine AS builder +FROM golang:1.25-alpine AS builder + +ARG VERSION=dev WORKDIR /build COPY go.mod go.sum ./ RUN go mod download COPY . . -RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o mc-proxy ./cmd/mc-proxy +RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w -X main.version=${VERSION}" \ + -o mc-proxy ./cmd/mc-proxy FROM alpine:3.21 RUN addgroup -S mc-proxy && adduser -S mc-proxy -G mc-proxy + COPY --from=builder /build/mc-proxy /usr/local/bin/mc-proxy USER mc-proxy ENTRYPOINT ["mc-proxy"] +CMD ["server", "--config", "/srv/mc-proxy/mc-proxy.toml"] diff --git a/Makefile b/Makefile index 08ad93d..247f7d0 100644 --- a/Makefile +++ b/Makefile @@ -30,7 +30,7 @@ clean: rm -f mc-proxy docker: - docker build -t mc-proxy -f Dockerfile . + docker build --build-arg VERSION=$(shell git describe --tags --always --dirty) -t mc-proxy -f Dockerfile . devserver: mc-proxy @mkdir -p srv diff --git a/README.md b/README.md index fc659a9..e20d3b3 100644 --- a/README.md +++ b/README.md @@ -1,19 +1,74 @@ -mc-proxy is a TLS proxy and router for Metacircular Dynamics projects; -it follows the Metacircular Engineering Standards. +# mc-proxy -Metacircular services are deployed to a machine that runs these projects -as containers. The proxy should do a few things: +mc-proxy is a Layer 4 TLS SNI proxy and router for +[Metacircular Dynamics](https://metacircular.net) services. It reads the SNI +hostname from incoming TLS ClientHello messages and proxies the raw TCP stream +to the matched backend. It does not terminate TLS. -1. It should have a global firewall front-end. It should allow a few - things: +A global firewall (IP, CIDR, GeoIP country blocking) is evaluated before any +routing decision. Blocked connections receive a TCP RST with no further +information. - 1. Per-country blocks using GeoIP for compliance reasons. - 2. Normal IP/CIDR blocks. Note that a proxy has an explicit port - setting, so the firewall doesn't need to consider ports. - 3. For endpoints marked as HTTPS, we should consider how to do - user-agent blocking. +## Quick Start -2. It should inspect the hostname and route that to the proper - container, similar to how haproxy would do it. +```bash +# Build +make mc-proxy +# Run locally (creates srv/ with example config on first run) +make devserver +# Full CI pipeline: vet → lint → test → build +make all +``` + +## Configuration + +Copy the example config and edit it: + +```bash +cp mc-proxy.toml.example /srv/mc-proxy/mc-proxy.toml +``` + +See [ARCHITECTURE.md](ARCHITECTURE.md) for the full configuration reference. + +Key sections: +- `[database]` — SQLite database path (required) +- `[[listeners]]` — TCP ports to bind and their route tables (seeds DB on first run) +- `[grpc]` — optional gRPC admin API with TLS/mTLS +- `[firewall]` — global blocklist (IP, CIDR, GeoIP country) +- `[proxy]` — connect timeout, idle timeout, shutdown timeout + +## CLI Commands + +| Command | Purpose | +|---------|---------| +| `mc-proxy server -c ` | Start the proxy | +| `mc-proxy status -c ` | Query a running instance's health via gRPC | +| `mc-proxy snapshot -c ` | Create a database backup (`VACUUM INTO`) | + +## Deployment + +See [RUNBOOK.md](RUNBOOK.md) for operational procedures. + +```bash +# Install on a Linux host +sudo deploy/scripts/install.sh + +# Or build and run as a container +make docker +docker run -v /srv/mc-proxy:/srv/mc-proxy mc-proxy server -c /srv/mc-proxy/mc-proxy.toml +``` + +## Design + +mc-proxy intentionally omits a REST API and web frontend. The gRPC admin API +is the sole management interface. This is an intentional departure from the +Metacircular engineering standards — mc-proxy is pre-auth infrastructure and +a minimal attack surface is prioritized over interface breadth. + +See [ARCHITECTURE.md](ARCHITECTURE.md) for the full system specification. + +## License + +Proprietary. Metacircular Dynamics. diff --git a/RUNBOOK.md b/RUNBOOK.md new file mode 100644 index 0000000..4b40e59 --- /dev/null +++ b/RUNBOOK.md @@ -0,0 +1,304 @@ +# RUNBOOK.md + +Operational procedures for mc-proxy. Written for operators, not developers. + +## Service Overview + +mc-proxy is a Layer 4 TLS SNI proxy. It routes incoming TLS connections to +backend services based on the SNI hostname. It does not terminate TLS or +inspect application-layer traffic. A global firewall blocks connections by +IP, CIDR, or GeoIP country before routing. + +## Health Checks + +### Via gRPC (requires admin API enabled) + +```bash +mc-proxy status -c /srv/mc-proxy/mc-proxy.toml +``` + +Expected output: + +``` +mc-proxy v0.1.0 +uptime: 4h32m10s +connections: 1247 + + :443 routes=2 active=12 + :8443 routes=1 active=3 + :9443 routes=1 active=0 +``` + +### Via systemd + +```bash +systemctl status mc-proxy +journalctl -u mc-proxy -n 50 --no-pager +``` + +### Via process + +```bash +ss -tlnp | grep mc-proxy +``` + +Verify all configured listener ports are in LISTEN state. + +## Common Operations + +### Start / Stop / Restart + +```bash +systemctl start mc-proxy +systemctl stop mc-proxy +systemctl restart mc-proxy +``` + +Stopping the service triggers graceful shutdown: new connections are refused, +in-flight connections drain for up to `shutdown_timeout` (default 30s), then +remaining connections are force-closed. + +### View Logs + +```bash +# Recent logs +journalctl -u mc-proxy -n 100 --no-pager + +# Follow live +journalctl -u mc-proxy -f + +# Filter by severity +journalctl -u mc-proxy -p err +``` + +### Reload GeoIP Database + +Send SIGHUP to reload the GeoIP database without restarting: + +```bash +systemctl kill -s HUP mc-proxy +``` + +Or: + +```bash +kill -HUP $(pidof mc-proxy) +``` + +Verify in logs: + +``` +level=INFO msg="received SIGHUP, reloading GeoIP database" +``` + +### Create a Database Backup + +```bash +# Manual backup +mc-proxy snapshot -c /srv/mc-proxy/mc-proxy.toml + +# Manual backup to a specific path +mc-proxy snapshot -c /srv/mc-proxy/mc-proxy.toml -o /tmp/mc-proxy-backup.db +``` + +Automated daily backups run via the systemd timer: + +```bash +# Check timer status +systemctl list-timers mc-proxy-backup.timer + +# Run backup manually via systemd +systemctl start mc-proxy-backup.service + +# View backup logs +journalctl -u mc-proxy-backup.service -n 20 --no-pager +``` + +Backups are stored in `/srv/mc-proxy/backups/` and pruned after 30 days. + +### Restore from Backup + +1. Stop the service: + ```bash + systemctl stop mc-proxy + ``` +2. Replace the database: + ```bash + cp /srv/mc-proxy/backups/mc-proxy-.db /srv/mc-proxy/mc-proxy.db + chown mc-proxy:mc-proxy /srv/mc-proxy/mc-proxy.db + chmod 0600 /srv/mc-proxy/mc-proxy.db + ``` +3. Start the service: + ```bash + systemctl start mc-proxy + ``` +4. Verify health: + ```bash + mc-proxy status -c /srv/mc-proxy/mc-proxy.toml + ``` + +### Manage Routes at Runtime (gRPC) + +Routes can be added and removed at runtime via the gRPC admin API using +`grpcurl` or any gRPC client. + +```bash +# List routes for a listener +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/ListRoutes \ + -d '{"listener_addr": ":443"}' + +# Add a route +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/AddRoute \ + -d '{"listener_addr": ":443", "route": {"hostname": "new.metacircular.net", "backend": "127.0.0.1:38443"}}' + +# Remove a route +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/RemoveRoute \ + -d '{"listener_addr": ":443", "hostname": "old.metacircular.net"}' +``` + +### Manage Firewall Rules at Runtime (gRPC) + +```bash +# List rules +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/GetFirewallRules + +# Block an IP +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/AddFirewallRule \ + -d '{"rule": {"type": "FIREWALL_RULE_TYPE_IP", "value": "203.0.113.50"}}' + +# Block a CIDR +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/AddFirewallRule \ + -d '{"rule": {"type": "FIREWALL_RULE_TYPE_CIDR", "value": "198.51.100.0/24"}}' + +# Block a country +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/AddFirewallRule \ + -d '{"rule": {"type": "FIREWALL_RULE_TYPE_COUNTRY", "value": "RU"}}' + +# Remove a rule +grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/RemoveFirewallRule \ + -d '{"rule": {"type": "FIREWALL_RULE_TYPE_IP", "value": "203.0.113.50"}}' +``` + +## Incident Procedures + +### Proxy Not Starting + +1. Check logs for the error: + ```bash + journalctl -u mc-proxy -n 50 --no-pager + ``` +2. Common causes: + - **"database.path is required"** — config file missing or malformed. + - **"firewall: geoip_db is required"** — country blocks configured but GeoIP database missing. + - **"address already in use"** — another process holds the port. + ```bash + ss -tlnp | grep ':' + ``` + - **Permission denied on database** — check ownership: + ```bash + ls -la /srv/mc-proxy/mc-proxy.db + chown mc-proxy:mc-proxy /srv/mc-proxy/mc-proxy.db + ``` + +### High Connection Count / Resource Exhaustion + +1. Check active connections: + ```bash + mc-proxy status -c /srv/mc-proxy/mc-proxy.toml + ``` +2. Check system-level connection count: + ```bash + ss -tn | grep -c ':' + ``` +3. If under attack, add firewall rules via gRPC to block the source: + ```bash + grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \ + localhost:9090 mc_proxy.v1.ProxyAdminService/AddFirewallRule \ + -d '{"rule": {"type": "FIREWALL_RULE_TYPE_IP", "value": ""}}' + ``` +4. If many IPs from one region, consider a country block or CIDR block. + +### Database Corruption + +1. Stop the service: + ```bash + systemctl stop mc-proxy + ``` +2. Check database integrity: + ```bash + sqlite3 /srv/mc-proxy/mc-proxy.db "PRAGMA integrity_check;" + ``` +3. If corrupted, restore from the most recent backup (see [Restore from Backup](#restore-from-backup)). +4. If no backups exist, delete the database and restart. The service will + re-seed from the TOML configuration: + ```bash + rm /srv/mc-proxy/mc-proxy.db + systemctl start mc-proxy + ``` + Note: any routes or firewall rules added at runtime via gRPC will be lost. + +### GeoIP Database Stale or Missing + +1. Download a fresh copy of GeoLite2-Country.mmdb from MaxMind. +2. Place it at the configured path: + ```bash + cp GeoLite2-Country.mmdb /srv/mc-proxy/GeoLite2-Country.mmdb + chown mc-proxy:mc-proxy /srv/mc-proxy/GeoLite2-Country.mmdb + ``` +3. Reload without restart: + ```bash + systemctl kill -s HUP mc-proxy + ``` + +### Certificate Expiry (gRPC Admin API) + +The gRPC admin API uses TLS certificates from `/srv/mc-proxy/certs/`. +Certificates are loaded at startup; replacing them requires a restart. + +1. Replace the certificates: + ```bash + cp new-cert.pem /srv/mc-proxy/certs/cert.pem + cp new-key.pem /srv/mc-proxy/certs/key.pem + chown mc-proxy:mc-proxy /srv/mc-proxy/certs/*.pem + chmod 0600 /srv/mc-proxy/certs/key.pem + ``` +2. Restart: + ```bash + systemctl restart mc-proxy + ``` + +Note: certificate expiry does not affect the proxy listeners — they do not +terminate TLS. + +### Backend Unreachable + +If a backend service is down, connections to routes pointing at that backend +will fail at the dial phase and the client receives a TCP RST. mc-proxy logs +the dial failure at `warn` level. + +1. Check logs for dial errors: + ```bash + journalctl -u mc-proxy -n 100 --no-pager | grep "dial" + ``` +2. Verify the backend is running: + ```bash + ss -tlnp | grep ':' + ``` +3. This is not an mc-proxy issue — fix the backend service. + +## Escalation + +If the runbook does not resolve the issue: + +1. Collect logs: `journalctl -u mc-proxy --since "1 hour ago" > /tmp/mc-proxy-logs.txt` +2. Collect status: `mc-proxy status -c /srv/mc-proxy/mc-proxy.toml > /tmp/mc-proxy-status.txt` +3. Collect database state: `mc-proxy snapshot -c /srv/mc-proxy/mc-proxy.toml -o /tmp/mc-proxy-escalation.db` +4. Escalate with the collected artifacts. diff --git a/deploy/docker/docker-compose.yml b/deploy/docker/docker-compose.yml new file mode 100644 index 0000000..5e154bf --- /dev/null +++ b/deploy/docker/docker-compose.yml @@ -0,0 +1,14 @@ +services: + mc-proxy: + build: + context: ../.. + dockerfile: Dockerfile + args: + VERSION: "${VERSION:-dev}" + ports: + - "443:443" + - "8443:8443" + - "9443:9443" + volumes: + - /srv/mc-proxy:/srv/mc-proxy + restart: unless-stopped diff --git a/internal/grpcserver/grpcserver_test.go b/internal/grpcserver/grpcserver_test.go new file mode 100644 index 0000000..9d38c49 --- /dev/null +++ b/internal/grpcserver/grpcserver_test.go @@ -0,0 +1,451 @@ +package grpcserver + +import ( + "context" + "io" + "log/slog" + "net" + "path/filepath" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/status" + "google.golang.org/grpc/test/bufconn" + + pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" + "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/db" + "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" + "git.wntrmute.dev/kyle/mc-proxy/internal/server" +) + +// testEnv bundles all the objects needed for a grpcserver test. +type testEnv struct { + client pb.ProxyAdminServiceClient + conn *grpc.ClientConn + store *db.Store + srv *server.Server +} + +func setup(t *testing.T) *testEnv { + t.Helper() + + // Database in temp dir. + dbPath := filepath.Join(t.TempDir(), "test.db") + store, err := db.Open(dbPath) + if err != nil { + t.Fatalf("open db: %v", err) + } + t.Cleanup(func() { store.Close() }) + + if err := store.Migrate(); err != nil { + t.Fatalf("migrate: %v", err) + } + + // Seed with one listener and one route. + listeners := []config.Listener{ + { + Addr: ":443", + Routes: []config.Route{ + {Hostname: "a.test", Backend: "127.0.0.1:8443"}, + }, + }, + } + fw := config.Firewall{ + BlockedIPs: []string{"10.0.0.1"}, + } + if err := store.Seed(listeners, fw); err != nil { + t.Fatalf("seed: %v", err) + } + + // Build server with matching in-memory state. + fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil) + if err != nil { + t.Fatalf("firewall: %v", err) + } + + cfg := &config.Config{ + Proxy: config.Proxy{ + ConnectTimeout: config.Duration{Duration: 5 * time.Second}, + IdleTimeout: config.Duration{Duration: 30 * time.Second}, + ShutdownTimeout: config.Duration{Duration: 5 * time.Second}, + }, + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + // Load listener data from DB to get correct IDs. + dbListeners, err := store.ListListeners() + if err != nil { + t.Fatalf("list listeners: %v", err) + } + var listenerData []server.ListenerData + for _, l := range dbListeners { + dbRoutes, err := store.ListRoutes(l.ID) + if err != nil { + t.Fatalf("list routes: %v", err) + } + routes := make(map[string]string, len(dbRoutes)) + for _, r := range dbRoutes { + routes[r.Hostname] = r.Backend + } + listenerData = append(listenerData, server.ListenerData{ + ID: l.ID, + Addr: l.Addr, + Routes: routes, + }) + } + + srv := server.New(cfg, fwObj, listenerData, logger, "test-version") + + // Set up bufconn gRPC server (no TLS for tests). + lis := bufconn.Listen(1024 * 1024) + grpcSrv := grpc.NewServer() + admin := &AdminServer{ + srv: srv, + store: store, + logger: logger, + } + pb.RegisterProxyAdminServiceServer(grpcSrv, admin) + + go func() { + if err := grpcSrv.Serve(lis); err != nil { + t.Logf("grpc serve: %v", err) + } + }() + t.Cleanup(grpcSrv.Stop) + + conn, err := grpc.NewClient("passthrough://bufconn", + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("dial bufconn: %v", err) + } + t.Cleanup(func() { conn.Close() }) + + return &testEnv{ + client: pb.NewProxyAdminServiceClient(conn), + conn: conn, + store: store, + srv: srv, + } +} + +func TestGetStatus(t *testing.T) { + env := setup(t) + ctx := context.Background() + + resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{}) + if err != nil { + t.Fatalf("GetStatus: %v", err) + } + if resp.Version != "test-version" { + t.Fatalf("got version %q, want %q", resp.Version, "test-version") + } + if len(resp.Listeners) != 1 { + t.Fatalf("got %d listeners, want 1", len(resp.Listeners)) + } + if resp.Listeners[0].Addr != ":443" { + t.Fatalf("got listener addr %q, want %q", resp.Listeners[0].Addr, ":443") + } + if resp.Listeners[0].RouteCount != 1 { + t.Fatalf("got route count %d, want 1", resp.Listeners[0].RouteCount) + } +} + +func TestListRoutes(t *testing.T) { + env := setup(t) + ctx := context.Background() + + resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"}) + if err != nil { + t.Fatalf("ListRoutes: %v", err) + } + if len(resp.Routes) != 1 { + t.Fatalf("got %d routes, want 1", len(resp.Routes)) + } + if resp.Routes[0].Hostname != "a.test" { + t.Fatalf("got hostname %q, want %q", resp.Routes[0].Hostname, "a.test") + } + if resp.Routes[0].Backend != "127.0.0.1:8443" { + t.Fatalf("got backend %q, want %q", resp.Routes[0].Backend, "127.0.0.1:8443") + } +} + +func TestListRoutesNotFound(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":9999"}) + if err == nil { + t.Fatal("expected error for nonexistent listener") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound { + t.Fatalf("expected NotFound, got %v", err) + } +} + +func TestAddRoute(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{Hostname: "b.test", Backend: "127.0.0.1:9443"}, + }) + if err != nil { + t.Fatalf("AddRoute: %v", err) + } + + // Verify in-memory. + resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"}) + if err != nil { + t.Fatalf("ListRoutes: %v", err) + } + if len(resp.Routes) != 2 { + t.Fatalf("got %d routes, want 2", len(resp.Routes)) + } + + // Verify in DB. + dbListeners, err := env.store.ListListeners() + if err != nil { + t.Fatalf("list listeners: %v", err) + } + dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID) + if err != nil { + t.Fatalf("list routes: %v", err) + } + if len(dbRoutes) != 2 { + t.Fatalf("DB has %d routes, want 2", len(dbRoutes)) + } +} + +func TestAddRouteDuplicate(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"}, + }) + if err == nil { + t.Fatal("expected error for duplicate route") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.AlreadyExists { + t.Fatalf("expected AlreadyExists, got %v", err) + } +} + +func TestAddRouteValidation(t *testing.T) { + env := setup(t) + ctx := context.Background() + + // Missing route. + _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ListenerAddr: ":443"}) + if err == nil { + t.Fatal("expected error for nil route") + } + + // Missing hostname. + _, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{Backend: "127.0.0.1:1"}, + }) + if err == nil { + t.Fatal("expected error for empty hostname") + } + + // Missing backend. + _, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{Hostname: "x.test"}, + }) + if err == nil { + t.Fatal("expected error for empty backend") + } +} + +func TestRemoveRoute(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{ + ListenerAddr: ":443", + Hostname: "a.test", + }) + if err != nil { + t.Fatalf("RemoveRoute: %v", err) + } + + // Verify removed from memory. + resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"}) + if err != nil { + t.Fatalf("ListRoutes: %v", err) + } + if len(resp.Routes) != 0 { + t.Fatalf("got %d routes, want 0", len(resp.Routes)) + } +} + +func TestRemoveRouteNotFound(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{ + ListenerAddr: ":443", + Hostname: "nonexistent.test", + }) + if err == nil { + t.Fatal("expected error for removing nonexistent route") + } +} + +func TestGetFirewallRules(t *testing.T) { + env := setup(t) + ctx := context.Background() + + resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{}) + if err != nil { + t.Fatalf("GetFirewallRules: %v", err) + } + if len(resp.Rules) != 1 { + t.Fatalf("got %d rules, want 1", len(resp.Rules)) + } + if resp.Rules[0].Type != pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP { + t.Fatalf("got type %v, want IP", resp.Rules[0].Type) + } + if resp.Rules[0].Value != "10.0.0.1" { + t.Fatalf("got value %q, want %q", resp.Rules[0].Value, "10.0.0.1") + } +} + +func TestAddFirewallRule(t *testing.T) { + env := setup(t) + ctx := context.Background() + + // Add IP rule. + _, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, + Value: "10.0.0.2", + }, + }) + if err != nil { + t.Fatalf("AddFirewallRule IP: %v", err) + } + + // Add CIDR rule. + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, + Value: "192.168.0.0/16", + }, + }) + if err != nil { + t.Fatalf("AddFirewallRule CIDR: %v", err) + } + + // Add country rule. + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, + Value: "RU", + }, + }) + if err != nil { + t.Fatalf("AddFirewallRule country: %v", err) + } + + // Verify. + resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{}) + if err != nil { + t.Fatalf("GetFirewallRules: %v", err) + } + if len(resp.Rules) != 4 { + t.Fatalf("got %d rules, want 4", len(resp.Rules)) + } + + // Verify DB persistence. + dbRules, err := env.store.ListFirewallRules() + if err != nil { + t.Fatalf("list firewall rules: %v", err) + } + if len(dbRules) != 4 { + t.Fatalf("DB has %d rules, want 4", len(dbRules)) + } +} + +func TestAddFirewallRuleValidation(t *testing.T) { + env := setup(t) + ctx := context.Background() + + // Nil rule. + _, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{}) + if err == nil { + t.Fatal("expected error for nil rule") + } + + // Unknown type. + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED, + Value: "x", + }, + }) + if err == nil { + t.Fatal("expected error for unspecified rule type") + } + + // Empty value. + _, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, + }, + }) + if err == nil { + t.Fatal("expected error for empty value") + } +} + +func TestRemoveFirewallRule(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, + Value: "10.0.0.1", + }, + }) + if err != nil { + t.Fatalf("RemoveFirewallRule: %v", err) + } + + resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{}) + if err != nil { + t.Fatalf("GetFirewallRules: %v", err) + } + if len(resp.Rules) != 0 { + t.Fatalf("got %d rules, want 0", len(resp.Rules)) + } +} + +func TestRemoveFirewallRuleNotFound(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, + Value: "99.99.99.99", + }, + }) + if err == nil { + t.Fatal("expected error for removing nonexistent rule") + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go new file mode 100644 index 0000000..5605ac7 --- /dev/null +++ b/internal/server/server_test.go @@ -0,0 +1,746 @@ +package server + +import ( + "context" + "encoding/binary" + "io" + "log/slog" + "net" + "sync" + "testing" + "time" + + "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" +) + +// echoServer accepts one connection, copies everything back, then closes. +func echoServer(t *testing.T, ln net.Listener) { + t.Helper() + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + io.Copy(conn, conn) +} + +// newTestServer creates a Server with the given listener data and no firewall rules. +func newTestServer(t *testing.T, listeners []ListenerData) *Server { + t.Helper() + fw, err := firewall.New("", nil, nil, nil) + if err != nil { + t.Fatalf("creating firewall: %v", err) + } + cfg := &config.Config{ + Proxy: config.Proxy{ + ConnectTimeout: config.Duration{Duration: 5 * time.Second}, + IdleTimeout: config.Duration{Duration: 30 * time.Second}, + ShutdownTimeout: config.Duration{Duration: 5 * time.Second}, + }, + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + return New(cfg, fw, listeners, logger, "test") +} + +// startAndStop starts the server in a goroutine and returns a cancel function +// that shuts it down and waits for it to exit. +func startAndStop(t *testing.T, srv *Server) context.CancelFunc { + t.Helper() + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + if err := srv.Run(ctx); err != nil { + t.Errorf("server.Run: %v", err) + } + }() + // Give the listeners a moment to bind. + time.Sleep(50 * time.Millisecond) + return func() { + cancel() + wg.Wait() + } +} + +func TestProxyRoundTrip(t *testing.T) { + // Start an echo backend. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + go echoServer(t, backendLn) + + // Pick a free port for the proxy listener. + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]string{ + "echo.test": backendLn.Addr().String(), + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + // Connect through the proxy with a fake ClientHello. + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("echo.test") + if _, err := conn.Write(hello); err != nil { + t.Fatalf("write ClientHello: %v", err) + } + + // The backend will echo our ClientHello back. Read it. + echoed := make([]byte, len(hello)) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if _, err := io.ReadFull(conn, echoed); err != nil { + t.Fatalf("read echoed data: %v", err) + } + + // Send some additional data. + payload := []byte("hello from client") + if _, err := conn.Write(payload); err != nil { + t.Fatalf("write payload: %v", err) + } + + buf := make([]byte, len(payload)) + if _, err := io.ReadFull(conn, buf); err != nil { + t.Fatalf("read echoed payload: %v", err) + } + if string(buf) != string(payload) { + t.Fatalf("got %q, want %q", buf, payload) + } +} + +func TestNoRouteResets(t *testing.T) { + // Proxy listener with no routes for the requested hostname. + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]string{ + "other.test": "127.0.0.1:1", // exists but won't match + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("unknown.test") + if _, err := conn.Write(hello); err != nil { + t.Fatalf("write ClientHello: %v", err) + } + + // The proxy should close the connection (no route match). + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Fatal("expected connection to be closed, but read succeeded") + } +} + +func TestFirewallBlocks(t *testing.T) { + // Start a backend that should never be reached. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + + reached := make(chan struct{}, 1) + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + conn.Close() + reached <- struct{}{} + }() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + // Create a firewall that blocks 127.0.0.1 (the test client). + fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil) + if err != nil { + t.Fatalf("creating firewall: %v", err) + } + + cfg := &config.Config{ + Proxy: config.Proxy{ + ConnectTimeout: config.Duration{Duration: 5 * time.Second}, + IdleTimeout: config.Duration{Duration: 30 * time.Second}, + ShutdownTimeout: config.Duration{Duration: 5 * time.Second}, + }, + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := New(cfg, fw, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]string{ + "echo.test": backendLn.Addr().String(), + }, + }, + }, logger, "test") + + ctx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + srv.Run(ctx) + }() + time.Sleep(50 * time.Millisecond) + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("echo.test") + conn.Write(hello) + + // Connection should be closed (blocked by firewall). + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Fatal("expected connection to be closed by firewall") + } + + // Backend should not have been reached. + select { + case <-reached: + t.Fatal("backend was reached despite firewall block") + case <-time.After(200 * time.Millisecond): + // Expected. + } + + cancel() + wg.Wait() +} + +func TestNotTLSResets(t *testing.T) { + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]string{"x.test": "127.0.0.1:1"}, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + // Send HTTP, not TLS. + conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n")) + + conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Fatal("expected connection to be closed for non-TLS data") + } +} + +func TestConnectionTracking(t *testing.T) { + // Backend that holds connections open until we close it. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + + var backendConns []net.Conn + var mu sync.Mutex + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + mu.Lock() + backendConns = append(backendConns, conn) + mu.Unlock() + // Hold connection open, drain input. + go io.Copy(io.Discard, conn) + } + }() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]string{ + "conn.test": backendLn.Addr().String(), + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + if got := srv.TotalConnections(); got != 0 { + t.Fatalf("expected 0 connections before any clients, got %d", got) + } + + // Open two connections through the proxy. + var clientConns []net.Conn + for i := range 2 { + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy %d: %v", i, err) + } + hello := buildClientHello("conn.test") + if _, err := conn.Write(hello); err != nil { + t.Fatalf("write ClientHello %d: %v", i, err) + } + clientConns = append(clientConns, conn) + } + + // Give connections time to be established. + time.Sleep(100 * time.Millisecond) + + if got := srv.TotalConnections(); got != 2 { + t.Fatalf("expected 2 active connections, got %d", got) + } + + // Close one client and its corresponding backend connection. + clientConns[0].Close() + mu.Lock() + if len(backendConns) > 0 { + backendConns[0].Close() + } + mu.Unlock() + + // Wait for the relay goroutines to detect the close. + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + if srv.TotalConnections() == 1 { + break + } + time.Sleep(50 * time.Millisecond) + } + if got := srv.TotalConnections(); got != 1 { + t.Fatalf("expected 1 active connection after closing one, got %d", got) + } + + // Clean up. + clientConns[1].Close() + mu.Lock() + for _, c := range backendConns { + c.Close() + } + mu.Unlock() +} + +func TestMultipleListeners(t *testing.T) { + // Two backends. + backendA, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend A listen: %v", err) + } + defer backendA.Close() + + backendB, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend B listen: %v", err) + } + defer backendB.Close() + + // Each backend writes its identity and closes. + serve := func(ln net.Listener, id string) { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + // Drain the incoming data, then write identity. + go io.Copy(io.Discard, conn) + conn.Write([]byte(id)) + } + go serve(backendA, "A") + go serve(backendB, "B") + + // Two proxy listeners, same hostname, different backends. + ln1, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port 1: %v", err) + } + addr1 := ln1.Addr().String() + ln1.Close() + + ln2, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port 2: %v", err) + } + addr2 := ln2.Addr().String() + ln2.Close() + + srv := newTestServer(t, []ListenerData{ + {ID: 1, Addr: addr1, Routes: map[string]string{"svc.test": backendA.Addr().String()}}, + {ID: 2, Addr: addr2, Routes: map[string]string{"svc.test": backendB.Addr().String()}}, + }) + + stop := startAndStop(t, srv) + defer stop() + + readID := func(proxyAddr string) string { + t.Helper() + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial %s: %v", proxyAddr, err) + } + defer conn.Close() + + hello := buildClientHello("svc.test") + conn.Write(hello) + + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + buf := make([]byte, 128) + // Read what the backend sends back: echoed ClientHello + ID. + // The backend drains input and writes the ID, so we read until we + // find the ID byte at the end. + var all []byte + for { + n, err := conn.Read(buf) + all = append(all, buf[:n]...) + if err != nil { + break + } + } + if len(all) == 0 { + t.Fatalf("no data from %s", proxyAddr) + } + // The ID is the last byte. + return string(all[len(all)-1:]) + } + + idA := readID(addr1) + idB := readID(addr2) + + if idA != "A" { + t.Fatalf("listener 1: got backend %q, want A", idA) + } + if idB != "B" { + t.Fatalf("listener 2: got backend %q, want B", idB) + } +} + +func TestCaseInsensitiveRouting(t *testing.T) { + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + go echoServer(t, backendLn) + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]string{ + "echo.test": backendLn.Addr().String(), + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + // SNI extraction lowercases the hostname, so "ECHO.TEST" should match + // the route for "echo.test". + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("ECHO.TEST") + if _, err := conn.Write(hello); err != nil { + t.Fatalf("write ClientHello: %v", err) + } + + echoed := make([]byte, len(hello)) + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + if _, err := io.ReadFull(conn, echoed); err != nil { + t.Fatalf("read echoed data: %v", err) + } +} + +func TestBackendUnreachable(t *testing.T) { + // Find a port that nothing is listening on. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + deadAddr := ln.Addr().String() + ln.Close() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + Routes: map[string]string{ + "dead.test": deadAddr, + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("dead.test") + conn.Write(hello) + + // Proxy should close the connection after failing to dial backend. + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + _, err = conn.Read(make([]byte, 1)) + if err == nil { + t.Fatal("expected connection to be closed when backend is unreachable") + } +} + +func TestGracefulShutdown(t *testing.T) { + // Backend that holds the connection open. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + io.Copy(io.Discard, conn) + }() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("finding free port: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + fw, err := firewall.New("", nil, nil, nil) + if err != nil { + t.Fatalf("creating firewall: %v", err) + } + cfg := &config.Config{ + Proxy: config.Proxy{ + ConnectTimeout: config.Duration{Duration: 5 * time.Second}, + IdleTimeout: config.Duration{Duration: 30 * time.Second}, + ShutdownTimeout: config.Duration{Duration: 2 * time.Second}, + }, + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + srv := New(cfg, fw, []ListenerData{ + {ID: 1, Addr: proxyAddr, Routes: map[string]string{"hold.test": backendLn.Addr().String()}}, + }, logger, "test") + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + done <- srv.Run(ctx) + }() + time.Sleep(50 * time.Millisecond) + + // Establish a connection that will be in-flight during shutdown. + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial proxy: %v", err) + } + defer conn.Close() + + hello := buildClientHello("hold.test") + conn.Write(hello) + time.Sleep(50 * time.Millisecond) + + // Trigger shutdown. + cancel() + + // Server should exit within the shutdown timeout. + select { + case err := <-done: + if err != nil { + t.Fatalf("server.Run returned error: %v", err) + } + case <-time.After(5 * time.Second): + t.Fatal("server did not shut down within 5 seconds") + } +} + +func TestListenerStateRoutes(t *testing.T) { + ls := &ListenerState{ + ID: 1, + Addr: ":443", + routes: map[string]string{ + "a.test": "127.0.0.1:1", + }, + } + + // AddRoute + if err := ls.AddRoute("b.test", "127.0.0.1:2"); err != nil { + t.Fatalf("AddRoute: %v", err) + } + + // AddRoute duplicate + if err := ls.AddRoute("b.test", "127.0.0.1:3"); err == nil { + t.Fatal("expected error for duplicate route") + } + + // Routes snapshot + routes := ls.Routes() + if len(routes) != 2 { + t.Fatalf("expected 2 routes, got %d", len(routes)) + } + + // RemoveRoute + if err := ls.RemoveRoute("a.test"); err != nil { + t.Fatalf("RemoveRoute: %v", err) + } + + // RemoveRoute not found + if err := ls.RemoveRoute("nonexistent.test"); err == nil { + t.Fatal("expected error for removing nonexistent route") + } + + routes = ls.Routes() + if len(routes) != 1 { + t.Fatalf("expected 1 route, got %d", len(routes)) + } + if routes["b.test"] != "127.0.0.1:2" { + t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"]) + } +} + +// --- ClientHello builder helpers (mirrors internal/sni test helpers) --- + +func buildClientHello(serverName string) []byte { + return buildClientHelloWithExtensions(sniExtension(serverName)) +} + +func buildClientHelloWithExtensions(extensions []byte) []byte { + var hello []byte + + hello = append(hello, 0x03, 0x03) // TLS 1.2 + hello = append(hello, make([]byte, 32)...) // random + hello = append(hello, 0x00) // session ID: empty + hello = append(hello, 0x00, 0x02, 0x00, 0x9C) // cipher suites + hello = append(hello, 0x01, 0x00) // compression methods + + if len(extensions) > 0 { + hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions))) + hello = append(hello, extensions...) + } + + handshake := []byte{0x01, 0x00, 0x00, 0x00} + handshake[1] = byte(len(hello) >> 16) + handshake[2] = byte(len(hello) >> 8) + handshake[3] = byte(len(hello)) + handshake = append(handshake, hello...) + + record := []byte{0x16, 0x03, 0x01} + record = binary.BigEndian.AppendUint16(record, uint16(len(handshake))) + record = append(record, handshake...) + + return record +} + +func sniExtension(serverName string) []byte { + name := []byte(serverName) + + var entry []byte + entry = append(entry, 0x00) + entry = binary.BigEndian.AppendUint16(entry, uint16(len(name))) + entry = append(entry, name...) + + var list []byte + list = binary.BigEndian.AppendUint16(list, uint16(len(entry))) + list = append(list, entry...) + + var ext []byte + ext = binary.BigEndian.AppendUint16(ext, 0x0000) + ext = binary.BigEndian.AppendUint16(ext, uint16(len(list))) + ext = append(ext, list...) + + return ext +} +