6 Commits

Author SHA1 Message Date
a60e5cb86a Fix golangci-lint v2 compliance, make all passes clean
- Fix 314 errcheck violations (blank identifier for unrecoverable errors)
- Fix errorlint violation (errors.Is for io.EOF)
- Remove unused serveL7Route test helper
- Simplify Duration.Seconds() selectors in tests
- Remove unnecessary fmt.Sprintf in test
- Migrate exclusion rules from issues.exclusions to linters.exclusions (v2 schema)
- Add gosec test exclusions (G115, G304, G402, G705)
- Disable fieldalignment govet analyzer (optimization, not correctness)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 13:30:43 -07:00
4f3249fdc3 Regenerate proto files for mc/ module path
Raw descriptor bytes in .pb.go files were corrupted by the sed-based
module path rename (string length changed, breaking protobuf binary
encoding). Regenerated with protoc to fix.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:54:22 -07:00
f31a7f20fb Bump flake.nix version to match latest tag
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:16:38 -07:00
feeadc582b Migrate module path from kyle/ to mc/ org
All import paths updated to git.wntrmute.dev/mc/. Bumps mcdsl to v1.2.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:05:59 -07:00
a45ed03432 Use http.Transport for non-TLS backends (HTTP/1.1 support)
The h2c-only transport (http2.Transport) fails against backends like
Gitea that only speak HTTP/1.1. Switch to standard http.Transport for
non-TLS backends, which handles HTTP/1.1 natively and can upgrade to
h2c if the backend supports it.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 23:32:42 -07:00
dc1816b159 Add MCP deployment section to RUNBOOK.md
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 22:09:18 -07:00
54 changed files with 533 additions and 428 deletions

View File

@@ -9,6 +9,20 @@ run:
tests: true tests: true
linters: linters:
exclusions:
paths:
- vendor
rules:
# In test files, suppress gosec rules that are false positives in test code:
# G101: hardcoded test credentials (intentional fixtures)
# G115: integer overflow in type conversions (test TLS packet builders)
# G304: file paths from variables (t.TempDir paths)
# G402: InsecureSkipVerify (required for test TLS clients)
# G705: XSS via taint analysis (test HTTP handlers, not real servers)
- path: "_test\\.go"
linters:
- gosec
text: "G101|G115|G304|G402|G705"
default: none default: none
enable: enable:
# --- Correctness --- # --- Correctness ---
@@ -52,12 +66,15 @@ linters:
check-type-assertions: true check-type-assertions: true
govet: govet:
# Enable all analyzers except shadow. The shadow analyzer flags the idiomatic # Enable all analyzers except shadow and fieldalignment. The shadow analyzer
# `if err := f(); err != nil { ... }` pattern as shadowing an outer `err`, # flags the idiomatic `if err := f(); err != nil { ... }` pattern as shadowing
# which is ubiquitous in Go and does not pose a security risk in this codebase. # an outer `err`, which is ubiquitous in Go. The fieldalignment analyzer
# suggests struct field reordering for memory efficiency — useful as a one-off
# audit but too noisy for CI (every struct change triggers it).
enable-all: true enable-all: true
disable: disable:
- shadow - shadow
- fieldalignment
gosec: gosec:
# Treat all gosec findings as errors, not warnings. # Treat all gosec findings as errors, not warnings.
@@ -110,15 +127,3 @@ issues:
# Do not cap the number of reported issues; in security code every finding matters. # Do not cap the number of reported issues; in security code every finding matters.
max-issues-per-linter: 0 max-issues-per-linter: 0
max-same-issues: 0 max-same-issues: 0
exclusions:
paths:
- vendor
rules:
# In test files, allow hardcoded test credentials (gosec G101) since they are
# intentional fixtures, not production secrets.
- path: "_test\\.go"
linters:
- gosec
text: "G101"

View File

@@ -26,7 +26,7 @@ go test ./internal/sni -run TestExtract
## Architecture ## Architecture
- **Module path**: `git.wntrmute.dev/kyle/mc-proxy` - **Module path**: `git.wntrmute.dev/mc/mc-proxy`
- **Go with CGO_ENABLED=0**, statically linked, Alpine containers - **Go with CGO_ENABLED=0**, statically linked, Alpine containers
- **Dual mode, per-route** — L4 (passthrough) and L7 (TLS-terminating HTTP/2 reverse proxy) coexist on the same listener - **Dual mode, per-route** — L4 (passthrough) and L7 (TLS-terminating HTTP/2 reverse proxy) coexist on the same listener
- **PROXY protocol** — listeners accept v1/v2; routes send v2. Enables edge→origin deployments over Tailscale - **PROXY protocol** — listeners accept v1/v2; routes send v2. Enables edge→origin deployments over Tailscale

View File

