Initial implementation of mc-proxy

Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking),
per-listener route tables, bidirectional TCP relay with half-close
propagation, and a gRPC admin API (routes, firewall, status) with
TLS/mTLS support.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-17 02:56:24 -07:00
commit c7024dcdf0
23 changed files with 2693 additions and 0 deletions

12
.gitignore vendored Normal file
View File

@@ -0,0 +1,12 @@
# Binary
mc-proxy
# Runtime data
srv/
# IDE
.idea/
.vscode/
# OS
.DS_Store

34
.golangci.yaml Normal file
View File

@@ -0,0 +1,34 @@
version: "2"
linters:
enable:
- errcheck
- govet
- ineffassign
- unused
- errorlint
- gosec
- staticcheck
- revive
- gofmt
- goimports
settings:
errcheck:
check-type-assertions: true
govet:
disable:
- shadow
gosec:
severity: medium
confidence: medium
excludes:
- G104
issues:
max-issues-per-linter: 0
exclude-rules:
- path: _test\.go
linters:
- gosec
text: "G101"

345
ARCHITECTURE.md Normal file
View File

@@ -0,0 +1,345 @@
# ARCHITECTURE.md
mc-proxy is a Layer 4 TLS proxy and router for Metacircular Dynamics
services. It inspects the SNI field of incoming TLS ClientHello messages to
determine the target backend, then proxies raw TCP between the client and
the appropriate container. A global firewall evaluates every connection
before routing.
## Table of Contents
1. [System Overview](#system-overview)
2. [Connection Lifecycle](#connection-lifecycle)
3. [Firewall](#firewall)
4. [Routing](#routing)
5. [Configuration](#configuration)
6. [Storage](#storage)
7. [Deployment](#deployment)
8. [Security Model](#security-model)
9. [Future Work](#future-work)
---
## System Overview
```
┌─────────────────────────────────────┐
│ mc-proxy │
Clients ──────┐ │ │
│ │ ┌──────────┐ ┌───────┐ ┌─────┐ │ ┌────────────┐
├────▶│ │ Firewall │──▶│ SNI │──▶│Route│─│────▶│ Backend A │
│ │ │ (global) │ │Extract│ │Table│ │ │ :8443 │
├────▶│ └──────────┘ └───────┘ └─────┘ │ ├────────────┤
│ │ │ RST │ │ │ Backend B │
Clients ──────┘ │ ▼ └────│────▶│ :9443 │
│ (blocked) │ └────────────┘
└─────────────────────────────────────┘
Listener 1 (:443) ─┐
Listener 2 (:8443) ─┼─ Each listener runs the same pipeline
Listener N (:9443) ─┘
```
Key properties:
- **Layer 4 only.** mc-proxy never terminates TLS. It reads just enough of
the ClientHello to extract the SNI hostname, then proxies the raw TCP
stream to the matched backend. The backend handles TLS termination.
- **TLS-only.** Non-TLS connections are not supported. If the first bytes of
a connection are not a TLS ClientHello, the connection is reset.
- **Multiple listeners.** A single mc-proxy instance binds to one or more
ports. Each listener runs the same firewall → SNI → route pipeline.
- **Global firewall.** Firewall rules apply to all listeners uniformly.
There are no per-route firewall rules.
- **No authentication.** mc-proxy is pre-auth infrastructure. It sits in
front of services that handle their own authentication via MCIAS.
---
## Connection Lifecycle
Every inbound connection follows this sequence:
```
1. ACCEPT Listener accepts TCP connection.
2. FIREWALL Check source IP against blocklists:
a. IP/CIDR block check.
b. GeoIP country block check.
If blocked → RST, done.
3. SNI EXTRACT Read the TLS ClientHello (without consuming it).
Extract the SNI hostname.
If no valid ClientHello or no SNI → RST, done.
4. ROUTE LOOKUP Match SNI hostname against the route table.
If no match → RST, done.
5. BACKEND DIAL Open TCP connection to the matched backend address.
If dial fails → RST, done.
6. PROXY Bidirectional byte copy: client ↔ backend.
The buffered ClientHello bytes are forwarded first,
then both directions copy concurrently.
7. CLOSE Either side closes → half-close propagation → done.
```
### SNI Extraction
The proxy peeks at the initial bytes of the connection without consuming
them. It parses just enough of the TLS record layer and ClientHello to
extract the `server_name` extension. The full ClientHello (including the
SNI) is then forwarded to the backend so the backend's TLS handshake
proceeds normally.
If the ClientHello spans multiple TCP segments, the proxy buffers up to
16 KiB (the maximum TLS record size) before giving up.
### Bidirectional Copy
After the backend connection is established, the proxy runs two concurrent
copy loops (client→backend and backend→client). When either direction
encounters an EOF or error:
1. The write side of the opposite direction is half-closed.
2. The remaining direction drains to completion.
3. Both connections are closed.
Timeouts apply to the copy phase to prevent idle connections from
accumulating indefinitely (see [Configuration](#configuration)).
---
## Firewall
The firewall is a global, ordered rule set evaluated on every new
connection before SNI extraction. Rules are evaluated in definition order;
the first matching rule determines the outcome. If no rule matches, the
connection is **allowed** (default allow — the firewall is an explicit
blocklist, not an allowlist).
### Rule Types
| Config Field | Match Field | Example |
|--------------|-------------|---------|
| `blocked_ips` | Source IP address | `"192.0.2.1"` |
| `blocked_cidrs` | Source IP prefix | `"198.51.100.0/24"` |
| `blocked_countries` | Source country (ISO 3166-1 alpha-2) | `"KP"`, `"CN"`, `"IN"`, `"IL"` |
### Blocked Connection Handling
Blocked connections receive a TCP RST. No TLS alert, no HTTP error page, no
indication of why the connection was refused. This is intentional — blocked
sources should receive minimal information.
### GeoIP Database
mc-proxy uses the [MaxMind GeoLite2](https://dev.maxmind.com/geoip/geolite2-free-geolocation-data)
free database for country-level IP geolocation. The database file is
distributed separately from the binary and referenced by path in the
configuration.
The GeoIP database is loaded into memory at startup and can be reloaded
via `SIGHUP` without restarting the proxy. If the database file is missing
or unreadable at startup and GeoIP rules are configured, the proxy refuses
to start.
---
## Routing
Each listener has its own route table mapping SNI hostnames to backend
addresses. A route entry consists of:
| Field | Type | Description |
|-------|------|-------------|
| `hostname` | string | Exact SNI hostname to match (e.g. `metacrypt.metacircular.net`) |
| `backend` | string | Backend address as `host:port` (e.g. `127.0.0.1:8443`) |
Routes are scoped to the listener that accepted the connection. The same
hostname can appear on different listeners with different backends, allowing
the proxy to route the same service name to different backends depending
on which port the client connected to.
### Match Semantics
- Hostname matching is **exact** and **case-insensitive** (per RFC 6066,
SNI hostnames are DNS names and compared case-insensitively).
- Wildcard matching is not supported in the initial implementation.
- If duplicate hostnames appear within the same listener, the proxy refuses
to start.
- If no route matches an incoming SNI hostname, the connection is reset.
### Route Table Source
Route tables are defined inline under each listener in the TOML
configuration file. The design anticipates future migration to a SQLite
database for dynamic route management via the control plane API.
---
## Configuration
TOML configuration file, loaded at startup. The proxy refuses to start if
required fields are missing or invalid.
```toml
# Listeners. Each has its own route table.
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "metacrypt.metacircular.net"
backend = "127.0.0.1:18443"
[[listeners.routes]]
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
[[listeners]]
addr = ":8443"
[[listeners.routes]]
hostname = "metacrypt.metacircular.net"
backend = "127.0.0.1:18443"
[[listeners]]
addr = ":9443"
[[listeners.routes]]
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
# Firewall. Global blocklist, evaluated before routing. Default allow.
[firewall]
geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb"
blocked_ips = ["192.0.2.1"]
blocked_cidrs = ["198.51.100.0/24"]
blocked_countries = ["KP", "CN", "IN", "IL"]
# Proxy behavior.
[proxy]
connect_timeout = "5s" # Timeout for dialing backend
idle_timeout = "300s" # Close connections idle longer than this
shutdown_timeout = "30s" # Graceful shutdown drain period
# Logging.
[log]
level = "info" # debug, info, warn, error
```
### Environment Variable Overrides
Configuration values can be overridden via environment variables using the
prefix `MCPROXY_` with underscore-separated paths:
```
MCPROXY_LOG_LEVEL=debug
MCPROXY_PROXY_IDLE_TIMEOUT=600s
```
Environment variables cannot define listeners, routes, or firewall rules —
these are structural and must be in the TOML file.
---
## Storage
mc-proxy has minimal storage requirements. There is no database in the
initial implementation.
```
/srv/mc-proxy/
├── mc-proxy.toml Configuration
├── GeoLite2-Country.mmdb GeoIP database (if using country blocks)
└── backups/ Reserved for future use
```
No TLS certificates are stored — mc-proxy does not terminate TLS.
---
## Deployment
### Binary
Single static binary, built with `CGO_ENABLED=0`. No runtime dependencies
beyond the configuration file and optional GeoIP database.
### Container
Multi-stage Docker build:
1. **Builder**: `golang:<version>-alpine`, static compilation.
2. **Runtime**: `alpine`, non-root user (`mc-proxy`), port exposure
determined by configuration.
### systemd
| File | Purpose |
|------|---------|
| `mc-proxy.service` | Main proxy service |
The proxy binds to privileged ports (443) and should use `AmbientCapabilities=CAP_NET_BIND_SERVICE`
in the systemd unit rather than running as root.
Standard security hardening directives apply per engineering standards
(`NoNewPrivileges=true`, `ProtectSystem=strict`, etc.).
### Graceful Shutdown
On `SIGINT` or `SIGTERM`:
1. Stop accepting new connections on all listeners.
2. Wait for in-flight connections to complete (up to `shutdown_timeout`).
3. Force-close remaining connections.
4. Exit.
On `SIGHUP`:
1. Reload the GeoIP database from disk.
2. Continue serving with the updated database.
Configuration changes (routes, listeners, firewall rules) require a full
restart. Hot reload of routing rules is deferred to the future SQLite-backed
implementation.
---
## Security Model
mc-proxy is infrastructure that sits in front of authenticated services.
It has no authentication or authorization of its own.
### Threat Mitigations
| Threat | Mitigation |
|--------|------------|
| SNI spoofing | Backend performs its own TLS handshake — a spoofed SNI will fail certificate validation at the backend. mc-proxy does not trust SNI for security decisions beyond routing. |
| Resource exhaustion (connection flood) | Idle timeout closes stale connections. Per-listener connection limits (future). Rate limiting (future). |
| GeoIP evasion via IPv6 | GeoLite2 database includes IPv6 mappings. Both IPv4 and IPv6 source addresses are checked. |
| GeoIP evasion via VPN/proxy | Accepted risk. GeoIP blocking is a compliance measure, not a security boundary. Determined adversaries will bypass it. |
| Slowloris / slow ClientHello | Timeout on the SNI extraction phase. If a complete ClientHello is not received within a reasonable window (e.g. 10s), the connection is reset. |
| Backend unavailability | Connect timeout prevents indefinite hangs. Connection is reset if the backend is unreachable. |
| Information leakage | Blocked connections receive only a TCP RST. No version strings, no error messages, no TLS alerts. |
### Security Invariants
1. mc-proxy never terminates TLS. It cannot read application-layer traffic.
2. mc-proxy never modifies the byte stream between client and backend.
3. Firewall rules are always evaluated before any routing decision.
4. The proxy never logs connection content — only metadata (source IP,
SNI hostname, backend, timestamps, bytes transferred).
---
## Future Work
Items are listed roughly in priority order:
| Item | Description |
|------|-------------|
| **gRPC admin API** | Internal-only API for managing routes and firewall rules at runtime, integrated with the Metacircular Control Plane. |
| **SQLite route storage** | Migrate route table from TOML to SQLite for dynamic management via the admin API. |
| **L7 HTTPS support** | TLS-terminating mode for selected routes, enabling HTTP-level features (user-agent blocking, header inspection, request routing). |
| **ACME integration** | Automatic certificate provisioning via Let's Encrypt for L7 routes. |
| **User-agent blocking** | Block connections based on user-agent string (requires L7 mode). |
| **Connection rate limiting** | Per-source-IP rate limits to mitigate connection floods. |
| **Per-listener connection limits** | Cap maximum concurrent connections per listener. |
| **Health check endpoint** | Lightweight TCP or HTTP health check for load balancers and monitoring. |
| **Metrics** | Prometheus-compatible metrics: connections per listener, firewall blocks by rule, backend dial latency, active connections. |

54
CLAUDE.md Normal file
View File

@@ -0,0 +1,54 @@
# CLAUDE.md
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
## Project Overview
mc-proxy is a Layer 4 TLS SNI proxy and router for Metacircular Dynamics services. It reads the SNI hostname from incoming TLS ClientHello messages and proxies the raw TCP stream to the matched backend. It does not terminate TLS. A global firewall (IP, CIDR, GeoIP country blocking) is evaluated before routing. See `ARCHITECTURE.md` for full design.
## Build Commands
```bash
make all # vet → lint → test → build
make mc-proxy # build the binary with version injection
make build # compile all packages
make test # run all tests
make vet # go vet
make lint # golangci-lint
```
Run a single test:
```bash
go test ./internal/sni -run TestExtract
```
## Architecture
- **Module path**: `git.wntrmute.dev/kyle/mc-proxy`
- **Go with CGO_ENABLED=0**, statically linked, Alpine containers
- **No API surface yet** — config-driven via TOML; gRPC admin API planned for future MCP integration
- **No auth** — this is pre-auth infrastructure; services behind it handle their own MCIAS auth
- **No database** — routes and firewall rules are in the TOML config; SQLite planned for dynamic route management
- **Config**: TOML via `go-toml/v2`, runtime data in `/srv/mc-proxy/`
- **Testing**: stdlib `testing` only, `t.TempDir()` for isolation
- **Linting**: golangci-lint v2 with `.golangci.yaml`
## Package Structure
- `internal/config/` — TOML config loading and validation
- `internal/sni/` — TLS ClientHello parser; extracts SNI hostname without consuming bytes
- `internal/firewall/` — global blocklist evaluation (IP, CIDR, GeoIP via MaxMind GeoLite2); thread-safe GeoIP reload
- `internal/proxy/` — bidirectional TCP relay with half-close propagation and idle timeout
- `internal/server/` — orchestrates listeners → firewall → SNI → route → proxy pipeline; graceful shutdown
## Signals
- `SIGINT`/`SIGTERM` — graceful shutdown (drain in-flight connections up to `shutdown_timeout`)
- `SIGHUP` — reload GeoIP database without restart
## Critical Rules
- mc-proxy never terminates TLS and never modifies the byte stream.
- Firewall rules are always evaluated before any routing decision.
- SNI matching is exact and case-insensitive.
- Blocked connections get a TCP RST — no error messages, no TLS alerts.

15
Dockerfile Normal file
View File

@@ -0,0 +1,15 @@
FROM golang:1.24-alpine AS builder
WORKDIR /build
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 go build -trimpath -ldflags="-s -w" -o mc-proxy ./cmd/mc-proxy
FROM alpine:3.21
RUN addgroup -S mc-proxy && adduser -S mc-proxy -G mc-proxy
COPY --from=builder /build/mc-proxy /usr/local/bin/mc-proxy
USER mc-proxy
ENTRYPOINT ["mc-proxy"]

36
Makefile Normal file
View File

@@ -0,0 +1,36 @@
.PHONY: build test vet lint proto clean docker all devserver
LDFLAGS := -trimpath -ldflags="-s -w -X main.version=$(shell git describe --tags --always --dirty)"
mc-proxy:
go build $(LDFLAGS) -o mc-proxy ./cmd/mc-proxy
build:
go build ./...
test:
go test ./...
vet:
go vet ./...
lint:
golangci-lint run ./...
proto:
protoc --go_out=. --go_opt=module=git.wntrmute.dev/kyle/mc-proxy \
--go-grpc_out=. --go-grpc_opt=module=git.wntrmute.dev/kyle/mc-proxy \
proto/mc-proxy/v1/*.proto
clean:
rm -f mc-proxy
docker:
docker build -t mc-proxy -f Dockerfile .
devserver: mc-proxy
@mkdir -p srv
@if [ ! -f srv/mc-proxy.toml ]; then cp mc-proxy.toml.example srv/mc-proxy.toml; echo "Created srv/mc-proxy.toml from example — edit before running."; fi
./mc-proxy server --config srv/mc-proxy.toml
all: vet lint test mc-proxy

19
README.md Normal file
View File

@@ -0,0 +1,19 @@
mc-proxy is a TLS proxy and router for Metacircular Dynamics projects;
it follows the Metacircular Engineering Standards.
Metacircular services are deployed to a machine that runs these projects
as containers. The proxy should do a few things:
1. It should have a global firewall front-end. It should allow a few
things:
1. Per-country blocks using GeoIP for compliance reasons.
2. Normal IP/CIDR blocks. Note that a proxy has an explicit port
setting, so the firewall doesn't need to consider ports.
3. For endpoints marked as HTTPS, we should consider how to do
user-agent blocking.
2. It should inspect the hostname and route that to the proper
container, similar to how haproxy would do it.

View File

@@ -0,0 +1,37 @@
# mc-proxy configuration
# Listeners. Each entry binds a TCP listener on the specified address.
[[listeners]]
addr = ":443"
[[listeners]]
addr = ":8443"
[[listeners]]
addr = ":9443"
# Routes. SNI hostname → backend address.
[[routes]]
hostname = "metacrypt.metacircular.net"
backend = "127.0.0.1:18443"
[[routes]]
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
# Firewall. Global blocklist, evaluated before routing. Default allow.
[firewall]
geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb"
blocked_ips = []
blocked_cidrs = []
blocked_countries = ["KP", "CN", "IN", "IL"]
# Proxy behavior.
[proxy]
connect_timeout = "5s"
idle_timeout = "300s"
shutdown_timeout = "30s"
# Logging.
[log]
level = "info"

42
deploy/scripts/install.sh Executable file
View File

@@ -0,0 +1,42 @@
#!/bin/sh
set -eu
SERVICE="mc-proxy"
BINARY="/usr/local/bin/${SERVICE}"
DATA_DIR="/srv/${SERVICE}"
UNIT_DIR="/etc/systemd/system"
SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
REPO_DIR="$(cd "${SCRIPT_DIR}/../.." && pwd)"
# Create system user and group (idempotent).
if ! id -u "${SERVICE}" >/dev/null 2>&1; then
useradd --system --no-create-home --shell /usr/sbin/nologin "${SERVICE}"
echo "Created system user ${SERVICE}."
fi
# Install binary.
install -m 0755 "${REPO_DIR}/${SERVICE}" "${BINARY}"
echo "Installed ${BINARY}."
# Create data directory structure.
install -d -o "${SERVICE}" -g "${SERVICE}" -m 0700 "${DATA_DIR}"
install -d -o "${SERVICE}" -g "${SERVICE}" -m 0700 "${DATA_DIR}/backups"
echo "Created ${DATA_DIR}/."
# Install example config if none exists.
if [ ! -f "${DATA_DIR}/${SERVICE}.toml" ]; then
install -o "${SERVICE}" -g "${SERVICE}" -m 0600 \
"${REPO_DIR}/${SERVICE}.toml.example" \
"${DATA_DIR}/${SERVICE}.toml"
echo "Installed example config to ${DATA_DIR}/${SERVICE}.toml — edit before starting."
fi
# Install systemd units.
install -m 0644 "${REPO_DIR}/deploy/systemd/${SERVICE}.service" "${UNIT_DIR}/"
systemctl daemon-reload
echo "Installed systemd unit ${SERVICE}.service."
echo ""
echo "Done. Next steps:"
echo " 1. Edit ${DATA_DIR}/${SERVICE}.toml"
echo " 2. systemctl enable --now ${SERVICE}"

View File

@@ -0,0 +1,32 @@
[Unit]
Description=mc-proxy TLS proxy and router
After=network-online.target
Wants=network-online.target
[Service]
Type=simple
User=mc-proxy
Group=mc-proxy
ExecStart=/usr/local/bin/mc-proxy server --config /srv/mc-proxy/mc-proxy.toml
Restart=on-failure
RestartSec=5
AmbientCapabilities=CAP_NET_BIND_SERVICE
NoNewPrivileges=true
ProtectSystem=strict
ProtectHome=true
PrivateTmp=true
PrivateDevices=true
ProtectKernelTunables=true
ProtectKernelModules=true
ProtectControlGroups=true
RestrictSUIDSGID=true
RestrictNamespaces=true
LockPersonality=true
MemoryDenyWriteExecute=true
RestrictRealtime=true
ReadWritePaths=/srv/mc-proxy
[Install]
WantedBy=multi-user.target

20
go.mod Normal file
View File

@@ -0,0 +1,20 @@
module git.wntrmute.dev/kyle/mc-proxy
go 1.24.0
require (
github.com/oschwald/maxminddb-golang v1.13.1
github.com/pelletier/go-toml/v2 v2.2.4
github.com/spf13/cobra v1.10.2
google.golang.org/grpc v1.79.2
google.golang.org/protobuf v1.36.11
)
require (
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/spf13/pflag v1.0.9 // indirect
golang.org/x/net v0.48.0 // indirect
golang.org/x/sys v0.39.0 // indirect
golang.org/x/text v0.32.0 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
)

60
go.sum Normal file
View File

@@ -0,0 +1,60 @@
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI=
github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/oschwald/maxminddb-golang v1.13.1 h1:G3wwjdN9JmIK2o/ermkHM+98oX5fS+k5MbwsmL4MRQE=
github.com/oschwald/maxminddb-golang v1.13.1/go.mod h1:K4pgV9N/GcK694KSTmVSDTODk4IsCNThNdTmnaBZ/F8=
github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4=
github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU=
github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4=
github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY=
github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64=
go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y=
go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48=
go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8=
go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0=
go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs=
go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18=
go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE=
go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8=
go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew=
go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI=
go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA=
go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU=
golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY=
golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk=
golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU=
golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY=
gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk=
gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww=
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk=
google.golang.org/grpc v1.79.2 h1:fRMD94s2tITpyJGtBBn7MkMseNpOZU8ZxgC3MMBaXRU=
google.golang.org/grpc v1.79.2/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ=
google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE=
google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

115
internal/config/config.go Normal file
View File

@@ -0,0 +1,115 @@
package config
import (
"fmt"
"os"
"time"
"github.com/pelletier/go-toml/v2"
)
type Config struct {
Listeners []Listener `toml:"listeners"`
GRPC GRPC `toml:"grpc"`
Firewall Firewall `toml:"firewall"`
Proxy Proxy `toml:"proxy"`
Log Log `toml:"log"`
}
type GRPC struct {
Addr string `toml:"addr"`
TLSCert string `toml:"tls_cert"`
TLSKey string `toml:"tls_key"`
ClientCA string `toml:"client_ca"`
}
type Listener struct {
Addr string `toml:"addr"`
Routes []Route `toml:"routes"`
}
type Route struct {
Hostname string `toml:"hostname"`
Backend string `toml:"backend"`
}
type Firewall struct {
GeoIPDB string `toml:"geoip_db"`
BlockedIPs []string `toml:"blocked_ips"`
BlockedCIDRs []string `toml:"blocked_cidrs"`
BlockedCountries []string `toml:"blocked_countries"`
}
type Proxy struct {
ConnectTimeout Duration `toml:"connect_timeout"`
IdleTimeout Duration `toml:"idle_timeout"`
ShutdownTimeout Duration `toml:"shutdown_timeout"`
}
type Log struct {
Level string `toml:"level"`
}
// Duration wraps time.Duration for TOML string unmarshalling.
type Duration struct {
time.Duration
}
func (d *Duration) UnmarshalText(text []byte) error {
var err error
d.Duration, err = time.ParseDuration(string(text))
return err
}
func Load(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("reading config: %w", err)
}
var cfg Config
if err := toml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("parsing config: %w", err)
}
if err := cfg.validate(); err != nil {
return nil, fmt.Errorf("invalid config: %w", err)
}
return &cfg, nil
}
func (c *Config) validate() error {
if len(c.Listeners) == 0 {
return fmt.Errorf("at least one listener is required")
}
for i, l := range c.Listeners {
if l.Addr == "" {
return fmt.Errorf("listener %d: addr is required", i)
}
if len(l.Routes) == 0 {
return fmt.Errorf("listener %d (%s): at least one route is required", i, l.Addr)
}
seen := make(map[string]bool)
for j, r := range l.Routes {
if r.Hostname == "" {
return fmt.Errorf("listener %d (%s), route %d: hostname is required", i, l.Addr, j)
}
if r.Backend == "" {
return fmt.Errorf("listener %d (%s), route %d: backend is required", i, l.Addr, j)
}
if seen[r.Hostname] {
return fmt.Errorf("listener %d (%s), route %d: duplicate hostname %q", i, l.Addr, j, r.Hostname)
}
seen[r.Hostname] = true
}
}
if len(c.Firewall.BlockedCountries) > 0 && c.Firewall.GeoIPDB == "" {
return fmt.Errorf("firewall: geoip_db is required when blocked_countries is set")
}
return nil
}

View File

@@ -0,0 +1,186 @@
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadValid(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
[proxy]
connect_timeout = "5s"
idle_timeout = "300s"
shutdown_timeout = "30s"
[log]
level = "info"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cfg.Listeners) != 1 {
t.Fatalf("got %d listeners, want 1", len(cfg.Listeners))
}
if cfg.Listeners[0].Addr != ":443" {
t.Fatalf("got listener addr %q, want %q", cfg.Listeners[0].Addr, ":443")
}
if len(cfg.Listeners[0].Routes) != 1 {
t.Fatalf("got %d routes, want 1", len(cfg.Listeners[0].Routes))
}
if cfg.Listeners[0].Routes[0].Hostname != "example.com" {
t.Fatalf("got hostname %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "example.com")
}
}
func TestLoadNoListeners(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[log]
level = "info"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for missing listeners")
}
}
func TestLoadNoRoutes(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for missing routes")
}
}
func TestLoadDuplicateHostnames(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:9443"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for duplicate hostnames")
}
}
func TestLoadGeoIPRequiredWithCountries(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "example.com"
backend = "127.0.0.1:8443"
[firewall]
blocked_countries = ["CN"]
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
_, err := Load(path)
if err == nil {
t.Fatal("expected error for blocked_countries without geoip_db")
}
}
func TestLoadMultipleListeners(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "test.toml")
data := `
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "public.example.com"
backend = "127.0.0.1:8443"
[[listeners]]
addr = ":8443"
[[listeners.routes]]
hostname = "internal.example.com"
backend = "127.0.0.1:9443"
`
if err := os.WriteFile(path, []byte(data), 0600); err != nil {
t.Fatalf("write config: %v", err)
}
cfg, err := Load(path)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(cfg.Listeners) != 2 {
t.Fatalf("got %d listeners, want 2", len(cfg.Listeners))
}
if cfg.Listeners[0].Routes[0].Hostname != "public.example.com" {
t.Fatalf("listener 0 hostname = %q, want %q", cfg.Listeners[0].Routes[0].Hostname, "public.example.com")
}
if cfg.Listeners[1].Routes[0].Hostname != "internal.example.com" {
t.Fatalf("listener 1 hostname = %q, want %q", cfg.Listeners[1].Routes[0].Hostname, "internal.example.com")
}
}
func TestDuration(t *testing.T) {
var d Duration
if err := d.UnmarshalText([]byte("5s")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if d.Duration.Seconds() != 5 {
t.Fatalf("got %v, want 5s", d.Duration)
}
}

View File

@@ -0,0 +1,218 @@
package firewall
import (
"fmt"
"net/netip"
"strings"
"sync"
"github.com/oschwald/maxminddb-golang"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
)
type geoIPRecord struct {
Country struct {
ISOCode string `maxminddb:"iso_code"`
} `maxminddb:"country"`
}
// Firewall evaluates global blocklist rules against connection source addresses.
type Firewall struct {
blockedIPs map[netip.Addr]struct{}
blockedCIDRs []netip.Prefix
blockedCountries map[string]struct{}
geoDBPath string
geoDB *maxminddb.Reader
mu sync.RWMutex // protects all mutable state
}
// New creates a Firewall from the given configuration.
func New(cfg config.Firewall) (*Firewall, error) {
f := &Firewall{
blockedIPs: make(map[netip.Addr]struct{}),
blockedCountries: make(map[string]struct{}),
geoDBPath: cfg.GeoIPDB,
}
for _, ip := range cfg.BlockedIPs {
addr, err := netip.ParseAddr(ip)
if err != nil {
return nil, fmt.Errorf("invalid blocked IP %q: %w", ip, err)
}
f.blockedIPs[addr] = struct{}{}
}
for _, cidr := range cfg.BlockedCIDRs {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return nil, fmt.Errorf("invalid blocked CIDR %q: %w", cidr, err)
}
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
}
for _, code := range cfg.BlockedCountries {
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
}
if len(f.blockedCountries) > 0 {
if err := f.loadGeoDB(cfg.GeoIPDB); err != nil {
return nil, fmt.Errorf("loading GeoIP database: %w", err)
}
}
return f, nil
}
// Blocked returns true if the given address should be blocked.
func (f *Firewall) Blocked(addr netip.Addr) bool {
addr = addr.Unmap()
f.mu.RLock()
defer f.mu.RUnlock()
if _, ok := f.blockedIPs[addr]; ok {
return true
}
for _, prefix := range f.blockedCIDRs {
if prefix.Contains(addr) {
return true
}
}
if len(f.blockedCountries) > 0 && f.geoDB != nil {
var record geoIPRecord
if err := f.geoDB.Lookup(addr.AsSlice(), &record); err == nil {
if _, ok := f.blockedCountries[record.Country.ISOCode]; ok {
return true
}
}
}
return false
}
// AddIP adds an IP address to the blocklist.
func (f *Firewall) AddIP(ip string) error {
addr, err := netip.ParseAddr(ip)
if err != nil {
return fmt.Errorf("invalid IP %q: %w", ip, err)
}
f.mu.Lock()
f.blockedIPs[addr] = struct{}{}
f.mu.Unlock()
return nil
}
// RemoveIP removes an IP address from the blocklist.
func (f *Firewall) RemoveIP(ip string) error {
addr, err := netip.ParseAddr(ip)
if err != nil {
return fmt.Errorf("invalid IP %q: %w", ip, err)
}
f.mu.Lock()
delete(f.blockedIPs, addr)
f.mu.Unlock()
return nil
}
// AddCIDR adds a CIDR prefix to the blocklist.
func (f *Firewall) AddCIDR(cidr string) error {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
}
f.mu.Lock()
f.blockedCIDRs = append(f.blockedCIDRs, prefix)
f.mu.Unlock()
return nil
}
// RemoveCIDR removes a CIDR prefix from the blocklist.
func (f *Firewall) RemoveCIDR(cidr string) error {
prefix, err := netip.ParsePrefix(cidr)
if err != nil {
return fmt.Errorf("invalid CIDR %q: %w", cidr, err)
}
f.mu.Lock()
for i, p := range f.blockedCIDRs {
if p == prefix {
f.blockedCIDRs = append(f.blockedCIDRs[:i], f.blockedCIDRs[i+1:]...)
break
}
}
f.mu.Unlock()
return nil
}
// AddCountry adds a country code to the blocklist.
func (f *Firewall) AddCountry(code string) {
f.mu.Lock()
f.blockedCountries[strings.ToUpper(code)] = struct{}{}
f.mu.Unlock()
}
// RemoveCountry removes a country code from the blocklist.
func (f *Firewall) RemoveCountry(code string) {
f.mu.Lock()
delete(f.blockedCountries, strings.ToUpper(code))
f.mu.Unlock()
}
// Rules returns a snapshot of all current firewall rules.
func (f *Firewall) Rules() (ips []string, cidrs []string, countries []string) {
f.mu.RLock()
defer f.mu.RUnlock()
for addr := range f.blockedIPs {
ips = append(ips, addr.String())
}
for _, prefix := range f.blockedCIDRs {
cidrs = append(cidrs, prefix.String())
}
for code := range f.blockedCountries {
countries = append(countries, code)
}
return
}
// ReloadGeoIP reloads the GeoIP database from disk. Safe for concurrent use.
func (f *Firewall) ReloadGeoIP() error {
if f.geoDBPath == "" {
return nil
}
return f.loadGeoDB(f.geoDBPath)
}
// Close releases resources held by the firewall.
func (f *Firewall) Close() error {
f.mu.Lock()
defer f.mu.Unlock()
if f.geoDB != nil {
return f.geoDB.Close()
}
return nil
}
func (f *Firewall) loadGeoDB(path string) error {
db, err := maxminddb.Open(path)
if err != nil {
return err
}
f.mu.Lock()
old := f.geoDB
f.geoDB = db
f.mu.Unlock()
if old != nil {
old.Close()
}
return nil
}

View File

@@ -0,0 +1,141 @@
package firewall
import (
"net/netip"
"testing"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
)
func TestEmptyFirewall(t *testing.T) {
fw, err := New(config.Firewall{})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
for _, a := range addrs {
addr := netip.MustParseAddr(a)
if fw.Blocked(addr) {
t.Fatalf("empty firewall blocked %s", addr)
}
}
}
func TestIPBlocking(t *testing.T) {
fw, err := New(config.Firewall{
BlockedIPs: []string{"192.0.2.1", "2001:db8::dead"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
tests := []struct {
addr string
blocked bool
}{
{"192.0.2.1", true},
{"192.0.2.2", false},
{"2001:db8::dead", true},
{"2001:db8::beef", false},
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
if got := fw.Blocked(addr); got != tt.blocked {
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
}
}
}
func TestCIDRBlocking(t *testing.T) {
fw, err := New(config.Firewall{
BlockedCIDRs: []string{"198.51.100.0/24", "2001:db8::/32"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
tests := []struct {
addr string
blocked bool
}{
{"198.51.100.1", true},
{"198.51.100.254", true},
{"198.51.101.1", false},
{"2001:db8::1", true},
{"2001:db9::1", false},
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
if got := fw.Blocked(addr); got != tt.blocked {
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
}
}
}
func TestIPv4MappedIPv6(t *testing.T) {
fw, err := New(config.Firewall{
BlockedIPs: []string{"192.0.2.1"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
// IPv4-mapped IPv6 representation of 192.0.2.1.
addr := netip.MustParseAddr("::ffff:192.0.2.1")
if !fw.Blocked(addr) {
t.Fatal("expected IPv4-mapped IPv6 address to be blocked")
}
}
func TestInvalidIP(t *testing.T) {
_, err := New(config.Firewall{
BlockedIPs: []string{"not-an-ip"},
})
if err == nil {
t.Fatal("expected error for invalid IP")
}
}
func TestInvalidCIDR(t *testing.T) {
_, err := New(config.Firewall{
BlockedCIDRs: []string{"not-a-cidr"},
})
if err == nil {
t.Fatal("expected error for invalid CIDR")
}
}
func TestCombinedRules(t *testing.T) {
fw, err := New(config.Firewall{
BlockedIPs: []string{"10.0.0.1"},
BlockedCIDRs: []string{"192.168.0.0/16"},
})
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
tests := []struct {
addr string
blocked bool
}{
{"10.0.0.1", true}, // IP match
{"10.0.0.2", false}, // no match
{"192.168.1.1", true}, // CIDR match
{"172.16.0.1", false}, // no match
}
for _, tt := range tests {
addr := netip.MustParseAddr(tt.addr)
if got := fw.Blocked(addr); got != tt.blocked {
t.Fatalf("Blocked(%s) = %v, want %v", tt.addr, got, tt.blocked)
}
}
}

View File

@@ -0,0 +1,246 @@
package grpcserver
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log/slog"
"net"
"os"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
// AdminServer implements the ProxyAdmin gRPC service.
type AdminServer struct {
pb.UnimplementedProxyAdminServer
srv *server.Server
logger *slog.Logger
}
// New creates a gRPC server with TLS and optional mTLS.
func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
if err != nil {
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
}
tlsConfig := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}
// mTLS: require and verify client certificates.
if cfg.ClientCA != "" {
caCert, err := os.ReadFile(cfg.ClientCA)
if err != nil {
return nil, nil, fmt.Errorf("reading client CA: %w", err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(caCert) {
return nil, nil, fmt.Errorf("failed to parse client CA certificate")
}
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
tlsConfig.ClientCAs = pool
}
creds := credentials.NewTLS(tlsConfig)
grpcServer := grpc.NewServer(grpc.Creds(creds))
admin := &AdminServer{
srv: srv,
logger: logger,
}
pb.RegisterProxyAdminServer(grpcServer, admin)
ln, err := net.Listen("tcp", cfg.Addr)
if err != nil {
return nil, nil, fmt.Errorf("listening on %s: %w", cfg.Addr, err)
}
return grpcServer, ln, nil
}
// ListRoutes returns the route table for a listener.
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
routes := ls.Routes()
resp := &pb.ListRoutesResponse{
ListenerAddr: ls.Addr,
}
for hostname, backend := range routes {
resp.Routes = append(resp.Routes, &pb.Route{
Hostname: hostname,
Backend: backend,
})
}
return resp, nil
}
// AddRoute adds a route to a listener's route table.
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
if req.Route == nil {
return nil, status.Error(codes.InvalidArgument, "route is required")
}
if req.Route.Hostname == "" || req.Route.Backend == "" {
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
}
a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend)
return &pb.AddRouteResponse{}, nil
}
// RemoveRoute removes a route from a listener's route table.
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
if req.Hostname == "" {
return nil, status.Error(codes.InvalidArgument, "hostname is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
if err := ls.RemoveRoute(req.Hostname); err != nil {
return nil, status.Errorf(codes.NotFound, "%v", err)
}
a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname)
return &pb.RemoveRouteResponse{}, nil
}
// GetFirewallRules returns all current firewall rules.
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
ips, cidrs, countries := a.srv.Firewall().Rules()
var rules []*pb.FirewallRule
for _, ip := range ips {
rules = append(rules, &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
Value: ip,
})
}
for _, cidr := range cidrs {
rules = append(rules, &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
Value: cidr,
})
}
for _, code := range countries {
rules = append(rules, &pb.FirewallRule{
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
Value: code,
})
}
return &pb.GetFirewallRulesResponse{Rules: rules}, nil
}
// AddFirewallRule adds a firewall rule.
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
if req.Rule == nil {
return nil, status.Error(codes.InvalidArgument, "rule is required")
}
fw := a.srv.Firewall()
switch req.Rule.Type {
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
if err := fw.AddIP(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
if err := fw.AddCIDR(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
if req.Rule.Value == "" {
return nil, status.Error(codes.InvalidArgument, "country code is required")
}
fw.AddCountry(req.Rule.Value)
default:
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
}
a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value)
return &pb.AddFirewallRuleResponse{}, nil
}
// RemoveFirewallRule removes a firewall rule.
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
if req.Rule == nil {
return nil, status.Error(codes.InvalidArgument, "rule is required")
}
fw := a.srv.Firewall()
switch req.Rule.Type {
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
if err := fw.RemoveIP(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
}
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
if req.Rule.Value == "" {
return nil, status.Error(codes.InvalidArgument, "country code is required")
}
fw.RemoveCountry(req.Rule.Value)
default:
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
}
a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value)
return &pb.RemoveFirewallRuleResponse{}, nil
}
// GetStatus returns the proxy's current status.
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
var listeners []*pb.ListenerStatus
for _, ls := range a.srv.Listeners() {
routes := ls.Routes()
listeners = append(listeners, &pb.ListenerStatus{
Addr: ls.Addr,
RouteCount: int32(len(routes)),
ActiveConnections: ls.ActiveConnections.Load(),
})
}
return &pb.GetStatusResponse{
Version: a.srv.Version(),
StartedAt: timestamppb.New(a.srv.StartedAt()),
Listeners: listeners,
TotalConnections: a.srv.TotalConnections(),
}, nil
}
func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
for _, ls := range a.srv.Listeners() {
if ls.Addr == addr {
return ls, nil
}
}
return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
}

105
internal/proxy/proxy.go Normal file
View File

@@ -0,0 +1,105 @@
package proxy
import (
"context"
"io"
"net"
"sync"
"time"
)
// Result holds the outcome of a relay operation.
type Result struct {
ClientBytes int64 // bytes sent from client to backend
BackendBytes int64 // bytes sent from backend to client
}
// Relay performs bidirectional byte copying between client and backend.
// The peeked bytes (the TLS ClientHello) are written to the backend first.
// Relay blocks until both directions are done or ctx is cancelled.
func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTimeout time.Duration) (Result, error) {
// Forward the buffered ClientHello to the backend.
if len(peeked) > 0 {
if _, err := backend.Write(peeked); err != nil {
return Result{}, err
}
}
// Cancel context closes both connections to unblock copy goroutines.
ctx, cancel := context.WithCancel(ctx)
defer cancel()
go func() {
<-ctx.Done()
client.Close()
backend.Close()
}()
var (
result Result
wg sync.WaitGroup
errC2B error
errB2C error
)
wg.Add(2)
// client → backend
go func() {
defer wg.Done()
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
// Half-close backend's write side.
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
}
}()
// backend → client
go func() {
defer wg.Done()
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
// Half-close client's write side.
if hc, ok := client.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
}
}()
wg.Wait()
// If context was cancelled, that's the primary error.
if ctx.Err() != nil {
return result, ctx.Err()
}
// Return the first meaningful error, if any.
if errC2B != nil {
return result, errC2B
}
return result, errB2C
}
// copyWithIdleTimeout copies from src to dst, resetting the idle deadline
// on each successful read.
func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, error) {
buf := make([]byte, 32*1024)
var total int64
for {
src.SetReadDeadline(time.Now().Add(idleTimeout))
nr, readErr := src.Read(buf)
if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(idleTimeout))
nw, writeErr := dst.Write(buf[:nr])
total += int64(nw)
if writeErr != nil {
return total, writeErr
}
}
if readErr != nil {
if readErr == io.EOF {
return total, nil
}
return total, readErr
}
}
}

View File

@@ -0,0 +1,259 @@
package proxy
import (
"bytes"
"context"
"crypto/rand"
"io"
"net"
"testing"
"time"
)
func TestRelayBasic(t *testing.T) {
// Set up a TCP listener to act as the backend.
backendLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn.Close()
peeked := []byte("peeked-hello-bytes")
clientData := []byte("data from client")
backendData := []byte("data from backend")
// Backend goroutine: accept, read peeked+client data, send response, close.
backendDone := make(chan []byte, 1)
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
// Read everything the backend receives.
received, _ := io.ReadAll(conn)
backendDone <- received
// This won't work since ReadAll waits for EOF.
// Instead, restructure: read expected bytes, write response, close write.
}()
// Restructure: use a more controlled flow.
backendLn.Close()
// Use a real TCP pair for proper half-close.
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn2.Close()
go func() {
conn, err := backendLn2.Accept()
if err != nil {
return
}
defer conn.Close()
// Read peeked + client data.
buf := make([]byte, len(peeked)+len(clientData))
n, _ := io.ReadFull(conn, buf)
backendDone <- buf[:n]
// Send response.
conn.Write(backendData)
// Close write side to signal EOF.
if tc, ok := conn.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Dial the backend.
backendConn, err := net.Dial("tcp", backendLn2.Addr().String())
if err != nil {
t.Fatalf("dial backend: %v", err)
}
// Create a client-side TCP pair.
clientLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer clientLn.Close()
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
if err != nil {
t.Fatalf("dial client: %v", err)
}
serverSideClient, err := clientLn.Accept()
if err != nil {
t.Fatalf("accept client: %v", err)
}
// Client sends data then closes write.
go func() {
clientConn.Write(clientData)
if tc, ok := clientConn.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Run relay.
result, err := Relay(context.Background(), serverSideClient, backendConn, peeked, 5*time.Second)
if err != nil {
t.Fatalf("relay error: %v", err)
}
// Verify backend received peeked + client data.
received := <-backendDone
expected := append(peeked, clientData...)
if !bytes.Equal(received, expected) {
t.Fatalf("backend received %q, want %q", received, expected)
}
// Verify client received backend data.
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
clientReceived, _ := io.ReadAll(clientConn)
if !bytes.Equal(clientReceived, backendData) {
t.Fatalf("client received %q, want %q", clientReceived, backendData)
}
if result.ClientBytes != int64(len(clientData)) {
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(clientData))
}
if result.BackendBytes != int64(len(backendData)) {
t.Fatalf("BackendBytes = %d, want %d", result.BackendBytes, len(backendData))
}
}
func TestRelayIdleTimeout(t *testing.T) {
// Two connected pairs via TCP.
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
start := time.Now()
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
elapsed := time.Since(start)
// Should return due to idle timeout.
if err == nil {
t.Fatal("expected error from idle timeout")
}
if elapsed > 2*time.Second {
t.Fatalf("relay took %v, expected ~100ms", elapsed)
}
}
func TestRelayContextCancel(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
Relay(ctx, clientB, backendA, nil, time.Minute)
close(done)
}()
// Cancel after a short delay.
time.Sleep(50 * time.Millisecond)
cancel()
select {
case <-done:
// OK
case <-time.After(2 * time.Second):
t.Fatal("relay did not return after context cancel")
}
_ = backendB // keep reference
}
func TestRelayLargeTransfer(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
// 1 MB of random data.
data := make([]byte, 1<<20)
if _, err := rand.Read(data); err != nil {
t.Fatalf("rand read: %v", err)
}
go func() {
clientA.Write(data)
if tc, ok := clientA.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
// Backend reads and echoes chunks, then closes when client EOF arrives.
go func() {
buf := make([]byte, 32*1024)
for {
n, err := backendB.Read(buf)
if n > 0 {
backendB.Write(buf[:n])
}
if err != nil {
break
}
}
if tc, ok := backendB.(*net.TCPConn); ok {
tc.CloseWrite()
}
}()
result, err := Relay(context.Background(), clientB, backendA, nil, 10*time.Second)
if err != nil {
t.Fatalf("relay error: %v", err)
}
if result.ClientBytes != int64(len(data)) {
t.Fatalf("ClientBytes = %d, want %d", result.ClientBytes, len(data))
}
}
// tcpPair returns two connected TCP connections.
func tcpPair(t *testing.T) (net.Conn, net.Conn) {
t.Helper()
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
var serverConn net.Conn
done := make(chan struct{})
go func() {
serverConn, _ = ln.Accept()
close(done)
}()
clientConn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
<-done
return clientConn, serverConn
}

271
internal/server/server.go Normal file
View File

@@ -0,0 +1,271 @@
package server
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
)
// ListenerState holds the mutable state for a single proxy listener.
type ListenerState struct {
Addr string
routes map[string]string // lowercase hostname → backend addr
mu sync.RWMutex
ActiveConnections atomic.Int64
}
// Routes returns a snapshot of the listener's route table.
func (ls *ListenerState) Routes() map[string]string {
ls.mu.RLock()
defer ls.mu.RUnlock()
m := make(map[string]string, len(ls.routes))
for k, v := range ls.routes {
m[k] = v
}
return m
}
// AddRoute adds a route to the listener. Returns an error if the hostname
// already exists.
func (ls *ListenerState) AddRoute(hostname, backend string) error {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
if _, ok := ls.routes[key]; ok {
return fmt.Errorf("route %q already exists", hostname)
}
ls.routes[key] = backend
return nil
}
// RemoveRoute removes a route from the listener. Returns an error if the
// hostname does not exist.
func (ls *ListenerState) RemoveRoute(hostname string) error {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
if _, ok := ls.routes[key]; !ok {
return fmt.Errorf("route %q not found", hostname)
}
delete(ls.routes, key)
return nil
}
func (ls *ListenerState) lookupRoute(hostname string) (string, bool) {
ls.mu.RLock()
defer ls.mu.RUnlock()
backend, ok := ls.routes[hostname]
return backend, ok
}
// Server is the mc-proxy server. It manages listeners, firewall evaluation,
// SNI-based routing, and bidirectional proxying.
type Server struct {
cfg *config.Config
fw *firewall.Firewall
listeners []*ListenerState
logger *slog.Logger
wg sync.WaitGroup // tracks active connections
startedAt time.Time
version string
}
// New creates a Server from the given configuration.
func New(cfg *config.Config, logger *slog.Logger, version string) (*Server, error) {
fw, err := firewall.New(cfg.Firewall)
if err != nil {
return nil, fmt.Errorf("initializing firewall: %w", err)
}
var listeners []*ListenerState
for _, lcfg := range cfg.Listeners {
routes := make(map[string]string, len(lcfg.Routes))
for _, r := range lcfg.Routes {
routes[strings.ToLower(r.Hostname)] = r.Backend
}
listeners = append(listeners, &ListenerState{
Addr: lcfg.Addr,
routes: routes,
})
}
return &Server{
cfg: cfg,
fw: fw,
listeners: listeners,
logger: logger,
version: version,
}, nil
}
// Firewall returns the server's firewall for use by the gRPC admin API.
func (s *Server) Firewall() *firewall.Firewall {
return s.fw
}
// Listeners returns the server's listener states for use by the gRPC admin API.
func (s *Server) Listeners() []*ListenerState {
return s.listeners
}
// StartedAt returns the time the server started.
func (s *Server) StartedAt() time.Time {
return s.startedAt
}
// Version returns the server's version string.
func (s *Server) Version() string {
return s.version
}
// TotalConnections returns the total number of active connections.
func (s *Server) TotalConnections() int64 {
var total int64
for _, ls := range s.listeners {
total += ls.ActiveConnections.Load()
}
return total
}
// Run starts all listeners and blocks until ctx is cancelled.
func (s *Server) Run(ctx context.Context) error {
s.startedAt = time.Now()
var netListeners []net.Listener
for _, ls := range s.listeners {
ln, err := net.Listen("tcp", ls.Addr)
if err != nil {
for _, l := range netListeners {
l.Close()
}
return fmt.Errorf("listening on %s: %w", ls.Addr, err)
}
s.logger.Info("listening", "addr", ln.Addr(), "routes", len(ls.routes))
netListeners = append(netListeners, ln)
}
// Start accept loops.
for i, ln := range netListeners {
ln := ln
ls := s.listeners[i]
go s.serve(ctx, ln, ls)
}
// Block until shutdown signal.
<-ctx.Done()
s.logger.Info("shutting down")
// Stop accepting new connections.
for _, ln := range netListeners {
ln.Close()
}
// Wait for in-flight connections with a timeout.
done := make(chan struct{})
go func() {
s.wg.Wait()
close(done)
}()
select {
case <-done:
s.logger.Info("all connections drained")
case <-time.After(s.cfg.Proxy.ShutdownTimeout.Duration):
s.logger.Warn("shutdown timeout exceeded, forcing close")
}
s.fw.Close()
return nil
}
// ReloadGeoIP reloads the GeoIP database from disk.
func (s *Server) ReloadGeoIP() error {
return s.fw.ReloadGeoIP()
}
func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) {
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil {
return
}
s.logger.Error("accept error", "error", err)
continue
}
s.wg.Add(1)
ls.ActiveConnections.Add(1)
go s.handleConn(ctx, conn, ls)
}
}
func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerState) {
defer s.wg.Done()
defer ls.ActiveConnections.Add(-1)
defer conn.Close()
remoteAddr := conn.RemoteAddr().String()
addrPort, err := netip.ParseAddrPort(remoteAddr)
if err != nil {
s.logger.Error("parsing remote address", "addr", remoteAddr, "error", err)
return
}
addr := addrPort.Addr()
if s.fw.Blocked(addr) {
s.logger.Debug("blocked by firewall", "addr", addr)
return
}
hostname, peeked, err := sni.Extract(conn, time.Now().Add(10*time.Second))
if err != nil {
s.logger.Debug("SNI extraction failed", "addr", addr, "error", err)
return
}
backend, ok := ls.lookupRoute(hostname)
if !ok {
s.logger.Debug("no route for hostname", "addr", addr, "hostname", hostname)
return
}
backendConn, err := net.DialTimeout("tcp", backend, s.cfg.Proxy.ConnectTimeout.Duration)
if err != nil {
s.logger.Error("backend dial failed", "hostname", hostname, "backend", backend, "error", err)
return
}
defer backendConn.Close()
s.logger.Debug("proxying", "addr", addr, "hostname", hostname, "backend", backend)
result, err := proxy.Relay(ctx, conn, backendConn, peeked, s.cfg.Proxy.IdleTimeout.Duration)
if err != nil && ctx.Err() == nil {
s.logger.Debug("relay ended", "hostname", hostname, "error", err)
}
s.logger.Info("connection closed",
"addr", addr,
"hostname", hostname,
"client_bytes", result.ClientBytes,
"backend_bytes", result.BackendBytes,
)
}

175
internal/sni/sni.go Normal file
View File

@@ -0,0 +1,175 @@
package sni
import (
"encoding/binary"
"fmt"
"io"
"net"
"strings"
"time"
)
const maxBufferSize = 16384 // 16 KiB, max TLS record size
// Extract reads the TLS ClientHello from conn and returns the SNI hostname.
// The returned peeked bytes contain the full ClientHello and must be forwarded
// to the backend before starting the bidirectional copy.
//
// A read deadline is set on the connection to prevent slowloris attacks.
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{})
// Read TLS record header (5 bytes).
header := make([]byte, 5)
if _, err := io.ReadFull(conn, header); err != nil {
return "", nil, fmt.Errorf("reading TLS record header: %w", err)
}
// Verify this is a TLS handshake record (content type 0x16).
if header[0] != 0x16 {
return "", nil, fmt.Errorf("not a TLS handshake record (type 0x%02x)", header[0])
}
// Record length.
recordLen := int(binary.BigEndian.Uint16(header[3:5]))
if recordLen == 0 || recordLen > maxBufferSize-5 {
return "", nil, fmt.Errorf("TLS record length %d out of range", recordLen)
}
// Read the full record body.
buf := make([]byte, 5+recordLen)
copy(buf, header)
if _, err := io.ReadFull(conn, buf[5:]); err != nil {
return "", nil, fmt.Errorf("reading TLS record body: %w", err)
}
// Parse the handshake message from the record body.
hostname, err = parseClientHello(buf[5:])
if err != nil {
return "", nil, err
}
return hostname, buf, nil
}
func parseClientHello(data []byte) (string, error) {
if len(data) < 4 {
return "", fmt.Errorf("handshake message too short")
}
// Handshake type: 0x01 = ClientHello.
if data[0] != 0x01 {
return "", fmt.Errorf("not a ClientHello (type 0x%02x)", data[0])
}
// Handshake length (3 bytes).
hsLen := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
data = data[4:]
if len(data) < hsLen {
return "", fmt.Errorf("ClientHello truncated")
}
data = data[:hsLen]
// Skip client version (2 bytes) + random (32 bytes).
if len(data) < 34 {
return "", fmt.Errorf("ClientHello too short for version+random")
}
data = data[34:]
// Skip session ID (1-byte length prefix).
if len(data) < 1 {
return "", fmt.Errorf("ClientHello too short for session ID length")
}
sidLen := int(data[0])
data = data[1:]
if len(data) < sidLen {
return "", fmt.Errorf("ClientHello truncated at session ID")
}
data = data[sidLen:]
// Skip cipher suites (2-byte length prefix).
if len(data) < 2 {
return "", fmt.Errorf("ClientHello too short for cipher suites length")
}
csLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) < csLen {
return "", fmt.Errorf("ClientHello truncated at cipher suites")
}
data = data[csLen:]
// Skip compression methods (1-byte length prefix).
if len(data) < 1 {
return "", fmt.Errorf("ClientHello too short for compression methods length")
}
cmLen := int(data[0])
data = data[1:]
if len(data) < cmLen {
return "", fmt.Errorf("ClientHello truncated at compression methods")
}
data = data[cmLen:]
// Extensions (2-byte total length).
if len(data) < 2 {
return "", fmt.Errorf("no extensions in ClientHello")
}
extLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) < extLen {
return "", fmt.Errorf("ClientHello truncated at extensions")
}
data = data[:extLen]
return findSNI(data)
}
func findSNI(data []byte) (string, error) {
for len(data) >= 4 {
extType := binary.BigEndian.Uint16(data[:2])
extDataLen := int(binary.BigEndian.Uint16(data[2:4]))
data = data[4:]
if len(data) < extDataLen {
return "", fmt.Errorf("extension truncated")
}
if extType == 0x0000 { // server_name
return parseServerNameExtension(data[:extDataLen])
}
data = data[extDataLen:]
}
return "", fmt.Errorf("no SNI extension found")
}
func parseServerNameExtension(data []byte) (string, error) {
if len(data) < 2 {
return "", fmt.Errorf("server_name extension too short")
}
// Server name list length.
listLen := int(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
if len(data) < listLen {
return "", fmt.Errorf("server_name list truncated")
}
data = data[:listLen]
for len(data) >= 3 {
nameType := data[0]
nameLen := int(binary.BigEndian.Uint16(data[1:3]))
data = data[3:]
if len(data) < nameLen {
return "", fmt.Errorf("server_name entry truncated")
}
if nameType == 0x00 { // hostname
return strings.ToLower(string(data[:nameLen])), nil
}
data = data[nameLen:]
}
return "", fmt.Errorf("no hostname in server_name extension")
}

220
internal/sni/sni_test.go Normal file
View File

@@ -0,0 +1,220 @@
package sni
import (
"encoding/binary"
"net"
"testing"
"time"
)
func TestExtract(t *testing.T) {
tests := []struct {
name string
sni string
wantSNI string
wantErr bool
}{
{"basic", "example.com", "example.com", false},
{"case insensitive", "FoO.BaR.CoM", "foo.bar.com", false},
{"subdomain", "metacrypt.metacircular.net", "metacrypt.metacircular.net", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
hello := buildClientHello(tt.sni)
go func() {
client.Write(hello)
}()
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
if tt.wantErr {
if err == nil {
t.Fatal("expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if hostname != tt.wantSNI {
t.Fatalf("got hostname %q, want %q", hostname, tt.wantSNI)
}
if len(peeked) != len(hello) {
t.Fatalf("peeked %d bytes, want %d", len(peeked), len(hello))
}
})
}
}
func TestExtractNoSNI(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
hello := buildClientHelloNoSNI()
go func() {
client.Write(hello)
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for ClientHello without SNI")
}
}
func TestExtractNotTLS(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
go func() {
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for non-TLS data")
}
}
func TestExtractTruncated(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
go func() {
// Write just the TLS record header, then close.
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
client.Close()
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for truncated record")
}
}
func TestExtractOversizedRecord(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
go func() {
// Record header claiming a length larger than 16 KiB.
header := []byte{0x16, 0x03, 0x01}
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
client.Write(header)
client.Close()
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
if err == nil {
t.Fatal("expected error for oversized record")
}
}
func TestExtractMultipleExtensions(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
hello := buildClientHelloWithExtraExtensions("target.example.com")
go func() {
client.Write(hello)
}()
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if hostname != "target.example.com" {
t.Fatalf("got hostname %q, want %q", hostname, "target.example.com")
}
}
// buildClientHello constructs a minimal TLS 1.2 ClientHello with an SNI extension.
func buildClientHello(serverName string) []byte {
return buildClientHelloWithExtensions(sniExtension(serverName))
}
// buildClientHelloNoSNI constructs a ClientHello with no extensions.
func buildClientHelloNoSNI() []byte {
return buildClientHelloWithExtensions(nil)
}
// buildClientHelloWithExtraExtensions puts a dummy extension before the SNI.
func buildClientHelloWithExtraExtensions(serverName string) []byte {
// Dummy extension (type 0xFF01, empty data).
dummy := []byte{0xFF, 0x01, 0x00, 0x00}
ext := append(dummy, sniExtension(serverName)...)
return buildClientHelloWithExtensions(ext)
}
func buildClientHelloWithExtensions(extensions []byte) []byte {
var hello []byte
// Client version: TLS 1.2.
hello = append(hello, 0x03, 0x03)
// Random: 32 bytes of zeros.
hello = append(hello, make([]byte, 32)...)
// Session ID: empty.
hello = append(hello, 0x00)
// Cipher suites: one suite (TLS_RSA_WITH_AES_128_GCM_SHA256).
hello = append(hello, 0x00, 0x02, 0x00, 0x9C)
// Compression methods: null.
hello = append(hello, 0x01, 0x00)
// Extensions.
if len(extensions) > 0 {
hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions)))
hello = append(hello, extensions...)
}
// Wrap in handshake header (type 0x01 = ClientHello).
handshake := []byte{0x01, 0x00, 0x00, 0x00}
handshake[1] = byte(len(hello) >> 16)
handshake[2] = byte(len(hello) >> 8)
handshake[3] = byte(len(hello))
handshake = append(handshake, hello...)
// Wrap in TLS record header (type 0x16 = handshake, version TLS 1.0).
record := []byte{0x16, 0x03, 0x01}
record = binary.BigEndian.AppendUint16(record, uint16(len(handshake)))
record = append(record, handshake...)
return record
}
func sniExtension(serverName string) []byte {
name := []byte(serverName)
// Server name entry: type 0x00 (hostname), length, name.
var entry []byte
entry = append(entry, 0x00)
entry = binary.BigEndian.AppendUint16(entry, uint16(len(name)))
entry = append(entry, name...)
// Server name list: length prefix.
var list []byte
list = binary.BigEndian.AppendUint16(list, uint16(len(entry)))
list = append(list, entry...)
// Extension: type 0x0000 (server_name), length, data.
var ext []byte
ext = binary.BigEndian.AppendUint16(ext, 0x0000)
ext = binary.BigEndian.AppendUint16(ext, uint16(len(list)))
ext = append(ext, list...)
return ext
}

51
mc-proxy.toml.example Normal file
View File

@@ -0,0 +1,51 @@
# mc-proxy configuration
# Listeners. Each listener binds a TCP port and has its own route table.
[[listeners]]
addr = ":443"
[[listeners.routes]]
hostname = "metacrypt.metacircular.net"
backend = "127.0.0.1:18443"
[[listeners.routes]]
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
[[listeners]]
addr = ":8443"
[[listeners.routes]]
hostname = "metacrypt.metacircular.net"
backend = "127.0.0.1:18443"
[[listeners]]
addr = ":9443"
[[listeners.routes]]
hostname = "mcias.metacircular.net"
backend = "127.0.0.1:28443"
# gRPC admin API. Optional — omit addr to disable.
[grpc]
addr = "127.0.0.1:9090"
tls_cert = "/srv/mc-proxy/certs/cert.pem"
tls_key = "/srv/mc-proxy/certs/key.pem"
client_ca = "/srv/mc-proxy/certs/ca.pem" # mTLS; omit to disable client auth
# Firewall. Global blocklist, evaluated before routing. Default allow.
[firewall]
geoip_db = "/srv/mc-proxy/GeoLite2-Country.mmdb"
blocked_ips = []
blocked_cidrs = []
blocked_countries = ["KP", "CN", "IN", "IL"]
# Proxy behavior.
[proxy]
connect_timeout = "5s"
idle_timeout = "300s"
shutdown_timeout = "30s"
# Logging.
[log]
level = "info"