Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a60e5cb86a | |||
| 4f3249fdc3 | |||
| f31a7f20fb | |||
| feeadc582b | |||
| a45ed03432 | |||
| dc1816b159 |
@@ -9,6 +9,20 @@ run:
|
|||||||
tests: true
|
tests: true
|
||||||
|
|
||||||
linters:
|
linters:
|
||||||
|
exclusions:
|
||||||
|
paths:
|
||||||
|
- vendor
|
||||||
|
rules:
|
||||||
|
# In test files, suppress gosec rules that are false positives in test code:
|
||||||
|
# G101: hardcoded test credentials (intentional fixtures)
|
||||||
|
# G115: integer overflow in type conversions (test TLS packet builders)
|
||||||
|
# G304: file paths from variables (t.TempDir paths)
|
||||||
|
# G402: InsecureSkipVerify (required for test TLS clients)
|
||||||
|
# G705: XSS via taint analysis (test HTTP handlers, not real servers)
|
||||||
|
- path: "_test\\.go"
|
||||||
|
linters:
|
||||||
|
- gosec
|
||||||
|
text: "G101|G115|G304|G402|G705"
|
||||||
default: none
|
default: none
|
||||||
enable:
|
enable:
|
||||||
# --- Correctness ---
|
# --- Correctness ---
|
||||||
@@ -52,12 +66,15 @@ linters:
|
|||||||
check-type-assertions: true
|
check-type-assertions: true
|
||||||
|
|
||||||
govet:
|
govet:
|
||||||
# Enable all analyzers except shadow. The shadow analyzer flags the idiomatic
|
# Enable all analyzers except shadow and fieldalignment. The shadow analyzer
|
||||||
# `if err := f(); err != nil { ... }` pattern as shadowing an outer `err`,
|
# flags the idiomatic `if err := f(); err != nil { ... }` pattern as shadowing
|
||||||
# which is ubiquitous in Go and does not pose a security risk in this codebase.
|
# an outer `err`, which is ubiquitous in Go. The fieldalignment analyzer
|
||||||
|
# suggests struct field reordering for memory efficiency — useful as a one-off
|
||||||
|
# audit but too noisy for CI (every struct change triggers it).
|
||||||
enable-all: true
|
enable-all: true
|
||||||
disable:
|
disable:
|
||||||
- shadow
|
- shadow
|
||||||
|
- fieldalignment
|
||||||
|
|
||||||
gosec:
|
gosec:
|
||||||
# Treat all gosec findings as errors, not warnings.
|
# Treat all gosec findings as errors, not warnings.
|
||||||
@@ -110,15 +127,3 @@ issues:
|
|||||||
# Do not cap the number of reported issues; in security code every finding matters.
|
# Do not cap the number of reported issues; in security code every finding matters.
|
||||||
max-issues-per-linter: 0
|
max-issues-per-linter: 0
|
||||||
max-same-issues: 0
|
max-same-issues: 0
|
||||||
|
|
||||||
exclusions:
|
|
||||||
paths:
|
|
||||||
- vendor
|
|
||||||
rules:
|
|
||||||
# In test files, allow hardcoded test credentials (gosec G101) since they are
|
|
||||||
# intentional fixtures, not production secrets.
|
|
||||||
- path: "_test\\.go"
|
|
||||||
linters:
|
|
||||||
- gosec
|
|
||||||
text: "G101"
|
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ go test ./internal/sni -run TestExtract
|
|||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
- **Module path**: `git.wntrmute.dev/kyle/mc-proxy`
|
- **Module path**: `git.wntrmute.dev/mc/mc-proxy`
|
||||||
- **Go with CGO_ENABLED=0**, statically linked, Alpine containers
|
- **Go with CGO_ENABLED=0**, statically linked, Alpine containers
|
||||||
- **Dual mode, per-route** — L4 (passthrough) and L7 (TLS-terminating HTTP/2 reverse proxy) coexist on the same listener
|
- **Dual mode, per-route** — L4 (passthrough) and L7 (TLS-terminating HTTP/2 reverse proxy) coexist on the same listener
|
||||||
- **PROXY protocol** — listeners accept v1/v2; routes send v2. Enables edge→origin deployments over Tailscale
|
- **PROXY protocol** — listeners accept v1/v2; routes send v2. Enables edge→origin deployments over Tailscale
|
||||||
|
|||||||
4
Makefile
4
Makefile
@@ -21,8 +21,8 @@ lint:
|
|||||||
golangci-lint run ./...
|
golangci-lint run ./...
|
||||||
|
|
||||||
proto:
|
proto:
|
||||||
protoc --go_out=. --go_opt=module=git.wntrmute.dev/kyle/mc-proxy \
|
protoc --go_out=. --go_opt=module=git.wntrmute.dev/mc/mc-proxy \
|
||||||
--go-grpc_out=. --go-grpc_opt=module=git.wntrmute.dev/kyle/mc-proxy \
|
--go-grpc_out=. --go-grpc_opt=module=git.wntrmute.dev/mc/mc-proxy \
|
||||||
proto/mc_proxy/v1/*.proto
|
proto/mc_proxy/v1/*.proto
|
||||||
|
|
||||||
proto-lint:
|
proto-lint:
|
||||||
|
|||||||
50
RUNBOOK.md
50
RUNBOOK.md
@@ -187,6 +187,56 @@ grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \
|
|||||||
-d '{"rule": {"type": "FIREWALL_RULE_TYPE_IP", "value": "203.0.113.50"}}'
|
-d '{"rule": {"type": "FIREWALL_RULE_TYPE_IP", "value": "203.0.113.50"}}'
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Deployment with MCP
|
||||||
|
|
||||||
|
mc-proxy runs on rift as a single container managed by MCP. The service
|
||||||
|
definition lives at `~/.config/mcp/services/mc-proxy.toml` on rift (reference
|
||||||
|
copy at `deploy/mc-proxy-rift.toml` in this repo). The container mounts
|
||||||
|
`/srv/mc-proxy` which holds the config file, SQLite database, GeoIP database,
|
||||||
|
and TLS certificates for backends. It runs as `--user 0:0` under rootless
|
||||||
|
podman.
|
||||||
|
|
||||||
|
Listeners: `:443` (L7 terminating), `:8443` (L4 passthrough), `:9443` (L4
|
||||||
|
passthrough).
|
||||||
|
|
||||||
|
### Deploy or Update
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mcp deploy mc-proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
### Restart / Stop
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mcp restart mc-proxy
|
||||||
|
mcp stop mc-proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
### Check Status
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mcp ps
|
||||||
|
mcp status mc-proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
### View Logs
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ssh rift 'doas su - mcp -s /bin/sh -c "podman logs mc-proxy"'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Update Routes
|
||||||
|
|
||||||
|
Edit the config at `/srv/mc-proxy/mc-proxy.toml` on rift, then restart:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mcp restart mc-proxy
|
||||||
|
```
|
||||||
|
|
||||||
|
Routes added at runtime via the gRPC admin API are persisted in the database
|
||||||
|
and survive restarts. Editing the TOML config is only necessary for changing
|
||||||
|
listener definitions or static seed routes.
|
||||||
|
|
||||||
## Incident Procedures
|
## Incident Procedures
|
||||||
|
|
||||||
### Proxy Not Starting
|
### Proxy Not Starting
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import (
|
|||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||||
|
|
||||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Client provides access to the mc-proxy admin API.
|
// Client provides access to the mc-proxy admin API.
|
||||||
|
|||||||
@@ -15,12 +15,12 @@ import (
|
|||||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||||
"google.golang.org/grpc/test/bufconn"
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
|
||||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
"git.wntrmute.dev/mc/mc-proxy/internal/db"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver"
|
"git.wntrmute.dev/mc/mc-proxy/internal/grpcserver"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
"git.wntrmute.dev/mc/mc-proxy/internal/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupTestClient(t *testing.T) *Client {
|
func setupTestClient(t *testing.T) *Client {
|
||||||
@@ -32,7 +32,7 @@ func setupTestClient(t *testing.T) *Client {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open db: %v", err)
|
t.Fatalf("open db: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { store.Close() })
|
t.Cleanup(func() { _ = store.Close() })
|
||||||
|
|
||||||
if err := store.Migrate(); err != nil {
|
if err := store.Migrate(); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
@@ -128,7 +128,7 @@ func setupTestClient(t *testing.T) *Client {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial bufconn: %v", err)
|
t.Fatalf("dial bufconn: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { conn.Close() })
|
t.Cleanup(func() { _ = conn.Close() })
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
"git.wntrmute.dev/mc/mc-proxy/internal/db"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver"
|
"git.wntrmute.dev/mc/mc-proxy/internal/grpcserver"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics"
|
"git.wntrmute.dev/mc/mc-proxy/internal/metrics"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
"git.wntrmute.dev/mc/mc-proxy/internal/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
func serverCmd() *cobra.Command {
|
func serverCmd() *cobra.Command {
|
||||||
@@ -40,7 +40,7 @@ func serverCmd() *cobra.Command {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("opening database: %w", err)
|
return fmt.Errorf("opening database: %w", err)
|
||||||
}
|
}
|
||||||
defer store.Close()
|
defer func() { _ = store.Close() }()
|
||||||
|
|
||||||
if err := store.Migrate(); err != nil {
|
if err := store.Migrate(); err != nil {
|
||||||
return fmt.Errorf("running migrations: %w", err)
|
return fmt.Errorf("running migrations: %w", err)
|
||||||
@@ -93,7 +93,7 @@ func serverCmd() *cobra.Command {
|
|||||||
}()
|
}()
|
||||||
defer func() {
|
defer func() {
|
||||||
grpcSrv.GracefulStop()
|
grpcSrv.GracefulStop()
|
||||||
os.Remove(cfg.GRPC.SocketPath())
|
_ = os.Remove(cfg.GRPC.SocketPath())
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
"git.wntrmute.dev/mc/mc-proxy/internal/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
func snapshotCmd() *cobra.Command {
|
func snapshotCmd() *cobra.Command {
|
||||||
@@ -32,7 +32,7 @@ func snapshotCmd() *cobra.Command {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("opening database: %w", err)
|
return fmt.Errorf("opening database: %w", err)
|
||||||
}
|
}
|
||||||
defer store.Close()
|
defer func() { _ = store.Close() }()
|
||||||
|
|
||||||
dataDir := filepath.Dir(cfg.Database.Path)
|
dataDir := filepath.Dir(cfg.Database.Path)
|
||||||
|
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func statusCmd() *cobra.Command {
|
func statusCmd() *cobra.Command {
|
||||||
@@ -33,7 +33,7 @@ func statusCmd() *cobra.Command {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("connecting to gRPC API: %w", err)
|
return fmt.Errorf("connecting to gRPC API: %w", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
client := pb.NewProxyAdminServiceClient(conn)
|
client := pb.NewProxyAdminServiceClient(conn)
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/client/mcproxy"
|
"git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
func firewallCmd() *cobra.Command {
|
func firewallCmd() *cobra.Command {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/client/mcproxy"
|
"git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
func healthCmd() *cobra.Command {
|
func healthCmd() *cobra.Command {
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
mcproxy "git.wntrmute.dev/kyle/mc-proxy/client/mcproxy"
|
mcproxy "git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
func policiesCmd() *cobra.Command {
|
func policiesCmd() *cobra.Command {
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/client/mcproxy"
|
"git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
const defaultSocketPath = "/srv/mc-proxy/mc-proxy.sock"
|
const defaultSocketPath = "/srv/mc-proxy/mc-proxy.sock"
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
mcproxy "git.wntrmute.dev/kyle/mc-proxy/client/mcproxy"
|
mcproxy "git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
func routesCmd() *cobra.Command {
|
func routesCmd() *cobra.Command {
|
||||||
|
|||||||
@@ -10,7 +10,7 @@
|
|||||||
let
|
let
|
||||||
system = "x86_64-linux";
|
system = "x86_64-linux";
|
||||||
pkgs = nixpkgs.legacyPackages.${system};
|
pkgs = nixpkgs.legacyPackages.${system};
|
||||||
version = "0.1.0";
|
version = "1.1.0";
|
||||||
in
|
in
|
||||||
{
|
{
|
||||||
packages.${system} = {
|
packages.${system} = {
|
||||||
|
|||||||
@@ -1449,7 +1449,7 @@ const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" +
|
|||||||
"\x0eListL7Policies\x12\".mc_proxy.v1.ListL7PoliciesRequest\x1a#.mc_proxy.v1.ListL7PoliciesResponse\x12P\n" +
|
"\x0eListL7Policies\x12\".mc_proxy.v1.ListL7PoliciesRequest\x1a#.mc_proxy.v1.ListL7PoliciesResponse\x12P\n" +
|
||||||
"\vAddL7Policy\x12\x1f.mc_proxy.v1.AddL7PolicyRequest\x1a .mc_proxy.v1.AddL7PolicyResponse\x12Y\n" +
|
"\vAddL7Policy\x12\x1f.mc_proxy.v1.AddL7PolicyRequest\x1a .mc_proxy.v1.AddL7PolicyResponse\x12Y\n" +
|
||||||
"\x0eRemoveL7Policy\x12\".mc_proxy.v1.RemoveL7PolicyRequest\x1a#.mc_proxy.v1.RemoveL7PolicyResponse\x12J\n" +
|
"\x0eRemoveL7Policy\x12\".mc_proxy.v1.RemoveL7PolicyRequest\x1a#.mc_proxy.v1.RemoveL7PolicyResponse\x12J\n" +
|
||||||
"\tGetStatus\x12\x1d.mc_proxy.v1.GetStatusRequest\x1a\x1e.mc_proxy.v1.GetStatusResponseB:Z8git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1;mcproxyv1b\x06proto3"
|
"\tGetStatus\x12\x1d.mc_proxy.v1.GetStatusRequest\x1a\x1e.mc_proxy.v1.GetStatusResponseB8Z6git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1;mcproxyv1b\x06proto3"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
file_proto_mc_proxy_v1_admin_proto_rawDescOnce sync.Once
|
file_proto_mc_proxy_v1_admin_proto_rawDescOnce sync.Once
|
||||||
|
|||||||
4
go.mod
4
go.mod
@@ -1,9 +1,9 @@
|
|||||||
module git.wntrmute.dev/kyle/mc-proxy
|
module git.wntrmute.dev/mc/mc-proxy
|
||||||
|
|
||||||
go 1.25.7
|
go 1.25.7
|
||||||
|
|
||||||
require (
|
require (
|
||||||
git.wntrmute.dev/kyle/mcdsl v1.0.0
|
git.wntrmute.dev/mc/mcdsl v1.2.0
|
||||||
github.com/oschwald/maxminddb-golang v1.13.1
|
github.com/oschwald/maxminddb-golang v1.13.1
|
||||||
github.com/prometheus/client_golang v1.23.2
|
github.com/prometheus/client_golang v1.23.2
|
||||||
github.com/spf13/cobra v1.10.2
|
github.com/spf13/cobra v1.10.2
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -1,5 +1,5 @@
|
|||||||
git.wntrmute.dev/kyle/mcdsl v1.0.0 h1:YB7dx4gdNYKKcVySpL6UkwHqdCJ9Nl1yS0+eHk0hNtk=
|
git.wntrmute.dev/mc/mcdsl v1.2.0 h1:41hep7/PNZJfN0SN/nM+rQpyF1GSZcvNNjyVG81DI7U=
|
||||||
git.wntrmute.dev/kyle/mcdsl v1.0.0/go.mod h1:wo0tGfUAxci3XnOe4/rFmR0RjUElKdYUazc+Np986sg=
|
git.wntrmute.dev/mc/mcdsl v1.2.0/go.mod h1:lXYrAt74ZUix6rx9oVN8d2zH1YJoyp4uxPVKQ+SSxuM=
|
||||||
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
|
||||||
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
|
||||||
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
mcdslconfig "git.wntrmute.dev/kyle/mcdsl/config"
|
mcdslconfig "git.wntrmute.dev/mc/mcdsl/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Duration is an alias for the mcdsl config.Duration type, which wraps
|
// Duration is an alias for the mcdsl config.Duration type, which wraps
|
||||||
|
|||||||
@@ -304,7 +304,7 @@ func TestDuration(t *testing.T) {
|
|||||||
if err := d.UnmarshalText([]byte("5s")); err != nil {
|
if err := d.UnmarshalText([]byte("5s")); err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
if d.Duration.Seconds() != 5 {
|
if d.Seconds() != 5 {
|
||||||
t.Fatalf("got %v, want 5s", d.Duration)
|
t.Fatalf("got %v, want 5s", d.Duration)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -340,7 +340,7 @@ level = "info"
|
|||||||
if cfg.Log.Level != "debug" {
|
if cfg.Log.Level != "debug" {
|
||||||
t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug")
|
t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug")
|
||||||
}
|
}
|
||||||
if cfg.Proxy.IdleTimeout.Duration.Seconds() != 600 {
|
if cfg.Proxy.IdleTimeout.Seconds() != 600 {
|
||||||
t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration)
|
t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration)
|
||||||
}
|
}
|
||||||
if cfg.Database.Path != "/override/test.db" {
|
if cfg.Database.Path != "/override/test.db" {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
mcdsldb "git.wntrmute.dev/mc/mcdsl/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Store wraps a SQLite database connection for mc-proxy persistence.
|
// Store wraps a SQLite database connection for mc-proxy persistence.
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"path/filepath"
|
"path/filepath"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func openTestDB(t *testing.T) *Store {
|
func openTestDB(t *testing.T) *Store {
|
||||||
@@ -17,7 +17,7 @@ func openTestDB(t *testing.T) *Store {
|
|||||||
if err := store.Migrate(); err != nil {
|
if err := store.Migrate(); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { store.Close() })
|
t.Cleanup(func() { _ = store.Close() })
|
||||||
return store
|
return store
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,8 +266,8 @@ func TestRouteCascadeDelete(t *testing.T) {
|
|||||||
store := openTestDB(t)
|
store := openTestDB(t)
|
||||||
|
|
||||||
listenerID, _ := store.CreateListener(":443", false, 0)
|
listenerID, _ := store.CreateListener(":443", false, 0)
|
||||||
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
|
_, _ = store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
|
||||||
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
|
_, _ = store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
|
||||||
|
|
||||||
if err := store.DeleteListener(listenerID); err != nil {
|
if err := store.DeleteListener(listenerID); err != nil {
|
||||||
t.Fatalf("delete listener: %v", err)
|
t.Fatalf("delete listener: %v", err)
|
||||||
@@ -412,7 +412,7 @@ func TestSeed(t *testing.T) {
|
|||||||
func TestSnapshot(t *testing.T) {
|
func TestSnapshot(t *testing.T) {
|
||||||
store := openTestDB(t)
|
store := openTestDB(t)
|
||||||
|
|
||||||
store.CreateListener(":443", false, 0)
|
_, _ = store.CreateListener(":443", false, 0)
|
||||||
|
|
||||||
dest := filepath.Join(t.TempDir(), "backup.db")
|
dest := filepath.Join(t.TempDir(), "backup.db")
|
||||||
if err := store.Snapshot(dest); err != nil {
|
if err := store.Snapshot(dest); err != nil {
|
||||||
@@ -424,7 +424,7 @@ func TestSnapshot(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open backup: %v", err)
|
t.Fatalf("open backup: %v", err)
|
||||||
}
|
}
|
||||||
defer backup.Close()
|
defer func() { _ = backup.Close() }()
|
||||||
|
|
||||||
if err := backup.Migrate(); err != nil {
|
if err := backup.Migrate(); err != nil {
|
||||||
t.Fatalf("migrate backup: %v", err)
|
t.Fatalf("migrate backup: %v", err)
|
||||||
@@ -463,7 +463,7 @@ func TestMigrationV2Upgrade(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open: %v", err)
|
t.Fatalf("open: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { store.Close() })
|
t.Cleanup(func() { _ = store.Close() })
|
||||||
|
|
||||||
// Run full migrations (v1 + v2).
|
// Run full migrations (v1 + v2).
|
||||||
if err := store.Migrate(); err != nil {
|
if err := store.Migrate(); err != nil {
|
||||||
@@ -556,10 +556,10 @@ func TestL7PolicyCascadeDelete(t *testing.T) {
|
|||||||
|
|
||||||
lid, _ := store.CreateListener(":443", false, 0)
|
lid, _ := store.CreateListener(":443", false, 0)
|
||||||
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
||||||
store.CreateL7Policy(rid, "block_user_agent", "Bot")
|
_, _ = store.CreateL7Policy(rid, "block_user_agent", "Bot")
|
||||||
|
|
||||||
// Deleting the route should cascade-delete its policies.
|
// Deleting the route should cascade-delete its policies.
|
||||||
store.DeleteRoute(lid, "api.test")
|
_ = store.DeleteRoute(lid, "api.test")
|
||||||
|
|
||||||
policies, _ := store.ListL7Policies(rid)
|
policies, _ := store.ListL7Policies(rid)
|
||||||
if len(policies) != 0 {
|
if len(policies) != 0 {
|
||||||
@@ -585,7 +585,7 @@ func TestGetRouteID(t *testing.T) {
|
|||||||
store := openTestDB(t)
|
store := openTestDB(t)
|
||||||
|
|
||||||
lid, _ := store.CreateListener(":443", false, 0)
|
lid, _ := store.CreateListener(":443", false, 0)
|
||||||
store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
_, _ = store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
||||||
|
|
||||||
rid, err := store.GetRouteID(lid, "api.test")
|
rid, err := store.GetRouteID(lid, "api.test")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ func (s *Store) ListFirewallRules() ([]FirewallRule, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("querying firewall rules: %w", err)
|
return nil, fmt.Errorf("querying firewall rules: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
var rules []FirewallRule
|
var rules []FirewallRule
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ func (s *Store) ListL7Policies(routeID int64) ([]L7Policy, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("querying l7 policies: %w", err)
|
return nil, fmt.Errorf("querying l7 policies: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
var policies []L7Policy
|
var policies []L7Policy
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func (s *Store) ListListeners() ([]Listener, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("querying listeners: %w", err)
|
return nil, fmt.Errorf("querying listeners: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
var listeners []Listener
|
var listeners []Listener
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
mcdsldb "git.wntrmute.dev/mc/mcdsl/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Migrations is the ordered list of schema migrations for mc-proxy.
|
// Migrations is the ordered list of schema migrations for mc-proxy.
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("querying routes: %w", err)
|
return nil, fmt.Errorf("querying routes: %w", err)
|
||||||
}
|
}
|
||||||
defer rows.Close()
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
var routes []Route
|
var routes []Route
|
||||||
for rows.Next() {
|
for rows.Next() {
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Seed populates the database from TOML config data. Only called when the
|
// Seed populates the database from TOML config data. Only called when the
|
||||||
@@ -14,7 +14,7 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("beginning seed transaction: %w", err)
|
return fmt.Errorf("beginning seed transaction: %w", err)
|
||||||
}
|
}
|
||||||
defer tx.Rollback()
|
defer func() { _ = tx.Rollback() }()
|
||||||
|
|
||||||
for _, l := range listeners {
|
for _, l := range listeners {
|
||||||
result, err := tx.Exec(
|
result, err := tx.Exec(
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package db
|
package db
|
||||||
|
|
||||||
import (
|
import (
|
||||||
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
mcdsldb "git.wntrmute.dev/mc/mcdsl/db"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Snapshot creates a consistent backup of the database using VACUUM INTO.
|
// Snapshot creates a consistent backup of the database using VACUUM INTO.
|
||||||
|
|||||||
@@ -234,7 +234,7 @@ func (f *Firewall) loadGeoDB(path string) error {
|
|||||||
f.mu.Unlock()
|
f.mu.Unlock()
|
||||||
|
|
||||||
if old != nil {
|
if old != nil {
|
||||||
old.Close()
|
_ = old.Close()
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ func TestEmptyFirewall(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
|
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
|
||||||
for _, a := range addrs {
|
for _, a := range addrs {
|
||||||
@@ -27,7 +27,7 @@ func TestIPBlocking(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
addr string
|
addr string
|
||||||
@@ -52,7 +52,7 @@ func TestCIDRBlocking(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
addr string
|
addr string
|
||||||
@@ -78,7 +78,7 @@ func TestIPv4MappedIPv6(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
addr := netip.MustParseAddr("::ffff:192.0.2.1")
|
addr := netip.MustParseAddr("::ffff:192.0.2.1")
|
||||||
if !fw.Blocked(addr) {
|
if !fw.Blocked(addr) {
|
||||||
@@ -105,7 +105,7 @@ func TestCombinedRules(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
addr string
|
addr string
|
||||||
@@ -130,7 +130,7 @@ func TestRateLimitBlocking(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
addr := netip.MustParseAddr("10.0.0.1")
|
addr := netip.MustParseAddr("10.0.0.1")
|
||||||
|
|
||||||
@@ -151,7 +151,7 @@ func TestRateLimitBlocklistFirst(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
blockedAddr := netip.MustParseAddr("10.0.0.1")
|
blockedAddr := netip.MustParseAddr("10.0.0.1")
|
||||||
otherAddr := netip.MustParseAddr("10.0.0.2")
|
otherAddr := netip.MustParseAddr("10.0.0.2")
|
||||||
@@ -175,7 +175,7 @@ func TestBlockedWithReason(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
addr string
|
addr string
|
||||||
@@ -216,7 +216,7 @@ func TestRuntimeMutation(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unexpected error: %v", err)
|
t.Fatalf("unexpected error: %v", err)
|
||||||
}
|
}
|
||||||
defer fw.Close()
|
defer func() { _ = fw.Close() }()
|
||||||
|
|
||||||
addr := netip.MustParseAddr("10.0.0.1")
|
addr := netip.MustParseAddr("10.0.0.1")
|
||||||
if fw.Blocked(addr) {
|
if fw.Blocked(addr) {
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func (rl *rateLimiter) Allow(addr netip.Addr) bool {
|
|||||||
now := rl.now().UnixNano()
|
now := rl.now().UnixNano()
|
||||||
|
|
||||||
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
|
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
|
||||||
entry := val.(*rateLimitEntry)
|
entry, _ := val.(*rateLimitEntry)
|
||||||
|
|
||||||
windowStart := entry.start.Load()
|
windowStart := entry.start.Load()
|
||||||
if now-windowStart >= rl.window.Nanoseconds() {
|
if now-windowStart >= rl.window.Nanoseconds() {
|
||||||
@@ -70,7 +70,7 @@ func (rl *rateLimiter) cleanup() {
|
|||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
cutoff := rl.now().Add(-2 * rl.window).UnixNano()
|
cutoff := rl.now().Add(-2 * rl.window).UnixNano()
|
||||||
rl.entries.Range(func(key, value any) bool {
|
rl.entries.Range(func(key, value any) bool {
|
||||||
entry := value.(*rateLimitEntry)
|
entry, _ := value.(*rateLimitEntry)
|
||||||
if entry.start.Load() < cutoff {
|
if entry.start.Load() < cutoff {
|
||||||
rl.entries.Delete(key)
|
rl.entries.Delete(key)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,10 +17,10 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
"git.wntrmute.dev/mc/mc-proxy/internal/db"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
"git.wntrmute.dev/mc/mc-proxy/internal/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
|
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
|
||||||
@@ -53,7 +53,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
|
|||||||
path := cfg.SocketPath()
|
path := cfg.SocketPath()
|
||||||
|
|
||||||
// Remove stale socket file from a previous run.
|
// Remove stale socket file from a previous run.
|
||||||
os.Remove(path)
|
_ = os.Remove(path)
|
||||||
|
|
||||||
ln, err := net.Listen("unix", path)
|
ln, err := net.Listen("unix", path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -61,7 +61,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := os.Chmod(path, 0600); err != nil {
|
if err := os.Chmod(path, 0600); err != nil {
|
||||||
ln.Close()
|
_ = ln.Close()
|
||||||
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
|
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -446,7 +446,7 @@ func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.
|
|||||||
}
|
}
|
||||||
listeners = append(listeners, &pb.ListenerStatus{
|
listeners = append(listeners, &pb.ListenerStatus{
|
||||||
Addr: ls.Addr,
|
Addr: ls.Addr,
|
||||||
RouteCount: int32(len(routes)),
|
RouteCount: int32(len(routes)), //nolint:gosec // route count can never exceed int32
|
||||||
ActiveConnections: ls.ActiveConnections.Load(),
|
ActiveConnections: ls.ActiveConnections.Load(),
|
||||||
ProxyProtocol: ls.ProxyProtocol,
|
ProxyProtocol: ls.ProxyProtocol,
|
||||||
MaxConnections: ls.MaxConnections,
|
MaxConnections: ls.MaxConnections,
|
||||||
|
|||||||
@@ -15,11 +15,11 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/grpc/test/bufconn"
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
|
||||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
"git.wntrmute.dev/mc/mc-proxy/internal/db"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
"git.wntrmute.dev/mc/mc-proxy/internal/server"
|
||||||
)
|
)
|
||||||
|
|
||||||
// testEnv bundles all the objects needed for a grpcserver test.
|
// testEnv bundles all the objects needed for a grpcserver test.
|
||||||
@@ -39,7 +39,7 @@ func setup(t *testing.T) *testEnv {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("open db: %v", err)
|
t.Fatalf("open db: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { store.Close() })
|
t.Cleanup(func() { _ = store.Close() })
|
||||||
|
|
||||||
if err := store.Migrate(); err != nil {
|
if err := store.Migrate(); err != nil {
|
||||||
t.Fatalf("migrate: %v", err)
|
t.Fatalf("migrate: %v", err)
|
||||||
@@ -130,7 +130,7 @@ func setup(t *testing.T) *testEnv {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial bufconn: %v", err)
|
t.Fatalf("dial bufconn: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { conn.Close() })
|
t.Cleanup(func() { _ = conn.Close() })
|
||||||
|
|
||||||
return &testEnv{
|
return &testEnv{
|
||||||
client: pb.NewProxyAdminServiceClient(conn),
|
client: pb.NewProxyAdminServiceClient(conn),
|
||||||
@@ -775,7 +775,7 @@ func TestRemoveL7Policy(t *testing.T) {
|
|||||||
env := setup(t)
|
env := setup(t)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
_, _ = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||||
ListenerAddr: ":443",
|
ListenerAddr: ":443",
|
||||||
Hostname: "a.test",
|
Hostname: "a.test",
|
||||||
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
|
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics"
|
"git.wntrmute.dev/mc/mc-proxy/internal/metrics"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PolicyRule defines an L7 blocking policy.
|
// PolicyRule defines an L7 blocking policy.
|
||||||
|
|||||||
@@ -13,27 +13,27 @@ func TestPrefixConnRead(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
conn.Write([]byte("WORLD"))
|
_, _ = conn.Write([]byte("WORLD"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial: %v", err)
|
t.Fatalf("dial: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
pc := NewPrefixConn(conn, []byte("HELLO"))
|
pc := NewPrefixConn(conn, []byte("HELLO"))
|
||||||
|
|
||||||
// Read all data: should get "HELLOWORLD".
|
// Read all data: should get "HELLOWORLD".
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
all, err := io.ReadAll(pc)
|
all, err := io.ReadAll(pc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ReadAll: %v", err)
|
t.Fatalf("ReadAll: %v", err)
|
||||||
@@ -48,22 +48,22 @@ func TestPrefixConnSmallReads(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
conn.Write([]byte("CD"))
|
_, _ = conn.Write([]byte("CD"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial: %v", err)
|
t.Fatalf("dial: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
pc := NewPrefixConn(conn, []byte("AB"))
|
pc := NewPrefixConn(conn, []byte("AB"))
|
||||||
|
|
||||||
@@ -79,7 +79,7 @@ func TestPrefixConnSmallReads(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Now reads come from the underlying conn.
|
// Now reads come from the underlying conn.
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
rest, err := io.ReadAll(pc)
|
rest, err := io.ReadAll(pc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ReadAll: %v", err)
|
t.Fatalf("ReadAll: %v", err)
|
||||||
@@ -94,25 +94,25 @@ func TestPrefixConnEmptyPrefix(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := ln.Accept()
|
conn, err := ln.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
conn.Write([]byte("DATA"))
|
_, _ = conn.Write([]byte("DATA"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
conn, err := net.Dial("tcp", ln.Addr().String())
|
conn, err := net.Dial("tcp", ln.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial: %v", err)
|
t.Fatalf("dial: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
pc := NewPrefixConn(conn, nil)
|
pc := NewPrefixConn(conn, nil)
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
all, err := io.ReadAll(pc)
|
all, err := io.ReadAll(pc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("ReadAll: %v", err)
|
t.Fatalf("ReadAll: %v", err)
|
||||||
@@ -127,12 +127,12 @@ func TestPrefixConnDelegates(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, _ := ln.Accept()
|
conn, _ := ln.Accept()
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -140,7 +140,7 @@ func TestPrefixConnDelegates(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial: %v", err)
|
t.Fatalf("dial: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
pc := NewPrefixConn(conn, []byte("X"))
|
pc := NewPrefixConn(conn, []byte("X"))
|
||||||
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics"
|
"git.wntrmute.dev/mc/mc-proxy/internal/metrics"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
"git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -120,7 +120,7 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
|
|||||||
ReadHeaderTimeout: 30 * time.Second,
|
ReadHeaderTimeout: 30 * time.Second,
|
||||||
}
|
}
|
||||||
singleConn := newSingleConnListener(tlsConn)
|
singleConn := newSingleConnListener(tlsConn)
|
||||||
srv.Serve(singleConn)
|
_ = srv.Serve(singleConn)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -186,16 +186,16 @@ func newTransport(route RouteConfig) (http.RoundTripper, error) {
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// h2c: HTTP/2 over plaintext TCP.
|
// Plain HTTP backend. Use standard http.Transport which speaks
|
||||||
return &http2.Transport{
|
// HTTP/1.1 by default and can upgrade to h2c if the backend
|
||||||
AllowHTTP: true,
|
// supports it. This handles backends like Gitea that only speak
|
||||||
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) {
|
// HTTP/1.1.
|
||||||
conn, err := dialBackend(ctx, network, addr, connectTimeout, route)
|
return &http.Transport{
|
||||||
if err != nil {
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
return nil, err
|
return dialBackend(ctx, network, addr, connectTimeout, route)
|
||||||
}
|
|
||||||
return conn, nil
|
|
||||||
},
|
},
|
||||||
|
MaxIdleConnsPerHost: 10,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,7 +213,7 @@ func dialBackend(ctx context.Context, network, addr string, timeout time.Duratio
|
|||||||
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
|
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
|
||||||
if clientAddr.IsValid() {
|
if clientAddr.IsValid() {
|
||||||
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
|
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
|
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -58,8 +58,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating cert file: %v", err)
|
t.Fatalf("creating cert file: %v", err)
|
||||||
}
|
}
|
||||||
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
_ = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
certFile.Close()
|
_ = certFile.Close()
|
||||||
|
|
||||||
keyDER, err := x509.MarshalECPrivateKey(key)
|
keyDER, err := x509.MarshalECPrivateKey(key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -69,8 +69,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("creating key file: %v", err)
|
t.Fatalf("creating key file: %v", err)
|
||||||
}
|
}
|
||||||
pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
_ = pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
keyFile.Close()
|
_ = keyFile.Close()
|
||||||
|
|
||||||
return certPath, keyPath
|
return certPath, keyPath
|
||||||
}
|
}
|
||||||
@@ -91,11 +91,11 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
|
|||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
srv.Close()
|
_ = srv.Close()
|
||||||
ln.Close()
|
_ = ln.Close()
|
||||||
})
|
})
|
||||||
|
|
||||||
go srv.Serve(ln)
|
go func() { _ = srv.Serve(ln) }()
|
||||||
return ln.Addr().String()
|
return ln.Addr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -118,7 +118,7 @@ func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TLS dial: %v", err)
|
t.Fatalf("TLS dial: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { conn.Close() })
|
t.Cleanup(func() { _ = conn.Close() })
|
||||||
|
|
||||||
// Create an HTTP/2 client transport over this single connection.
|
// Create an HTTP/2 client transport over this single connection.
|
||||||
tr := &http2.Transport{}
|
tr := &http2.Transport{}
|
||||||
@@ -142,29 +142,13 @@ func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
|
|||||||
return s.cc.RoundTrip(req)
|
return s.cc.RoundTrip(req)
|
||||||
}
|
}
|
||||||
|
|
||||||
// serveL7Route starts l7.Serve in a goroutine for a single connection.
|
|
||||||
// Returns when the goroutine completes.
|
|
||||||
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
|
|
||||||
t.Helper()
|
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
||||||
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger)
|
|
||||||
if l7Err != nil {
|
|
||||||
t.Logf("l7.Serve: %v", l7Err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestL7H2CBackend(t *testing.T) {
|
func TestL7H2CBackend(t *testing.T) {
|
||||||
certPath, keyPath := testCert(t, "l7.test")
|
certPath, keyPath := testCert(t, "l7.test")
|
||||||
|
|
||||||
// Start an h2c backend.
|
// Start an h2c backend.
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("X-Backend", "ok")
|
w.Header().Set("X-Backend", "ok")
|
||||||
fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
|
_, _ = fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Start a TCP listener for the L7 proxy.
|
// Start a TCP listener for the L7 proxy.
|
||||||
@@ -172,7 +156,7 @@ func TestL7H2CBackend(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -190,17 +174,17 @@ func TestL7H2CBackend(t *testing.T) {
|
|||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
// No peeked bytes — the client is connecting directly with TLS.
|
// No peeked bytes — the client is connecting directly with TLS.
|
||||||
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Connect as an HTTP/2 TLS client.
|
// Connect as an HTTP/2 TLS client.
|
||||||
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
|
||||||
|
|
||||||
resp, err := client.Get(fmt.Sprintf("https://l7.test/foo"))
|
resp, err := client.Get("https://l7.test/foo")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET: %v", err)
|
t.Fatalf("GET: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
||||||
@@ -221,7 +205,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
|
|||||||
|
|
||||||
// Backend that echoes the forwarding headers.
|
// Backend that echoes the forwarding headers.
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
|
_, _ = fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
|
||||||
r.Header.Get("X-Forwarded-For"),
|
r.Header.Get("X-Forwarded-For"),
|
||||||
r.Header.Get("X-Forwarded-Proto"),
|
r.Header.Get("X-Forwarded-Proto"),
|
||||||
r.Header.Get("X-Real-IP"),
|
r.Header.Get("X-Real-IP"),
|
||||||
@@ -232,7 +216,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -248,7 +232,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
|
|||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
|
||||||
@@ -256,7 +240,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET: %v", err)
|
t.Fatalf("GET: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
|
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
|
||||||
@@ -274,13 +258,13 @@ func TestL7BackendUnreachable(t *testing.T) {
|
|||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
deadAddr := ln.Addr().String()
|
deadAddr := ln.Addr().String()
|
||||||
ln.Close()
|
_ = ln.Close()
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: deadAddr,
|
Backend: deadAddr,
|
||||||
@@ -296,7 +280,7 @@ func TestL7BackendUnreachable(t *testing.T) {
|
|||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
|
||||||
@@ -304,7 +288,7 @@ func TestL7BackendUnreachable(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET: %v", err)
|
t.Fatalf("GET: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusBadGateway {
|
if resp.StatusCode != http.StatusBadGateway {
|
||||||
t.Fatalf("status = %d, want 502", resp.StatusCode)
|
t.Fatalf("status = %d, want 502", resp.StatusCode)
|
||||||
@@ -342,14 +326,14 @@ func TestL7MultipleRequests(t *testing.T) {
|
|||||||
var reqCount int
|
var reqCount int
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
reqCount++
|
reqCount++
|
||||||
fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
|
_, _ = fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -365,7 +349,7 @@ func TestL7MultipleRequests(t *testing.T) {
|
|||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
|
||||||
Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
|
||||||
@@ -378,7 +362,7 @@ func TestL7MultipleRequests(t *testing.T) {
|
|||||||
t.Fatalf("GET %s: %v", path, err)
|
t.Fatalf("GET %s: %v", path, err)
|
||||||
}
|
}
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
want := fmt.Sprintf("req=%d path=%s", i+1, path)
|
want := fmt.Sprintf("req=%d path=%s", i+1, path)
|
||||||
if string(body) != want {
|
if string(body) != want {
|
||||||
@@ -396,14 +380,14 @@ func TestL7LargeResponse(t *testing.T) {
|
|||||||
largeBody[i] = byte(i % 256)
|
largeBody[i] = byte(i % 256)
|
||||||
}
|
}
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Write(largeBody)
|
_, _ = w.Write(largeBody)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -418,7 +402,7 @@ func TestL7LargeResponse(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test")
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test")
|
||||||
@@ -426,7 +410,7 @@ func TestL7LargeResponse(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET: %v", err)
|
t.Fatalf("GET: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
if len(body) != len(largeBody) {
|
if len(body) != len(largeBody) {
|
||||||
@@ -455,7 +439,7 @@ func TestL7GRPCTrailers(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -470,7 +454,7 @@ func TestL7GRPCTrailers(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test")
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test")
|
||||||
@@ -480,10 +464,10 @@ func TestL7GRPCTrailers(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("POST: %v", err)
|
t.Fatalf("POST: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
// Read body to trigger trailer delivery.
|
// Read body to trigger trailer delivery.
|
||||||
io.ReadAll(resp.Body)
|
_, _ = io.ReadAll(resp.Body)
|
||||||
|
|
||||||
// Verify trailers were forwarded through the proxy.
|
// Verify trailers were forwarded through the proxy.
|
||||||
grpcStatus := resp.Trailer.Get("Grpc-Status")
|
grpcStatus := resp.Trailer.Get("Grpc-Status")
|
||||||
@@ -500,14 +484,14 @@ func TestL7HTTP11Fallback(t *testing.T) {
|
|||||||
certPath, keyPath := testCert(t, "http11.test")
|
certPath, keyPath := testCert(t, "http11.test")
|
||||||
|
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "proto=%s", r.Proto)
|
_, _ = fmt.Fprintf(w, "proto=%s", r.Proto)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -522,7 +506,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Connect with HTTP/1.1 only (no h2 ALPN).
|
// Connect with HTTP/1.1 only (no h2 ALPN).
|
||||||
@@ -538,7 +522,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET: %v", err)
|
t.Fatalf("GET: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
t.Fatalf("status = %d, want 200", resp.StatusCode)
|
||||||
@@ -556,14 +540,14 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
|
|||||||
certPath, keyPath := testCert(t, "policy.test")
|
certPath, keyPath := testCert(t, "policy.test")
|
||||||
|
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, "should-not-reach")
|
_, _ = fmt.Fprint(w, "should-not-reach")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -581,7 +565,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
|
||||||
@@ -591,7 +575,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET: %v", err)
|
t.Fatalf("GET: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
if resp.StatusCode != 403 {
|
if resp.StatusCode != 403 {
|
||||||
t.Fatalf("status = %d, want 403", resp.StatusCode)
|
t.Fatalf("status = %d, want 403", resp.StatusCode)
|
||||||
@@ -602,14 +586,14 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
|
|||||||
certPath, keyPath := testCert(t, "reqhdr.test")
|
certPath, keyPath := testCert(t, "reqhdr.test")
|
||||||
|
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, "ok")
|
_, _ = fmt.Fprint(w, "ok")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
defer proxyLn.Close()
|
defer func() { _ = proxyLn.Close() }()
|
||||||
|
|
||||||
route := RouteConfig{
|
route := RouteConfig{
|
||||||
Backend: backendAddr,
|
Backend: backendAddr,
|
||||||
@@ -630,7 +614,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
|
|||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -641,7 +625,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET without header: %v", err)
|
t.Fatalf("GET without header: %v", err)
|
||||||
}
|
}
|
||||||
resp1.Body.Close()
|
_ = resp1.Body.Close()
|
||||||
if resp1.StatusCode != 403 {
|
if resp1.StatusCode != 403 {
|
||||||
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
|
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
|
||||||
}
|
}
|
||||||
@@ -654,7 +638,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("GET with header: %v", err)
|
t.Fatalf("GET with header: %v", err)
|
||||||
}
|
}
|
||||||
defer resp2.Body.Close()
|
defer func() { _ = resp2.Body.Close() }()
|
||||||
body, _ := io.ReadAll(resp2.Body)
|
body, _ := io.ReadAll(resp2.Body)
|
||||||
if resp2.StatusCode != 200 {
|
if resp2.StatusCode != 200 {
|
||||||
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)
|
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package proxy
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -31,8 +32,8 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
<-ctx.Done()
|
||||||
client.Close()
|
_ = client.Close()
|
||||||
backend.Close()
|
_ = backend.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -50,7 +51,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
|
|||||||
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
|
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
|
||||||
// Half-close backend's write side.
|
// Half-close backend's write side.
|
||||||
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
|
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
|
||||||
hc.CloseWrite()
|
_ = hc.CloseWrite()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -60,7 +61,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
|
|||||||
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
|
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
|
||||||
// Half-close client's write side.
|
// Half-close client's write side.
|
||||||
if hc, ok := client.(interface{ CloseWrite() error }); ok {
|
if hc, ok := client.(interface{ CloseWrite() error }); ok {
|
||||||
hc.CloseWrite()
|
_ = hc.CloseWrite()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -85,10 +86,10 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
|
|||||||
var total int64
|
var total int64
|
||||||
|
|
||||||
for {
|
for {
|
||||||
src.SetReadDeadline(time.Now().Add(idleTimeout))
|
_ = src.SetReadDeadline(time.Now().Add(idleTimeout))
|
||||||
nr, readErr := src.Read(buf)
|
nr, readErr := src.Read(buf)
|
||||||
if nr > 0 {
|
if nr > 0 {
|
||||||
dst.SetWriteDeadline(time.Now().Add(idleTimeout))
|
_ = dst.SetWriteDeadline(time.Now().Add(idleTimeout))
|
||||||
nw, writeErr := dst.Write(buf[:nr])
|
nw, writeErr := dst.Write(buf[:nr])
|
||||||
total += int64(nw)
|
total += int64(nw)
|
||||||
if writeErr != nil {
|
if writeErr != nil {
|
||||||
@@ -96,7 +97,7 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
if readErr == io.EOF {
|
if errors.Is(readErr, io.EOF) {
|
||||||
return total, nil
|
return total, nil
|
||||||
}
|
}
|
||||||
return total, readErr
|
return total, readErr
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
peeked := []byte("peeked-hello-bytes")
|
peeked := []byte("peeked-hello-bytes")
|
||||||
clientData := []byte("data from client")
|
clientData := []byte("data from client")
|
||||||
@@ -29,7 +29,7 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
// Read everything the backend receives.
|
// Read everything the backend receives.
|
||||||
received, _ := io.ReadAll(conn)
|
received, _ := io.ReadAll(conn)
|
||||||
@@ -40,21 +40,21 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Restructure: use a more controlled flow.
|
// Restructure: use a more controlled flow.
|
||||||
backendLn.Close()
|
_ = backendLn.Close()
|
||||||
|
|
||||||
// Use a real TCP pair for proper half-close.
|
// Use a real TCP pair for proper half-close.
|
||||||
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
|
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn2.Close()
|
defer func() { _ = backendLn2.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := backendLn2.Accept()
|
conn, err := backendLn2.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
// Read peeked + client data.
|
// Read peeked + client data.
|
||||||
buf := make([]byte, len(peeked)+len(clientData))
|
buf := make([]byte, len(peeked)+len(clientData))
|
||||||
@@ -62,11 +62,11 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
backendDone <- buf[:n]
|
backendDone <- buf[:n]
|
||||||
|
|
||||||
// Send response.
|
// Send response.
|
||||||
conn.Write(backendData)
|
_, _ = conn.Write(backendData)
|
||||||
|
|
||||||
// Close write side to signal EOF.
|
// Close write side to signal EOF.
|
||||||
if tc, ok := conn.(*net.TCPConn); ok {
|
if tc, ok := conn.(*net.TCPConn); ok {
|
||||||
tc.CloseWrite()
|
_ = tc.CloseWrite()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -81,7 +81,7 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer clientLn.Close()
|
defer func() { _ = clientLn.Close() }()
|
||||||
|
|
||||||
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
|
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -94,9 +94,9 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
|
|
||||||
// Client sends data then closes write.
|
// Client sends data then closes write.
|
||||||
go func() {
|
go func() {
|
||||||
clientConn.Write(clientData)
|
_, _ = clientConn.Write(clientData)
|
||||||
if tc, ok := clientConn.(*net.TCPConn); ok {
|
if tc, ok := clientConn.(*net.TCPConn); ok {
|
||||||
tc.CloseWrite()
|
_ = tc.CloseWrite()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -114,7 +114,7 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Verify client received backend data.
|
// Verify client received backend data.
|
||||||
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
clientReceived, _ := io.ReadAll(clientConn)
|
clientReceived, _ := io.ReadAll(clientConn)
|
||||||
if !bytes.Equal(clientReceived, backendData) {
|
if !bytes.Equal(clientReceived, backendData) {
|
||||||
t.Fatalf("client received %q, want %q", clientReceived, backendData)
|
t.Fatalf("client received %q, want %q", clientReceived, backendData)
|
||||||
@@ -131,12 +131,12 @@ func TestRelayBasic(t *testing.T) {
|
|||||||
func TestRelayIdleTimeout(t *testing.T) {
|
func TestRelayIdleTimeout(t *testing.T) {
|
||||||
// Two connected pairs via TCP.
|
// Two connected pairs via TCP.
|
||||||
clientA, clientB := tcpPair(t)
|
clientA, clientB := tcpPair(t)
|
||||||
defer clientA.Close()
|
defer func() { _ = clientA.Close() }()
|
||||||
defer clientB.Close()
|
defer func() { _ = clientB.Close() }()
|
||||||
|
|
||||||
backendA, backendB := tcpPair(t)
|
backendA, backendB := tcpPair(t)
|
||||||
defer backendA.Close()
|
defer func() { _ = backendA.Close() }()
|
||||||
defer backendB.Close()
|
defer func() { _ = backendB.Close() }()
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
|
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
|
||||||
@@ -154,18 +154,18 @@ func TestRelayIdleTimeout(t *testing.T) {
|
|||||||
|
|
||||||
func TestRelayContextCancel(t *testing.T) {
|
func TestRelayContextCancel(t *testing.T) {
|
||||||
clientA, clientB := tcpPair(t)
|
clientA, clientB := tcpPair(t)
|
||||||
defer clientA.Close()
|
defer func() { _ = clientA.Close() }()
|
||||||
defer clientB.Close()
|
defer func() { _ = clientB.Close() }()
|
||||||
|
|
||||||
backendA, backendB := tcpPair(t)
|
backendA, backendB := tcpPair(t)
|
||||||
defer backendA.Close()
|
defer func() { _ = backendA.Close() }()
|
||||||
defer backendB.Close()
|
defer func() { _ = backendB.Close() }()
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
go func() {
|
go func() {
|
||||||
Relay(ctx, clientB, backendA, nil, time.Minute)
|
_, _ = Relay(ctx, clientB, backendA, nil, time.Minute)
|
||||||
close(done)
|
close(done)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -185,12 +185,12 @@ func TestRelayContextCancel(t *testing.T) {
|
|||||||
|
|
||||||
func TestRelayLargeTransfer(t *testing.T) {
|
func TestRelayLargeTransfer(t *testing.T) {
|
||||||
clientA, clientB := tcpPair(t)
|
clientA, clientB := tcpPair(t)
|
||||||
defer clientA.Close()
|
defer func() { _ = clientA.Close() }()
|
||||||
defer clientB.Close()
|
defer func() { _ = clientB.Close() }()
|
||||||
|
|
||||||
backendA, backendB := tcpPair(t)
|
backendA, backendB := tcpPair(t)
|
||||||
defer backendA.Close()
|
defer func() { _ = backendA.Close() }()
|
||||||
defer backendB.Close()
|
defer func() { _ = backendB.Close() }()
|
||||||
|
|
||||||
// 1 MB of random data.
|
// 1 MB of random data.
|
||||||
data := make([]byte, 1<<20)
|
data := make([]byte, 1<<20)
|
||||||
@@ -199,9 +199,9 @@ func TestRelayLargeTransfer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
clientA.Write(data)
|
_, _ = clientA.Write(data)
|
||||||
if tc, ok := clientA.(*net.TCPConn); ok {
|
if tc, ok := clientA.(*net.TCPConn); ok {
|
||||||
tc.CloseWrite()
|
_ = tc.CloseWrite()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -211,14 +211,14 @@ func TestRelayLargeTransfer(t *testing.T) {
|
|||||||
for {
|
for {
|
||||||
n, err := backendB.Read(buf)
|
n, err := backendB.Read(buf)
|
||||||
if n > 0 {
|
if n > 0 {
|
||||||
backendB.Write(buf[:n])
|
_, _ = backendB.Write(buf[:n])
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if tc, ok := backendB.(*net.TCPConn); ok {
|
if tc, ok := backendB.(*net.TCPConn); ok {
|
||||||
tc.CloseWrite()
|
_ = tc.CloseWrite()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -240,7 +240,7 @@ func tcpPair(t *testing.T) (net.Conn, net.Conn) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
defer ln.Close()
|
defer func() { _ = ln.Close() }()
|
||||||
|
|
||||||
var serverConn net.Conn
|
var serverConn net.Conn
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
|
|||||||
@@ -50,8 +50,8 @@ const (
|
|||||||
// It reads only the exact bytes of the PROXY header, leaving the connection
|
// It reads only the exact bytes of the PROXY header, leaving the connection
|
||||||
// positioned at the first byte after the header (e.g., TLS ClientHello).
|
// positioned at the first byte after the header (e.g., TLS ClientHello).
|
||||||
func Parse(conn net.Conn, deadline time.Time) (Header, error) {
|
func Parse(conn net.Conn, deadline time.Time) (Header, error) {
|
||||||
conn.SetReadDeadline(deadline)
|
_ = conn.SetReadDeadline(deadline)
|
||||||
defer conn.SetReadDeadline(time.Time{})
|
defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
|
||||||
|
|
||||||
// Read the first byte to determine version.
|
// Read the first byte to determine version.
|
||||||
var first [1]byte
|
var first [1]byte
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { ln.Close() })
|
t.Cleanup(func() { _ = ln.Close() })
|
||||||
|
|
||||||
ch := make(chan net.Conn, 1)
|
ch := make(chan net.Conn, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -31,10 +31,10 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial: %v", err)
|
t.Fatalf("dial: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { w.Close() })
|
t.Cleanup(func() { _ = w.Close() })
|
||||||
|
|
||||||
r := <-ch
|
r := <-ch
|
||||||
t.Cleanup(func() { r.Close() })
|
t.Cleanup(func() { _ = r.Close() })
|
||||||
return r, w
|
return r, w
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -42,7 +42,7 @@ func TestParseV1TCP4(t *testing.T) {
|
|||||||
reader, writer := pipeWithDeadline(t)
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
|
_, _ = writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
@@ -68,7 +68,7 @@ func TestParseV1TCP6(t *testing.T) {
|
|||||||
reader, writer := pipeWithDeadline(t)
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
|
_, _ = writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
@@ -98,7 +98,7 @@ func TestParseV2TCP4(t *testing.T) {
|
|||||||
buf = append(buf, 10, 0, 0, 1) // dst IP
|
buf = append(buf, 10, 0, 0, 1) // dst IP
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 12345)
|
buf = binary.BigEndian.AppendUint16(buf, 12345)
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 443)
|
buf = binary.BigEndian.AppendUint16(buf, 443)
|
||||||
writer.Write(buf)
|
_, _ = writer.Write(buf)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
@@ -137,7 +137,7 @@ func TestParseV2TCP6(t *testing.T) {
|
|||||||
buf = append(buf, dst[:]...)
|
buf = append(buf, dst[:]...)
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 56324)
|
buf = binary.BigEndian.AppendUint16(buf, 56324)
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 8443)
|
buf = binary.BigEndian.AppendUint16(buf, 8443)
|
||||||
writer.Write(buf)
|
_, _ = writer.Write(buf)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
@@ -163,7 +163,7 @@ func TestParseV2Local(t *testing.T) {
|
|||||||
buf = append(buf, 0x20) // version 2, LOCAL command
|
buf = append(buf, 0x20) // version 2, LOCAL command
|
||||||
buf = append(buf, 0x00) // unspec family, unspec protocol
|
buf = append(buf, 0x00) // unspec family, unspec protocol
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 0)
|
buf = binary.BigEndian.AppendUint16(buf, 0)
|
||||||
writer.Write(buf)
|
_, _ = writer.Write(buf)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
@@ -195,7 +195,7 @@ func TestParseV1Malformed(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
reader, writer := pipeWithDeadline(t)
|
reader, writer := pipeWithDeadline(t)
|
||||||
go func() {
|
go func() {
|
||||||
writer.Write([]byte(tt.data))
|
_, _ = writer.Write([]byte(tt.data))
|
||||||
}()
|
}()
|
||||||
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -211,7 +211,7 @@ func TestParseV2Malformed(t *testing.T) {
|
|||||||
go func() {
|
go func() {
|
||||||
bad := make([]byte, v2HeaderLen)
|
bad := make([]byte, v2HeaderLen)
|
||||||
bad[0] = v2Signature[0] // first byte matches but rest doesn't
|
bad[0] = v2Signature[0] // first byte matches but rest doesn't
|
||||||
writer.Write(bad)
|
_, _ = writer.Write(bad)
|
||||||
}()
|
}()
|
||||||
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -227,7 +227,7 @@ func TestParseV2Malformed(t *testing.T) {
|
|||||||
buf = append(buf, 0x31) // version 3, PROXY command
|
buf = append(buf, 0x31) // version 3, PROXY command
|
||||||
buf = append(buf, 0x11)
|
buf = append(buf, 0x11)
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 0)
|
buf = binary.BigEndian.AppendUint16(buf, 0)
|
||||||
writer.Write(buf)
|
_, _ = writer.Write(buf)
|
||||||
}()
|
}()
|
||||||
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -244,8 +244,8 @@ func TestParseV2Malformed(t *testing.T) {
|
|||||||
buf = append(buf, 0x11) // AF_INET, STREAM
|
buf = append(buf, 0x11) // AF_INET, STREAM
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 12)
|
buf = binary.BigEndian.AppendUint16(buf, 12)
|
||||||
buf = append(buf, 1, 2, 3) // only 3 bytes, need 12
|
buf = append(buf, 1, 2, 3) // only 3 bytes, need 12
|
||||||
writer.Write(buf)
|
_, _ = writer.Write(buf)
|
||||||
writer.Close()
|
_ = writer.Close()
|
||||||
}()
|
}()
|
||||||
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -261,7 +261,7 @@ func TestParseV2Malformed(t *testing.T) {
|
|||||||
buf = append(buf, 0x21) // version 2, PROXY
|
buf = append(buf, 0x21) // version 2, PROXY
|
||||||
buf = append(buf, 0x31) // AF_UNIX (3), STREAM
|
buf = append(buf, 0x31) // AF_UNIX (3), STREAM
|
||||||
buf = binary.BigEndian.AppendUint16(buf, 0)
|
buf = binary.BigEndian.AppendUint16(buf, 0)
|
||||||
writer.Write(buf)
|
_, _ = writer.Write(buf)
|
||||||
}()
|
}()
|
||||||
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -332,7 +332,7 @@ func TestRoundTripV2IPv4(t *testing.T) {
|
|||||||
reader, writer := pipeWithDeadline(t)
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
WriteV2(writer, src, dst)
|
_ = WriteV2(writer, src, dst)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
@@ -361,7 +361,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
|
|||||||
reader, writer := pipeWithDeadline(t)
|
reader, writer := pipeWithDeadline(t)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
WriteV2(writer, src, dst)
|
_ = WriteV2(writer, src, dst)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
|
||||||
@@ -380,7 +380,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
|
|||||||
func TestParseGarbageFirstByte(t *testing.T) {
|
func TestParseGarbageFirstByte(t *testing.T) {
|
||||||
reader, writer := pipeWithDeadline(t)
|
reader, writer := pipeWithDeadline(t)
|
||||||
go func() {
|
go func() {
|
||||||
writer.Write([]byte{0xFF, 0x00, 0x01})
|
_, _ = writer.Write([]byte{0xFF, 0x00, 0x01})
|
||||||
}()
|
}()
|
||||||
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
_, err := Parse(reader, time.Now().Add(2*time.Second))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
|||||||
@@ -11,13 +11,13 @@ import (
|
|||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/l7"
|
"git.wntrmute.dev/mc/mc-proxy/internal/l7"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics"
|
"git.wntrmute.dev/mc/mc-proxy/internal/metrics"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy"
|
"git.wntrmute.dev/mc/mc-proxy/internal/proxy"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
"git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
"git.wntrmute.dev/mc/mc-proxy/internal/sni"
|
||||||
)
|
)
|
||||||
|
|
||||||
// L7PolicyRule is an L7 blocking policy attached to a route.
|
// L7PolicyRule is an L7 blocking policy attached to a route.
|
||||||
@@ -235,7 +235,7 @@ func (s *Server) Run(ctx context.Context) error {
|
|||||||
ln, err := net.Listen("tcp", ls.Addr)
|
ln, err := net.Listen("tcp", ls.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
for _, l := range netListeners {
|
for _, l := range netListeners {
|
||||||
l.Close()
|
_ = l.Close()
|
||||||
}
|
}
|
||||||
return fmt.Errorf("listening on %s: %w", ls.Addr, err)
|
return fmt.Errorf("listening on %s: %w", ls.Addr, err)
|
||||||
}
|
}
|
||||||
@@ -253,7 +253,7 @@ func (s *Server) Run(ctx context.Context) error {
|
|||||||
s.logger.Info("shutting down")
|
s.logger.Info("shutting down")
|
||||||
|
|
||||||
for _, ln := range netListeners {
|
for _, ln := range netListeners {
|
||||||
ln.Close()
|
_ = ln.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
done := make(chan struct{})
|
done := make(chan struct{})
|
||||||
@@ -272,7 +272,7 @@ func (s *Server) Run(ctx context.Context) error {
|
|||||||
<-done
|
<-done
|
||||||
}
|
}
|
||||||
|
|
||||||
s.fw.Close()
|
_ = s.fw.Close()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -294,7 +294,7 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
|
|||||||
|
|
||||||
// Enforce per-listener connection limit.
|
// Enforce per-listener connection limit.
|
||||||
if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections {
|
if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections {
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections)
|
s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -311,7 +311,7 @@ func (s *Server) forceCloseAll() {
|
|||||||
for _, ls := range s.listeners {
|
for _, ls := range s.listeners {
|
||||||
ls.connMu.Lock()
|
ls.connMu.Lock()
|
||||||
for conn := range ls.activeConns {
|
for conn := range ls.activeConns {
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
}
|
}
|
||||||
ls.connMu.Unlock()
|
ls.connMu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -321,7 +321,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
|
|||||||
defer s.wg.Done()
|
defer s.wg.Done()
|
||||||
defer ls.ActiveConnections.Add(-1)
|
defer ls.ActiveConnections.Add(-1)
|
||||||
defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec()
|
defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec()
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
ls.connMu.Lock()
|
ls.connMu.Lock()
|
||||||
ls.activeConns[conn] = struct{}{}
|
ls.activeConns[conn] = struct{}{}
|
||||||
@@ -392,7 +392,7 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
|
|||||||
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
|
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer backendConn.Close()
|
defer func() { _ = backendConn.Close() }()
|
||||||
|
|
||||||
// Send PROXY protocol v2 header to backend if configured.
|
// Send PROXY protocol v2 header to backend if configured.
|
||||||
if route.SendProxyProtocol {
|
if route.SendProxyProtocol {
|
||||||
|
|||||||
@@ -24,9 +24,9 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
"git.wntrmute.dev/mc/mc-proxy/internal/config"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
|
||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto"
|
"git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
|
||||||
"golang.org/x/net/http2"
|
"golang.org/x/net/http2"
|
||||||
"golang.org/x/net/http2/h2c"
|
"golang.org/x/net/http2/h2c"
|
||||||
)
|
)
|
||||||
@@ -43,8 +43,8 @@ func echoServer(t *testing.T, ln net.Listener) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
io.Copy(conn, conn)
|
_, _ = io.Copy(conn, conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// newTestServer creates a Server with the given listener data and no firewall rules.
|
// newTestServer creates a Server with the given listener data and no firewall rules.
|
||||||
@@ -92,7 +92,7 @@ func TestProxyRoundTrip(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
go echoServer(t, backendLn)
|
go echoServer(t, backendLn)
|
||||||
|
|
||||||
// Pick a free port for the proxy listener.
|
// Pick a free port for the proxy listener.
|
||||||
@@ -101,7 +101,7 @@ func TestProxyRoundTrip(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -121,7 +121,7 @@ func TestProxyRoundTrip(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("echo.test")
|
hello := buildClientHello("echo.test")
|
||||||
if _, err := conn.Write(hello); err != nil {
|
if _, err := conn.Write(hello); err != nil {
|
||||||
@@ -130,7 +130,7 @@ func TestProxyRoundTrip(t *testing.T) {
|
|||||||
|
|
||||||
// The backend will echo our ClientHello back. Read it.
|
// The backend will echo our ClientHello back. Read it.
|
||||||
echoed := make([]byte, len(hello))
|
echoed := make([]byte, len(hello))
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
if _, err := io.ReadFull(conn, echoed); err != nil {
|
if _, err := io.ReadFull(conn, echoed); err != nil {
|
||||||
t.Fatalf("read echoed data: %v", err)
|
t.Fatalf("read echoed data: %v", err)
|
||||||
}
|
}
|
||||||
@@ -157,7 +157,7 @@ func TestNoRouteResets(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -176,7 +176,7 @@ func TestNoRouteResets(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("unknown.test")
|
hello := buildClientHello("unknown.test")
|
||||||
if _, err := conn.Write(hello); err != nil {
|
if _, err := conn.Write(hello); err != nil {
|
||||||
@@ -184,7 +184,7 @@ func TestNoRouteResets(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// The proxy should close the connection (no route match).
|
// The proxy should close the connection (no route match).
|
||||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
_, err = conn.Read(make([]byte, 1))
|
_, err = conn.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected connection to be closed, but read succeeded")
|
t.Fatal("expected connection to be closed, but read succeeded")
|
||||||
@@ -197,7 +197,7 @@ func TestFirewallBlocks(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
reached := make(chan struct{}, 1)
|
reached := make(chan struct{}, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -205,7 +205,7 @@ func TestFirewallBlocks(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
reached <- struct{}{}
|
reached <- struct{}{}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -214,7 +214,7 @@ func TestFirewallBlocks(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
// Create a firewall that blocks 127.0.0.1 (the test client).
|
// Create a firewall that blocks 127.0.0.1 (the test client).
|
||||||
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
|
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
|
||||||
@@ -245,7 +245,7 @@ func TestFirewallBlocks(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
srv.Run(ctx)
|
_ = srv.Run(ctx)
|
||||||
}()
|
}()
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
@@ -253,13 +253,13 @@ func TestFirewallBlocks(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("echo.test")
|
hello := buildClientHello("echo.test")
|
||||||
conn.Write(hello)
|
_, _ = conn.Write(hello)
|
||||||
|
|
||||||
// Connection should be closed (blocked by firewall).
|
// Connection should be closed (blocked by firewall).
|
||||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
_, err = conn.Read(make([]byte, 1))
|
_, err = conn.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected connection to be closed by firewall")
|
t.Fatal("expected connection to be closed by firewall")
|
||||||
@@ -283,7 +283,7 @@ func TestNotTLSResets(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -300,12 +300,12 @@ func TestNotTLSResets(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
// Send HTTP, not TLS.
|
// Send HTTP, not TLS.
|
||||||
conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
|
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
|
||||||
|
|
||||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
_, err = conn.Read(make([]byte, 1))
|
_, err = conn.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected connection to be closed for non-TLS data")
|
t.Fatal("expected connection to be closed for non-TLS data")
|
||||||
@@ -318,7 +318,7 @@ func TestConnectionTracking(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
var backendConns []net.Conn
|
var backendConns []net.Conn
|
||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
@@ -332,7 +332,7 @@ func TestConnectionTracking(t *testing.T) {
|
|||||||
backendConns = append(backendConns, conn)
|
backendConns = append(backendConns, conn)
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
// Hold connection open, drain input.
|
// Hold connection open, drain input.
|
||||||
go io.Copy(io.Discard, conn)
|
go func() { _, _ = io.Copy(io.Discard, conn) }()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -341,7 +341,7 @@ func TestConnectionTracking(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -382,10 +382,10 @@ func TestConnectionTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Close one client and its corresponding backend connection.
|
// Close one client and its corresponding backend connection.
|
||||||
clientConns[0].Close()
|
_ = clientConns[0].Close()
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
if len(backendConns) > 0 {
|
if len(backendConns) > 0 {
|
||||||
backendConns[0].Close()
|
_ = backendConns[0].Close()
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
|
|
||||||
@@ -402,10 +402,10 @@ func TestConnectionTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Clean up.
|
// Clean up.
|
||||||
clientConns[1].Close()
|
_ = clientConns[1].Close()
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
for _, c := range backendConns {
|
for _, c := range backendConns {
|
||||||
c.Close()
|
_ = c.Close()
|
||||||
}
|
}
|
||||||
mu.Unlock()
|
mu.Unlock()
|
||||||
}
|
}
|
||||||
@@ -416,13 +416,13 @@ func TestMultipleListeners(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend A listen: %v", err)
|
t.Fatalf("backend A listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendA.Close()
|
defer func() { _ = backendA.Close() }()
|
||||||
|
|
||||||
backendB, err := net.Listen("tcp", "127.0.0.1:0")
|
backendB, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend B listen: %v", err)
|
t.Fatalf("backend B listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendB.Close()
|
defer func() { _ = backendB.Close() }()
|
||||||
|
|
||||||
// Each backend writes its identity and closes.
|
// Each backend writes its identity and closes.
|
||||||
serve := func(ln net.Listener, id string) {
|
serve := func(ln net.Listener, id string) {
|
||||||
@@ -430,10 +430,10 @@ func TestMultipleListeners(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
// Drain the incoming data, then write identity.
|
// Drain the incoming data, then write identity.
|
||||||
go io.Copy(io.Discard, conn)
|
go func() { _, _ = io.Copy(io.Discard, conn) }()
|
||||||
conn.Write([]byte(id))
|
_, _ = conn.Write([]byte(id))
|
||||||
}
|
}
|
||||||
go serve(backendA, "A")
|
go serve(backendA, "A")
|
||||||
go serve(backendB, "B")
|
go serve(backendB, "B")
|
||||||
@@ -444,14 +444,14 @@ func TestMultipleListeners(t *testing.T) {
|
|||||||
t.Fatalf("finding free port 1: %v", err)
|
t.Fatalf("finding free port 1: %v", err)
|
||||||
}
|
}
|
||||||
addr1 := ln1.Addr().String()
|
addr1 := ln1.Addr().String()
|
||||||
ln1.Close()
|
_ = ln1.Close()
|
||||||
|
|
||||||
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
ln2, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("finding free port 2: %v", err)
|
t.Fatalf("finding free port 2: %v", err)
|
||||||
}
|
}
|
||||||
addr2 := ln2.Addr().String()
|
addr2 := ln2.Addr().String()
|
||||||
ln2.Close()
|
_ = ln2.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
|
{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
|
||||||
@@ -467,12 +467,12 @@ func TestMultipleListeners(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial %s: %v", proxyAddr, err)
|
t.Fatalf("dial %s: %v", proxyAddr, err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("svc.test")
|
hello := buildClientHello("svc.test")
|
||||||
conn.Write(hello)
|
_, _ = conn.Write(hello)
|
||||||
|
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
buf := make([]byte, 128)
|
buf := make([]byte, 128)
|
||||||
// Read what the backend sends back: echoed ClientHello + ID.
|
// Read what the backend sends back: echoed ClientHello + ID.
|
||||||
// The backend drains input and writes the ID, so we read until we
|
// The backend drains input and writes the ID, so we read until we
|
||||||
@@ -508,7 +508,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
go echoServer(t, backendLn)
|
go echoServer(t, backendLn)
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
@@ -516,7 +516,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -537,7 +537,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("ECHO.TEST")
|
hello := buildClientHello("ECHO.TEST")
|
||||||
if _, err := conn.Write(hello); err != nil {
|
if _, err := conn.Write(hello); err != nil {
|
||||||
@@ -545,7 +545,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
echoed := make([]byte, len(hello))
|
echoed := make([]byte, len(hello))
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
if _, err := io.ReadFull(conn, echoed); err != nil {
|
if _, err := io.ReadFull(conn, echoed); err != nil {
|
||||||
t.Fatalf("read echoed data: %v", err)
|
t.Fatalf("read echoed data: %v", err)
|
||||||
}
|
}
|
||||||
@@ -558,14 +558,14 @@ func TestBackendUnreachable(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
deadAddr := ln.Addr().String()
|
deadAddr := ln.Addr().String()
|
||||||
ln.Close()
|
_ = ln.Close()
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -584,13 +584,13 @@ func TestBackendUnreachable(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("dead.test")
|
hello := buildClientHello("dead.test")
|
||||||
conn.Write(hello)
|
_, _ = conn.Write(hello)
|
||||||
|
|
||||||
// Proxy should close the connection after failing to dial backend.
|
// Proxy should close the connection after failing to dial backend.
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
_, err = conn.Read(make([]byte, 1))
|
_, err = conn.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected connection to be closed when backend is unreachable")
|
t.Fatal("expected connection to be closed when backend is unreachable")
|
||||||
@@ -603,15 +603,15 @@ func TestGracefulShutdown(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
conn, err := backendLn.Accept()
|
conn, err := backendLn.Accept()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
io.Copy(io.Discard, conn)
|
_, _ = io.Copy(io.Discard, conn)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
@@ -619,7 +619,7 @@ func TestGracefulShutdown(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
fw, err := firewall.New("", nil, nil, nil, 0, 0)
|
fw, err := firewall.New("", nil, nil, nil, 0, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -649,10 +649,10 @@ func TestGracefulShutdown(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("hold.test")
|
hello := buildClientHello("hold.test")
|
||||||
conn.Write(hello)
|
_, _ = conn.Write(hello)
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
// Trigger shutdown.
|
// Trigger shutdown.
|
||||||
@@ -719,7 +719,7 @@ func TestProxyProtocolReceive(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
go echoServer(t, backendLn)
|
go echoServer(t, backendLn)
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
@@ -727,7 +727,7 @@ func TestProxyProtocolReceive(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -747,22 +747,22 @@ func TestProxyProtocolReceive(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
// Send PROXY v2 header followed by TLS ClientHello.
|
// Send PROXY v2 header followed by TLS ClientHello.
|
||||||
var ppBuf bytes.Buffer
|
var ppBuf bytes.Buffer
|
||||||
proxyproto.WriteV2(&ppBuf,
|
_ = proxyproto.WriteV2(&ppBuf,
|
||||||
netip.MustParseAddrPort("203.0.113.50:12345"),
|
netip.MustParseAddrPort("203.0.113.50:12345"),
|
||||||
netip.MustParseAddrPort("198.51.100.1:443"),
|
netip.MustParseAddrPort("198.51.100.1:443"),
|
||||||
)
|
)
|
||||||
conn.Write(ppBuf.Bytes())
|
_, _ = conn.Write(ppBuf.Bytes())
|
||||||
|
|
||||||
hello := buildClientHello("echo.test")
|
hello := buildClientHello("echo.test")
|
||||||
conn.Write(hello)
|
_, _ = conn.Write(hello)
|
||||||
|
|
||||||
// Backend should echo the ClientHello back (not the PROXY header).
|
// Backend should echo the ClientHello back (not the PROXY header).
|
||||||
echoed := make([]byte, len(hello))
|
echoed := make([]byte, len(hello))
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
if _, err := io.ReadFull(conn, echoed); err != nil {
|
if _, err := io.ReadFull(conn, echoed); err != nil {
|
||||||
t.Fatalf("read echoed data: %v", err)
|
t.Fatalf("read echoed data: %v", err)
|
||||||
}
|
}
|
||||||
@@ -774,7 +774,7 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -794,13 +794,13 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
// Send garbage instead of a valid PROXY header.
|
// Send garbage instead of a valid PROXY header.
|
||||||
conn.Write([]byte("NOT A PROXY HEADER\r\n"))
|
_, _ = conn.Write([]byte("NOT A PROXY HEADER\r\n"))
|
||||||
|
|
||||||
// Connection should be closed.
|
// Connection should be closed.
|
||||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
_, err = conn.Read(make([]byte, 1))
|
_, err = conn.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected connection to be closed for invalid PROXY header")
|
t.Fatal("expected connection to be closed for invalid PROXY header")
|
||||||
@@ -813,7 +813,7 @@ func TestProxyProtocolSend(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
received := make(chan []byte, 1)
|
received := make(chan []byte, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -821,9 +821,9 @@ func TestProxyProtocolSend(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
// Read all available data; the proxy sends PROXY header + ClientHello.
|
// Read all available data; the proxy sends PROXY header + ClientHello.
|
||||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
var all []byte
|
var all []byte
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
for {
|
for {
|
||||||
@@ -845,7 +845,7 @@ func TestProxyProtocolSend(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -868,10 +868,10 @@ func TestProxyProtocolSend(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("pp.test")
|
hello := buildClientHello("pp.test")
|
||||||
conn.Write(hello)
|
_, _ = conn.Write(hello)
|
||||||
|
|
||||||
// The backend should receive: PROXY v2 header + ClientHello.
|
// The backend should receive: PROXY v2 header + ClientHello.
|
||||||
select {
|
select {
|
||||||
@@ -904,7 +904,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
received := make(chan []byte, 1)
|
received := make(chan []byte, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -912,7 +912,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
buf := make([]byte, 4096)
|
buf := make([]byte, 4096)
|
||||||
n, _ := conn.Read(buf)
|
n, _ := conn.Read(buf)
|
||||||
received <- buf[:n]
|
received <- buf[:n]
|
||||||
@@ -923,7 +923,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -942,10 +942,10 @@ func TestProxyProtocolNotSent(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello("nopp.test")
|
hello := buildClientHello("nopp.test")
|
||||||
conn.Write(hello)
|
_, _ = conn.Write(hello)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case data := <-received:
|
case data := <-received:
|
||||||
@@ -964,7 +964,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
reached := make(chan struct{}, 1)
|
reached := make(chan struct{}, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -972,7 +972,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
reached <- struct{}{}
|
reached <- struct{}{}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -981,7 +981,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
|||||||
t.Fatalf("finding free port: %v", err)
|
t.Fatalf("finding free port: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
// Block 203.0.113.50 (the "real" client IP from PROXY header).
|
// Block 203.0.113.50 (the "real" client IP from PROXY header).
|
||||||
// 127.0.0.1 (the actual TCP peer) is NOT blocked.
|
// 127.0.0.1 (the actual TCP peer) is NOT blocked.
|
||||||
@@ -1014,7 +1014,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
srv.Run(ctx)
|
_ = srv.Run(ctx)
|
||||||
}()
|
}()
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
@@ -1022,19 +1022,19 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial proxy: %v", err)
|
t.Fatalf("dial proxy: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
// Send PROXY v2 with the blocked real IP.
|
// Send PROXY v2 with the blocked real IP.
|
||||||
var ppBuf bytes.Buffer
|
var ppBuf bytes.Buffer
|
||||||
proxyproto.WriteV2(&ppBuf,
|
_ = proxyproto.WriteV2(&ppBuf,
|
||||||
netip.MustParseAddrPort("203.0.113.50:12345"),
|
netip.MustParseAddrPort("203.0.113.50:12345"),
|
||||||
netip.MustParseAddrPort("198.51.100.1:443"),
|
netip.MustParseAddrPort("198.51.100.1:443"),
|
||||||
)
|
)
|
||||||
conn.Write(ppBuf.Bytes())
|
_, _ = conn.Write(ppBuf.Bytes())
|
||||||
conn.Write(buildClientHello("blocked.test"))
|
_, _ = conn.Write(buildClientHello("blocked.test"))
|
||||||
|
|
||||||
// Connection should be closed (firewall blocks real IP).
|
// Connection should be closed (firewall blocks real IP).
|
||||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
_, err = conn.Read(make([]byte, 1))
|
_, err = conn.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected connection to be closed by firewall")
|
t.Fatal("expected connection to be closed by firewall")
|
||||||
@@ -1060,7 +1060,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
for {
|
for {
|
||||||
@@ -1068,7 +1068,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
go io.Copy(io.Discard, conn)
|
go func() { _, _ = io.Copy(io.Discard, conn) }()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -1077,7 +1077,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
|
|||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -1100,7 +1100,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial %d: %v", i, err)
|
t.Fatalf("dial %d: %v", i, err)
|
||||||
}
|
}
|
||||||
conn.Write(buildClientHello("limit.test"))
|
_, _ = conn.Write(buildClientHello("limit.test"))
|
||||||
conns = append(conns, conn)
|
conns = append(conns, conn)
|
||||||
}
|
}
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -1110,16 +1110,16 @@ func TestConnectionLimitEnforced(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial 3: %v", err)
|
t.Fatalf("dial 3: %v", err)
|
||||||
}
|
}
|
||||||
conn3.Write(buildClientHello("limit.test"))
|
_, _ = conn3.Write(buildClientHello("limit.test"))
|
||||||
conn3.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = conn3.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
_, err = conn3.Read(make([]byte, 1))
|
_, err = conn3.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected 3rd connection to be closed due to limit")
|
t.Fatal("expected 3rd connection to be closed due to limit")
|
||||||
}
|
}
|
||||||
conn3.Close()
|
_ = conn3.Close()
|
||||||
|
|
||||||
// Close one existing connection.
|
// Close one existing connection.
|
||||||
conns[0].Close()
|
_ = conns[0].Close()
|
||||||
time.Sleep(200 * time.Millisecond)
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
// Now a new connection should succeed.
|
// Now a new connection should succeed.
|
||||||
@@ -1127,8 +1127,8 @@ func TestConnectionLimitEnforced(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial 4: %v", err)
|
t.Fatalf("dial 4: %v", err)
|
||||||
}
|
}
|
||||||
defer conn4.Close()
|
defer func() { _ = conn4.Close() }()
|
||||||
conn4.Write(buildClientHello("limit.test"))
|
_, _ = conn4.Write(buildClientHello("limit.test"))
|
||||||
|
|
||||||
// Give it time to be proxied.
|
// Give it time to be proxied.
|
||||||
time.Sleep(100 * time.Millisecond)
|
time.Sleep(100 * time.Millisecond)
|
||||||
@@ -1138,7 +1138,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
|
|||||||
|
|
||||||
// Clean up.
|
// Clean up.
|
||||||
for _, c := range conns[1:] {
|
for _, c := range conns[1:] {
|
||||||
c.Close()
|
_ = c.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1155,7 +1155,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
|
|||||||
|
|
||||||
// h2c backend on origin that echoes the X-Forwarded-For.
|
// h2c backend on origin that echoes the X-Forwarded-For.
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
|
_, _ = fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
// Origin proxy: proxy_protocol=true listener, L7 route to backend.
|
// Origin proxy: proxy_protocol=true listener, L7 route to backend.
|
||||||
@@ -1164,7 +1164,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
|
|||||||
t.Fatalf("origin listen: %v", err)
|
t.Fatalf("origin listen: %v", err)
|
||||||
}
|
}
|
||||||
originAddr := originLn.Addr().String()
|
originAddr := originLn.Addr().String()
|
||||||
originLn.Close()
|
_ = originLn.Close()
|
||||||
|
|
||||||
originFw, _ := firewall.New("", nil, nil, nil, 0, 0)
|
originFw, _ := firewall.New("", nil, nil, nil, 0, 0)
|
||||||
originCfg := &config.Config{
|
originCfg := &config.Config{
|
||||||
@@ -1196,7 +1196,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
|
|||||||
originWg.Add(1)
|
originWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer originWg.Done()
|
defer originWg.Done()
|
||||||
originSrv.Run(originCtx)
|
_ = originSrv.Run(originCtx)
|
||||||
}()
|
}()
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1210,7 +1210,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
|
|||||||
t.Fatalf("edge listen: %v", err)
|
t.Fatalf("edge listen: %v", err)
|
||||||
}
|
}
|
||||||
edgeAddr := edgeLn.Addr().String()
|
edgeAddr := edgeLn.Addr().String()
|
||||||
edgeLn.Close()
|
_ = edgeLn.Close()
|
||||||
|
|
||||||
edgeFw, _ := firewall.New("", nil, nil, nil, 0, 0)
|
edgeFw, _ := firewall.New("", nil, nil, nil, 0, 0)
|
||||||
edgeSrv := New(originCfg, edgeFw, []ListenerData{
|
edgeSrv := New(originCfg, edgeFw, []ListenerData{
|
||||||
@@ -1232,7 +1232,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
|
|||||||
edgeWg.Add(1)
|
edgeWg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer edgeWg.Done()
|
defer edgeWg.Done()
|
||||||
edgeSrv.Run(edgeCtx)
|
_ = edgeSrv.Run(edgeCtx)
|
||||||
}()
|
}()
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -1253,7 +1253,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TLS dial edge: %v", err)
|
t.Fatalf("TLS dial edge: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
tr := &http2.Transport{}
|
tr := &http2.Transport{}
|
||||||
h2conn, err := tr.NewClientConn(conn)
|
h2conn, err := tr.NewClientConn(conn)
|
||||||
@@ -1266,7 +1266,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("RoundTrip: %v", err)
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
if resp.StatusCode != 200 {
|
if resp.StatusCode != 200 {
|
||||||
@@ -1289,7 +1289,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("backend listen: %v", err)
|
t.Fatalf("backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer backendLn.Close()
|
defer func() { _ = backendLn.Close() }()
|
||||||
|
|
||||||
reached := make(chan struct{}, 1)
|
reached := make(chan struct{}, 1)
|
||||||
go func() {
|
go func() {
|
||||||
@@ -1297,7 +1297,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
conn.Close()
|
_ = conn.Close()
|
||||||
reached <- struct{}{}
|
reached <- struct{}{}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -1306,7 +1306,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
|
|||||||
t.Fatalf("origin listen: %v", err)
|
t.Fatalf("origin listen: %v", err)
|
||||||
}
|
}
|
||||||
originAddr := originLn.Addr().String()
|
originAddr := originLn.Addr().String()
|
||||||
originLn.Close()
|
_ = originLn.Close()
|
||||||
|
|
||||||
// Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header.
|
// Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header.
|
||||||
originFw, _ := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0)
|
originFw, _ := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0)
|
||||||
@@ -1334,7 +1334,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
|
|||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func() {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
originSrv.Run(ctx)
|
_ = originSrv.Run(ctx)
|
||||||
}()
|
}()
|
||||||
time.Sleep(50 * time.Millisecond)
|
time.Sleep(50 * time.Millisecond)
|
||||||
|
|
||||||
@@ -1343,18 +1343,18 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial origin: %v", err)
|
t.Fatalf("dial origin: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
var ppBuf bytes.Buffer
|
var ppBuf bytes.Buffer
|
||||||
proxyproto.WriteV2(&ppBuf,
|
_ = proxyproto.WriteV2(&ppBuf,
|
||||||
netip.MustParseAddrPort("198.51.100.99:12345"),
|
netip.MustParseAddrPort("198.51.100.99:12345"),
|
||||||
netip.MustParseAddrPort("10.0.0.1:443"),
|
netip.MustParseAddrPort("10.0.0.1:443"),
|
||||||
)
|
)
|
||||||
conn.Write(ppBuf.Bytes())
|
_, _ = conn.Write(ppBuf.Bytes())
|
||||||
conn.Write(buildClientHello("blocked.test"))
|
_, _ = conn.Write(buildClientHello("blocked.test"))
|
||||||
|
|
||||||
// Connection should be dropped by firewall.
|
// Connection should be dropped by firewall.
|
||||||
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
|
||||||
_, err = conn.Read(make([]byte, 1))
|
_, err = conn.Read(make([]byte, 1))
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal("expected connection to be closed")
|
t.Fatal("expected connection to be closed")
|
||||||
@@ -1396,12 +1396,12 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
|
|||||||
certPath = filepath.Join(dir, "cert.pem")
|
certPath = filepath.Join(dir, "cert.pem")
|
||||||
keyPath = filepath.Join(dir, "key.pem")
|
keyPath = filepath.Join(dir, "key.pem")
|
||||||
cf, _ := os.Create(certPath)
|
cf, _ := os.Create(certPath)
|
||||||
pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
_ = pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
cf.Close()
|
_ = cf.Close()
|
||||||
keyDER, _ := x509.MarshalECPrivateKey(key)
|
keyDER, _ := x509.MarshalECPrivateKey(key)
|
||||||
kf, _ := os.Create(keyPath)
|
kf, _ := os.Create(keyPath)
|
||||||
pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
_ = pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
|
||||||
kf.Close()
|
_ = kf.Close()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1417,8 +1417,8 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("listen: %v", err)
|
t.Fatalf("listen: %v", err)
|
||||||
}
|
}
|
||||||
t.Cleanup(func() { srv.Close(); ln.Close() })
|
t.Cleanup(func() { _ = srv.Close(); _ = ln.Close() })
|
||||||
go srv.Serve(ln)
|
go func() { _ = srv.Serve(ln) }()
|
||||||
return ln.Addr().String()
|
return ln.Addr().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1426,7 +1426,7 @@ func TestL7ThroughServer(t *testing.T) {
|
|||||||
certPath, keyPath := testCert(t, "l7srv.test")
|
certPath, keyPath := testCert(t, "l7srv.test")
|
||||||
|
|
||||||
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
|
_, _ = fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
@@ -1434,7 +1434,7 @@ func TestL7ThroughServer(t *testing.T) {
|
|||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -1467,7 +1467,7 @@ func TestL7ThroughServer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TLS dial: %v", err)
|
t.Fatalf("TLS dial: %v", err)
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer func() { _ = conn.Close() }()
|
||||||
|
|
||||||
tr := &http2.Transport{}
|
tr := &http2.Transport{}
|
||||||
h2conn, err := tr.NewClientConn(conn)
|
h2conn, err := tr.NewClientConn(conn)
|
||||||
@@ -1480,7 +1480,7 @@ func TestL7ThroughServer(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("RoundTrip: %v", err)
|
t.Fatalf("RoundTrip: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
|
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
|
// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
|
||||||
@@ -1502,12 +1502,12 @@ func TestMixedL4L7SameListener(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("l4 backend listen: %v", err)
|
t.Fatalf("l4 backend listen: %v", err)
|
||||||
}
|
}
|
||||||
defer l4BackendLn.Close()
|
defer func() { _ = l4BackendLn.Close() }()
|
||||||
go echoServer(t, l4BackendLn)
|
go echoServer(t, l4BackendLn)
|
||||||
|
|
||||||
// L7 backend: h2c HTTP server.
|
// L7 backend: h2c HTTP server.
|
||||||
l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
fmt.Fprint(w, "l7-response")
|
_, _ = fmt.Fprint(w, "l7-response")
|
||||||
}))
|
}))
|
||||||
|
|
||||||
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
@@ -1515,7 +1515,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
|
|||||||
t.Fatalf("proxy listen: %v", err)
|
t.Fatalf("proxy listen: %v", err)
|
||||||
}
|
}
|
||||||
proxyAddr := proxyLn.Addr().String()
|
proxyAddr := proxyLn.Addr().String()
|
||||||
proxyLn.Close()
|
_ = proxyLn.Close()
|
||||||
|
|
||||||
srv := newTestServer(t, []ListenerData{
|
srv := newTestServer(t, []ListenerData{
|
||||||
{
|
{
|
||||||
@@ -1541,11 +1541,11 @@ func TestMixedL4L7SameListener(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("dial L4: %v", err)
|
t.Fatalf("dial L4: %v", err)
|
||||||
}
|
}
|
||||||
defer l4Conn.Close()
|
defer func() { _ = l4Conn.Close() }()
|
||||||
hello := buildClientHello("l4echo.test")
|
hello := buildClientHello("l4echo.test")
|
||||||
l4Conn.Write(hello)
|
_, _ = l4Conn.Write(hello)
|
||||||
echoed := make([]byte, len(hello))
|
echoed := make([]byte, len(hello))
|
||||||
l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
_ = l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||||
if _, err := io.ReadFull(l4Conn, echoed); err != nil {
|
if _, err := io.ReadFull(l4Conn, echoed); err != nil {
|
||||||
t.Fatalf("L4 echo read: %v", err)
|
t.Fatalf("L4 echo read: %v", err)
|
||||||
}
|
}
|
||||||
@@ -1563,7 +1563,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("TLS dial L7: %v", err)
|
t.Fatalf("TLS dial L7: %v", err)
|
||||||
}
|
}
|
||||||
defer l7Conn.Close()
|
defer func() { _ = l7Conn.Close() }()
|
||||||
|
|
||||||
tr := &http2.Transport{}
|
tr := &http2.Transport{}
|
||||||
h2conn, err := tr.NewClientConn(l7Conn)
|
h2conn, err := tr.NewClientConn(l7Conn)
|
||||||
@@ -1576,7 +1576,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("L7 RoundTrip: %v", err)
|
t.Fatalf("L7 RoundTrip: %v", err)
|
||||||
}
|
}
|
||||||
defer resp.Body.Close()
|
defer func() { _ = resp.Body.Close() }()
|
||||||
body, _ := io.ReadAll(resp.Body)
|
body, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
if string(body) != "l7-response" {
|
if string(body) != "l7-response" {
|
||||||
@@ -1636,4 +1636,3 @@ func sniExtension(serverName string) []byte {
|
|||||||
|
|
||||||
return ext
|
return ext
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ const maxBufferSize = 16384 // 16 KiB, max TLS record size
|
|||||||
//
|
//
|
||||||
// A read deadline is set on the connection to prevent slowloris attacks.
|
// A read deadline is set on the connection to prevent slowloris attacks.
|
||||||
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
|
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
|
||||||
conn.SetReadDeadline(deadline)
|
_ = conn.SetReadDeadline(deadline)
|
||||||
defer conn.SetReadDeadline(time.Time{})
|
defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
|
||||||
|
|
||||||
// Read TLS record header (5 bytes).
|
// Read TLS record header (5 bytes).
|
||||||
header := make([]byte, 5)
|
header := make([]byte, 5)
|
||||||
|
|||||||
@@ -22,13 +22,13 @@ func TestExtract(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
defer server.Close()
|
defer func() { _ = server.Close() }()
|
||||||
|
|
||||||
hello := buildClientHello(tt.sni)
|
hello := buildClientHello(tt.sni)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
client.Write(hello)
|
_, _ = client.Write(hello)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
|
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
|
||||||
@@ -53,13 +53,13 @@ func TestExtract(t *testing.T) {
|
|||||||
|
|
||||||
func TestExtractNoSNI(t *testing.T) {
|
func TestExtractNoSNI(t *testing.T) {
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
defer server.Close()
|
defer func() { _ = server.Close() }()
|
||||||
|
|
||||||
hello := buildClientHelloNoSNI()
|
hello := buildClientHelloNoSNI()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
client.Write(hello)
|
_, _ = client.Write(hello)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||||
@@ -70,11 +70,11 @@ func TestExtractNoSNI(t *testing.T) {
|
|||||||
|
|
||||||
func TestExtractNotTLS(t *testing.T) {
|
func TestExtractNotTLS(t *testing.T) {
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
defer server.Close()
|
defer func() { _ = server.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
_, _ = client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||||
@@ -85,13 +85,13 @@ func TestExtractNotTLS(t *testing.T) {
|
|||||||
|
|
||||||
func TestExtractTruncated(t *testing.T) {
|
func TestExtractTruncated(t *testing.T) {
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
defer server.Close()
|
defer func() { _ = server.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Write just the TLS record header, then close.
|
// Write just the TLS record header, then close.
|
||||||
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
|
_, _ = client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
|
||||||
client.Close()
|
_ = client.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||||
@@ -102,15 +102,15 @@ func TestExtractTruncated(t *testing.T) {
|
|||||||
|
|
||||||
func TestExtractOversizedRecord(t *testing.T) {
|
func TestExtractOversizedRecord(t *testing.T) {
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
defer server.Close()
|
defer func() { _ = server.Close() }()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Record header claiming a length larger than 16 KiB.
|
// Record header claiming a length larger than 16 KiB.
|
||||||
header := []byte{0x16, 0x03, 0x01}
|
header := []byte{0x16, 0x03, 0x01}
|
||||||
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
|
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
|
||||||
client.Write(header)
|
_, _ = client.Write(header)
|
||||||
client.Close()
|
_ = client.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
_, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||||
@@ -121,13 +121,13 @@ func TestExtractOversizedRecord(t *testing.T) {
|
|||||||
|
|
||||||
func TestExtractMultipleExtensions(t *testing.T) {
|
func TestExtractMultipleExtensions(t *testing.T) {
|
||||||
client, server := net.Pipe()
|
client, server := net.Pipe()
|
||||||
defer client.Close()
|
defer func() { _ = client.Close() }()
|
||||||
defer server.Close()
|
defer func() { _ = server.Close() }()
|
||||||
|
|
||||||
hello := buildClientHelloWithExtraExtensions("target.example.com")
|
hello := buildClientHelloWithExtraExtensions("target.example.com")
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
client.Write(hello)
|
_, _ = client.Write(hello)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
|
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ syntax = "proto3";
|
|||||||
|
|
||||||
package mc_proxy.v1;
|
package mc_proxy.v1;
|
||||||
|
|
||||||
option go_package = "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1;mcproxyv1";
|
option go_package = "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1;mcproxyv1";
|
||||||
|
|
||||||
import "google/protobuf/timestamp.proto";
|
import "google/protobuf/timestamp.proto";
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ import (
|
|||||||
|
|
||||||
"github.com/pelletier/go-toml/v2"
|
"github.com/pelletier/go-toml/v2"
|
||||||
|
|
||||||
"git.wntrmute.dev/kyle/mcdsl/auth"
|
"git.wntrmute.dev/mc/mcdsl/auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Base contains the configuration sections common to all Metacircular
|
// Base contains the configuration sections common to all Metacircular
|
||||||
@@ -144,6 +144,8 @@ func Load[T any](path string, envPrefix string) (*T, error) {
|
|||||||
applyEnvToStruct(reflect.ValueOf(&cfg).Elem(), envPrefix)
|
applyEnvToStruct(reflect.ValueOf(&cfg).Elem(), envPrefix)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
applyPortEnv(&cfg)
|
||||||
|
|
||||||
applyBaseDefaults(&cfg)
|
applyBaseDefaults(&cfg)
|
||||||
|
|
||||||
if err := validateBase(&cfg); err != nil {
|
if err := validateBase(&cfg); err != nil {
|
||||||
@@ -239,6 +241,70 @@ func findBase(cfg any) *Base {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// applyPortEnv overrides ServerConfig.ListenAddr and ServerConfig.GRPCAddr
|
||||||
|
// from $PORT and $PORT_GRPC respectively. These environment variables are
|
||||||
|
// set by the MCP agent to assign authoritative port bindings, so they take
|
||||||
|
// precedence over both TOML values and generic env overrides.
|
||||||
|
func applyPortEnv(cfg any) {
|
||||||
|
sc := findServerConfig(cfg)
|
||||||
|
if sc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if port, ok := os.LookupEnv("PORT"); ok {
|
||||||
|
sc.ListenAddr = ":" + port
|
||||||
|
}
|
||||||
|
if port, ok := os.LookupEnv("PORT_GRPC"); ok {
|
||||||
|
sc.GRPCAddr = ":" + port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// findServerConfig returns a pointer to the ServerConfig in the config
|
||||||
|
// struct. It first checks for an embedded Base (which contains Server),
|
||||||
|
// then walks the struct tree via reflection to find any ServerConfig field
|
||||||
|
// directly (e.g., the Metacrypt pattern where ServerConfig is embedded
|
||||||
|
// without Base).
|
||||||
|
func findServerConfig(cfg any) *ServerConfig {
|
||||||
|
if base := findBase(cfg); base != nil {
|
||||||
|
return &base.Server
|
||||||
|
}
|
||||||
|
|
||||||
|
return findServerConfigReflect(reflect.ValueOf(cfg))
|
||||||
|
}
|
||||||
|
|
||||||
|
// findServerConfigReflect walks the struct tree to find a ServerConfig field.
|
||||||
|
func findServerConfigReflect(v reflect.Value) *ServerConfig {
|
||||||
|
if v.Kind() == reflect.Ptr {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
if v.Kind() != reflect.Struct {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
scType := reflect.TypeOf(ServerConfig{})
|
||||||
|
t := v.Type()
|
||||||
|
for i := range t.NumField() {
|
||||||
|
field := t.Field(i)
|
||||||
|
fv := v.Field(i)
|
||||||
|
|
||||||
|
if field.Type == scType {
|
||||||
|
sc, ok := fv.Addr().Interface().(*ServerConfig)
|
||||||
|
if ok {
|
||||||
|
return sc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recurse into embedded or nested structs.
|
||||||
|
if fv.Kind() == reflect.Struct && field.Type != scType {
|
||||||
|
if sc := findServerConfigReflect(fv); sc != nil {
|
||||||
|
return sc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// applyEnvToStruct recursively walks a struct and overrides field values
|
// applyEnvToStruct recursively walks a struct and overrides field values
|
||||||
// from environment variables. The env variable name is built from the
|
// from environment variables. The env variable name is built from the
|
||||||
// prefix and the toml tag: PREFIX_SECTION_FIELD (uppercased).
|
// prefix and the toml tag: PREFIX_SECTION_FIELD (uppercased).
|
||||||
8
vendor/modules.txt
vendored
8
vendor/modules.txt
vendored
@@ -1,8 +1,8 @@
|
|||||||
# git.wntrmute.dev/kyle/mcdsl v1.0.0
|
# git.wntrmute.dev/mc/mcdsl v1.2.0
|
||||||
## explicit; go 1.25.7
|
## explicit; go 1.25.7
|
||||||
git.wntrmute.dev/kyle/mcdsl/auth
|
git.wntrmute.dev/mc/mcdsl/auth
|
||||||
git.wntrmute.dev/kyle/mcdsl/config
|
git.wntrmute.dev/mc/mcdsl/config
|
||||||
git.wntrmute.dev/kyle/mcdsl/db
|
git.wntrmute.dev/mc/mcdsl/db
|
||||||
# github.com/beorn7/perks v1.0.1
|
# github.com/beorn7/perks v1.0.1
|
||||||
## explicit; go 1.11
|
## explicit; go 1.11
|
||||||
github.com/beorn7/perks/quantile
|
github.com/beorn7/perks/quantile
|
||||||
|
|||||||
Reference in New Issue
Block a user