From c7024dcdf0fdbb115279400c73d69f66174ff33a Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Tue, 17 Mar 2026 02:56:24 -0700 Subject: [PATCH] Initial implementation of mc-proxy Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) --- .gitignore | 12 + .golangci.yaml | 34 +++ ARCHITECTURE.md | 345 +++++++++++++++++++++++++++++ CLAUDE.md | 54 +++++ Dockerfile | 15 ++ Makefile | 36 +++ README.md | 19 ++ deploy/mc-proxy.toml.example | 37 ++++ deploy/scripts/install.sh | 42 ++++ deploy/systemd/mc-proxy.service | 32 +++ go.mod | 20 ++ go.sum | 60 +++++ internal/config/config.go | 115 ++++++++++ internal/config/config_test.go | 186 ++++++++++++++++ internal/firewall/firewall.go | 218 ++++++++++++++++++ internal/firewall/firewall_test.go | 141 ++++++++++++ internal/grpcserver/grpcserver.go | 246 ++++++++++++++++++++ internal/proxy/proxy.go | 105 +++++++++ internal/proxy/proxy_test.go | 259 ++++++++++++++++++++++ internal/server/server.go | 271 ++++++++++++++++++++++ internal/sni/sni.go | 175 +++++++++++++++ internal/sni/sni_test.go | 220 ++++++++++++++++++ mc-proxy.toml.example | 51 +++++ 23 files changed, 2693 insertions(+) create mode 100644 .gitignore create mode 100644 .golangci.yaml create mode 100644 ARCHITECTURE.md create mode 100644 CLAUDE.md create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 README.md create mode 100644 deploy/mc-proxy.toml.example create mode 100755 deploy/scripts/install.sh create mode 100644 deploy/systemd/mc-proxy.service create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/config/config.go create mode 100644 internal/config/config_test.go create mode 100644 internal/firewall/firewall.go create mode 100644 internal/firewall/firewall_test.go create mode 100644 internal/grpcserver/grpcserver.go create mode 100644 internal/proxy/proxy.go create mode 100644 internal/proxy/proxy_test.go create mode 100644 internal/server/server.go create mode 100644 internal/sni/sni.go create mode 100644 internal/sni/sni_test.go create mode 100644 mc-proxy.toml.example diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..75c8bea --- /dev/null +++ b/.gitignore @@ -0,0 +1,12 @@ +# Binary +mc-proxy + +# Runtime data +srv/ + +# IDE +.idea/ +.vscode/ + +# OS +.DS_Store diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..affe97e --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,34 @@ +version: "2" + +linters: + enable: + - errcheck + - govet + - ineffassign + - unused + - errorlint + - gosec + - staticcheck + - revive + - gofmt + - goimports + + settings: + errcheck: + check-type-assertions: true + govet: + disable: + - shadow + gosec: + severity: medium + confidence: medium + excludes: + - G104 + +issues: + max-issues-per-linter: 0 + exclude-rules: + - path: _test\.go + linters: + - gosec + text: "G101" diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md new file mode 100644 index 0000000..d66e7e1 --- /dev/null +++ b/ARCHITECTURE.md @@ -0,0 +1,345 @@ +# ARCHITECTURE.md + +mc-proxy is a Layer 4 TLS proxy and router for Metacircular Dynamics +services. It inspects the SNI field of incoming TLS ClientHello messages to +determine the target backend, then proxies raw TCP between the client and +the appropriate container. A global firewall evaluates every connection +before routing. + +## Table of Contents + +1. [System Overview](#system-overview) +2. [Connection Lifecycle](#connection-lifecycle) +3. [Firewall](#firewall) +4. [Routing](#routing) +5. [Configuration](#configuration) +6. [Storage](#storage) +7. [Deployment](#deployment) +8. [Security Model](#security-model) +9. [Future Work](#future-work) + +--- + +## System Overview + +``` + ┌─────────────────────────────────────┐ + │ mc-proxy │ + Clients ──────┐ │ │ + │ │ ┌──────────┐ ┌───────┐ ┌─────┐ │ ┌────────────┐ + ├────▶│ │ Firewall │──▶│ SNI │──▶│Route│─│────▶│ Backend A │ + │ │ │ (global) │ │Extract│ │Table│ │ │ :8443 │ + ├────▶│ └──────────┘ └───────┘ └─────┘ │ ├────────────┤ + │ │ │ RST │ │ │ Backend B │ + Clients ──────┘ │ ▼ └────│────▶│ :9443 │ + │ (blocked) │ └────────────┘ + └─────────────────────────────────────┘ + + Listener 1 (:443) ─┐ + Listener 2 (:8443) ─┼─ Each listener runs the same pipeline + Listener N (:9443) ─┘ +``` + +Key properties: + +- **Layer 4 only.** mc-proxy never terminates TLS. It reads just enough of + the ClientHello to extract the SNI hostname, then proxies the raw TCP + stream to the matched backend. The backend handles TLS termination. +- **TLS-only.** Non-TLS connections are not supported. If the first bytes of + a connection are not a TLS ClientHello, the connection is reset. +- **Multiple listeners.** A single mc-proxy instance binds to one or more + ports. Each listener runs the same firewall → SNI → route pipeline. +- **Global firewall.** Firewall rules apply to all listeners uniformly. + There are no per-route firewall rules. +- **No authentication.** mc-proxy is pre-auth infrastructure. It sits in + front of services that handle their own authentication via MCIAS. + +--- + +## Connection Lifecycle + +Every inbound connection follows this sequence: + +``` +1. ACCEPT Listener accepts TCP connection. +2. FIREWALL Check source IP against blocklists: + a. IP/CIDR block check. + b. GeoIP country block check. + If blocked → RST, done. +3. SNI EXTRACT Read the TLS ClientHello (without consuming it). + Extract the SNI hostname. + If no valid ClientHello or no SNI → RST, done. +4. ROUTE LOOKUP Match SNI hostname against the route table. + If no match → RST, done. +5. BACKEND DIAL Open TCP connection to the matched backend address. + If dial fails → RST, done. +6. PROXY Bidirectional byte copy: client ↔ backend. + The buffered ClientHello bytes are forwarded first, + then both directions copy concurrently. +7. CLOSE Either side closes → half-close propagation → done. +``` + +### SNI Extraction + +The proxy peeks at the initial bytes of the connection without consuming +them. It parses just enough of the TLS record layer and ClientHello to +extract the `server_name` extension. The full ClientHello (including the +SNI) is then forwarded to the backend so the backend's TLS handshake +proceeds normally. + +If the ClientHello spans multiple TCP segments, the proxy buffers up to +16 KiB (the maximum TLS record size) before giving up. + +### Bidirectional Copy + +After the backend connection is established, the proxy runs two concurrent +copy loops (client→backend and backend→client). When either direction +encounters an EOF or error: + +1. The write side of the opposite direction is half-closed. +2. The remaining direction drains to completion. +3. Both connections are closed. + +Timeouts apply to the copy phase to prevent idle connections from +accumulating indefinitely (see [Configuration](#configuration)). + +--- + +## Firewall + +The firewall is a global, ordered rule set evaluated on every new +connection before SNI extraction. Rules are evaluated in definition order; +the first matching rule determines the outcome. If no rule matches, the +connection is **allowed** (default allow — the firewall is an explicit +blocklist, not an allowlist). + +### Rule Types + +| Config Field | Match Field | Example | +|--------------|-------------|---------| +| `blocked_ips` | Source IP address | `"192.0.2.1"` | +| `blocked_cidrs` | Source IP prefix | `"198.51.100.0/24"` | +| `blocked_countries` | Source country (ISO 3166-1 alpha-2) | `"KP"`, `"CN"`, `"IN"`, `"IL"` | + +### Blocked Connection Handling + +Blocked connections receive a TCP RST. No TLS alert, no HTTP error page, no +indication of why the connection was refused. This is intentional — blocked +sources should receive minimal information. + +### GeoIP Database + +mc-proxy uses the [MaxMind GeoLite2](https://dev.maxmind.com/geoip/geolite2-free-geolocation-data) +free database for country-level IP geolocation. The database file is +distributed separately from the binary and referenced by path in the +configuration. + +The GeoIP database is loaded into memory at startup and can be reloaded +via `SIGHUP` without restarting the proxy. If the database file is missing +or unreadable at startup and GeoIP rules are configured, the proxy refuses +to start. + +--- + +## Routing + +Each listener has its own route table mapping SNI hostnames to backend +addresses. A route entry consists of: + +| Field | Type | Description | +|-------|------|-------------| +| `hostname` | string | Exact SNI hostname to match (e.g. `metacrypt.metacircular.net`) | +| `backend` | string | Backend address as `host:port` (e.g. `127.0.0.1:8443`) | + +Routes are scoped to the listener that accepted the connection. The same +hostname can appear on different listeners with different backends, allowing +the proxy to route the same service name to different backends depending +on which port the client connected to. + +### Match Semantics + +- Hostname matching is **exact** and **case-insensitive** (per RFC 6066, + SNI hostnames are DNS names and compared case-insensitively). +- Wildcard matching is not supported in the initial implementation. +- If duplicate hostnames appear within the same listener, the proxy refuses + to start. +- If no route matches an incoming SNI hostname, the connection is reset. + +### Route Table Source + +Route tables are defined inline under each listener in the TOML +configuration file. The design anticipates future migration to a SQLite +database for dynamic route management via the control plane API. + +--- + +## Configuration + +TOML configuration file, loaded at startup. The proxy refuses to start if +required fields are missing or invalid. + +```toml +# Listeners. Each has its own route table. +[[listeners]] +addr = ":443" + + [[listeners.routes]] + hostname = "metacrypt.metacircular.net" + backend = "127.0.0.1:18443" + + [[listeners.routes]] + hostname = "mcias.metacircular.net" + backend = "127.0.0.1:28443" + +[[listeners]] +addr = ":8443" + + [[listeners.routes]] + hostname = "metacrypt.metacircular.net" + backend = "127.0.0.1:18443" + +[[listeners]] +addr = ":9443" + + [[listeners.routes]] + hostname = "mcias.metacircular.net" + backend = "127.0.0.1:28443" + +# Firewall. Global blocklist, evaluated before routing. Default allow. +[firewall] +geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb" +blocked_ips = ["192.0.2.1"] +blocked_cidrs = ["198.51.100.0/24"] +blocked_countries = ["KP", "CN", "IN", "IL"] + +# Proxy behavior. +[proxy] +connect_timeout = "5s" # Timeout for dialing backend +idle_timeout = "300s" # Close connections idle longer than this +shutdown_timeout = "30s" # Graceful shutdown drain period + +# Logging. +[log] +level = "info" # debug, info, warn, error +``` + +### Environment Variable Overrides + +Configuration values can be overridden via environment variables using the +prefix `MCPROXY_` with underscore-separated paths: + +``` +MCPROXY_LOG_LEVEL=debug +MCPROXY_PROXY_IDLE_TIMEOUT=600s +``` + +Environment variables cannot define listeners, routes, or firewall rules — +these are structural and must be in the TOML file. + +--- + +## Storage + +mc-proxy has minimal storage requirements. There is no database in the +initial implementation. + +``` +/srv/mc-proxy/ +├── mc-proxy.toml Configuration +├── GeoLite2-Country.mmdb GeoIP database (if using country blocks) +└── backups/ Reserved for future use +``` + +No TLS certificates are stored — mc-proxy does not terminate TLS. + +--- + +## Deployment + +### Binary + +Single static binary, built with `CGO_ENABLED=0`. No runtime dependencies +beyond the configuration file and optional GeoIP database. + +### Container + +Multi-stage Docker build: + +1. **Builder**: `golang:-alpine`, static compilation. +2. **Runtime**: `alpine`, non-root user (`mc-proxy`), port exposure + determined by configuration. + +### systemd + +| File | Purpose | +|------|---------| +| `mc-proxy.service` | Main proxy service | + +The proxy binds to privileged ports (443) and should use `AmbientCapabilities=CAP_NET_BIND_SERVICE` +in the systemd unit rather than running as root. + +Standard security hardening directives apply per engineering standards +(`NoNewPrivileges=true`, `ProtectSystem=strict`, etc.). + +### Graceful Shutdown + +On `SIGINT` or `SIGTERM`: + +1. Stop accepting new connections on all listeners. +2. Wait for in-flight connections to complete (up to `shutdown_timeout`). +3. Force-close remaining connections. +4. Exit. + +On `SIGHUP`: + +1. Reload the GeoIP database from disk. +2. Continue serving with the updated database. + +Configuration changes (routes, listeners, firewall rules) require a full +restart. Hot reload of routing rules is deferred to the future SQLite-backed +implementation. + +--- + +## Security Model + +mc-proxy is infrastructure that sits in front of authenticated services. +It has no authentication or authorization of its own. + +### Threat Mitigations + +| Threat | Mitigation | +|--------|------------| +| SNI spoofing | Backend performs its own TLS handshake — a spoofed SNI will fail certificate validation at the backend. mc-proxy does not trust SNI for security decisions beyond routing. | +| Resource exhaustion (connection flood) | Idle timeout closes stale connections. Per-listener connection limits (future). Rate limiting (future). | +| GeoIP evasion via IPv6 | GeoLite2 database includes IPv6 mappings. Both IPv4 and IPv6 source addresses are checked. | +| GeoIP evasion via VPN/proxy | Accepted risk. GeoIP blocking is a compliance measure, not a security boundary. Determined adversaries will bypass it. | +| Slowloris / slow ClientHello | Timeout on the SNI extraction phase. If a complete ClientHello is not received within a reasonable window (e.g. 10s), the connection is reset. | +| Backend unavailability | Connect timeout prevents indefinite hangs. Connection is reset if the backend is unreachable. | +| Information leakage | Blocked connections receive only a TCP RST. No version strings, no error messages, no TLS alerts. | + +### Security Invariants + +1. mc-proxy never terminates TLS. It cannot read application-layer traffic. +2. mc-proxy never modifies the byte stream between client and backend. +3. Firewall rules are always evaluated before any routing decision. +4. The proxy never logs connection content — only metadata (source IP, + SNI hostname, backend, timestamps, bytes transferred). + +--- + +## Future Work + +Items are listed roughly in priority order: + +| Item | Description | +|------|-------------| +| **gRPC admin API** | Internal-only API for managing routes and firewall rules at runtime, integrated with the Metacircular Control Plane. | +| **SQLite route storage** | Migrate route table from TOML to SQLite for dynamic management via the admin API. | +| **L7 HTTPS support** | TLS-terminating mode for selected routes, enabling HTTP-level features (user-agent blocking, header inspection, request routing). | +| **ACME integration** | Automatic certificate provisioning via Let's Encrypt for L7 routes. | +| **User-agent blocking** | Block connections based on user-agent string (requires L7 mode). | +| **Connection rate limiting** | Per-source-IP rate limits to mitigate connection floods. | +| **Per-listener connection limits** | Cap maximum concurrent connections per listener. | +| **Health check endpoint** | Lightweight TCP or HTTP health check for load balancers and monitoring. | +| **Metrics** | Prometheus-compatible metrics: connections per listener, firewall blocks by rule, backend dial latency, active connections. | diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..cb9de74 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,54 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +mc-proxy is a Layer 4 TLS SNI proxy and router for Metacircular Dynamics services. It reads the SNI hostname from incoming TLS ClientHello messages and proxies the raw TCP stream to the matched backend. It does not terminate TLS. A global firewall (IP, CIDR, GeoIP country blocking) is evaluated before routing. See `ARCHITECTURE.md` for full design. + +## Build Commands + +```bash +make all # vet → lint → test → build +make mc-proxy # build the binary with version injection +make build # compile all packages +make test # run all tests +make vet # go vet +make lint # golangci-lint +``` + +Run a single test: +```bash +go test ./internal/sni -run TestExtract +``` + +## Architecture + +- **Module path**: `git.wntrmute.dev/kyle/mc-proxy` +- **Go with CGO_ENABLED=0**, statically linked, Alpine containers +- **No API surface yet** — config-driven via TOML; gRPC admin API planned for future MCP integration +- **No auth** — this is pre-auth infrastructure; services behind it handle their own MCIAS auth +- **No database** — routes and firewall rules are in the TOML config; SQLite planned for dynamic route management +- **Config**: TOML via `go-toml/v2`, runtime data in `/srv/mc-proxy/` +- **Testing**: stdlib `testing` only, `t.TempDir()` for isolation +- **Linting**: golangci-lint v2 with `.golangci.yaml` + +## Package Structure + +- `internal/config/` — TOML config loading and validation +- `internal/sni/` — TLS ClientHello parser; extracts SNI hostname without consuming bytes +- `internal/firewall/` — global blocklist evaluation (IP, CIDR, GeoIP via MaxMind GeoLite2); thread-safe GeoIP reload +- `internal/proxy/` — bidirectional TCP relay with half-close propagation and idle timeout +- `internal/server/` — orchestrates listeners → firewall → SNI → route → proxy pipeline; graceful shutdown + +## Signals + +- `SIGINT`/`SIGTERM` — graceful shutdown (drain in-flight connections up to `shutdown_timeout`) +- `SIGHUP` — reload GeoIP database without restart + +## Critical Rules + +- mc-proxy never terminates TLS and never modifies the byte stream. +- Firewall rules are always evaluated before any routing decision. +- SNI matching is exact and case-insensitive. +- Blocked connections get a TCP RST — no error messages, no TLS alerts. diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6f755db --- /dev/null +++ b/Dockerfile @@ -0,0 +1,15 @@ +FROM golang:1.24-alpine AS builder + +WORKDIR /build +COPY go.mod go.sum ./ +RUN go mod download +COPY . . +RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o mc-proxy ./cmd/mc-proxy + +FROM alpine:3.21 + +RUN addgroup -S mc-proxy && adduser -S mc-proxy -G mc-proxy +COPY --from=builder /build/mc-proxy /usr/local/bin/mc-proxy + +USER mc-proxy +ENTRYPOINT ["mc-proxy"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..344b288 --- /dev/null +++ b/Makefile @@ -0,0 +1,36 @@ +.PHONY: build test vet lint proto clean docker all devserver + +LDFLAGS := -trimpath -ldflags="-s -w -X main.version=$(shell git describe --tags --always --dirty)" + +mc-proxy: + go build $(LDFLAGS) -o mc-proxy ./cmd/mc-proxy + +build: + go build ./... + +test: + go test ./... + +vet: + go vet ./... + +lint: + golangci-lint run ./... + +proto: + protoc --go_out=. --go_opt=module=git.wntrmute.dev/kyle/mc-proxy \ + --go-grpc_out=. --go-grpc_opt=module=git.wntrmute.dev/kyle/mc-proxy \ + proto/mc-proxy/v1/*.proto + +clean: + rm -f mc-proxy + +docker: + docker build -t mc-proxy -f Dockerfile . + +devserver: mc-proxy + @mkdir -p srv + @if [ ! -f srv/mc-proxy.toml ]; then cp mc-proxy.toml.example srv/mc-proxy.toml; echo "Created srv/mc-proxy.toml from example — edit before running."; fi + ./mc-proxy server --config srv/mc-proxy.toml + +all: vet lint test mc-proxy diff --git a/README.md b/README.md new file mode 100644 index 0000000..fc659a9 --- /dev/null +++ b/README.md @@ -0,0 +1,19 @@ +mc-proxy is a TLS proxy and router for Metacircular Dynamics projects; +it follows the Metacircular Engineering Standards. + +Metacircular services are deployed to a machine that runs these projects +as containers. The proxy should do a few things: + +1. It should have a global firewall front-end. It should allow a few + things: + + 1. Per-country blocks using GeoIP for compliance reasons. + 2. Normal IP/CIDR blocks. Note that a proxy has an explicit port + setting, so the firewall doesn't need to consider ports. + 3. For endpoints marked as HTTPS, we should consider how to do + user-agent blocking. + +2. It should inspect the hostname and route that to the proper + container, similar to how haproxy would do it. + + diff --git a/deploy/mc-proxy.toml.example b/deploy/mc-proxy.toml.example new file mode 100644 index 0000000..47c4ec8 --- /dev/null +++ b/deploy/mc-proxy.toml.example @@ -0,0 +1,37 @@ +# mc-proxy configuration + +# Listeners. Each entry binds a TCP listener on the specified address. +[[listeners]] +addr = ":443" + +[[listeners]] +addr = ":8443" + +[[listeners]] +addr = ":9443" + +# Routes. SNI hostname → backend address. +[[routes]] +hostname = "metacrypt.metacircular.net" +backend = "127.0.0.1:18443" + +[[routes]] +hostname = "mcias.metacircular.net" +backend = "127.0.0.1:28443" + +# Firewall. Global blocklist, evaluated before routing. Default allow. +[firewall] +geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb" +blocked_ips = [] +blocked_cidrs = [] +blocked_countries = ["KP", "CN", "IN", "IL"] + +# Proxy behavior. +[proxy] +connect_timeout = "5s" +idle_timeout = "300s" +shutdown_timeout = "30s" + +# Logging. +[log] +level = "info" diff --git a/deploy/scripts/install.sh b/deploy/scripts/install.sh new file mode 100755 index 0000000..b563da8 --- /dev/null +++ b/deploy/scripts/install.sh @@ -0,0 +1,42 @@ +#!/bin/sh +set -eu + +SERVICE="mc-proxy" +BINARY="/usr/local/bin/${SERVICE}" +DATA_DIR="/srv/${SERVICE}" +UNIT_DIR="/etc/systemd/system" +SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)" +REPO_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)" + +# Create system user and group (idempotent). +if ! id -u "${SERVICE}" >/dev/null 2>&1; then + useradd --system --no-create-home --shell /usr/sbin/nologin "${SERVICE}" + echo "Created system user ${SERVICE}." +fi + +# Install binary. +install -m 0755 "${REPO_DIR}/${SERVICE}" "${BINARY}" +echo "Installed ${BINARY}." + +# Create data directory structure. +install -d -o "${SERVICE}" -g "${SERVICE}" -m 0700 "${DATA_DIR}" +install -d -o "${SERVICE}" -g "${SERVICE}" -m 0700 "${DATA_DIR}/backups" +echo "Created ${DATA_DIR}/." + +# Install example config if none exists. +if [ ! -f "${DATA_DIR}/${SERVICE}.toml" ]; then + install -o "${SERVICE}" -g "${SERVICE}" -m 0600 \ + "${REPO_DIR}/${SERVICE}.toml.example" \ + "${DATA_DIR}/${SERVICE}.toml" + echo "Installed example config to ${DATA_DIR}/${SERVICE}.toml — edit before starting." +fi + +# Install systemd units. +install -m 0644 "${REPO_DIR}/deploy/systemd/${SERVICE}.service" "${UNIT_DIR}/" +systemctl daemon-reload +echo "Installed systemd unit ${SERVICE}.service." + +echo "" +echo "Done. Next steps:" +echo " 1. Edit ${DATA_DIR}/${SERVICE}.toml" +echo " 2. systemctl enable --now ${SERVICE}" diff --git a/deploy/systemd/mc-proxy.service b/deploy/systemd/mc-proxy.service new file mode 100644 index 0000000..11edcf7 --- /dev/null +++ b/deploy/systemd/mc-proxy.service @@ -0,0 +1,32 @@ +[Unit] +Description=mc-proxy TLS proxy and router +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +User=mc-proxy +Group=mc-proxy +ExecStart=/usr/local/bin/mc-proxy server --config /srv/mc-proxy/mc-proxy.toml +Restart=on-failure +RestartSec=5 + +AmbientCapabilities=CAP_NET_BIND_SERVICE + +NoNewPrivileges=true +ProtectSystem=strict +ProtectHome=true +PrivateTmp=true +PrivateDevices=true +ProtectKernelTunables=true +ProtectKernelModules=true +ProtectControlGroups=true +RestrictSUIDSGID=true +RestrictNamespaces=true +LockPersonality=true +MemoryDenyWriteExecute=true +RestrictRealtime=true +ReadWritePaths=/srv/mc-proxy + +[Install] +WantedBy=multi-user.target diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..f133b1b --- /dev/null +++ b/go.mod @@ -0,0 +1,20 @@ +module git.wntrmute.dev/kyle/mc-proxy + +go 1.24.0 + +require ( + github.com/oschwald/maxminddb-golang v1.13.1 + github.com/pelletier/go-toml/v2 v2.2.4 + github.com/spf13/cobra v1.10.2 + google.golang.org/grpc v1.79.2 + google.golang.org/protobuf v1.36.11 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.9 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..575f3f8 --- /dev/null +++ b/go.sum @@ -0,0 +1,60 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE= +github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU= +google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..5ceb98d --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,115 @@ +package config + +import ( + "fmt" + "os" + "time" + + "github.com/pelletier/go-toml/v2" +) + +type Config struct { + Listeners []Listener `toml:"listeners"` + GRPC GRPC `toml:"grpc"` + Firewall Firewall `toml:"firewall"` + Proxy Proxy `toml:"proxy"` + Log Log `toml:"log"` +} + +type GRPC struct { + Addr string `toml:"addr"` + TLSCert string `toml:"tls_cert"` + TLSKey string `toml:"tls_key"` + ClientCA string `toml:"client_ca"` +} + +type Listener struct { + Addr string `toml:"addr"` + Routes []Route `toml:"routes"` +} + +type Route struct { + Hostname string `toml:"hostname"` + Backend string `toml:"backend"` +} + +type Firewall struct { + GeoIPDB string `toml:"geoip_db"` + BlockedIPs []string `toml:"blocked_ips"` + BlockedCIDRs []string `toml:"blocked_cidrs"` + BlockedCountries []string `toml:"blocked_countries"` +} + +type Proxy struct { + ConnectTimeout Duration `toml:"connect_timeout"` + IdleTimeout Duration `toml:"idle_timeout"` + ShutdownTimeout Duration `toml:"shutdown_timeout"` +} + +type Log struct { + Level string `toml:"level"` +} + +// Duration wraps time.Duration for TOML string unmarshalling. +type Duration struct { + time.Duration +} + +func (d *Duration) UnmarshalText(text []byte) error { + var err error + d.Duration, err = time.ParseDuration(string(text)) + return err +} + +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("reading config: %w", err) + } + + var cfg Config + if err := toml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("parsing config: %w", err) + } + + if err := cfg.validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + return &cfg, nil +} + +func (c *Config) validate() error { + if len(c.Listeners) == 0 { + return fmt.Errorf("at least one listener is required") + } + + for i, l := range c.Listeners { + if l.Addr == "" { + return fmt.Errorf("listener %d: addr is required", i) + } + if len(l.Routes) == 0 { + return fmt.Errorf("listener %d (%s): at least one route is required", i, l.Addr) + } + + seen := make(map[string]bool) + for j, r := range l.Routes { + if r.Hostname == "" { + return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j) + } + if r.Backend == "" { + return fmt.Errorf("listener %d (%s), route %d: backend is required", i, l.Addr, j) + } + if seen[r.Hostname] { + return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname) + } + seen[r.Hostname] = true + } + } + + if len(c.Firewall.BlockedCountries) > 0 && c.Firewall.GeoIPDB == "" { + return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set") + } + + return nil +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..eee7dbb --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,186 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestLoadValid(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[[listeners]] +addr = ":443" + + [[listeners.routes]] + hostname = "example.com" + backend = "127.0.0.1:8443" + +[proxy] +connect_timeout = "5s" +idle_timeout = "300s" +shutdown_timeout = "30s" + +[log] +level = "info" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(cfg.Listeners) != 1 { + t.Fatalf("got %d listeners, want 1", len(cfg.Listeners)) + } + if cfg.Listeners[0].Addr != ":443" { + t.Fatalf("got listener addr %q, want %q", cfg.Listeners[0].Addr, ":443") + } + if len(cfg.Listeners[0].Routes) != 1 { + t.Fatalf("got %d routes, want 1", len(cfg.Listeners[0].Routes)) + } + if cfg.Listeners[0].Routes[0].Hostname != "example.com" { + t.Fatalf("got hostname %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "example.com") + } +} + +func TestLoadNoListeners(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[log] +level = "info" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for missing listeners") + } +} + +func TestLoadNoRoutes(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[[listeners]] +addr = ":443" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for missing routes") + } +} + +func TestLoadDuplicateHostnames(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[[listeners]] +addr = ":443" + + [[listeners.routes]] + hostname = "example.com" + backend = "127.0.0.1:8443" + + [[listeners.routes]] + hostname = "example.com" + backend = "127.0.0.1:9443" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for duplicate hostnames") + } +} + +func TestLoadGeoIPRequiredWithCountries(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[[listeners]] +addr = ":443" + + [[listeners.routes]] + hostname = "example.com" + backend = "127.0.0.1:8443" + +[firewall] +blocked_countries = ["CN"] +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for blocked_countries without geoip_db") + } +} + +func TestLoadMultipleListeners(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[[listeners]] +addr = ":443" + + [[listeners.routes]] + hostname = "public.example.com" + backend = "127.0.0.1:8443" + +[[listeners]] +addr = ":8443" + + [[listeners.routes]] + hostname = "internal.example.com" + backend = "127.0.0.1:9443" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if len(cfg.Listeners) != 2 { + t.Fatalf("got %d listeners, want 2", len(cfg.Listeners)) + } + if cfg.Listeners[0].Routes[0].Hostname != "public.example.com" { + t.Fatalf("listener 0 hostname = %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "public.example.com") + } + if cfg.Listeners[1].Routes[0].Hostname != "internal.example.com" { + t.Fatalf("listener 1 hostname = %q, want %q", cfg.Listeners[1].Routes[0].Hostname, "internal.example.com") + } +} + +func TestDuration(t *testing.T) { + var d Duration + if err := d.UnmarshalText([]byte("5s")); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d.Duration.Seconds() != 5 { + t.Fatalf("got %v, want 5s", d.Duration) + } +} diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go new file mode 100644 index 0000000..5c19687 --- /dev/null +++ b/internal/firewall/firewall.go @@ -0,0 +1,218 @@ +package firewall + +import ( + "fmt" + "net/netip" + "strings" + "sync" + + "github.com/oschwald/maxminddb-golang" + + "git.wntrmute.dev/kyle/mc-proxy/internal/config" +) + +type geoIPRecord struct { + Country struct { + ISOCode string `maxminddb:"iso_code"` + } `maxminddb:"country"` +} + +// Firewall evaluates global blocklist rules against connection source addresses. +type Firewall struct { + blockedIPs map[netip.Addr]struct{} + blockedCIDRs []netip.Prefix + blockedCountries map[string]struct{} + geoDBPath string + geoDB *maxminddb.Reader + mu sync.RWMutex // protects all mutable state +} + +// New creates a Firewall from the given configuration. +func New(cfg config.Firewall) (*Firewall, error) { + f := &Firewall{ + blockedIPs: make(map[netip.Addr]struct{}), + blockedCountries: make(map[string]struct{}), + geoDBPath: cfg.GeoIPDB, + } + + for _, ip := range cfg.BlockedIPs { + addr, err := netip.ParseAddr(ip) + if err != nil { + return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err) + } + f.blockedIPs[addr] = struct{}{} + } + + for _, cidr := range cfg.BlockedCIDRs { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err) + } + f.blockedCIDRs = append(f.blockedCIDRs, prefix) + } + + for _, code := range cfg.BlockedCountries { + f.blockedCountries[strings.ToUpper(code)] = struct{}{} + } + + if len(f.blockedCountries) > 0 { + if err := f.loadGeoDB(cfg.GeoIPDB); err != nil { + return nil, fmt.Errorf("loading GeoIP database: %w", err) + } + } + + return f, nil +} + +// Blocked returns true if the given address should be blocked. +func (f *Firewall) Blocked(addr netip.Addr) bool { + addr = addr.Unmap() + + f.mu.RLock() + defer f.mu.RUnlock() + + if _, ok := f.blockedIPs[addr]; ok { + return true + } + + for _, prefix := range f.blockedCIDRs { + if prefix.Contains(addr) { + return true + } + } + + if len(f.blockedCountries) > 0 && f.geoDB != nil { + var record geoIPRecord + if err := f.geoDB.Lookup(addr.AsSlice(), &record); err == nil { + if _, ok := f.blockedCountries[record.Country.ISOCode]; ok { + return true + } + } + } + + return false +} + +// AddIP adds an IP address to the blocklist. +func (f *Firewall) AddIP(ip string) error { + addr, err := netip.ParseAddr(ip) + if err != nil { + return fmt.Errorf("invalid IP %q: %w", ip, err) + } + + f.mu.Lock() + f.blockedIPs[addr] = struct{}{} + f.mu.Unlock() + return nil +} + +// RemoveIP removes an IP address from the blocklist. +func (f *Firewall) RemoveIP(ip string) error { + addr, err := netip.ParseAddr(ip) + if err != nil { + return fmt.Errorf("invalid IP %q: %w", ip, err) + } + + f.mu.Lock() + delete(f.blockedIPs, addr) + f.mu.Unlock() + return nil +} + +// AddCIDR adds a CIDR prefix to the blocklist. +func (f *Firewall) AddCIDR(cidr string) error { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + return fmt.Errorf("invalid CIDR %q: %w", cidr, err) + } + + f.mu.Lock() + f.blockedCIDRs = append(f.blockedCIDRs, prefix) + f.mu.Unlock() + return nil +} + +// RemoveCIDR removes a CIDR prefix from the blocklist. +func (f *Firewall) RemoveCIDR(cidr string) error { + prefix, err := netip.ParsePrefix(cidr) + if err != nil { + return fmt.Errorf("invalid CIDR %q: %w", cidr, err) + } + + f.mu.Lock() + for i, p := range f.blockedCIDRs { + if p == prefix { + f.blockedCIDRs = append(f.blockedCIDRs[:i], f.blockedCIDRs[i+1:]...) + break + } + } + f.mu.Unlock() + return nil +} + +// AddCountry adds a country code to the blocklist. +func (f *Firewall) AddCountry(code string) { + f.mu.Lock() + f.blockedCountries[strings.ToUpper(code)] = struct{}{} + f.mu.Unlock() +} + +// RemoveCountry removes a country code from the blocklist. +func (f *Firewall) RemoveCountry(code string) { + f.mu.Lock() + delete(f.blockedCountries, strings.ToUpper(code)) + f.mu.Unlock() +} + +// Rules returns a snapshot of all current firewall rules. +func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) { + f.mu.RLock() + defer f.mu.RUnlock() + + for addr := range f.blockedIPs { + ips = append(ips, addr.String()) + } + for _, prefix := range f.blockedCIDRs { + cidrs = append(cidrs, prefix.String()) + } + for code := range f.blockedCountries { + countries = append(countries, code) + } + return +} + +// ReloadGeoIP reloads the GeoIP database from disk. Safe for concurrent use. +func (f *Firewall) ReloadGeoIP() error { + if f.geoDBPath == "" { + return nil + } + return f.loadGeoDB(f.geoDBPath) +} + +// Close releases resources held by the firewall. +func (f *Firewall) Close() error { + f.mu.Lock() + defer f.mu.Unlock() + + if f.geoDB != nil { + return f.geoDB.Close() + } + return nil +} + +func (f *Firewall) loadGeoDB(path string) error { + db, err := maxminddb.Open(path) + if err != nil { + return err + } + + f.mu.Lock() + old := f.geoDB + f.geoDB = db + f.mu.Unlock() + + if old != nil { + old.Close() + } + return nil +} diff --git a/internal/firewall/firewall_test.go b/internal/firewall/firewall_test.go new file mode 100644 index 0000000..f442c28 --- /dev/null +++ b/internal/firewall/firewall_test.go @@ -0,0 +1,141 @@ +package firewall + +import ( + "net/netip" + "testing" + + "git.wntrmute.dev/kyle/mc-proxy/internal/config" +) + +func TestEmptyFirewall(t *testing.T) { + fw, err := New(config.Firewall{}) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"} + for _, a := range addrs { + addr := netip.MustParseAddr(a) + if fw.Blocked(addr) { + t.Fatalf("empty firewall blocked %s", addr) + } + } +} + +func TestIPBlocking(t *testing.T) { + fw, err := New(config.Firewall{ + BlockedIPs: []string{"192.0.2.1", "2001:db8::dead"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + tests := []struct { + addr string + blocked bool + }{ + {"192.0.2.1", true}, + {"192.0.2.2", false}, + {"2001:db8::dead", true}, + {"2001:db8::beef", false}, + } + + for _, tt := range tests { + addr := netip.MustParseAddr(tt.addr) + if got := fw.Blocked(addr); got != tt.blocked { + t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked) + } + } +} + +func TestCIDRBlocking(t *testing.T) { + fw, err := New(config.Firewall{ + BlockedCIDRs: []string{"198.51.100.0/24", "2001:db8::/32"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + tests := []struct { + addr string + blocked bool + }{ + {"198.51.100.1", true}, + {"198.51.100.254", true}, + {"198.51.101.1", false}, + {"2001:db8::1", true}, + {"2001:db9::1", false}, + } + + for _, tt := range tests { + addr := netip.MustParseAddr(tt.addr) + if got := fw.Blocked(addr); got != tt.blocked { + t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked) + } + } +} + +func TestIPv4MappedIPv6(t *testing.T) { + fw, err := New(config.Firewall{ + BlockedIPs: []string{"192.0.2.1"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + // IPv4-mapped IPv6 representation of 192.0.2.1. + addr := netip.MustParseAddr("::ffff:192.0.2.1") + if !fw.Blocked(addr) { + t.Fatal("expected IPv4-mapped IPv6 address to be blocked") + } +} + +func TestInvalidIP(t *testing.T) { + _, err := New(config.Firewall{ + BlockedIPs: []string{"not-an-ip"}, + }) + if err == nil { + t.Fatal("expected error for invalid IP") + } +} + +func TestInvalidCIDR(t *testing.T) { + _, err := New(config.Firewall{ + BlockedCIDRs: []string{"not-a-cidr"}, + }) + if err == nil { + t.Fatal("expected error for invalid CIDR") + } +} + +func TestCombinedRules(t *testing.T) { + fw, err := New(config.Firewall{ + BlockedIPs: []string{"10.0.0.1"}, + BlockedCIDRs: []string{"192.168.0.0/16"}, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + defer fw.Close() + + tests := []struct { + addr string + blocked bool + }{ + {"10.0.0.1", true}, // IP match + {"10.0.0.2", false}, // no match + {"192.168.1.1", true}, // CIDR match + {"172.16.0.1", false}, // no match + } + + for _, tt := range tests { + addr := netip.MustParseAddr(tt.addr) + if got := fw.Blocked(addr); got != tt.blocked { + t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked) + } + } +} diff --git a/internal/grpcserver/grpcserver.go b/internal/grpcserver/grpcserver.go new file mode 100644 index 0000000..e40e568 --- /dev/null +++ b/internal/grpcserver/grpcserver.go @@ -0,0 +1,246 @@ +package grpcserver + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "log/slog" + "net" + "os" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + + pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1" + "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/server" +) + +// AdminServer implements the ProxyAdmin gRPC service. +type AdminServer struct { + pb.UnimplementedProxyAdminServer + srv *server.Server + logger *slog.Logger +} + +// New creates a gRPC server with TLS and optional mTLS. +func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) { + cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) + if err != nil { + return nil, nil, fmt.Errorf("loading TLS keypair: %w", err) + } + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS13, + } + + // mTLS: require and verify client certificates. + if cfg.ClientCA != "" { + caCert, err := os.ReadFile(cfg.ClientCA) + if err != nil { + return nil, nil, fmt.Errorf("reading client CA: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caCert) { + return nil, nil, fmt.Errorf("failed to parse client CA certificate") + } + tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert + tlsConfig.ClientCAs = pool + } + + creds := credentials.NewTLS(tlsConfig) + grpcServer := grpc.NewServer(grpc.Creds(creds)) + + admin := &AdminServer{ + srv: srv, + logger: logger, + } + pb.RegisterProxyAdminServer(grpcServer, admin) + + ln, err := net.Listen("tcp", cfg.Addr) + if err != nil { + return nil, nil, fmt.Errorf("listening on %s: %w", cfg.Addr, err) + } + + return grpcServer, ln, nil +} + +// ListRoutes returns the route table for a listener. +func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) { + ls, err := a.findListener(req.ListenerAddr) + if err != nil { + return nil, err + } + + routes := ls.Routes() + resp := &pb.ListRoutesResponse{ + ListenerAddr: ls.Addr, + } + for hostname, backend := range routes { + resp.Routes = append(resp.Routes, &pb.Route{ + Hostname: hostname, + Backend: backend, + }) + } + return resp, nil +} + +// AddRoute adds a route to a listener's route table. +func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) { + if req.Route == nil { + return nil, status.Error(codes.InvalidArgument, "route is required") + } + if req.Route.Hostname == "" || req.Route.Backend == "" { + return nil, status.Error(codes.InvalidArgument, "hostname and backend are required") + } + + ls, err := a.findListener(req.ListenerAddr) + if err != nil { + return nil, err + } + + if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil { + return nil, status.Errorf(codes.AlreadyExists, "%v", err) + } + + a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend) + return &pb.AddRouteResponse{}, nil +} + +// RemoveRoute removes a route from a listener's route table. +func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) { + if req.Hostname == "" { + return nil, status.Error(codes.InvalidArgument, "hostname is required") + } + + ls, err := a.findListener(req.ListenerAddr) + if err != nil { + return nil, err + } + + if err := ls.RemoveRoute(req.Hostname); err != nil { + return nil, status.Errorf(codes.NotFound, "%v", err) + } + + a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname) + return &pb.RemoveRouteResponse{}, nil +} + +// GetFirewallRules returns all current firewall rules. +func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) { + ips, cidrs, countries := a.srv.Firewall().Rules() + + var rules []*pb.FirewallRule + for _, ip := range ips { + rules = append(rules, &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, + Value: ip, + }) + } + for _, cidr := range cidrs { + rules = append(rules, &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, + Value: cidr, + }) + } + for _, code := range countries { + rules = append(rules, &pb.FirewallRule{ + Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, + Value: code, + }) + } + + return &pb.GetFirewallRulesResponse{Rules: rules}, nil +} + +// AddFirewallRule adds a firewall rule. +func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) { + if req.Rule == nil { + return nil, status.Error(codes.InvalidArgument, "rule is required") + } + + fw := a.srv.Firewall() + switch req.Rule.Type { + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: + if err := fw.AddIP(req.Rule.Value); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: + if err := fw.AddCIDR(req.Rule.Value); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: + if req.Rule.Value == "" { + return nil, status.Error(codes.InvalidArgument, "country code is required") + } + fw.AddCountry(req.Rule.Value) + default: + return nil, status.Error(codes.InvalidArgument, "unknown rule type") + } + + a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value) + return &pb.AddFirewallRuleResponse{}, nil +} + +// RemoveFirewallRule removes a firewall rule. +func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) { + if req.Rule == nil { + return nil, status.Error(codes.InvalidArgument, "rule is required") + } + + fw := a.srv.Firewall() + switch req.Rule.Type { + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: + if err := fw.RemoveIP(req.Rule.Value); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: + if err := fw.RemoveCIDR(req.Rule.Value); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: + if req.Rule.Value == "" { + return nil, status.Error(codes.InvalidArgument, "country code is required") + } + fw.RemoveCountry(req.Rule.Value) + default: + return nil, status.Error(codes.InvalidArgument, "unknown rule type") + } + + a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value) + return &pb.RemoveFirewallRuleResponse{}, nil +} + +// GetStatus returns the proxy's current status. +func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) { + var listeners []*pb.ListenerStatus + for _, ls := range a.srv.Listeners() { + routes := ls.Routes() + listeners = append(listeners, &pb.ListenerStatus{ + Addr: ls.Addr, + RouteCount: int32(len(routes)), + ActiveConnections: ls.ActiveConnections.Load(), + }) + } + + return &pb.GetStatusResponse{ + Version: a.srv.Version(), + StartedAt: timestamppb.New(a.srv.StartedAt()), + Listeners: listeners, + TotalConnections: a.srv.TotalConnections(), + }, nil +} + +func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) { + for _, ls := range a.srv.Listeners() { + if ls.Addr == addr { + return ls, nil + } + } + return nil, status.Errorf(codes.NotFound, "listener %q not found", addr) +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 0000000..efa0cbd --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,105 @@ +package proxy + +import ( + "context" + "io" + "net" + "sync" + "time" +) + +// Result holds the outcome of a relay operation. +type Result struct { + ClientBytes int64 // bytes sent from client to backend + BackendBytes int64 // bytes sent from backend to client +} + +// Relay performs bidirectional byte copying between client and backend. +// The peeked bytes (the TLS ClientHello) are written to the backend first. +// Relay blocks until both directions are done or ctx is cancelled. +func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTimeout time.Duration) (Result, error) { + // Forward the buffered ClientHello to the backend. + if len(peeked) > 0 { + if _, err := backend.Write(peeked); err != nil { + return Result{}, err + } + } + + // Cancel context closes both connections to unblock copy goroutines. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + <-ctx.Done() + client.Close() + backend.Close() + }() + + var ( + result Result + wg sync.WaitGroup + errC2B error + errB2C error + ) + + wg.Add(2) + + // client → backend + go func() { + defer wg.Done() + result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout) + // Half-close backend's write side. + if hc, ok := backend.(interface{ CloseWrite() error }); ok { + hc.CloseWrite() + } + }() + + // backend → client + go func() { + defer wg.Done() + result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout) + // Half-close client's write side. + if hc, ok := client.(interface{ CloseWrite() error }); ok { + hc.CloseWrite() + } + }() + + wg.Wait() + + // If context was cancelled, that's the primary error. + if ctx.Err() != nil { + return result, ctx.Err() + } + + // Return the first meaningful error, if any. + if errC2B != nil { + return result, errC2B + } + return result, errB2C +} + +// copyWithIdleTimeout copies from src to dst, resetting the idle deadline +// on each successful read. +func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, error) { + buf := make([]byte, 32*1024) + var total int64 + + for { + src.SetReadDeadline(time.Now().Add(idleTimeout)) + nr, readErr := src.Read(buf) + if nr > 0 { + dst.SetWriteDeadline(time.Now().Add(idleTimeout)) + nw, writeErr := dst.Write(buf[:nr]) + total += int64(nw) + if writeErr != nil { + return total, writeErr + } + } + if readErr != nil { + if readErr == io.EOF { + return total, nil + } + return total, readErr + } + } +} diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go new file mode 100644 index 0000000..cbaf986 --- /dev/null +++ b/internal/proxy/proxy_test.go @@ -0,0 +1,259 @@ +package proxy + +import ( + "bytes" + "context" + "crypto/rand" + "io" + "net" + "testing" + "time" +) + +func TestRelayBasic(t *testing.T) { + // Set up a TCP listener to act as the backend. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer backendLn.Close() + + peeked := []byte("peeked-hello-bytes") + clientData := []byte("data from client") + backendData := []byte("data from backend") + + // Backend goroutine: accept, read peeked+client data, send response, close. + backendDone := make(chan []byte, 1) + go func() { + conn, err := backendLn.Accept() + if err != nil { + return + } + defer conn.Close() + + // Read everything the backend receives. + received, _ := io.ReadAll(conn) + backendDone <- received + + // This won't work since ReadAll waits for EOF. + // Instead, restructure: read expected bytes, write response, close write. + }() + + // Restructure: use a more controlled flow. + backendLn.Close() + + // Use a real TCP pair for proper half-close. + backendLn2, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer backendLn2.Close() + + go func() { + conn, err := backendLn2.Accept() + if err != nil { + return + } + defer conn.Close() + + // Read peeked + client data. + buf := make([]byte, len(peeked)+len(clientData)) + n, _ := io.ReadFull(conn, buf) + backendDone <- buf[:n] + + // Send response. + conn.Write(backendData) + + // Close write side to signal EOF. + if tc, ok := conn.(*net.TCPConn); ok { + tc.CloseWrite() + } + }() + + // Dial the backend. + backendConn, err := net.Dial("tcp", backendLn2.Addr().String()) + if err != nil { + t.Fatalf("dial backend: %v", err) + } + + // Create a client-side TCP pair. + clientLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer clientLn.Close() + + clientConn, err := net.Dial("tcp", clientLn.Addr().String()) + if err != nil { + t.Fatalf("dial client: %v", err) + } + serverSideClient, err := clientLn.Accept() + if err != nil { + t.Fatalf("accept client: %v", err) + } + + // Client sends data then closes write. + go func() { + clientConn.Write(clientData) + if tc, ok := clientConn.(*net.TCPConn); ok { + tc.CloseWrite() + } + }() + + // Run relay. + result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second) + if err != nil { + t.Fatalf("relay error: %v", err) + } + + // Verify backend received peeked + client data. + received := <-backendDone + expected := append(peeked, clientData...) + if !bytes.Equal(received, expected) { + t.Fatalf("backend received %q, want %q", received, expected) + } + + // Verify client received backend data. + clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) + clientReceived, _ := io.ReadAll(clientConn) + if !bytes.Equal(clientReceived, backendData) { + t.Fatalf("client received %q, want %q", clientReceived, backendData) + } + + if result.ClientBytes != int64(len(clientData)) { + t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData)) + } + if result.BackendBytes != int64(len(backendData)) { + t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData)) + } +} + +func TestRelayIdleTimeout(t *testing.T) { + // Two connected pairs via TCP. + clientA, clientB := tcpPair(t) + defer clientA.Close() + defer clientB.Close() + + backendA, backendB := tcpPair(t) + defer backendA.Close() + defer backendB.Close() + + start := time.Now() + _, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond) + elapsed := time.Since(start) + + // Should return due to idle timeout. + if err == nil { + t.Fatal("expected error from idle timeout") + } + + if elapsed > 2*time.Second { + t.Fatalf("relay took %v, expected ~100ms", elapsed) + } +} + +func TestRelayContextCancel(t *testing.T) { + clientA, clientB := tcpPair(t) + defer clientA.Close() + defer clientB.Close() + + backendA, backendB := tcpPair(t) + defer backendA.Close() + defer backendB.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + done := make(chan struct{}) + go func() { + Relay(ctx, clientB, backendA, nil, time.Minute) + close(done) + }() + + // Cancel after a short delay. + time.Sleep(50 * time.Millisecond) + cancel() + + select { + case <-done: + // OK + case <-time.After(2 * time.Second): + t.Fatal("relay did not return after context cancel") + } + + _ = backendB // keep reference +} + +func TestRelayLargeTransfer(t *testing.T) { + clientA, clientB := tcpPair(t) + defer clientA.Close() + defer clientB.Close() + + backendA, backendB := tcpPair(t) + defer backendA.Close() + defer backendB.Close() + + // 1 MB of random data. + data := make([]byte, 1<<20) + if _, err := rand.Read(data); err != nil { + t.Fatalf("rand read: %v", err) + } + + go func() { + clientA.Write(data) + if tc, ok := clientA.(*net.TCPConn); ok { + tc.CloseWrite() + } + }() + + // Backend reads and echoes chunks, then closes when client EOF arrives. + go func() { + buf := make([]byte, 32*1024) + for { + n, err := backendB.Read(buf) + if n > 0 { + backendB.Write(buf[:n]) + } + if err != nil { + break + } + } + if tc, ok := backendB.(*net.TCPConn); ok { + tc.CloseWrite() + } + }() + + result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second) + if err != nil { + t.Fatalf("relay error: %v", err) + } + + if result.ClientBytes != int64(len(data)) { + t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data)) + } +} + +// tcpPair returns two connected TCP connections. +func tcpPair(t *testing.T) (net.Conn, net.Conn) { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen: %v", err) + } + defer ln.Close() + + var serverConn net.Conn + done := make(chan struct{}) + go func() { + serverConn, _ = ln.Accept() + close(done) + }() + + clientConn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatalf("dial: %v", err) + } + + <-done + return clientConn, serverConn +} diff --git a/internal/server/server.go b/internal/server/server.go new file mode 100644 index 0000000..6ec71a8 --- /dev/null +++ b/internal/server/server.go @@ -0,0 +1,271 @@ +package server + +import ( + "context" + "fmt" + "log/slog" + "net" + "net/netip" + "strings" + "sync" + "sync/atomic" + "time" + + "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" + "git.wntrmute.dev/kyle/mc-proxy/internal/proxy" + "git.wntrmute.dev/kyle/mc-proxy/internal/sni" +) + +// ListenerState holds the mutable state for a single proxy listener. +type ListenerState struct { + Addr string + routes map[string]string // lowercase hostname → backend addr + mu sync.RWMutex + ActiveConnections atomic.Int64 +} + +// Routes returns a snapshot of the listener's route table. +func (ls *ListenerState) Routes() map[string]string { + ls.mu.RLock() + defer ls.mu.RUnlock() + + m := make(map[string]string, len(ls.routes)) + for k, v := range ls.routes { + m[k] = v + } + return m +} + +// AddRoute adds a route to the listener. Returns an error if the hostname +// already exists. +func (ls *ListenerState) AddRoute(hostname, backend string) error { + key := strings.ToLower(hostname) + + ls.mu.Lock() + defer ls.mu.Unlock() + + if _, ok := ls.routes[key]; ok { + return fmt.Errorf("route %q already exists", hostname) + } + ls.routes[key] = backend + return nil +} + +// RemoveRoute removes a route from the listener. Returns an error if the +// hostname does not exist. +func (ls *ListenerState) RemoveRoute(hostname string) error { + key := strings.ToLower(hostname) + + ls.mu.Lock() + defer ls.mu.Unlock() + + if _, ok := ls.routes[key]; !ok { + return fmt.Errorf("route %q not found", hostname) + } + delete(ls.routes, key) + return nil +} + +func (ls *ListenerState) lookupRoute(hostname string) (string, bool) { + ls.mu.RLock() + defer ls.mu.RUnlock() + + backend, ok := ls.routes[hostname] + return backend, ok +} + +// Server is the mc-proxy server. It manages listeners, firewall evaluation, +// SNI-based routing, and bidirectional proxying. +type Server struct { + cfg *config.Config + fw *firewall.Firewall + listeners []*ListenerState + logger *slog.Logger + wg sync.WaitGroup // tracks active connections + startedAt time.Time + version string +} + +// New creates a Server from the given configuration. +func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, error) { + fw, err := firewall.New(cfg.Firewall) + if err != nil { + return nil, fmt.Errorf("initializing firewall: %w", err) + } + + var listeners []*ListenerState + for _, lcfg := range cfg.Listeners { + routes := make(map[string]string, len(lcfg.Routes)) + for _, r := range lcfg.Routes { + routes[strings.ToLower(r.Hostname)] = r.Backend + } + listeners = append(listeners, &ListenerState{ + Addr: lcfg.Addr, + routes: routes, + }) + } + + return &Server{ + cfg: cfg, + fw: fw, + listeners: listeners, + logger: logger, + version: version, + }, nil +} + +// Firewall returns the server's firewall for use by the gRPC admin API. +func (s *Server) Firewall() *firewall.Firewall { + return s.fw +} + +// Listeners returns the server's listener states for use by the gRPC admin API. +func (s *Server) Listeners() []*ListenerState { + return s.listeners +} + +// StartedAt returns the time the server started. +func (s *Server) StartedAt() time.Time { + return s.startedAt +} + +// Version returns the server's version string. +func (s *Server) Version() string { + return s.version +} + +// TotalConnections returns the total number of active connections. +func (s *Server) TotalConnections() int64 { + var total int64 + for _, ls := range s.listeners { + total += ls.ActiveConnections.Load() + } + return total +} + +// Run starts all listeners and blocks until ctx is cancelled. +func (s *Server) Run(ctx context.Context) error { + s.startedAt = time.Now() + + var netListeners []net.Listener + + for _, ls := range s.listeners { + ln, err := net.Listen("tcp", ls.Addr) + if err != nil { + for _, l := range netListeners { + l.Close() + } + return fmt.Errorf("listening on %s: %w", ls.Addr, err) + } + s.logger.Info("listening", "addr", ln.Addr(), "routes", len(ls.routes)) + netListeners = append(netListeners, ln) + } + + // Start accept loops. + for i, ln := range netListeners { + ln := ln + ls := s.listeners[i] + go s.serve(ctx, ln, ls) + } + + // Block until shutdown signal. + <-ctx.Done() + s.logger.Info("shutting down") + + // Stop accepting new connections. + for _, ln := range netListeners { + ln.Close() + } + + // Wait for in-flight connections with a timeout. + done := make(chan struct{}) + go func() { + s.wg.Wait() + close(done) + }() + + select { + case <-done: + s.logger.Info("all connections drained") + case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration): + s.logger.Warn("shutdown timeout exceeded, forcing close") + } + + s.fw.Close() + return nil +} + +// ReloadGeoIP reloads the GeoIP database from disk. +func (s *Server) ReloadGeoIP() error { + return s.fw.ReloadGeoIP() +} + +func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) { + for { + conn, err := ln.Accept() + if err != nil { + if ctx.Err() != nil { + return + } + s.logger.Error("accept error", "error", err) + continue + } + + s.wg.Add(1) + ls.ActiveConnections.Add(1) + go s.handleConn(ctx, conn, ls) + } +} + +func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) { + defer s.wg.Done() + defer ls.ActiveConnections.Add(-1) + defer conn.Close() + + remoteAddr := conn.RemoteAddr().String() + addrPort, err := netip.ParseAddrPort(remoteAddr) + if err != nil { + s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err) + return + } + addr := addrPort.Addr() + + if s.fw.Blocked(addr) { + s.logger.Debug("blocked by firewall", "addr", addr) + return + } + + hostname, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second)) + if err != nil { + s.logger.Debug("SNI extraction failed", "addr", addr, "error", err) + return + } + + backend, ok := ls.lookupRoute(hostname) + if !ok { + s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname) + return + } + + backendConn, err := net.DialTimeout("tcp", backend, s.cfg.Proxy.ConnectTimeout.Duration) + if err != nil { + s.logger.Error("backend dial failed", "hostname", hostname, "backend", backend, "error", err) + return + } + defer backendConn.Close() + + s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", backend) + + result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration) + if err != nil && ctx.Err() == nil { + s.logger.Debug("relay ended", "hostname", hostname, "error", err) + } + + s.logger.Info("connection closed", + "addr", addr, + "hostname", hostname, + "client_bytes", result.ClientBytes, + "backend_bytes", result.BackendBytes, + ) +} diff --git a/internal/sni/sni.go b/internal/sni/sni.go new file mode 100644 index 0000000..9fe8de1 --- /dev/null +++ b/internal/sni/sni.go @@ -0,0 +1,175 @@ +package sni + +import ( + "encoding/binary" + "fmt" + "io" + "net" + "strings" + "time" +) + +const maxBufferSize = 16384 // 16 KiB, max TLS record size + +// Extract reads the TLS ClientHello from conn and returns the SNI hostname. +// The returned peeked bytes contain the full ClientHello and must be forwarded +// to the backend before starting the bidirectional copy. +// +// A read deadline is set on the connection to prevent slowloris attacks. +func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) { + conn.SetReadDeadline(deadline) + defer conn.SetReadDeadline(time.Time{}) + + // Read TLS record header (5 bytes). + header := make([]byte, 5) + if _, err := io.ReadFull(conn, header); err != nil { + return "", nil, fmt.Errorf("reading TLS record header: %w", err) + } + + // Verify this is a TLS handshake record (content type 0x16). + if header[0] != 0x16 { + return "", nil, fmt.Errorf("not a TLS handshake record (type 0x%02x)", header[0]) + } + + // Record length. + recordLen := int(binary.BigEndian.Uint16(header[3:5])) + if recordLen == 0 || recordLen > maxBufferSize-5 { + return "", nil, fmt.Errorf("TLS record length %d out of range", recordLen) + } + + // Read the full record body. + buf := make([]byte, 5+recordLen) + copy(buf, header) + if _, err := io.ReadFull(conn, buf[5:]); err != nil { + return "", nil, fmt.Errorf("reading TLS record body: %w", err) + } + + // Parse the handshake message from the record body. + hostname, err = parseClientHello(buf[5:]) + if err != nil { + return "", nil, err + } + + return hostname, buf, nil +} + +func parseClientHello(data []byte) (string, error) { + if len(data) < 4 { + return "", fmt.Errorf("handshake message too short") + } + + // Handshake type: 0x01 = ClientHello. + if data[0] != 0x01 { + return "", fmt.Errorf("not a ClientHello (type 0x%02x)", data[0]) + } + + // Handshake length (3 bytes). + hsLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < hsLen { + return "", fmt.Errorf("ClientHello truncated") + } + data = data[:hsLen] + + // Skip client version (2 bytes) + random (32 bytes). + if len(data) < 34 { + return "", fmt.Errorf("ClientHello too short for version+random") + } + data = data[34:] + + // Skip session ID (1-byte length prefix). + if len(data) < 1 { + return "", fmt.Errorf("ClientHello too short for session ID length") + } + sidLen := int(data[0]) + data = data[1:] + if len(data) < sidLen { + return "", fmt.Errorf("ClientHello truncated at session ID") + } + data = data[sidLen:] + + // Skip cipher suites (2-byte length prefix). + if len(data) < 2 { + return "", fmt.Errorf("ClientHello too short for cipher suites length") + } + csLen := int(binary.BigEndian.Uint16(data[:2])) + data = data[2:] + if len(data) < csLen { + return "", fmt.Errorf("ClientHello truncated at cipher suites") + } + data = data[csLen:] + + // Skip compression methods (1-byte length prefix). + if len(data) < 1 { + return "", fmt.Errorf("ClientHello too short for compression methods length") + } + cmLen := int(data[0]) + data = data[1:] + if len(data) < cmLen { + return "", fmt.Errorf("ClientHello truncated at compression methods") + } + data = data[cmLen:] + + // Extensions (2-byte total length). + if len(data) < 2 { + return "", fmt.Errorf("no extensions in ClientHello") + } + extLen := int(binary.BigEndian.Uint16(data[:2])) + data = data[2:] + if len(data) < extLen { + return "", fmt.Errorf("ClientHello truncated at extensions") + } + data = data[:extLen] + + return findSNI(data) +} + +func findSNI(data []byte) (string, error) { + for len(data) >= 4 { + extType := binary.BigEndian.Uint16(data[:2]) + extDataLen := int(binary.BigEndian.Uint16(data[2:4])) + data = data[4:] + if len(data) < extDataLen { + return "", fmt.Errorf("extension truncated") + } + + if extType == 0x0000 { // server_name + return parseServerNameExtension(data[:extDataLen]) + } + + data = data[extDataLen:] + } + + return "", fmt.Errorf("no SNI extension found") +} + +func parseServerNameExtension(data []byte) (string, error) { + if len(data) < 2 { + return "", fmt.Errorf("server_name extension too short") + } + + // Server name list length. + listLen := int(binary.BigEndian.Uint16(data[:2])) + data = data[2:] + if len(data) < listLen { + return "", fmt.Errorf("server_name list truncated") + } + data = data[:listLen] + + for len(data) >= 3 { + nameType := data[0] + nameLen := int(binary.BigEndian.Uint16(data[1:3])) + data = data[3:] + if len(data) < nameLen { + return "", fmt.Errorf("server_name entry truncated") + } + + if nameType == 0x00 { // hostname + return strings.ToLower(string(data[:nameLen])), nil + } + + data = data[nameLen:] + } + + return "", fmt.Errorf("no hostname in server_name extension") +} diff --git a/internal/sni/sni_test.go b/internal/sni/sni_test.go new file mode 100644 index 0000000..9826504 --- /dev/null +++ b/internal/sni/sni_test.go @@ -0,0 +1,220 @@ +package sni + +import ( + "encoding/binary" + "net" + "testing" + "time" +) + +func TestExtract(t *testing.T) { + tests := []struct { + name string + sni string + wantSNI string + wantErr bool + }{ + {"basic", "example.com", "example.com", false}, + {"case insensitive", "FoO.BaR.CoM", "foo.bar.com", false}, + {"subdomain", "metacrypt.metacircular.net", "metacrypt.metacircular.net", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + hello := buildClientHello(tt.sni) + + go func() { + client.Write(hello) + }() + + hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second)) + if tt.wantErr { + if err == nil { + t.Fatal("expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hostname != tt.wantSNI { + t.Fatalf("got hostname %q, want %q", hostname, tt.wantSNI) + } + if len(peeked) != len(hello) { + t.Fatalf("peeked %d bytes, want %d", len(peeked), len(hello)) + } + }) + } +} + +func TestExtractNoSNI(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + hello := buildClientHelloNoSNI() + + go func() { + client.Write(hello) + }() + + _, _, err := Extract(server, time.Now().Add(5*time.Second)) + if err == nil { + t.Fatal("expected error for ClientHello without SNI") + } +} + +func TestExtractNotTLS(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + go func() { + client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")) + }() + + _, _, err := Extract(server, time.Now().Add(5*time.Second)) + if err == nil { + t.Fatal("expected error for non-TLS data") + } +} + +func TestExtractTruncated(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + go func() { + // Write just the TLS record header, then close. + client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50}) + client.Close() + }() + + _, _, err := Extract(server, time.Now().Add(5*time.Second)) + if err == nil { + t.Fatal("expected error for truncated record") + } +} + +func TestExtractOversizedRecord(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + go func() { + // Record header claiming a length larger than 16 KiB. + header := []byte{0x16, 0x03, 0x01} + header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5 + client.Write(header) + client.Close() + }() + + _, _, err := Extract(server, time.Now().Add(5*time.Second)) + if err == nil { + t.Fatal("expected error for oversized record") + } +} + +func TestExtractMultipleExtensions(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + hello := buildClientHelloWithExtraExtensions("target.example.com") + + go func() { + client.Write(hello) + }() + + hostname, _, err := Extract(server, time.Now().Add(5*time.Second)) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if hostname != "target.example.com" { + t.Fatalf("got hostname %q, want %q", hostname, "target.example.com") + } +} + +// buildClientHello constructs a minimal TLS 1.2 ClientHello with an SNI extension. +func buildClientHello(serverName string) []byte { + return buildClientHelloWithExtensions(sniExtension(serverName)) +} + +// buildClientHelloNoSNI constructs a ClientHello with no extensions. +func buildClientHelloNoSNI() []byte { + return buildClientHelloWithExtensions(nil) +} + +// buildClientHelloWithExtraExtensions puts a dummy extension before the SNI. +func buildClientHelloWithExtraExtensions(serverName string) []byte { + // Dummy extension (type 0xFF01, empty data). + dummy := []byte{0xFF, 0x01, 0x00, 0x00} + ext := append(dummy, sniExtension(serverName)...) + return buildClientHelloWithExtensions(ext) +} + +func buildClientHelloWithExtensions(extensions []byte) []byte { + var hello []byte + + // Client version: TLS 1.2. + hello = append(hello, 0x03, 0x03) + + // Random: 32 bytes of zeros. + hello = append(hello, make([]byte, 32)...) + + // Session ID: empty. + hello = append(hello, 0x00) + + // Cipher suites: one suite (TLS_RSA_WITH_AES_128_GCM_SHA256). + hello = append(hello, 0x00, 0x02, 0x00, 0x9C) + + // Compression methods: null. + hello = append(hello, 0x01, 0x00) + + // Extensions. + if len(extensions) > 0 { + hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions))) + hello = append(hello, extensions...) + } + + // Wrap in handshake header (type 0x01 = ClientHello). + handshake := []byte{0x01, 0x00, 0x00, 0x00} + handshake[1] = byte(len(hello) >> 16) + handshake[2] = byte(len(hello) >> 8) + handshake[3] = byte(len(hello)) + handshake = append(handshake, hello...) + + // Wrap in TLS record header (type 0x16 = handshake, version TLS 1.0). + record := []byte{0x16, 0x03, 0x01} + record = binary.BigEndian.AppendUint16(record, uint16(len(handshake))) + record = append(record, handshake...) + + return record +} + +func sniExtension(serverName string) []byte { + name := []byte(serverName) + + // Server name entry: type 0x00 (hostname), length, name. + var entry []byte + entry = append(entry, 0x00) + entry = binary.BigEndian.AppendUint16(entry, uint16(len(name))) + entry = append(entry, name...) + + // Server name list: length prefix. + var list []byte + list = binary.BigEndian.AppendUint16(list, uint16(len(entry))) + list = append(list, entry...) + + // Extension: type 0x0000 (server_name), length, data. + var ext []byte + ext = binary.BigEndian.AppendUint16(ext, 0x0000) + ext = binary.BigEndian.AppendUint16(ext, uint16(len(list))) + ext = append(ext, list...) + + return ext +} diff --git a/mc-proxy.toml.example b/mc-proxy.toml.example new file mode 100644 index 0000000..a05fa7d --- /dev/null +++ b/mc-proxy.toml.example @@ -0,0 +1,51 @@ +# mc-proxy configuration + +# Listeners. Each listener binds a TCP port and has its own route table. +[[listeners]] +addr = ":443" + + [[listeners.routes]] + hostname = "metacrypt.metacircular.net" + backend = "127.0.0.1:18443" + + [[listeners.routes]] + hostname = "mcias.metacircular.net" + backend = "127.0.0.1:28443" + +[[listeners]] +addr = ":8443" + + [[listeners.routes]] + hostname = "metacrypt.metacircular.net" + backend = "127.0.0.1:18443" + +[[listeners]] +addr = ":9443" + + [[listeners.routes]] + hostname = "mcias.metacircular.net" + backend = "127.0.0.1:28443" + +# gRPC admin API. Optional — omit addr to disable. +[grpc] +addr = "127.0.0.1:9090" +tls_cert = "/srv/mc-proxy/certs/cert.pem" +tls_key = "/srv/mc-proxy/certs/key.pem" +client_ca = "/srv/mc-proxy/certs/ca.pem" # mTLS; omit to disable client auth + +# Firewall. Global blocklist, evaluated before routing. Default allow. +[firewall] +geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb" +blocked_ips = [] +blocked_cidrs = [] +blocked_countries = ["KP", "CN", "IN", "IL"] + +# Proxy behavior. +[proxy] +connect_timeout = "5s" +idle_timeout = "300s" +shutdown_timeout = "30s" + +# Logging. +[log] +level = "info"