@@ -21,8 +21,8 @@ lint:
golangci-lint run ./... golangci-lint run ./...
proto: proto:
protoc --go_out=. --go_opt=module=git.wntrmute.dev/kyle/mc-proxy \ protoc --go_out=. --go_opt=module=git.wntrmute.dev/mc/mc-proxy \
--go-grpc_out=. --go-grpc_opt=module=git.wntrmute.dev/kyle/mc-proxy \ --go-grpc_out=. --go-grpc_opt=module=git.wntrmute.dev/mc/mc-proxy \
proto/mc_proxy/v1/*.proto proto/mc_proxy/v1/*.proto
proto-lint: proto-lint:

View File

@@ -187,6 +187,56 @@ grpcurl -cacert ca.pem -cert client.pem -key client-key.pem \
-d '{"rule": {"type": "FIREWALL_RULE_TYPE_IP", "value": "203.0.113.50"}}' -d '{"rule": {"type": "FIREWALL_RULE_TYPE_IP", "value": "203.0.113.50"}}'
``` ```
## Deployment with MCP
mc-proxy runs on rift as a single container managed by MCP. The service
definition lives at `~/.config/mcp/services/mc-proxy.toml` on rift (reference
copy at `deploy/mc-proxy-rift.toml` in this repo). The container mounts
`/srv/mc-proxy` which holds the config file, SQLite database, GeoIP database,
and TLS certificates for backends. It runs as `--user 0:0` under rootless
podman.
Listeners: `:443` (L7 terminating), `:8443` (L4 passthrough), `:9443` (L4
passthrough).
### Deploy or Update
```bash
mcp deploy mc-proxy
```
### Restart / Stop
```bash
mcp restart mc-proxy
mcp stop mc-proxy
```
### Check Status
```bash
mcp ps
mcp status mc-proxy
```
### View Logs
```bash
ssh rift 'doas su - mcp -s /bin/sh -c "podman logs mc-proxy"'
```
### Update Routes
Edit the config at `/srv/mc-proxy/mc-proxy.toml` on rift, then restart:
```bash
mcp restart mc-proxy
```
Routes added at runtime via the gRPC admin API are persisted in the database
and survive restarts. Editing the TOML config is only necessary for changing
listener definitions or static seed routes.
## Incident Procedures ## Incident Procedures
### Proxy Not Starting ### Proxy Not Starting

View File

@@ -10,7 +10,7 @@ import (
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
healthpb "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
) )
// Client provides access to the mc-proxy admin API. // Client provides access to the mc-proxy admin API.

View File

@@ -15,12 +15,12 @@ import (
healthpb "google.golang.org/grpc/health/grpc_health_v1" healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/db" "git.wntrmute.dev/mc/mc-proxy/internal/db"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver" "git.wntrmute.dev/mc/mc-proxy/internal/grpcserver"
"git.wntrmute.dev/kyle/mc-proxy/internal/server" "git.wntrmute.dev/mc/mc-proxy/internal/server"
) )
func setupTestClient(t *testing.T) *Client { func setupTestClient(t *testing.T) *Client {
@@ -32,7 +32,7 @@ func setupTestClient(t *testing.T) *Client {
if err != nil { if err != nil {
t.Fatalf("open db: %v", err) t.Fatalf("open db: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err) t.Fatalf("migrate: %v", err)
@@ -128,7 +128,7 @@ func setupTestClient(t *testing.T) *Client {
if err != nil { if err != nil {
t.Fatalf("dial bufconn: %v", err) t.Fatalf("dial bufconn: %v", err)
} }
t.Cleanup(func() { conn.Close() }) t.Cleanup(func() { _ = conn.Close() })
return &Client{ return &Client{
conn: conn, conn: conn,

View File

@@ -11,12 +11,12 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/db" "git.wntrmute.dev/mc/mc-proxy/internal/db"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver" "git.wntrmute.dev/mc/mc-proxy/internal/grpcserver"
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics" "git.wntrmute.dev/mc/mc-proxy/internal/metrics"
"git.wntrmute.dev/kyle/mc-proxy/internal/server" "git.wntrmute.dev/mc/mc-proxy/internal/server"
) )
func serverCmd() *cobra.Command { func serverCmd() *cobra.Command {
@@ -40,7 +40,7 @@ func serverCmd() *cobra.Command {
if err != nil { if err != nil {
return fmt.Errorf("opening database: %w", err) return fmt.Errorf("opening database: %w", err)
} }
defer store.Close() defer func() { _ = store.Close() }()
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
return fmt.Errorf("running migrations: %w", err) return fmt.Errorf("running migrations: %w", err)
@@ -93,7 +93,7 @@ func serverCmd() *cobra.Command {
}() }()
defer func() { defer func() {
grpcSrv.GracefulStop() grpcSrv.GracefulStop()
os.Remove(cfg.GRPC.SocketPath()) _ = os.Remove(cfg.GRPC.SocketPath())
}() }()
} }

View File

@@ -9,8 +9,8 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/db" "git.wntrmute.dev/mc/mc-proxy/internal/db"
) )
func snapshotCmd() *cobra.Command { func snapshotCmd() *cobra.Command {
@@ -32,7 +32,7 @@ func snapshotCmd() *cobra.Command {
if err != nil { if err != nil {
return fmt.Errorf("opening database: %w", err) return fmt.Errorf("opening database: %w", err)
} }
defer store.Close() defer func() { _ = store.Close() }()
dataDir := filepath.Dir(cfg.Database.Path) dataDir := filepath.Dir(cfg.Database.Path)

View File

@@ -9,8 +9,8 @@ import (
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
) )
func statusCmd() *cobra.Command { func statusCmd() *cobra.Command {
@@ -33,7 +33,7 @@ func statusCmd() *cobra.Command {
if err != nil { if err != nil {
return fmt.Errorf("connecting to gRPC API: %w", err) return fmt.Errorf("connecting to gRPC API: %w", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
client := pb.NewProxyAdminServiceClient(conn) client := pb.NewProxyAdminServiceClient(conn)

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"git.wntrmute.dev/kyle/mc-proxy/client/mcproxy" "git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
) )
func firewallCmd() *cobra.Command { func firewallCmd() *cobra.Command {

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"git.wntrmute.dev/kyle/mc-proxy/client/mcproxy" "git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
) )
func healthCmd() *cobra.Command { func healthCmd() *cobra.Command {

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
mcproxy "git.wntrmute.dev/kyle/mc-proxy/client/mcproxy" mcproxy "git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
) )
func policiesCmd() *cobra.Command { func policiesCmd() *cobra.Command {

View File

@@ -6,7 +6,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
"git.wntrmute.dev/kyle/mc-proxy/client/mcproxy" "git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
) )
const defaultSocketPath = "/srv/mc-proxy/mc-proxy.sock" const defaultSocketPath = "/srv/mc-proxy/mc-proxy.sock"

View File

@@ -7,7 +7,7 @@ import (
"github.com/spf13/cobra" "github.com/spf13/cobra"
mcproxy "git.wntrmute.dev/kyle/mc-proxy/client/mcproxy" mcproxy "git.wntrmute.dev/mc/mc-proxy/client/mcproxy"
) )
func routesCmd() *cobra.Command { func routesCmd() *cobra.Command {

View File

@@ -10,7 +10,7 @@
let let
system = "x86_64-linux"; system = "x86_64-linux";
pkgs = nixpkgs.legacyPackages.${system}; pkgs = nixpkgs.legacyPackages.${system};
version = "0.1.0"; version = "1.1.0";
in in
{ {
packages.${system} = { packages.${system} = {

View File

@@ -1449,7 +1449,7 @@ const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" +
"\x0eListL7Policies\x12\".mc_proxy.v1.ListL7PoliciesRequest\x1a#.mc_proxy.v1.ListL7PoliciesResponse\x12P\n" + "\x0eListL7Policies\x12\".mc_proxy.v1.ListL7PoliciesRequest\x1a#.mc_proxy.v1.ListL7PoliciesResponse\x12P\n" +
"\vAddL7Policy\x12\x1f.mc_proxy.v1.AddL7PolicyRequest\x1a .mc_proxy.v1.AddL7PolicyResponse\x12Y\n" + "\vAddL7Policy\x12\x1f.mc_proxy.v1.AddL7PolicyRequest\x1a .mc_proxy.v1.AddL7PolicyResponse\x12Y\n" +
"\x0eRemoveL7Policy\x12\".mc_proxy.v1.RemoveL7PolicyRequest\x1a#.mc_proxy.v1.RemoveL7PolicyResponse\x12J\n" + "\x0eRemoveL7Policy\x12\".mc_proxy.v1.RemoveL7PolicyRequest\x1a#.mc_proxy.v1.RemoveL7PolicyResponse\x12J\n" +
"\tGetStatus\x12\x1d.mc_proxy.v1.GetStatusRequest\x1a\x1e.mc_proxy.v1.GetStatusResponseB:Z8git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1;mcproxyv1b\x06proto3" "\tGetStatus\x12\x1d.mc_proxy.v1.GetStatusRequest\x1a\x1e.mc_proxy.v1.GetStatusResponseB8Z6git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1;mcproxyv1b\x06proto3"
var ( var (
file_proto_mc_proxy_v1_admin_proto_rawDescOnce sync.Once file_proto_mc_proxy_v1_admin_proto_rawDescOnce sync.Once

4
go.mod
View File

@@ -1,9 +1,9 @@
module git.wntrmute.dev/kyle/mc-proxy module git.wntrmute.dev/mc/mc-proxy
go 1.25.7 go 1.25.7
require ( require (
git.wntrmute.dev/kyle/mcdsl v1.0.0 git.wntrmute.dev/mc/mcdsl v1.2.0
github.com/oschwald/maxminddb-golang v1.13.1 github.com/oschwald/maxminddb-golang v1.13.1
github.com/prometheus/client_golang v1.23.2 github.com/prometheus/client_golang v1.23.2
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2

4
go.sum
View File

@@ -1,5 +1,5 @@
git.wntrmute.dev/kyle/mcdsl v1.0.0 h1:YB7dx4gdNYKKcVySpL6UkwHqdCJ9Nl1yS0+eHk0hNtk= git.wntrmute.dev/mc/mcdsl v1.2.0 h1:41hep7/PNZJfN0SN/nM+rQpyF1GSZcvNNjyVG81DI7U=
git.wntrmute.dev/kyle/mcdsl v1.0.0/go.mod h1:wo0tGfUAxci3XnOe4/rFmR0RjUElKdYUazc+Np986sg= git.wntrmute.dev/mc/mcdsl v1.2.0/go.mod h1:lXYrAt74ZUix6rx9oVN8d2zH1YJoyp4uxPVKQ+SSxuM=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs=

View File

@@ -7,7 +7,7 @@ import (
"os" "os"
"strings" "strings"
mcdslconfig "git.wntrmute.dev/kyle/mcdsl/config" mcdslconfig "git.wntrmute.dev/mc/mcdsl/config"
) )
// Duration is an alias for the mcdsl config.Duration type, which wraps // Duration is an alias for the mcdsl config.Duration type, which wraps

View File

@@ -304,7 +304,7 @@ func TestDuration(t *testing.T) {
if err := d.UnmarshalText([]byte("5s")); err != nil { if err := d.UnmarshalText([]byte("5s")); err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if d.Duration.Seconds() != 5 { if d.Seconds() != 5 {
t.Fatalf("got %v, want 5s", d.Duration) t.Fatalf("got %v, want 5s", d.Duration)
} }
} }
@@ -340,7 +340,7 @@ level = "info"
if cfg.Log.Level != "debug" { if cfg.Log.Level != "debug" {
t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug") t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug")
} }
if cfg.Proxy.IdleTimeout.Duration.Seconds() != 600 { if cfg.Proxy.IdleTimeout.Seconds() != 600 {
t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration) t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration)
} }
if cfg.Database.Path != "/override/test.db" { if cfg.Database.Path != "/override/test.db" {

View File

@@ -4,7 +4,7 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db" mcdsldb "git.wntrmute.dev/mc/mcdsl/db"
) )
// Store wraps a SQLite database connection for mc-proxy persistence. // Store wraps a SQLite database connection for mc-proxy persistence.

View File

@@ -4,7 +4,7 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
) )
func openTestDB(t *testing.T) *Store { func openTestDB(t *testing.T) *Store {
@@ -17,7 +17,7 @@ func openTestDB(t *testing.T) *Store {
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err) t.Fatalf("migrate: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
return store return store
} }
@@ -266,8 +266,8 @@ func TestRouteCascadeDelete(t *testing.T) {
store := openTestDB(t) store := openTestDB(t)
listenerID, _ := store.CreateListener(":443", false, 0) listenerID, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false) _, _ = store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false) _, _ = store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
if err := store.DeleteListener(listenerID); err != nil { if err := store.DeleteListener(listenerID); err != nil {
t.Fatalf("delete listener: %v", err) t.Fatalf("delete listener: %v", err)
@@ -412,7 +412,7 @@ func TestSeed(t *testing.T) {
func TestSnapshot(t *testing.T) { func TestSnapshot(t *testing.T) {
store := openTestDB(t) store := openTestDB(t)
store.CreateListener(":443", false, 0) _, _ = store.CreateListener(":443", false, 0)
dest := filepath.Join(t.TempDir(), "backup.db") dest := filepath.Join(t.TempDir(), "backup.db")
if err := store.Snapshot(dest); err != nil { if err := store.Snapshot(dest); err != nil {
@@ -424,7 +424,7 @@ func TestSnapshot(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("open backup: %v", err) t.Fatalf("open backup: %v", err)
} }
defer backup.Close() defer func() { _ = backup.Close() }()
if err := backup.Migrate(); err != nil { if err := backup.Migrate(); err != nil {
t.Fatalf("migrate backup: %v", err) t.Fatalf("migrate backup: %v", err)
@@ -463,7 +463,7 @@ func TestMigrationV2Upgrade(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("open: %v", err) t.Fatalf("open: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
// Run full migrations (v1 + v2). // Run full migrations (v1 + v2).
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
@@ -556,10 +556,10 @@ func TestL7PolicyCascadeDelete(t *testing.T) {
lid, _ := store.CreateListener(":443", false, 0) lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false) rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
store.CreateL7Policy(rid, "block_user_agent", "Bot") _, _ = store.CreateL7Policy(rid, "block_user_agent", "Bot")
// Deleting the route should cascade-delete its policies. // Deleting the route should cascade-delete its policies.
store.DeleteRoute(lid, "api.test") _ = store.DeleteRoute(lid, "api.test")
policies, _ := store.ListL7Policies(rid) policies, _ := store.ListL7Policies(rid)
if len(policies) != 0 { if len(policies) != 0 {
@@ -585,7 +585,7 @@ func TestGetRouteID(t *testing.T) {
store := openTestDB(t) store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0) lid, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false) _, _ = store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
rid, err := store.GetRouteID(lid, "api.test") rid, err := store.GetRouteID(lid, "api.test")
if err != nil { if err != nil {

View File

@@ -15,7 +15,7 @@ func (s *Store) ListFirewallRules() ([]FirewallRule, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying firewall rules: %w", err) return nil, fmt.Errorf("querying firewall rules: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var rules []FirewallRule var rules []FirewallRule
for rows.Next() { for rows.Next() {

View File

@@ -19,7 +19,7 @@ func (s *Store) ListL7Policies(routeID int64) ([]L7Policy, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying l7 policies: %w", err) return nil, fmt.Errorf("querying l7 policies: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var policies []L7Policy var policies []L7Policy
for rows.Next() { for rows.Next() {

View File

@@ -16,7 +16,7 @@ func (s *Store) ListListeners() ([]Listener, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying listeners: %w", err) return nil, fmt.Errorf("querying listeners: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var listeners []Listener var listeners []Listener
for rows.Next() { for rows.Next() {

View File

@@ -1,7 +1,7 @@
package db package db
import ( import (
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db" mcdsldb "git.wntrmute.dev/mc/mcdsl/db"
) )
// Migrations is the ordered list of schema migrations for mc-proxy. // Migrations is the ordered list of schema migrations for mc-proxy.

View File

@@ -25,7 +25,7 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying routes: %w", err) return nil, fmt.Errorf("querying routes: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var routes []Route var routes []Route
for rows.Next() { for rows.Next() {

View File

@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"strings" "strings"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
) )
// Seed populates the database from TOML config data. Only called when the // Seed populates the database from TOML config data. Only called when the
@@ -14,7 +14,7 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
if err != nil { if err != nil {
return fmt.Errorf("beginning seed transaction: %w", err) return fmt.Errorf("beginning seed transaction: %w", err)
} }
defer tx.Rollback() defer func() { _ = tx.Rollback() }()
for _, l := range listeners { for _, l := range listeners {
result, err := tx.Exec( result, err := tx.Exec(

View File

@@ -1,7 +1,7 @@
package db package db
import ( import (
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db" mcdsldb "git.wntrmute.dev/mc/mcdsl/db"
) )
// Snapshot creates a consistent backup of the database using VACUUM INTO. // Snapshot creates a consistent backup of the database using VACUUM INTO.

View File

@@ -234,7 +234,7 @@ func (f *Firewall) loadGeoDB(path string) error {
f.mu.Unlock() f.mu.Unlock()
if old != nil { if old != nil {
old.Close() _ = old.Close()
} }
return nil return nil
} }

View File

@@ -11,7 +11,7 @@ func TestEmptyFirewall(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"} addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
for _, a := range addrs { for _, a := range addrs {
@@ -27,7 +27,7 @@ func TestIPBlocking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -52,7 +52,7 @@ func TestCIDRBlocking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -78,7 +78,7 @@ func TestIPv4MappedIPv6(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("::ffff:192.0.2.1") addr := netip.MustParseAddr("::ffff:192.0.2.1")
if !fw.Blocked(addr) { if !fw.Blocked(addr) {
@@ -105,7 +105,7 @@ func TestCombinedRules(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -130,7 +130,7 @@ func TestRateLimitBlocking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("10.0.0.1") addr := netip.MustParseAddr("10.0.0.1")
@@ -151,7 +151,7 @@ func TestRateLimitBlocklistFirst(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
blockedAddr := netip.MustParseAddr("10.0.0.1") blockedAddr := netip.MustParseAddr("10.0.0.1")
otherAddr := netip.MustParseAddr("10.0.0.2") otherAddr := netip.MustParseAddr("10.0.0.2")
@@ -175,7 +175,7 @@ func TestBlockedWithReason(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -216,7 +216,7 @@ func TestRuntimeMutation(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("10.0.0.1") addr := netip.MustParseAddr("10.0.0.1")
if fw.Blocked(addr) { if fw.Blocked(addr) {

View File

@@ -38,7 +38,7 @@ func (rl *rateLimiter) Allow(addr netip.Addr) bool {
now := rl.now().UnixNano() now := rl.now().UnixNano()
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{}) val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
entry := val.(*rateLimitEntry) entry, _ := val.(*rateLimitEntry)
windowStart := entry.start.Load() windowStart := entry.start.Load()
if now-windowStart >= rl.window.Nanoseconds() { if now-windowStart >= rl.window.Nanoseconds() {
@@ -70,7 +70,7 @@ func (rl *rateLimiter) cleanup() {
case <-ticker.C: case <-ticker.C:
cutoff := rl.now().Add(-2 * rl.window).UnixNano() cutoff := rl.now().Add(-2 * rl.window).UnixNano()
rl.entries.Range(func(key, value any) bool { rl.entries.Range(func(key, value any) bool {
entry := value.(*rateLimitEntry) entry, _ := value.(*rateLimitEntry)
if entry.start.Load() < cutoff { if entry.start.Load() < cutoff {
rl.entries.Delete(key) rl.entries.Delete(key)
} }

View File

@@ -17,10 +17,10 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/db" "git.wntrmute.dev/mc/mc-proxy/internal/db"
"git.wntrmute.dev/kyle/mc-proxy/internal/server" "git.wntrmute.dev/mc/mc-proxy/internal/server"
) )
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`) var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
@@ -53,7 +53,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
path := cfg.SocketPath() path := cfg.SocketPath()
// Remove stale socket file from a previous run. // Remove stale socket file from a previous run.
os.Remove(path) _ = os.Remove(path)
ln, err := net.Listen("unix", path) ln, err := net.Listen("unix", path)
if err != nil { if err != nil {
@@ -61,7 +61,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
} }
if err := os.Chmod(path, 0600); err != nil { if err := os.Chmod(path, 0600); err != nil {
ln.Close() _ = ln.Close()
return nil, nil, fmt.Errorf("setting socket permissions: %w", err) return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
} }
@@ -446,7 +446,7 @@ func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.
} }
listeners = append(listeners, &pb.ListenerStatus{ listeners = append(listeners, &pb.ListenerStatus{
Addr: ls.Addr, Addr: ls.Addr,
RouteCount: int32(len(routes)), RouteCount: int32(len(routes)), //nolint:gosec // route count can never exceed int32
ActiveConnections: ls.ActiveConnections.Load(), ActiveConnections: ls.ActiveConnections.Load(),
ProxyProtocol: ls.ProxyProtocol, ProxyProtocol: ls.ProxyProtocol,
MaxConnections: ls.MaxConnections, MaxConnections: ls.MaxConnections,

View File

@@ -15,11 +15,11 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn" "google.golang.org/grpc/test/bufconn"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/db" "git.wntrmute.dev/mc/mc-proxy/internal/db"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/server" "git.wntrmute.dev/mc/mc-proxy/internal/server"
) )
// testEnv bundles all the objects needed for a grpcserver test. // testEnv bundles all the objects needed for a grpcserver test.
@@ -39,7 +39,7 @@ func setup(t *testing.T) *testEnv {
if err != nil { if err != nil {
t.Fatalf("open db: %v", err) t.Fatalf("open db: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err) t.Fatalf("migrate: %v", err)
@@ -130,7 +130,7 @@ func setup(t *testing.T) *testEnv {
if err != nil { if err != nil {
t.Fatalf("dial bufconn: %v", err) t.Fatalf("dial bufconn: %v", err)
} }
t.Cleanup(func() { conn.Close() }) t.Cleanup(func() { _ = conn.Close() })
return &testEnv{ return &testEnv{
client: pb.NewProxyAdminServiceClient(conn), client: pb.NewProxyAdminServiceClient(conn),
@@ -775,7 +775,7 @@ func TestRemoveL7Policy(t *testing.T) {
env := setup(t) env := setup(t)
ctx := context.Background() ctx := context.Background()
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{ _, _ = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443", ListenerAddr: ":443",
Hostname: "a.test", Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"}, Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},

View File

@@ -4,7 +4,7 @@ import (
"net/http" "net/http"
"strings" "strings"
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics" "git.wntrmute.dev/mc/mc-proxy/internal/metrics"
) )
// PolicyRule defines an L7 blocking policy. // PolicyRule defines an L7 blocking policy.

View File

@@ -13,27 +13,27 @@ func TestPrefixConnRead(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
conn.Write([]byte("WORLD")) _, _ = conn.Write([]byte("WORLD"))
}() }()
conn, err := net.Dial("tcp", ln.Addr().String()) conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("HELLO")) pc := NewPrefixConn(conn, []byte("HELLO"))
// Read all data: should get "HELLOWORLD". // Read all data: should get "HELLOWORLD".
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc) all, err := io.ReadAll(pc)
if err != nil { if err != nil {
t.Fatalf("ReadAll: %v", err) t.Fatalf("ReadAll: %v", err)
@@ -48,22 +48,22 @@ func TestPrefixConnSmallReads(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
conn.Write([]byte("CD")) _, _ = conn.Write([]byte("CD"))
}() }()
conn, err := net.Dial("tcp", ln.Addr().String()) conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("AB")) pc := NewPrefixConn(conn, []byte("AB"))
@@ -79,7 +79,7 @@ func TestPrefixConnSmallReads(t *testing.T) {
} }
// Now reads come from the underlying conn. // Now reads come from the underlying conn.
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
rest, err := io.ReadAll(pc) rest, err := io.ReadAll(pc)
if err != nil { if err != nil {
t.Fatalf("ReadAll: %v", err) t.Fatalf("ReadAll: %v", err)
@@ -94,25 +94,25 @@ func TestPrefixConnEmptyPrefix(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
conn.Write([]byte("DATA")) _, _ = conn.Write([]byte("DATA"))
}() }()
conn, err := net.Dial("tcp", ln.Addr().String()) conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, nil) pc := NewPrefixConn(conn, nil)
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc) all, err := io.ReadAll(pc)
if err != nil { if err != nil {
t.Fatalf("ReadAll: %v", err) t.Fatalf("ReadAll: %v", err)
@@ -127,12 +127,12 @@ func TestPrefixConnDelegates(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, _ := ln.Accept() conn, _ := ln.Accept()
if conn != nil { if conn != nil {
conn.Close() _ = conn.Close()
} }
}() }()
@@ -140,7 +140,7 @@ func TestPrefixConnDelegates(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("X")) pc := NewPrefixConn(conn, []byte("X"))

View File

@@ -15,8 +15,8 @@ import (
"strconv" "strconv"
"time" "time"
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics" "git.wntrmute.dev/mc/mc-proxy/internal/metrics"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2" "golang.org/x/net/http2"
) )
@@ -120,7 +120,7 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
ReadHeaderTimeout: 30 * time.Second, ReadHeaderTimeout: 30 * time.Second,
} }
singleConn := newSingleConnListener(tlsConn) singleConn := newSingleConnListener(tlsConn)
srv.Serve(singleConn) _ = srv.Serve(singleConn)
} }
return nil return nil
@@ -186,16 +186,16 @@ func newTransport(route RouteConfig) (http.RoundTripper, error) {
}, nil }, nil
} }
// h2c: HTTP/2 over plaintext TCP. // Plain HTTP backend. Use standard http.Transport which speaks
return &http2.Transport{ // HTTP/1.1 by default and can upgrade to h2c if the backend
AllowHTTP: true, // supports it. This handles backends like Gitea that only speak
DialTLSContext: func(ctx context.Context, network, addr string, _ *tls.Config) (net.Conn, error) { // HTTP/1.1.
conn, err := dialBackend(ctx, network, addr, connectTimeout, route) return &http.Transport{
if err != nil { DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return nil, err return dialBackend(ctx, network, addr, connectTimeout, route)
}
return conn, nil
}, },
MaxIdleConnsPerHost: 10,
IdleConnTimeout: 90 * time.Second,
}, nil }, nil
} }
@@ -213,7 +213,7 @@ func dialBackend(ctx context.Context, network, addr string, timeout time.Duratio
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String()) backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
if clientAddr.IsValid() { if clientAddr.IsValid() {
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil { if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
conn.Close() _ = conn.Close()
return nil, fmt.Errorf("writing PROXY protocol header: %w", err) return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
} }
} }

View File

@@ -58,8 +58,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
if err != nil { if err != nil {
t.Fatalf("creating cert file: %v", err) t.Fatalf("creating cert file: %v", err)
} }
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) _ = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
certFile.Close() _ = certFile.Close()
keyDER, err := x509.MarshalECPrivateKey(key) keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil { if err != nil {
@@ -69,8 +69,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
if err != nil { if err != nil {
t.Fatalf("creating key file: %v", err) t.Fatalf("creating key file: %v", err)
} }
pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) _ = pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
keyFile.Close() _ = keyFile.Close()
return certPath, keyPath return certPath, keyPath
} }
@@ -91,11 +91,11 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
t.Cleanup(func() { t.Cleanup(func() {
srv.Close() _ = srv.Close()
ln.Close() _ = ln.Close()
}) })
go srv.Serve(ln) go func() { _ = srv.Serve(ln) }()
return ln.Addr().String() return ln.Addr().String()
} }
@@ -118,7 +118,7 @@ func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
if err != nil { if err != nil {
t.Fatalf("TLS dial: %v", err) t.Fatalf("TLS dial: %v", err)
} }
t.Cleanup(func() { conn.Close() }) t.Cleanup(func() { _ = conn.Close() })
// Create an HTTP/2 client transport over this single connection. // Create an HTTP/2 client transport over this single connection.
tr := &http2.Transport{} tr := &http2.Transport{}
@@ -142,29 +142,13 @@ func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
return s.cc.RoundTrip(req) return s.cc.RoundTrip(req)
} }
// serveL7Route starts l7.Serve in a goroutine for a single connection.
// Returns when the goroutine completes.
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
t.Helper()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
ctx := context.Background()
go func() {
l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger)
if l7Err != nil {
t.Logf("l7.Serve: %v", l7Err)
}
}()
}
func TestL7H2CBackend(t *testing.T) { func TestL7H2CBackend(t *testing.T) {
certPath, keyPath := testCert(t, "l7.test") certPath, keyPath := testCert(t, "l7.test")
// Start an h2c backend. // Start an h2c backend.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "ok") w.Header().Set("X-Backend", "ok")
fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path) _, _ = fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
})) }))
// Start a TCP listener for the L7 proxy. // Start a TCP listener for the L7 proxy.
@@ -172,7 +156,7 @@ func TestL7H2CBackend(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -190,17 +174,17 @@ func TestL7H2CBackend(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
// No peeked bytes — the client is connecting directly with TLS. // No peeked bytes — the client is connecting directly with TLS.
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
// Connect as an HTTP/2 TLS client. // Connect as an HTTP/2 TLS client.
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
resp, err := client.Get(fmt.Sprintf("https://l7.test/foo")) resp, err := client.Get("https://l7.test/foo")
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode) t.Fatalf("status = %d, want 200", resp.StatusCode)
@@ -221,7 +205,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
// Backend that echoes the forwarding headers. // Backend that echoes the forwarding headers.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "xff=%s xfp=%s xri=%s", _, _ = fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
r.Header.Get("X-Forwarded-For"), r.Header.Get("X-Forwarded-For"),
r.Header.Get("X-Forwarded-Proto"), r.Header.Get("X-Forwarded-Proto"),
r.Header.Get("X-Real-IP"), r.Header.Get("X-Real-IP"),
@@ -232,7 +216,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -248,7 +232,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
@@ -256,7 +240,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50" want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
@@ -274,13 +258,13 @@ func TestL7BackendUnreachable(t *testing.T) {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
deadAddr := ln.Addr().String() deadAddr := ln.Addr().String()
ln.Close() _ = ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: deadAddr, Backend: deadAddr,
@@ -296,7 +280,7 @@ func TestL7BackendUnreachable(t *testing.T) {
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
@@ -304,7 +288,7 @@ func TestL7BackendUnreachable(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusBadGateway { if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("status = %d, want 502", resp.StatusCode) t.Fatalf("status = %d, want 502", resp.StatusCode)
@@ -342,14 +326,14 @@ func TestL7MultipleRequests(t *testing.T) {
var reqCount int var reqCount int
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqCount++ reqCount++
fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path) _, _ = fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -365,7 +349,7 @@ func TestL7MultipleRequests(t *testing.T) {
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
@@ -378,7 +362,7 @@ func TestL7MultipleRequests(t *testing.T) {
t.Fatalf("GET %s: %v", path, err) t.Fatalf("GET %s: %v", path, err)
} }
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
resp.Body.Close() _ = resp.Body.Close()
want := fmt.Sprintf("req=%d path=%s", i+1, path) want := fmt.Sprintf("req=%d path=%s", i+1, path)
if string(body) != want { if string(body) != want {
@@ -396,14 +380,14 @@ func TestL7LargeResponse(t *testing.T) {
largeBody[i] = byte(i % 256) largeBody[i] = byte(i % 256)
} }
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(largeBody) _, _ = w.Write(largeBody)
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -418,7 +402,7 @@ func TestL7LargeResponse(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test")
@@ -426,7 +410,7 @@ func TestL7LargeResponse(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if len(body) != len(largeBody) { if len(body) != len(largeBody) {
@@ -455,7 +439,7 @@ func TestL7GRPCTrailers(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -470,7 +454,7 @@ func TestL7GRPCTrailers(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test")
@@ -480,10 +464,10 @@ func TestL7GRPCTrailers(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("POST: %v", err) t.Fatalf("POST: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
// Read body to trigger trailer delivery. // Read body to trigger trailer delivery.
io.ReadAll(resp.Body) _, _ = io.ReadAll(resp.Body)
// Verify trailers were forwarded through the proxy. // Verify trailers were forwarded through the proxy.
grpcStatus := resp.Trailer.Get("Grpc-Status") grpcStatus := resp.Trailer.Get("Grpc-Status")
@@ -500,14 +484,14 @@ func TestL7HTTP11Fallback(t *testing.T) {
certPath, keyPath := testCert(t, "http11.test") certPath, keyPath := testCert(t, "http11.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "proto=%s", r.Proto) _, _ = fmt.Fprintf(w, "proto=%s", r.Proto)
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -522,7 +506,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
// Connect with HTTP/1.1 only (no h2 ALPN). // Connect with HTTP/1.1 only (no h2 ALPN).
@@ -538,7 +522,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode) t.Fatalf("status = %d, want 200", resp.StatusCode)
@@ -556,14 +540,14 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
certPath, keyPath := testCert(t, "policy.test") certPath, keyPath := testCert(t, "policy.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "should-not-reach") _, _ = fmt.Fprint(w, "should-not-reach")
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -581,7 +565,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
@@ -591,7 +575,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 403 { if resp.StatusCode != 403 {
t.Fatalf("status = %d, want 403", resp.StatusCode) t.Fatalf("status = %d, want 403", resp.StatusCode)
@@ -602,14 +586,14 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
certPath, keyPath := testCert(t, "reqhdr.test") certPath, keyPath := testCert(t, "reqhdr.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "ok") _, _ = fmt.Fprint(w, "ok")
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -630,7 +614,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
} }
go func() { go func() {
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
} }
}() }()
@@ -641,7 +625,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET without header: %v", err) t.Fatalf("GET without header: %v", err)
} }
resp1.Body.Close() _ = resp1.Body.Close()
if resp1.StatusCode != 403 { if resp1.StatusCode != 403 {
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode) t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
} }
@@ -654,7 +638,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET with header: %v", err) t.Fatalf("GET with header: %v", err)
} }
defer resp2.Body.Close() defer func() { _ = resp2.Body.Close() }()
body, _ := io.ReadAll(resp2.Body) body, _ := io.ReadAll(resp2.Body)
if resp2.StatusCode != 200 { if resp2.StatusCode != 200 {
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode) t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)

View File

@@ -2,6 +2,7 @@ package proxy
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"sync" "sync"
@@ -31,8 +32,8 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
go func() { go func() {
<-ctx.Done() <-ctx.Done()
client.Close() _ = client.Close()
backend.Close() _ = backend.Close()
}() }()
var ( var (
@@ -50,7 +51,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout) result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
// Half-close backend's write side. // Half-close backend's write side.
if hc, ok := backend.(interface{ CloseWrite() error }); ok { if hc, ok := backend.(interface{ CloseWrite() error }); ok {
hc.CloseWrite() _ = hc.CloseWrite()
} }
}() }()
@@ -60,7 +61,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout) result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
// Half-close client's write side. // Half-close client's write side.
if hc, ok := client.(interface{ CloseWrite() error }); ok { if hc, ok := client.(interface{ CloseWrite() error }); ok {
hc.CloseWrite() _ = hc.CloseWrite()
} }
}() }()
@@ -85,10 +86,10 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
var total int64 var total int64
for { for {
src.SetReadDeadline(time.Now().Add(idleTimeout)) _ = src.SetReadDeadline(time.Now().Add(idleTimeout))
nr, readErr := src.Read(buf) nr, readErr := src.Read(buf)
if nr > 0 { if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(idleTimeout)) _ = dst.SetWriteDeadline(time.Now().Add(idleTimeout))
nw, writeErr := dst.Write(buf[:nr]) nw, writeErr := dst.Write(buf[:nr])
total += int64(nw) total += int64(nw)
if writeErr != nil { if writeErr != nil {
@@ -96,7 +97,7 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
} }
} }
if readErr != nil { if readErr != nil {
if readErr == io.EOF { if errors.Is(readErr, io.EOF) {
return total, nil return total, nil
} }
return total, readErr return total, readErr

View File

@@ -16,7 +16,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
peeked := []byte("peeked-hello-bytes") peeked := []byte("peeked-hello-bytes")
clientData := []byte("data from client") clientData := []byte("data from client")
@@ -29,7 +29,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Read everything the backend receives. // Read everything the backend receives.
received, _ := io.ReadAll(conn) received, _ := io.ReadAll(conn)
@@ -40,21 +40,21 @@ func TestRelayBasic(t *testing.T) {
}() }()
// Restructure: use a more controlled flow. // Restructure: use a more controlled flow.
backendLn.Close() _ = backendLn.Close()
// Use a real TCP pair for proper half-close. // Use a real TCP pair for proper half-close.
backendLn2, err := net.Listen("tcp", "127.0.0.1:0") backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer backendLn2.Close() defer func() { _ = backendLn2.Close() }()
go func() { go func() {
conn, err := backendLn2.Accept() conn, err := backendLn2.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Read peeked + client data. // Read peeked + client data.
buf := make([]byte, len(peeked)+len(clientData)) buf := make([]byte, len(peeked)+len(clientData))
@@ -62,11 +62,11 @@ func TestRelayBasic(t *testing.T) {
backendDone <- buf[:n] backendDone <- buf[:n]
// Send response. // Send response.
conn.Write(backendData) _, _ = conn.Write(backendData)
// Close write side to signal EOF. // Close write side to signal EOF.
if tc, ok := conn.(*net.TCPConn); ok { if tc, ok := conn.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -81,7 +81,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer clientLn.Close() defer func() { _ = clientLn.Close() }()
clientConn, err := net.Dial("tcp", clientLn.Addr().String()) clientConn, err := net.Dial("tcp", clientLn.Addr().String())
if err != nil { if err != nil {
@@ -94,9 +94,9 @@ func TestRelayBasic(t *testing.T) {
// Client sends data then closes write. // Client sends data then closes write.
go func() { go func() {
clientConn.Write(clientData) _, _ = clientConn.Write(clientData)
if tc, ok := clientConn.(*net.TCPConn); ok { if tc, ok := clientConn.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -114,7 +114,7 @@ func TestRelayBasic(t *testing.T) {
} }
// Verify client received backend data. // Verify client received backend data.
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
clientReceived, _ := io.ReadAll(clientConn) clientReceived, _ := io.ReadAll(clientConn)
if !bytes.Equal(clientReceived, backendData) { if !bytes.Equal(clientReceived, backendData) {
t.Fatalf("client received %q, want %q", clientReceived, backendData) t.Fatalf("client received %q, want %q", clientReceived, backendData)
@@ -131,12 +131,12 @@ func TestRelayBasic(t *testing.T) {
func TestRelayIdleTimeout(t *testing.T) { func TestRelayIdleTimeout(t *testing.T) {
// Two connected pairs via TCP. // Two connected pairs via TCP.
clientA, clientB := tcpPair(t) clientA, clientB := tcpPair(t)
defer clientA.Close() defer func() { _ = clientA.Close() }()
defer clientB.Close() defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t) backendA, backendB := tcpPair(t)
defer backendA.Close() defer func() { _ = backendA.Close() }()
defer backendB.Close() defer func() { _ = backendB.Close() }()
start := time.Now() start := time.Now()
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond) _, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
@@ -154,18 +154,18 @@ func TestRelayIdleTimeout(t *testing.T) {
func TestRelayContextCancel(t *testing.T) { func TestRelayContextCancel(t *testing.T) {
clientA, clientB := tcpPair(t) clientA, clientB := tcpPair(t)
defer clientA.Close() defer func() { _ = clientA.Close() }()
defer clientB.Close() defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t) backendA, backendB := tcpPair(t)
defer backendA.Close() defer func() { _ = backendA.Close() }()
defer backendB.Close() defer func() { _ = backendB.Close() }()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
Relay(ctx, clientB, backendA, nil, time.Minute) _, _ = Relay(ctx, clientB, backendA, nil, time.Minute)
close(done) close(done)
}() }()
@@ -185,12 +185,12 @@ func TestRelayContextCancel(t *testing.T) {
func TestRelayLargeTransfer(t *testing.T) { func TestRelayLargeTransfer(t *testing.T) {
clientA, clientB := tcpPair(t) clientA, clientB := tcpPair(t)
defer clientA.Close() defer func() { _ = clientA.Close() }()
defer clientB.Close() defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t) backendA, backendB := tcpPair(t)
defer backendA.Close() defer func() { _ = backendA.Close() }()
defer backendB.Close() defer func() { _ = backendB.Close() }()
// 1 MB of random data. // 1 MB of random data.
data := make([]byte, 1<<20) data := make([]byte, 1<<20)
@@ -199,9 +199,9 @@ func TestRelayLargeTransfer(t *testing.T) {
} }
go func() { go func() {
clientA.Write(data) _, _ = clientA.Write(data)
if tc, ok := clientA.(*net.TCPConn); ok { if tc, ok := clientA.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -211,14 +211,14 @@ func TestRelayLargeTransfer(t *testing.T) {
for { for {
n, err := backendB.Read(buf) n, err := backendB.Read(buf)
if n > 0 { if n > 0 {
backendB.Write(buf[:n]) _, _ = backendB.Write(buf[:n])
} }
if err != nil { if err != nil {
break break
} }
} }
if tc, ok := backendB.(*net.TCPConn); ok { if tc, ok := backendB.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -240,7 +240,7 @@ func tcpPair(t *testing.T) (net.Conn, net.Conn) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
var serverConn net.Conn var serverConn net.Conn
done := make(chan struct{}) done := make(chan struct{})

View File

@@ -50,8 +50,8 @@ const (
// It reads only the exact bytes of the PROXY header, leaving the connection // It reads only the exact bytes of the PROXY header, leaving the connection
// positioned at the first byte after the header (e.g., TLS ClientHello). // positioned at the first byte after the header (e.g., TLS ClientHello).
func Parse(conn net.Conn, deadline time.Time) (Header, error) { func Parse(conn net.Conn, deadline time.Time) (Header, error) {
conn.SetReadDeadline(deadline) _ = conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{}) defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
// Read the first byte to determine version. // Read the first byte to determine version.
var first [1]byte var first [1]byte

View File

@@ -16,7 +16,7 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
t.Cleanup(func() { ln.Close() }) t.Cleanup(func() { _ = ln.Close() })
ch := make(chan net.Conn, 1) ch := make(chan net.Conn, 1)
go func() { go func() {
@@ -31,10 +31,10 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
t.Cleanup(func() { w.Close() }) t.Cleanup(func() { _ = w.Close() })
r := <-ch r := <-ch
t.Cleanup(func() { r.Close() }) t.Cleanup(func() { _ = r.Close() })
return r, w return r, w
} }
@@ -42,7 +42,7 @@ func TestParseV1TCP4(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n")) _, _ = writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -68,7 +68,7 @@ func TestParseV1TCP6(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n")) _, _ = writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -98,7 +98,7 @@ func TestParseV2TCP4(t *testing.T) {
buf = append(buf, 10, 0, 0, 1) // dst IP buf = append(buf, 10, 0, 0, 1) // dst IP
buf = binary.BigEndian.AppendUint16(buf, 12345) buf = binary.BigEndian.AppendUint16(buf, 12345)
buf = binary.BigEndian.AppendUint16(buf, 443) buf = binary.BigEndian.AppendUint16(buf, 443)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -137,7 +137,7 @@ func TestParseV2TCP6(t *testing.T) {
buf = append(buf, dst[:]...) buf = append(buf, dst[:]...)
buf = binary.BigEndian.AppendUint16(buf, 56324) buf = binary.BigEndian.AppendUint16(buf, 56324)
buf = binary.BigEndian.AppendUint16(buf, 8443) buf = binary.BigEndian.AppendUint16(buf, 8443)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -163,7 +163,7 @@ func TestParseV2Local(t *testing.T) {
buf = append(buf, 0x20) // version 2, LOCAL command buf = append(buf, 0x20) // version 2, LOCAL command
buf = append(buf, 0x00) // unspec family, unspec protocol buf = append(buf, 0x00) // unspec family, unspec protocol
buf = binary.BigEndian.AppendUint16(buf, 0) buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -195,7 +195,7 @@ func TestParseV1Malformed(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte(tt.data)) _, _ = writer.Write([]byte(tt.data))
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -211,7 +211,7 @@ func TestParseV2Malformed(t *testing.T) {
go func() { go func() {
bad := make([]byte, v2HeaderLen) bad := make([]byte, v2HeaderLen)
bad[0] = v2Signature[0] // first byte matches but rest doesn't bad[0] = v2Signature[0] // first byte matches but rest doesn't
writer.Write(bad) _, _ = writer.Write(bad)
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -227,7 +227,7 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x31) // version 3, PROXY command buf = append(buf, 0x31) // version 3, PROXY command
buf = append(buf, 0x11) buf = append(buf, 0x11)
buf = binary.BigEndian.AppendUint16(buf, 0) buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -244,8 +244,8 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x11) // AF_INET, STREAM buf = append(buf, 0x11) // AF_INET, STREAM
buf = binary.BigEndian.AppendUint16(buf, 12) buf = binary.BigEndian.AppendUint16(buf, 12)
buf = append(buf, 1, 2, 3) // only 3 bytes, need 12 buf = append(buf, 1, 2, 3) // only 3 bytes, need 12
writer.Write(buf) _, _ = writer.Write(buf)
writer.Close() _ = writer.Close()
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -261,7 +261,7 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x21) // version 2, PROXY buf = append(buf, 0x21) // version 2, PROXY
buf = append(buf, 0x31) // AF_UNIX (3), STREAM buf = append(buf, 0x31) // AF_UNIX (3), STREAM
buf = binary.BigEndian.AppendUint16(buf, 0) buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -332,7 +332,7 @@ func TestRoundTripV2IPv4(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
WriteV2(writer, src, dst) _ = WriteV2(writer, src, dst)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -361,7 +361,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
WriteV2(writer, src, dst) _ = WriteV2(writer, src, dst)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -380,7 +380,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
func TestParseGarbageFirstByte(t *testing.T) { func TestParseGarbageFirstByte(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte{0xFF, 0x00, 0x01}) _, _ = writer.Write([]byte{0xFF, 0x00, 0x01})
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {

View File

@@ -11,13 +11,13 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/l7" "git.wntrmute.dev/mc/mc-proxy/internal/l7"
"git.wntrmute.dev/kyle/mc-proxy/internal/metrics" "git.wntrmute.dev/mc/mc-proxy/internal/metrics"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxy" "git.wntrmute.dev/mc/mc-proxy/internal/proxy"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
"git.wntrmute.dev/kyle/mc-proxy/internal/sni" "git.wntrmute.dev/mc/mc-proxy/internal/sni"
) )
// L7PolicyRule is an L7 blocking policy attached to a route. // L7PolicyRule is an L7 blocking policy attached to a route.
@@ -235,7 +235,7 @@ func (s *Server) Run(ctx context.Context) error {
ln, err := net.Listen("tcp", ls.Addr) ln, err := net.Listen("tcp", ls.Addr)
if err != nil { if err != nil {
for _, l := range netListeners { for _, l := range netListeners {
l.Close() _ = l.Close()
} }
return fmt.Errorf("listening on %s: %w", ls.Addr, err) return fmt.Errorf("listening on %s: %w", ls.Addr, err)
} }
@@ -253,7 +253,7 @@ func (s *Server) Run(ctx context.Context) error {
s.logger.Info("shutting down") s.logger.Info("shutting down")
for _, ln := range netListeners { for _, ln := range netListeners {
ln.Close() _ = ln.Close()
} }
done := make(chan struct{}) done := make(chan struct{})
@@ -272,7 +272,7 @@ func (s *Server) Run(ctx context.Context) error {
<-done <-done
} }
s.fw.Close() _ = s.fw.Close()
return nil return nil
} }
@@ -294,7 +294,7 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
// Enforce per-listener connection limit. // Enforce per-listener connection limit.
if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections { if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections {
conn.Close() _ = conn.Close()
s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections) s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections)
continue continue
} }
@@ -311,7 +311,7 @@ func (s *Server) forceCloseAll() {
for _, ls := range s.listeners { for _, ls := range s.listeners {
ls.connMu.Lock() ls.connMu.Lock()
for conn := range ls.activeConns { for conn := range ls.activeConns {
conn.Close() _ = conn.Close()
} }
ls.connMu.Unlock() ls.connMu.Unlock()
} }
@@ -321,7 +321,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
defer s.wg.Done() defer s.wg.Done()
defer ls.ActiveConnections.Add(-1) defer ls.ActiveConnections.Add(-1)
defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec() defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec()
defer conn.Close() defer func() { _ = conn.Close() }()
ls.connMu.Lock() ls.connMu.Lock()
ls.activeConns[conn] = struct{}{} ls.activeConns[conn] = struct{}{}
@@ -392,7 +392,7 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err) s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
return return
} }
defer backendConn.Close() defer func() { _ = backendConn.Close() }()
// Send PROXY protocol v2 header to backend if configured. // Send PROXY protocol v2 header to backend if configured.
if route.SendProxyProtocol { if route.SendProxyProtocol {

View File

@@ -24,9 +24,9 @@ import (
"testing" "testing"
"time" "time"
"git.wntrmute.dev/kyle/mc-proxy/internal/config" "git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall" "git.wntrmute.dev/mc/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/proxyproto" "git.wntrmute.dev/mc/mc-proxy/internal/proxyproto"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
) )
@@ -43,8 +43,8 @@ func echoServer(t *testing.T, ln net.Listener) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
io.Copy(conn, conn) _, _ = io.Copy(conn, conn)
} }
// newTestServer creates a Server with the given listener data and no firewall rules. // newTestServer creates a Server with the given listener data and no firewall rules.
@@ -92,7 +92,7 @@ func TestProxyRoundTrip(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn) go echoServer(t, backendLn)
// Pick a free port for the proxy listener. // Pick a free port for the proxy listener.
@@ -101,7 +101,7 @@ func TestProxyRoundTrip(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -121,7 +121,7 @@ func TestProxyRoundTrip(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("echo.test") hello := buildClientHello("echo.test")
if _, err := conn.Write(hello); err != nil { if _, err := conn.Write(hello); err != nil {
@@ -130,7 +130,7 @@ func TestProxyRoundTrip(t *testing.T) {
// The backend will echo our ClientHello back. Read it. // The backend will echo our ClientHello back. Read it.
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil { if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err) t.Fatalf("read echoed data: %v", err)
} }
@@ -157,7 +157,7 @@ func TestNoRouteResets(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -176,7 +176,7 @@ func TestNoRouteResets(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("unknown.test") hello := buildClientHello("unknown.test")
if _, err := conn.Write(hello); err != nil { if _, err := conn.Write(hello); err != nil {
@@ -184,7 +184,7 @@ func TestNoRouteResets(t *testing.T) {
} }
// The proxy should close the connection (no route match). // The proxy should close the connection (no route match).
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed, but read succeeded") t.Fatal("expected connection to be closed, but read succeeded")
@@ -197,7 +197,7 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1) reached := make(chan struct{}, 1)
go func() { go func() {
@@ -205,7 +205,7 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil { if err != nil {
return return
} }
conn.Close() _ = conn.Close()
reached <- struct{}{} reached <- struct{}{}
}() }()
@@ -214,7 +214,7 @@ func TestFirewallBlocks(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
// Create a firewall that blocks 127.0.0.1 (the test client). // Create a firewall that blocks 127.0.0.1 (the test client).
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0) fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
@@ -245,7 +245,7 @@ func TestFirewallBlocks(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
srv.Run(ctx) _ = srv.Run(ctx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@@ -253,13 +253,13 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("echo.test") hello := buildClientHello("echo.test")
conn.Write(hello) _, _ = conn.Write(hello)
// Connection should be closed (blocked by firewall). // Connection should be closed (blocked by firewall).
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed by firewall") t.Fatal("expected connection to be closed by firewall")
@@ -283,7 +283,7 @@ func TestNotTLSResets(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -300,12 +300,12 @@ func TestNotTLSResets(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send HTTP, not TLS. // Send HTTP, not TLS.
conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n")) _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed for non-TLS data") t.Fatal("expected connection to be closed for non-TLS data")
@@ -318,7 +318,7 @@ func TestConnectionTracking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
var backendConns []net.Conn var backendConns []net.Conn
var mu sync.Mutex var mu sync.Mutex
@@ -332,7 +332,7 @@ func TestConnectionTracking(t *testing.T) {
backendConns = append(backendConns, conn) backendConns = append(backendConns, conn)
mu.Unlock() mu.Unlock()
// Hold connection open, drain input. // Hold connection open, drain input.
go io.Copy(io.Discard, conn) go func() { _, _ = io.Copy(io.Discard, conn) }()
} }
}() }()
@@ -341,7 +341,7 @@ func TestConnectionTracking(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -382,10 +382,10 @@ func TestConnectionTracking(t *testing.T) {
} }
// Close one client and its corresponding backend connection. // Close one client and its corresponding backend connection.
clientConns[0].Close() _ = clientConns[0].Close()
mu.Lock() mu.Lock()
if len(backendConns) > 0 { if len(backendConns) > 0 {
backendConns[0].Close() _ = backendConns[0].Close()
} }
mu.Unlock() mu.Unlock()
@@ -402,10 +402,10 @@ func TestConnectionTracking(t *testing.T) {
} }
// Clean up. // Clean up.
clientConns[1].Close() _ = clientConns[1].Close()
mu.Lock() mu.Lock()
for _, c := range backendConns { for _, c := range backendConns {
c.Close() _ = c.Close()
} }
mu.Unlock() mu.Unlock()
} }
@@ -416,13 +416,13 @@ func TestMultipleListeners(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend A listen: %v", err) t.Fatalf("backend A listen: %v", err)
} }
defer backendA.Close() defer func() { _ = backendA.Close() }()
backendB, err := net.Listen("tcp", "127.0.0.1:0") backendB, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("backend B listen: %v", err) t.Fatalf("backend B listen: %v", err)
} }
defer backendB.Close() defer func() { _ = backendB.Close() }()
// Each backend writes its identity and closes. // Each backend writes its identity and closes.
serve := func(ln net.Listener, id string) { serve := func(ln net.Listener, id string) {
@@ -430,10 +430,10 @@ func TestMultipleListeners(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Drain the incoming data, then write identity. // Drain the incoming data, then write identity.
go io.Copy(io.Discard, conn) go func() { _, _ = io.Copy(io.Discard, conn) }()
conn.Write([]byte(id)) _, _ = conn.Write([]byte(id))
} }
go serve(backendA, "A") go serve(backendA, "A")
go serve(backendB, "B") go serve(backendB, "B")
@@ -444,14 +444,14 @@ func TestMultipleListeners(t *testing.T) {
t.Fatalf("finding free port 1: %v", err) t.Fatalf("finding free port 1: %v", err)
} }
addr1 := ln1.Addr().String() addr1 := ln1.Addr().String()
ln1.Close() _ = ln1.Close()
ln2, err := net.Listen("tcp", "127.0.0.1:0") ln2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("finding free port 2: %v", err) t.Fatalf("finding free port 2: %v", err)
} }
addr2 := ln2.Addr().String() addr2 := ln2.Addr().String()
ln2.Close() _ = ln2.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}}, {ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
@@ -467,12 +467,12 @@ func TestMultipleListeners(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial %s: %v", proxyAddr, err) t.Fatalf("dial %s: %v", proxyAddr, err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("svc.test") hello := buildClientHello("svc.test")
conn.Write(hello) _, _ = conn.Write(hello)
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := make([]byte, 128) buf := make([]byte, 128)
// Read what the backend sends back: echoed ClientHello + ID. // Read what the backend sends back: echoed ClientHello + ID.
// The backend drains input and writes the ID, so we read until we // The backend drains input and writes the ID, so we read until we
@@ -508,7 +508,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn) go echoServer(t, backendLn)
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -516,7 +516,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -537,7 +537,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("ECHO.TEST") hello := buildClientHello("ECHO.TEST")
if _, err := conn.Write(hello); err != nil { if _, err := conn.Write(hello); err != nil {
@@ -545,7 +545,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
} }
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil { if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err) t.Fatalf("read echoed data: %v", err)
} }
@@ -558,14 +558,14 @@ func TestBackendUnreachable(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
deadAddr := ln.Addr().String() deadAddr := ln.Addr().String()
ln.Close() _ = ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -584,13 +584,13 @@ func TestBackendUnreachable(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("dead.test") hello := buildClientHello("dead.test")
conn.Write(hello) _, _ = conn.Write(hello)
// Proxy should close the connection after failing to dial backend. // Proxy should close the connection after failing to dial backend.
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed when backend is unreachable") t.Fatal("expected connection to be closed when backend is unreachable")
@@ -603,15 +603,15 @@ func TestGracefulShutdown(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go func() { go func() {
conn, err := backendLn.Accept() conn, err := backendLn.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
io.Copy(io.Discard, conn) _, _ = io.Copy(io.Discard, conn)
}() }()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -619,7 +619,7 @@ func TestGracefulShutdown(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
fw, err := firewall.New("", nil, nil, nil, 0, 0) fw, err := firewall.New("", nil, nil, nil, 0, 0)
if err != nil { if err != nil {
@@ -649,10 +649,10 @@ func TestGracefulShutdown(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("hold.test") hello := buildClientHello("hold.test")
conn.Write(hello) _, _ = conn.Write(hello)
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
// Trigger shutdown. // Trigger shutdown.
@@ -719,7 +719,7 @@ func TestProxyProtocolReceive(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn) go echoServer(t, backendLn)
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -727,7 +727,7 @@ func TestProxyProtocolReceive(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -747,22 +747,22 @@ func TestProxyProtocolReceive(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send PROXY v2 header followed by TLS ClientHello. // Send PROXY v2 header followed by TLS ClientHello.
var ppBuf bytes.Buffer var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf, _ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("203.0.113.50:12345"), netip.MustParseAddrPort("203.0.113.50:12345"),
netip.MustParseAddrPort("198.51.100.1:443"), netip.MustParseAddrPort("198.51.100.1:443"),
) )
conn.Write(ppBuf.Bytes()) _, _ = conn.Write(ppBuf.Bytes())
hello := buildClientHello("echo.test") hello := buildClientHello("echo.test")
conn.Write(hello) _, _ = conn.Write(hello)
// Backend should echo the ClientHello back (not the PROXY header). // Backend should echo the ClientHello back (not the PROXY header).
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil { if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err) t.Fatalf("read echoed data: %v", err)
} }
@@ -774,7 +774,7 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -794,13 +794,13 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send garbage instead of a valid PROXY header. // Send garbage instead of a valid PROXY header.
conn.Write([]byte("NOT A PROXY HEADER\r\n")) _, _ = conn.Write([]byte("NOT A PROXY HEADER\r\n"))
// Connection should be closed. // Connection should be closed.
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed for invalid PROXY header") t.Fatal("expected connection to be closed for invalid PROXY header")
@@ -813,7 +813,7 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
received := make(chan []byte, 1) received := make(chan []byte, 1)
go func() { go func() {
@@ -821,9 +821,9 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Read all available data; the proxy sends PROXY header + ClientHello. // Read all available data; the proxy sends PROXY header + ClientHello.
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var all []byte var all []byte
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
@@ -845,7 +845,7 @@ func TestProxyProtocolSend(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -868,10 +868,10 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("pp.test") hello := buildClientHello("pp.test")
conn.Write(hello) _, _ = conn.Write(hello)
// The backend should receive: PROXY v2 header + ClientHello. // The backend should receive: PROXY v2 header + ClientHello.
select { select {
@@ -904,7 +904,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
received := make(chan []byte, 1) received := make(chan []byte, 1)
go func() { go func() {
@@ -912,7 +912,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
buf := make([]byte, 4096) buf := make([]byte, 4096)
n, _ := conn.Read(buf) n, _ := conn.Read(buf)
received <- buf[:n] received <- buf[:n]
@@ -923,7 +923,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -942,10 +942,10 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("nopp.test") hello := buildClientHello("nopp.test")
conn.Write(hello) _, _ = conn.Write(hello)
select { select {
case data := <-received: case data := <-received:
@@ -964,7 +964,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1) reached := make(chan struct{}, 1)
go func() { go func() {
@@ -972,7 +972,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil { if err != nil {
return return
} }
conn.Close() _ = conn.Close()
reached <- struct{}{} reached <- struct{}{}
}() }()
@@ -981,7 +981,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
// Block 203.0.113.50 (the "real" client IP from PROXY header). // Block 203.0.113.50 (the "real" client IP from PROXY header).
// 127.0.0.1 (the actual TCP peer) is NOT blocked. // 127.0.0.1 (the actual TCP peer) is NOT blocked.
@@ -1014,7 +1014,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
srv.Run(ctx) _ = srv.Run(ctx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@@ -1022,19 +1022,19 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send PROXY v2 with the blocked real IP. // Send PROXY v2 with the blocked real IP.
var ppBuf bytes.Buffer var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf, _ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("203.0.113.50:12345"), netip.MustParseAddrPort("203.0.113.50:12345"),
netip.MustParseAddrPort("198.51.100.1:443"), netip.MustParseAddrPort("198.51.100.1:443"),
) )
conn.Write(ppBuf.Bytes()) _, _ = conn.Write(ppBuf.Bytes())
conn.Write(buildClientHello("blocked.test")) _, _ = conn.Write(buildClientHello("blocked.test"))
// Connection should be closed (firewall blocks real IP). // Connection should be closed (firewall blocks real IP).
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed by firewall") t.Fatal("expected connection to be closed by firewall")
@@ -1060,7 +1060,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go func() { go func() {
for { for {
@@ -1068,7 +1068,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
return return
} }
go io.Copy(io.Discard, conn) go func() { _, _ = io.Copy(io.Discard, conn) }()
} }
}() }()
@@ -1077,7 +1077,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -1100,7 +1100,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial %d: %v", i, err) t.Fatalf("dial %d: %v", i, err)
} }
conn.Write(buildClientHello("limit.test")) _, _ = conn.Write(buildClientHello("limit.test"))
conns = append(conns, conn) conns = append(conns, conn)
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -1110,16 +1110,16 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial 3: %v", err) t.Fatalf("dial 3: %v", err)
} }
conn3.Write(buildClientHello("limit.test")) _, _ = conn3.Write(buildClientHello("limit.test"))
conn3.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn3.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn3.Read(make([]byte, 1)) _, err = conn3.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected 3rd connection to be closed due to limit") t.Fatal("expected 3rd connection to be closed due to limit")
} }
conn3.Close() _ = conn3.Close()
// Close one existing connection. // Close one existing connection.
conns[0].Close() _ = conns[0].Close()
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
// Now a new connection should succeed. // Now a new connection should succeed.
@@ -1127,8 +1127,8 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial 4: %v", err) t.Fatalf("dial 4: %v", err)
} }
defer conn4.Close() defer func() { _ = conn4.Close() }()
conn4.Write(buildClientHello("limit.test")) _, _ = conn4.Write(buildClientHello("limit.test"))
// Give it time to be proxied. // Give it time to be proxied.
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -1138,7 +1138,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
// Clean up. // Clean up.
for _, c := range conns[1:] { for _, c := range conns[1:] {
c.Close() _ = c.Close()
} }
} }
@@ -1155,7 +1155,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
// h2c backend on origin that echoes the X-Forwarded-For. // h2c backend on origin that echoes the X-Forwarded-For.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For")) _, _ = fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
})) }))
// Origin proxy: proxy_protocol=true listener, L7 route to backend. // Origin proxy: proxy_protocol=true listener, L7 route to backend.
@@ -1164,7 +1164,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
t.Fatalf("origin listen: %v", err) t.Fatalf("origin listen: %v", err)
} }
originAddr := originLn.Addr().String() originAddr := originLn.Addr().String()
originLn.Close() _ = originLn.Close()
originFw, _ := firewall.New("", nil, nil, nil, 0, 0) originFw, _ := firewall.New("", nil, nil, nil, 0, 0)
originCfg := &config.Config{ originCfg := &config.Config{
@@ -1196,7 +1196,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
originWg.Add(1) originWg.Add(1)
go func() { go func() {
defer originWg.Done() defer originWg.Done()
originSrv.Run(originCtx) _ = originSrv.Run(originCtx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
defer func() { defer func() {
@@ -1210,7 +1210,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
t.Fatalf("edge listen: %v", err) t.Fatalf("edge listen: %v", err)
} }
edgeAddr := edgeLn.Addr().String() edgeAddr := edgeLn.Addr().String()
edgeLn.Close() _ = edgeLn.Close()
edgeFw, _ := firewall.New("", nil, nil, nil, 0, 0) edgeFw, _ := firewall.New("", nil, nil, nil, 0, 0)
edgeSrv := New(originCfg, edgeFw, []ListenerData{ edgeSrv := New(originCfg, edgeFw, []ListenerData{
@@ -1232,7 +1232,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
edgeWg.Add(1) edgeWg.Add(1)
go func() { go func() {
defer edgeWg.Done() defer edgeWg.Done()
edgeSrv.Run(edgeCtx) _ = edgeSrv.Run(edgeCtx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
defer func() { defer func() {
@@ -1253,7 +1253,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("TLS dial edge: %v", err) t.Fatalf("TLS dial edge: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
tr := &http2.Transport{} tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn) h2conn, err := tr.NewClientConn(conn)
@@ -1266,7 +1266,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("RoundTrip: %v", err) t.Fatalf("RoundTrip: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
@@ -1289,7 +1289,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1) reached := make(chan struct{}, 1)
go func() { go func() {
@@ -1297,7 +1297,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil { if err != nil {
return return
} }
conn.Close() _ = conn.Close()
reached <- struct{}{} reached <- struct{}{}
}() }()
@@ -1306,7 +1306,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
t.Fatalf("origin listen: %v", err) t.Fatalf("origin listen: %v", err)
} }
originAddr := originLn.Addr().String() originAddr := originLn.Addr().String()
originLn.Close() _ = originLn.Close()
// Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header. // Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header.
originFw, _ := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0) originFw, _ := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0)
@@ -1334,7 +1334,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
originSrv.Run(ctx) _ = originSrv.Run(ctx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@@ -1343,18 +1343,18 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial origin: %v", err) t.Fatalf("dial origin: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
var ppBuf bytes.Buffer var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf, _ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("198.51.100.99:12345"), netip.MustParseAddrPort("198.51.100.99:12345"),
netip.MustParseAddrPort("10.0.0.1:443"), netip.MustParseAddrPort("10.0.0.1:443"),
) )
conn.Write(ppBuf.Bytes()) _, _ = conn.Write(ppBuf.Bytes())
conn.Write(buildClientHello("blocked.test")) _, _ = conn.Write(buildClientHello("blocked.test"))
// Connection should be dropped by firewall. // Connection should be dropped by firewall.
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed") t.Fatal("expected connection to be closed")
@@ -1396,12 +1396,12 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
certPath = filepath.Join(dir, "cert.pem") certPath = filepath.Join(dir, "cert.pem")
keyPath = filepath.Join(dir, "key.pem") keyPath = filepath.Join(dir, "key.pem")
cf, _ := os.Create(certPath) cf, _ := os.Create(certPath)
pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) _ = pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
cf.Close() _ = cf.Close()
keyDER, _ := x509.MarshalECPrivateKey(key) keyDER, _ := x509.MarshalECPrivateKey(key)
kf, _ := os.Create(keyPath) kf, _ := os.Create(keyPath)
pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) _ = pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
kf.Close() _ = kf.Close()
return return
} }
@@ -1417,8 +1417,8 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
t.Cleanup(func() { srv.Close(); ln.Close() }) t.Cleanup(func() { _ = srv.Close(); _ = ln.Close() })
go srv.Serve(ln) go func() { _ = srv.Serve(ln) }()
return ln.Addr().String() return ln.Addr().String()
} }
@@ -1426,7 +1426,7 @@ func TestL7ThroughServer(t *testing.T) {
certPath, keyPath := testCert(t, "l7srv.test") certPath, keyPath := testCert(t, "l7srv.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For")) _, _ = fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -1434,7 +1434,7 @@ func TestL7ThroughServer(t *testing.T) {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -1467,7 +1467,7 @@ func TestL7ThroughServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("TLS dial: %v", err) t.Fatalf("TLS dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
tr := &http2.Transport{} tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn) h2conn, err := tr.NewClientConn(conn)
@@ -1480,7 +1480,7 @@ func TestL7ThroughServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("RoundTrip: %v", err) t.Fatalf("RoundTrip: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since // The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
@@ -1502,12 +1502,12 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("l4 backend listen: %v", err) t.Fatalf("l4 backend listen: %v", err)
} }
defer l4BackendLn.Close() defer func() { _ = l4BackendLn.Close() }()
go echoServer(t, l4BackendLn) go echoServer(t, l4BackendLn)
// L7 backend: h2c HTTP server. // L7 backend: h2c HTTP server.
l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "l7-response") _, _ = fmt.Fprint(w, "l7-response")
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -1515,7 +1515,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -1541,11 +1541,11 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial L4: %v", err) t.Fatalf("dial L4: %v", err)
} }
defer l4Conn.Close() defer func() { _ = l4Conn.Close() }()
hello := buildClientHello("l4echo.test") hello := buildClientHello("l4echo.test")
l4Conn.Write(hello) _, _ = l4Conn.Write(hello)
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(l4Conn, echoed); err != nil { if _, err := io.ReadFull(l4Conn, echoed); err != nil {
t.Fatalf("L4 echo read: %v", err) t.Fatalf("L4 echo read: %v", err)
} }
@@ -1563,7 +1563,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("TLS dial L7: %v", err) t.Fatalf("TLS dial L7: %v", err)
} }
defer l7Conn.Close() defer func() { _ = l7Conn.Close() }()
tr := &http2.Transport{} tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(l7Conn) h2conn, err := tr.NewClientConn(l7Conn)
@@ -1576,7 +1576,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("L7 RoundTrip: %v", err) t.Fatalf("L7 RoundTrip: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if string(body) != "l7-response" { if string(body) != "l7-response" {
@@ -1636,4 +1636,3 @@ func sniExtension(serverName string) []byte {
return ext return ext
} }

View File

@@ -17,8 +17,8 @@ const maxBufferSize = 16384 // 16 KiB, max TLS record size
// //
// A read deadline is set on the connection to prevent slowloris attacks. // A read deadline is set on the connection to prevent slowloris attacks.
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) { func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
conn.SetReadDeadline(deadline) _ = conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{}) defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
// Read TLS record header (5 bytes). // Read TLS record header (5 bytes).
header := make([]byte, 5) header := make([]byte, 5)

View File

@@ -22,13 +22,13 @@ func TestExtract(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
hello := buildClientHello(tt.sni) hello := buildClientHello(tt.sni)
go func() { go func() {
client.Write(hello) _, _ = client.Write(hello)
}() }()
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second)) hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
@@ -53,13 +53,13 @@ func TestExtract(t *testing.T) {
func TestExtractNoSNI(t *testing.T) { func TestExtractNoSNI(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
hello := buildClientHelloNoSNI() hello := buildClientHelloNoSNI()
go func() { go func() {
client.Write(hello) _, _ = client.Write(hello)
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -70,11 +70,11 @@ func TestExtractNoSNI(t *testing.T) {
func TestExtractNotTLS(t *testing.T) { func TestExtractNotTLS(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
go func() { go func() {
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")) _, _ = client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -85,13 +85,13 @@ func TestExtractNotTLS(t *testing.T) {
func TestExtractTruncated(t *testing.T) { func TestExtractTruncated(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
go func() { go func() {
// Write just the TLS record header, then close. // Write just the TLS record header, then close.
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50}) _, _ = client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
client.Close() _ = client.Close()
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -102,15 +102,15 @@ func TestExtractTruncated(t *testing.T) {
func TestExtractOversizedRecord(t *testing.T) { func TestExtractOversizedRecord(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
go func() { go func() {
// Record header claiming a length larger than 16 KiB. // Record header claiming a length larger than 16 KiB.
header := []byte{0x16, 0x03, 0x01} header := []byte{0x16, 0x03, 0x01}
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5 header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
client.Write(header) _, _ = client.Write(header)
client.Close() _ = client.Close()
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -121,13 +121,13 @@ func TestExtractOversizedRecord(t *testing.T) {
func TestExtractMultipleExtensions(t *testing.T) { func TestExtractMultipleExtensions(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
hello := buildClientHelloWithExtraExtensions("target.example.com") hello := buildClientHelloWithExtraExtensions("target.example.com")
go func() { go func() {
client.Write(hello) _, _ = client.Write(hello)
}() }()
hostname, _, err := Extract(server, time.Now().Add(5*time.Second)) hostname, _, err := Extract(server, time.Now().Add(5*time.Second))

View File

@@ -2,7 +2,7 @@ syntax = "proto3";
package mc_proxy.v1; package mc_proxy.v1;
option go_package = "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1;mcproxyv1"; option go_package = "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1;mcproxyv1";
import "google/protobuf/timestamp.proto"; import "google/protobuf/timestamp.proto";

View File

@@ -34,7 +34,7 @@ import (
"github.com/pelletier/go-toml/v2" "github.com/pelletier/go-toml/v2"
"git.wntrmute.dev/kyle/mcdsl/auth" "git.wntrmute.dev/mc/mcdsl/auth"
) )
// Base contains the configuration sections common to all Metacircular // Base contains the configuration sections common to all Metacircular
@@ -144,6 +144,8 @@ func Load[T any](path string, envPrefix string) (*T, error) {
applyEnvToStruct(reflect.ValueOf(&cfg).Elem(), envPrefix) applyEnvToStruct(reflect.ValueOf(&cfg).Elem(), envPrefix)
} }
applyPortEnv(&cfg)
applyBaseDefaults(&cfg) applyBaseDefaults(&cfg)
if err := validateBase(&cfg); err != nil { if err := validateBase(&cfg); err != nil {
@@ -239,6 +241,70 @@ func findBase(cfg any) *Base {
return nil return nil
} }
// applyPortEnv overrides ServerConfig.ListenAddr and ServerConfig.GRPCAddr
// from $PORT and $PORT_GRPC respectively. These environment variables are
// set by the MCP agent to assign authoritative port bindings, so they take
// precedence over both TOML values and generic env overrides.
func applyPortEnv(cfg any) {
sc := findServerConfig(cfg)
if sc == nil {
return
}
if port, ok := os.LookupEnv("PORT"); ok {
sc.ListenAddr = ":" + port
}
if port, ok := os.LookupEnv("PORT_GRPC"); ok {
sc.GRPCAddr = ":" + port
}
}
// findServerConfig returns a pointer to the ServerConfig in the config
// struct. It first checks for an embedded Base (which contains Server),
// then walks the struct tree via reflection to find any ServerConfig field
// directly (e.g., the Metacrypt pattern where ServerConfig is embedded
// without Base).
func findServerConfig(cfg any) *ServerConfig {
if base := findBase(cfg); base != nil {
return &base.Server
}
return findServerConfigReflect(reflect.ValueOf(cfg))
}
// findServerConfigReflect walks the struct tree to find a ServerConfig field.
func findServerConfigReflect(v reflect.Value) *ServerConfig {
if v.Kind() == reflect.Ptr {
v = v.Elem()
}
if v.Kind() != reflect.Struct {
return nil
}
scType := reflect.TypeOf(ServerConfig{})
t := v.Type()
for i := range t.NumField() {
field := t.Field(i)
fv := v.Field(i)
if field.Type == scType {
sc, ok := fv.Addr().Interface().(*ServerConfig)
if ok {
return sc
}
}
// Recurse into embedded or nested structs.
if fv.Kind() == reflect.Struct && field.Type != scType {
if sc := findServerConfigReflect(fv); sc != nil {
return sc
}
}
}
return nil
}
// applyEnvToStruct recursively walks a struct and overrides field values // applyEnvToStruct recursively walks a struct and overrides field values
// from environment variables. The env variable name is built from the // from environment variables. The env variable name is built from the
// prefix and the toml tag: PREFIX_SECTION_FIELD (uppercased). // prefix and the toml tag: PREFIX_SECTION_FIELD (uppercased).

8
vendor/modules.txt vendored
View File

@@ -1,8 +1,8 @@
# git.wntrmute.dev/kyle/mcdsl v1.0.0 # git.wntrmute.dev/mc/mcdsl v1.2.0
## explicit; go 1.25.7 ## explicit; go 1.25.7
git.wntrmute.dev/kyle/mcdsl/auth git.wntrmute.dev/mc/mcdsl/auth
git.wntrmute.dev/kyle/mcdsl/config git.wntrmute.dev/mc/mcdsl/config
git.wntrmute.dev/kyle/mcdsl/db git.wntrmute.dev/mc/mcdsl/db
# github.com/beorn7/perks v1.0.1 # github.com/beorn7/perks v1.0.1
## explicit; go 1.11 ## explicit; go 1.11
github.com/beorn7/perks/quantile github.com/beorn7/perks/quantile