5 Commits

Author SHA1 Message Date
5580bf74b0 Add VERSION variable and push target to Makefile
Extract VERSION variable (was inline). Add version tag to docker image.
Add push target that builds then pushes to MCR.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 14:32:17 -07:00
28321e22f4 Make AddRoute idempotent (upsert instead of reject duplicates)
AddRoute now updates an existing route if one already exists for the
same (listener, hostname) pair, instead of returning AlreadyExists.
This makes repeated deploys idempotent — the MCP agent can register
routes on every deploy without needing to remove them first.

- DB: INSERT ... ON CONFLICT DO UPDATE (SQLite upsert)
- In-memory: overwrite existing route unconditionally
- gRPC: error code changed from AlreadyExists to Internal (for real DB errors)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 14:01:45 -07:00
a60e5cb86a Fix golangci-lint v2 compliance, make all passes clean
- Fix 314 errcheck violations (blank identifier for unrecoverable errors)
- Fix errorlint violation (errors.Is for io.EOF)
- Remove unused serveL7Route test helper
- Simplify Duration.Seconds() selectors in tests
- Remove unnecessary fmt.Sprintf in test
- Migrate exclusion rules from issues.exclusions to linters.exclusions (v2 schema)
- Add gosec test exclusions (G115, G304, G402, G705)
- Disable fieldalignment govet analyzer (optimization, not correctness)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 13:30:43 -07:00
4f3249fdc3 Regenerate proto files for mc/ module path
Raw descriptor bytes in .pb.go files were corrupted by the sed-based
module path rename (string length changed, breaking protobuf binary
encoding). Regenerated with protoc to fix.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:54:22 -07:00
f31a7f20fb Bump flake.nix version to match latest tag
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:16:38 -07:00
31 changed files with 409 additions and 392 deletions

View File

@@ -9,6 +9,20 @@ run:
tests: true
linters:
exclusions:
paths:
- vendor
rules:
# In test files, suppress gosec rules that are false positives in test code:
# G101: hardcoded test credentials (intentional fixtures)
# G115: integer overflow in type conversions (test TLS packet builders)
# G304: file paths from variables (t.TempDir paths)
# G402: InsecureSkipVerify (required for test TLS clients)
# G705: XSS via taint analysis (test HTTP handlers, not real servers)
- path: "_test\\.go"
linters:
- gosec
text: "G101|G115|G304|G402|G705"
default: none
enable:
# --- Correctness ---
@@ -52,12 +66,15 @@ linters:
check-type-assertions: true
govet:
# Enable all analyzers except shadow. The shadow analyzer flags the idiomatic
# `if err := f(); err != nil { ... }` pattern as shadowing an outer `err`,
# which is ubiquitous in Go and does not pose a security risk in this codebase.
# Enable all analyzers except shadow and fieldalignment. The shadow analyzer
# flags the idiomatic `if err := f(); err != nil { ... }` pattern as shadowing
# an outer `err`, which is ubiquitous in Go. The fieldalignment analyzer
# suggests struct field reordering for memory efficiency — useful as a one-off
# audit but too noisy for CI (every struct change triggers it).
enable-all: true
disable:
- shadow
- fieldalignment
gosec:
# Treat all gosec findings as errors, not warnings.
@@ -110,15 +127,3 @@ issues:
# Do not cap the number of reported issues; in security code every finding matters.
max-issues-per-linter: 0
max-same-issues: 0
exclusions:
paths:
- vendor
rules:
# In test files, allow hardcoded test credentials (gosec G101) since they are
# intentional fixtures, not production secrets.
- path: "_test\\.go"
linters:
- gosec
text: "G101"

View File

@@ -1,6 +1,8 @@
.PHONY: build test vet lint proto proto-lint clean docker all devserver
.PHONY: build test vet lint proto proto-lint clean docker push all devserver
LDFLAGS := -trimpath -ldflags="-s -w -X main.version=$(shell git describe --tags --always --dirty)"
MCR := mcr.svc.mcp.metacircular.net:8443
VERSION := $(shell git describe --tags --always --dirty)
LDFLAGS := -trimpath -ldflags="-s -w -X main.version=$(VERSION)"
mc-proxy:
go build $(LDFLAGS) -o mc-proxy ./cmd/mc-proxy
@@ -33,7 +35,10 @@ clean:
rm -f mc-proxy mcproxyctl
docker:
docker build --build-arg VERSION=$(shell git describe --tags --always --dirty) -t mc-proxy -f Dockerfile .
docker build --build-arg VERSION=$(VERSION) -t $(MCR)/mc-proxy:$(VERSION) -f Dockerfile .
push: docker
docker push $(MCR)/mc-proxy:$(VERSION)
devserver: mc-proxy
@mkdir -p srv

View File

@@ -32,7 +32,7 @@ func setupTestClient(t *testing.T) *Client {
if err != nil {
t.Fatalf("open db: %v", err)
}
t.Cleanup(func() { store.Close() })
t.Cleanup(func() { _ = store.Close() })
if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
@@ -128,7 +128,7 @@ func setupTestClient(t *testing.T) *Client {
if err != nil {
t.Fatalf("dial bufconn: %v", err)
}
t.Cleanup(func() { conn.Close() })
t.Cleanup(func() { _ = conn.Close() })
return &Client{
conn: conn,

View File

@@ -40,7 +40,7 @@ func serverCmd() *cobra.Command {
if err != nil {
return fmt.Errorf("opening database: %w", err)
}
defer store.Close()
defer func() { _ = store.Close() }()
if err := store.Migrate(); err != nil {
return fmt.Errorf("running migrations: %w", err)
@@ -93,7 +93,7 @@ func serverCmd() *cobra.Command {
}()
defer func() {
grpcSrv.GracefulStop()
os.Remove(cfg.GRPC.SocketPath())
_ = os.Remove(cfg.GRPC.SocketPath())
}()
}

View File

@@ -32,7 +32,7 @@ func snapshotCmd() *cobra.Command {
if err != nil {
return fmt.Errorf("opening database: %w", err)
}
defer store.Close()
defer func() { _ = store.Close() }()
dataDir := filepath.Dir(cfg.Database.Path)

View File

@@ -33,7 +33,7 @@ func statusCmd() *cobra.Command {
if err != nil {
return fmt.Errorf("connecting to gRPC API: %w", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
client := pb.NewProxyAdminServiceClient(conn)

View File

@@ -10,7 +10,7 @@
let
system = "x86_64-linux";
pkgs = nixpkgs.legacyPackages.${system};
version = "0.1.0";
version = "1.1.0";
in
{
packages.${system} = {

View File

@@ -1449,7 +1449,7 @@ const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" +
"\x0eListL7Policies\x12\".mc_proxy.v1.ListL7PoliciesRequest\x1a#.mc_proxy.v1.ListL7PoliciesResponse\x12P\n" +
"\vAddL7Policy\x12\x1f.mc_proxy.v1.AddL7PolicyRequest\x1a .mc_proxy.v1.AddL7PolicyResponse\x12Y\n" +
"\x0eRemoveL7Policy\x12\".mc_proxy.v1.RemoveL7PolicyRequest\x1a#.mc_proxy.v1.RemoveL7PolicyResponse\x12J\n" +
"\tGetStatus\x12\x1d.mc_proxy.v1.GetStatusRequest\x1a\x1e.mc_proxy.v1.GetStatusResponseB:Z8git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1;mcproxyv1b\x06proto3"
"\tGetStatus\x12\x1d.mc_proxy.v1.GetStatusRequest\x1a\x1e.mc_proxy.v1.GetStatusResponseB8Z6git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1;mcproxyv1b\x06proto3"
var (
file_proto_mc_proxy_v1_admin_proto_rawDescOnce sync.Once

View File

@@ -304,7 +304,7 @@ func TestDuration(t *testing.T) {
if err := d.UnmarshalText([]byte("5s")); err != nil {
t.Fatalf("unexpected error: %v", err)
}
if d.Duration.Seconds() != 5 {
if d.Seconds() != 5 {
t.Fatalf("got %v, want 5s", d.Duration)
}
}
@@ -340,7 +340,7 @@ level = "info"
if cfg.Log.Level != "debug" {
t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug")
}
if cfg.Proxy.IdleTimeout.Duration.Seconds() != 600 {
if cfg.Proxy.IdleTimeout.Seconds() != 600 {
t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration)
}
if cfg.Database.Path != "/override/test.db" {

View File

@@ -17,7 +17,7 @@ func openTestDB(t *testing.T) *Store {
if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { store.Close() })
t.Cleanup(func() { _ = store.Close() })
return store
}
@@ -250,15 +250,30 @@ func TestRouteL7Fields(t *testing.T) {
}
}
func TestRouteDuplicateHostname(t *testing.T) {
func TestRouteUpsert(t *testing.T) {
store := openTestDB(t)
listenerID, _ := store.CreateListener(":443", false, 0)
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443", "l4", "", "", false, false); err != nil {
t.Fatalf("first create: %v", err)
}
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443", "l4", "", "", false, false); err == nil {
t.Fatal("expected error for duplicate hostname on same listener")
// Same (listener, hostname) with different backend — should upsert, not error.
if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443", "l7", "/cert.pem", "/key.pem", false, false); err != nil {
t.Fatalf("upsert: %v", err)
}
routes, err := store.ListRoutes(listenerID)
if err != nil {
t.Fatalf("list routes: %v", err)
}
if len(routes) != 1 {
t.Fatalf("expected 1 route after upsert, got %d", len(routes))
}
if routes[0].Backend != "127.0.0.1:9443" {
t.Fatalf("expected updated backend, got %q", routes[0].Backend)
}
if routes[0].Mode != "l7" {
t.Fatalf("expected updated mode, got %q", routes[0].Mode)
}
}
@@ -266,8 +281,8 @@ func TestRouteCascadeDelete(t *testing.T) {
store := openTestDB(t)
listenerID, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
_, _ = store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
_, _ = store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
if err := store.DeleteListener(listenerID); err != nil {
t.Fatalf("delete listener: %v", err)
@@ -412,7 +427,7 @@ func TestSeed(t *testing.T) {
func TestSnapshot(t *testing.T) {
store := openTestDB(t)
store.CreateListener(":443", false, 0)
_, _ = store.CreateListener(":443", false, 0)
dest := filepath.Join(t.TempDir(), "backup.db")
if err := store.Snapshot(dest); err != nil {
@@ -424,7 +439,7 @@ func TestSnapshot(t *testing.T) {
if err != nil {
t.Fatalf("open backup: %v", err)
}
defer backup.Close()
defer func() { _ = backup.Close() }()
if err := backup.Migrate(); err != nil {
t.Fatalf("migrate backup: %v", err)
@@ -463,7 +478,7 @@ func TestMigrationV2Upgrade(t *testing.T) {
if err != nil {
t.Fatalf("open: %v", err)
}
t.Cleanup(func() { store.Close() })
t.Cleanup(func() { _ = store.Close() })
// Run full migrations (v1 + v2).
if err := store.Migrate(); err != nil {
@@ -556,10 +571,10 @@ func TestL7PolicyCascadeDelete(t *testing.T) {
lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
store.CreateL7Policy(rid, "block_user_agent", "Bot")
_, _ = store.CreateL7Policy(rid, "block_user_agent", "Bot")
// Deleting the route should cascade-delete its policies.
store.DeleteRoute(lid, "api.test")
_ = store.DeleteRoute(lid, "api.test")
policies, _ := store.ListL7Policies(rid)
if len(policies) != 0 {
@@ -585,7 +600,7 @@ func TestGetRouteID(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
_, _ = store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
rid, err := store.GetRouteID(lid, "api.test")
if err != nil {

View File

@@ -15,7 +15,7 @@ func (s *Store) ListFirewallRules() ([]FirewallRule, error) {
if err != nil {
return nil, fmt.Errorf("querying firewall rules: %w", err)
}
defer rows.Close()
defer func() { _ = rows.Close() }()
var rules []FirewallRule
for rows.Next() {

View File

@@ -19,7 +19,7 @@ func (s *Store) ListL7Policies(routeID int64) ([]L7Policy, error) {
if err != nil {
return nil, fmt.Errorf("querying l7 policies: %w", err)
}
defer rows.Close()
defer func() { _ = rows.Close() }()
var policies []L7Policy
for rows.Next() {

View File

@@ -16,7 +16,7 @@ func (s *Store) ListListeners() ([]Listener, error) {
if err != nil {
return nil, fmt.Errorf("querying listeners: %w", err)
}
defer rows.Close()
defer func() { _ = rows.Close() }()
var listeners []Listener
for rows.Next() {

View File

@@ -25,7 +25,7 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
if err != nil {
return nil, fmt.Errorf("querying routes: %w", err)
}
defer rows.Close()
defer func() { _ = rows.Close() }()
var routes []Route
for rows.Next() {
@@ -39,11 +39,20 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
return routes, rows.Err()
}
// CreateRoute inserts a route and returns its ID.
// CreateRoute inserts or updates a route and returns its ID. If a route
// for the same (listener_id, hostname) already exists, it is updated
// with the new values (upsert), making the operation idempotent.
func (s *Store) CreateRoute(listenerID int64, hostname, backend, mode, tlsCert, tlsKey string, backendTLS, sendProxyProtocol bool) (int64, error) {
result, err := s.db.Exec(
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(listener_id, hostname) DO UPDATE SET
backend = excluded.backend,
mode = excluded.mode,
tls_cert = excluded.tls_cert,
tls_key = excluded.tls_key,
backend_tls = excluded.backend_tls,
send_proxy_protocol = excluded.send_proxy_protocol`,
listenerID, hostname, backend, mode, tlsCert, tlsKey, backendTLS, sendProxyProtocol,
)
if err != nil {

View File

@@ -14,7 +14,7 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
if err != nil {
return fmt.Errorf("beginning seed transaction: %w", err)
}
defer tx.Rollback()
defer func() { _ = tx.Rollback() }()
for _, l := range listeners {
result, err := tx.Exec(

View File

@@ -234,7 +234,7 @@ func (f *Firewall) loadGeoDB(path string) error {
f.mu.Unlock()
if old != nil {
old.Close()
_ = old.Close()
}
return nil
}

View File

@@ -11,7 +11,7 @@ func TestEmptyFirewall(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
for _, a := range addrs {
@@ -27,7 +27,7 @@ func TestIPBlocking(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
tests := []struct {
addr string
@@ -52,7 +52,7 @@ func TestCIDRBlocking(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
tests := []struct {
addr string
@@ -78,7 +78,7 @@ func TestIPv4MappedIPv6(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("::ffff:192.0.2.1")
if !fw.Blocked(addr) {
@@ -105,7 +105,7 @@ func TestCombinedRules(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
tests := []struct {
addr string
@@ -130,7 +130,7 @@ func TestRateLimitBlocking(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("10.0.0.1")
@@ -151,7 +151,7 @@ func TestRateLimitBlocklistFirst(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
blockedAddr := netip.MustParseAddr("10.0.0.1")
otherAddr := netip.MustParseAddr("10.0.0.2")
@@ -175,7 +175,7 @@ func TestBlockedWithReason(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
tests := []struct {
addr string
@@ -216,7 +216,7 @@ func TestRuntimeMutation(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
defer fw.Close()
defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("10.0.0.1")
if fw.Blocked(addr) {

View File

@@ -38,7 +38,7 @@ func (rl *rateLimiter) Allow(addr netip.Addr) bool {
now := rl.now().UnixNano()
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
entry := val.(*rateLimitEntry)
entry, _ := val.(*rateLimitEntry)
windowStart := entry.start.Load()
if now-windowStart >= rl.window.Nanoseconds() {
@@ -70,7 +70,7 @@ func (rl *rateLimiter) cleanup() {
case <-ticker.C:
cutoff := rl.now().Add(-2 * rl.window).UnixNano()
rl.entries.Range(func(key, value any) bool {
entry := value.(*rateLimitEntry)
entry, _ := value.(*rateLimitEntry)
if entry.start.Load() < cutoff {
rl.entries.Delete(key)
}

View File

@@ -53,7 +53,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
path := cfg.SocketPath()
// Remove stale socket file from a previous run.
os.Remove(path)
_ = os.Remove(path)
ln, err := net.Listen("unix", path)
if err != nil {
@@ -61,7 +61,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
}
if err := os.Chmod(path, 0600); err != nil {
ln.Close()
_ = ln.Close()
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
}
@@ -144,10 +144,10 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
}
}
// Write-through: DB first, then memory.
// Write-through: DB first (upsert), then memory.
if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, mode,
req.Route.TlsCert, req.Route.TlsKey, req.Route.BackendTls, req.Route.SendProxyProtocol); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
return nil, status.Errorf(codes.Internal, "%v", err)
}
info := server.RouteInfo{
@@ -158,10 +158,7 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
BackendTLS: req.Route.BackendTls,
SendProxyProtocol: req.Route.SendProxyProtocol,
}
if err := ls.AddRoute(hostname, info); err != nil {
// DB succeeded but memory failed (should not happen since DB enforces uniqueness).
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
}
ls.AddRoute(hostname, info)
a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend, "mode", mode)
return &pb.AddRouteResponse{}, nil
@@ -446,7 +443,7 @@ func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.
}
listeners = append(listeners, &pb.ListenerStatus{
Addr: ls.Addr,
RouteCount: int32(len(routes)),
RouteCount: int32(len(routes)), //nolint:gosec // route count can never exceed int32
ActiveConnections: ls.ActiveConnections.Load(),
ProxyProtocol: ls.ProxyProtocol,
MaxConnections: ls.MaxConnections,

View File

@@ -39,7 +39,7 @@ func setup(t *testing.T) *testEnv {
if err != nil {
t.Fatalf("open db: %v", err)
}
t.Cleanup(func() { store.Close() })
t.Cleanup(func() { _ = store.Close() })
if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
@@ -130,7 +130,7 @@ func setup(t *testing.T) *testEnv {
if err != nil {
t.Fatalf("dial bufconn: %v", err)
}
t.Cleanup(func() { conn.Close() })
t.Cleanup(func() { _ = conn.Close() })
return &testEnv{
client: pb.NewProxyAdminServiceClient(conn),
@@ -229,19 +229,29 @@ func TestAddRoute(t *testing.T) {
}
}
func TestAddRouteDuplicate(t *testing.T) {
func TestAddRouteUpsert(t *testing.T) {
env := setup(t)
ctx := context.Background()
// a.test already exists from setup(). Adding again with a different
// backend should succeed (upsert) and update the route.
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
ListenerAddr: ":443",
Route: &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"},
})
if err == nil {
t.Fatal("expected error for duplicate route")
if err != nil {
t.Fatalf("upsert should succeed: %v", err)
}
// Verify the route was updated, not duplicated.
routes, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
if err != nil {
t.Fatalf("list routes: %v", err)
}
for _, r := range routes.Routes {
if r.Hostname == "a.test" && r.Backend != "127.0.0.1:1111" {
t.Fatalf("expected updated backend 127.0.0.1:1111, got %q", r.Backend)
}
if s, ok := status.FromError(err); !ok || s.Code() != codes.AlreadyExists {
t.Fatalf("expected AlreadyExists, got %v", err)
}
}
@@ -775,7 +785,7 @@ func TestRemoveL7Policy(t *testing.T) {
env := setup(t)
ctx := context.Background()
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
_, _ = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},

View File

@@ -13,27 +13,27 @@ func TestPrefixConnRead(t *testing.T) {
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
defer func() { _ = ln.Close() }()
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
conn.Write([]byte("WORLD"))
defer func() { _ = conn.Close() }()
_, _ = conn.Write([]byte("WORLD"))
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("HELLO"))
// Read all data: should get "HELLOWORLD".
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
@@ -48,22 +48,22 @@ func TestPrefixConnSmallReads(t *testing.T) {
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
defer func() { _ = ln.Close() }()
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
conn.Write([]byte("CD"))
defer func() { _ = conn.Close() }()
_, _ = conn.Write([]byte("CD"))
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("AB"))
@@ -79,7 +79,7 @@ func TestPrefixConnSmallReads(t *testing.T) {
}
// Now reads come from the underlying conn.
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
rest, err := io.ReadAll(pc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
@@ -94,25 +94,25 @@ func TestPrefixConnEmptyPrefix(t *testing.T) {
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
defer func() { _ = ln.Close() }()
go func() {
conn, err := ln.Accept()
if err != nil {
return
}
defer conn.Close()
conn.Write([]byte("DATA"))
defer func() { _ = conn.Close() }()
_, _ = conn.Write([]byte("DATA"))
}()
conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, nil)
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
@@ -127,12 +127,12 @@ func TestPrefixConnDelegates(t *testing.T) {
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
defer func() { _ = ln.Close() }()
go func() {
conn, _ := ln.Accept()
if conn != nil {
conn.Close()
_ = conn.Close()
}
}()
@@ -140,7 +140,7 @@ func TestPrefixConnDelegates(t *testing.T) {
if err != nil {
t.Fatalf("dial: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("X"))

View File

@@ -120,7 +120,7 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
ReadHeaderTimeout: 30 * time.Second,
}
singleConn := newSingleConnListener(tlsConn)
srv.Serve(singleConn)
_ = srv.Serve(singleConn)
}
return nil
@@ -213,7 +213,7 @@ func dialBackend(ctx context.Context, network, addr string, timeout time.Duratio
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
if clientAddr.IsValid() {
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
conn.Close()
_ = conn.Close()
return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
}
}

View File

@@ -58,8 +58,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
if err != nil {
t.Fatalf("creating cert file: %v", err)
}
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
certFile.Close()
_ = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
_ = certFile.Close()
keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil {
@@ -69,8 +69,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
if err != nil {
t.Fatalf("creating key file: %v", err)
}
pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
keyFile.Close()
_ = pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
_ = keyFile.Close()
return certPath, keyPath
}
@@ -91,11 +91,11 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() {
srv.Close()
ln.Close()
_ = srv.Close()
_ = ln.Close()
})
go srv.Serve(ln)
go func() { _ = srv.Serve(ln) }()
return ln.Addr().String()
}
@@ -118,7 +118,7 @@ func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
if err != nil {
t.Fatalf("TLS dial: %v", err)
}
t.Cleanup(func() { conn.Close() })
t.Cleanup(func() { _ = conn.Close() })
// Create an HTTP/2 client transport over this single connection.
tr := &http2.Transport{}
@@ -142,29 +142,13 @@ func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
return s.cc.RoundTrip(req)
}
// serveL7Route starts l7.Serve in a goroutine for a single connection.
// Returns when the goroutine completes.
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
t.Helper()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
ctx := context.Background()
go func() {
l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger)
if l7Err != nil {
t.Logf("l7.Serve: %v", l7Err)
}
}()
}
func TestL7H2CBackend(t *testing.T) {
certPath, keyPath := testCert(t, "l7.test")
// Start an h2c backend.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "ok")
fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
_, _ = fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
}))
// Start a TCP listener for the L7 proxy.
@@ -172,7 +156,7 @@ func TestL7H2CBackend(t *testing.T) {
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -190,17 +174,17 @@ func TestL7H2CBackend(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
// No peeked bytes — the client is connecting directly with TLS.
Serve(context.Background(), conn, nil, route, clientAddr, logger)
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
// Connect as an HTTP/2 TLS client.
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
resp, err := client.Get(fmt.Sprintf("https://l7.test/foo"))
resp, err := client.Get("https://l7.test/foo")
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode)
@@ -221,7 +205,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
// Backend that echoes the forwarding headers.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
_, _ = fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
r.Header.Get("X-Forwarded-For"),
r.Header.Get("X-Forwarded-Proto"),
r.Header.Get("X-Real-IP"),
@@ -232,7 +216,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -248,7 +232,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger)
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
@@ -256,7 +240,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
@@ -274,13 +258,13 @@ func TestL7BackendUnreachable(t *testing.T) {
t.Fatalf("listen: %v", err)
}
deadAddr := ln.Addr().String()
ln.Close()
_ = ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: deadAddr,
@@ -296,7 +280,7 @@ func TestL7BackendUnreachable(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger)
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
@@ -304,7 +288,7 @@ func TestL7BackendUnreachable(t *testing.T) {
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("status = %d, want 502", resp.StatusCode)
@@ -342,14 +326,14 @@ func TestL7MultipleRequests(t *testing.T) {
var reqCount int
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqCount++
fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
_, _ = fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -365,7 +349,7 @@ func TestL7MultipleRequests(t *testing.T) {
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger)
_ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
@@ -378,7 +362,7 @@ func TestL7MultipleRequests(t *testing.T) {
t.Fatalf("GET %s: %v", path, err)
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
_ = resp.Body.Close()
want := fmt.Sprintf("req=%d path=%s", i+1, path)
if string(body) != want {
@@ -396,14 +380,14 @@ func TestL7LargeResponse(t *testing.T) {
largeBody[i] = byte(i % 256)
}
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(largeBody)
_, _ = w.Write(largeBody)
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -418,7 +402,7 @@ func TestL7LargeResponse(t *testing.T) {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test")
@@ -426,7 +410,7 @@ func TestL7LargeResponse(t *testing.T) {
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if len(body) != len(largeBody) {
@@ -455,7 +439,7 @@ func TestL7GRPCTrailers(t *testing.T) {
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -470,7 +454,7 @@ func TestL7GRPCTrailers(t *testing.T) {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test")
@@ -480,10 +464,10 @@ func TestL7GRPCTrailers(t *testing.T) {
if err != nil {
t.Fatalf("POST: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
// Read body to trigger trailer delivery.
io.ReadAll(resp.Body)
_, _ = io.ReadAll(resp.Body)
// Verify trailers were forwarded through the proxy.
grpcStatus := resp.Trailer.Get("Grpc-Status")
@@ -500,14 +484,14 @@ func TestL7HTTP11Fallback(t *testing.T) {
certPath, keyPath := testCert(t, "http11.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "proto=%s", r.Proto)
_, _ = fmt.Fprintf(w, "proto=%s", r.Proto)
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -522,7 +506,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
// Connect with HTTP/1.1 only (no h2 ALPN).
@@ -538,7 +522,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode)
@@ -556,14 +540,14 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
certPath, keyPath := testCert(t, "policy.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "should-not-reach")
_, _ = fmt.Fprint(w, "should-not-reach")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -581,7 +565,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
@@ -591,7 +575,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 403 {
t.Fatalf("status = %d, want 403", resp.StatusCode)
@@ -602,14 +586,14 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
certPath, keyPath := testCert(t, "reqhdr.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "ok")
_, _ = fmt.Fprint(w, "ok")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
defer func() { _ = proxyLn.Close() }()
route := RouteConfig{
Backend: backendAddr,
@@ -630,7 +614,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
}
go func() {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
_ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
}
}()
@@ -641,7 +625,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
if err != nil {
t.Fatalf("GET without header: %v", err)
}
resp1.Body.Close()
_ = resp1.Body.Close()
if resp1.StatusCode != 403 {
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
}
@@ -654,7 +638,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
if err != nil {
t.Fatalf("GET with header: %v", err)
}
defer resp2.Body.Close()
defer func() { _ = resp2.Body.Close() }()
body, _ := io.ReadAll(resp2.Body)
if resp2.StatusCode != 200 {
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)

View File

@@ -2,6 +2,7 @@ package proxy
import (
"context"
"errors"
"io"
"net"
"sync"
@@ -31,8 +32,8 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
go func() {
<-ctx.Done()
client.Close()
backend.Close()
_ = client.Close()
_ = backend.Close()
}()
var (
@@ -50,7 +51,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
// Half-close backend's write side.
if hc, ok := backend.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
_ = hc.CloseWrite()
}
}()
@@ -60,7 +61,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
// Half-close client's write side.
if hc, ok := client.(interface{ CloseWrite() error }); ok {
hc.CloseWrite()
_ = hc.CloseWrite()
}
}()
@@ -85,10 +86,10 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
var total int64
for {
src.SetReadDeadline(time.Now().Add(idleTimeout))
_ = src.SetReadDeadline(time.Now().Add(idleTimeout))
nr, readErr := src.Read(buf)
if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(idleTimeout))
_ = dst.SetWriteDeadline(time.Now().Add(idleTimeout))
nw, writeErr := dst.Write(buf[:nr])
total += int64(nw)
if writeErr != nil {
@@ -96,7 +97,7 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
}
}
if readErr != nil {
if readErr == io.EOF {
if errors.Is(readErr, io.EOF) {
return total, nil
}
return total, readErr

View File

@@ -16,7 +16,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
peeked := []byte("peeked-hello-bytes")
clientData := []byte("data from client")
@@ -29,7 +29,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil {
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Read everything the backend receives.
received, _ := io.ReadAll(conn)
@@ -40,21 +40,21 @@ func TestRelayBasic(t *testing.T) {
}()
// Restructure: use a more controlled flow.
backendLn.Close()
_ = backendLn.Close()
// Use a real TCP pair for proper half-close.
backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
defer backendLn2.Close()
defer func() { _ = backendLn2.Close() }()
go func() {
conn, err := backendLn2.Accept()
if err != nil {
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Read peeked + client data.
buf := make([]byte, len(peeked)+len(clientData))
@@ -62,11 +62,11 @@ func TestRelayBasic(t *testing.T) {
backendDone <- buf[:n]
// Send response.
conn.Write(backendData)
_, _ = conn.Write(backendData)
// Close write side to signal EOF.
if tc, ok := conn.(*net.TCPConn); ok {
tc.CloseWrite()
_ = tc.CloseWrite()
}
}()
@@ -81,7 +81,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil {
t.Fatalf("listen: %v", err)
}
defer clientLn.Close()
defer func() { _ = clientLn.Close() }()
clientConn, err := net.Dial("tcp", clientLn.Addr().String())
if err != nil {
@@ -94,9 +94,9 @@ func TestRelayBasic(t *testing.T) {
// Client sends data then closes write.
go func() {
clientConn.Write(clientData)
_, _ = clientConn.Write(clientData)
if tc, ok := clientConn.(*net.TCPConn); ok {
tc.CloseWrite()
_ = tc.CloseWrite()
}
}()
@@ -114,7 +114,7 @@ func TestRelayBasic(t *testing.T) {
}
// Verify client received backend data.
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
_ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
clientReceived, _ := io.ReadAll(clientConn)
if !bytes.Equal(clientReceived, backendData) {
t.Fatalf("client received %q, want %q", clientReceived, backendData)
@@ -131,12 +131,12 @@ func TestRelayBasic(t *testing.T) {
func TestRelayIdleTimeout(t *testing.T) {
// Two connected pairs via TCP.
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
defer func() { _ = clientA.Close() }()
defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
defer func() { _ = backendA.Close() }()
defer func() { _ = backendB.Close() }()
start := time.Now()
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
@@ -154,18 +154,18 @@ func TestRelayIdleTimeout(t *testing.T) {
func TestRelayContextCancel(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
defer func() { _ = clientA.Close() }()
defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
defer func() { _ = backendA.Close() }()
defer func() { _ = backendB.Close() }()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
Relay(ctx, clientB, backendA, nil, time.Minute)
_, _ = Relay(ctx, clientB, backendA, nil, time.Minute)
close(done)
}()
@@ -185,12 +185,12 @@ func TestRelayContextCancel(t *testing.T) {
func TestRelayLargeTransfer(t *testing.T) {
clientA, clientB := tcpPair(t)
defer clientA.Close()
defer clientB.Close()
defer func() { _ = clientA.Close() }()
defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t)
defer backendA.Close()
defer backendB.Close()
defer func() { _ = backendA.Close() }()
defer func() { _ = backendB.Close() }()
// 1 MB of random data.
data := make([]byte, 1<<20)
@@ -199,9 +199,9 @@ func TestRelayLargeTransfer(t *testing.T) {
}
go func() {
clientA.Write(data)
_, _ = clientA.Write(data)
if tc, ok := clientA.(*net.TCPConn); ok {
tc.CloseWrite()
_ = tc.CloseWrite()
}
}()
@@ -211,14 +211,14 @@ func TestRelayLargeTransfer(t *testing.T) {
for {
n, err := backendB.Read(buf)
if n > 0 {
backendB.Write(buf[:n])
_, _ = backendB.Write(buf[:n])
}
if err != nil {
break
}
}
if tc, ok := backendB.(*net.TCPConn); ok {
tc.CloseWrite()
_ = tc.CloseWrite()
}
}()
@@ -240,7 +240,7 @@ func tcpPair(t *testing.T) (net.Conn, net.Conn) {
if err != nil {
t.Fatalf("listen: %v", err)
}
defer ln.Close()
defer func() { _ = ln.Close() }()
var serverConn net.Conn
done := make(chan struct{})

View File

@@ -50,8 +50,8 @@ const (
// It reads only the exact bytes of the PROXY header, leaving the connection
// positioned at the first byte after the header (e.g., TLS ClientHello).
func Parse(conn net.Conn, deadline time.Time) (Header, error) {
conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{})
_ = conn.SetReadDeadline(deadline)
defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
// Read the first byte to determine version.
var first [1]byte

View File

@@ -16,7 +16,7 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() { ln.Close() })
t.Cleanup(func() { _ = ln.Close() })
ch := make(chan net.Conn, 1)
go func() {
@@ -31,10 +31,10 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
if err != nil {
t.Fatalf("dial: %v", err)
}
t.Cleanup(func() { w.Close() })
t.Cleanup(func() { _ = w.Close() })
r := <-ch
t.Cleanup(func() { r.Close() })
t.Cleanup(func() { _ = r.Close() })
return r, w
}
@@ -42,7 +42,7 @@ func TestParseV1TCP4(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
_, _ = writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -68,7 +68,7 @@ func TestParseV1TCP6(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
_, _ = writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -98,7 +98,7 @@ func TestParseV2TCP4(t *testing.T) {
buf = append(buf, 10, 0, 0, 1) // dst IP
buf = binary.BigEndian.AppendUint16(buf, 12345)
buf = binary.BigEndian.AppendUint16(buf, 443)
writer.Write(buf)
_, _ = writer.Write(buf)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -137,7 +137,7 @@ func TestParseV2TCP6(t *testing.T) {
buf = append(buf, dst[:]...)
buf = binary.BigEndian.AppendUint16(buf, 56324)
buf = binary.BigEndian.AppendUint16(buf, 8443)
writer.Write(buf)
_, _ = writer.Write(buf)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -163,7 +163,7 @@ func TestParseV2Local(t *testing.T) {
buf = append(buf, 0x20) // version 2, LOCAL command
buf = append(buf, 0x00) // unspec family, unspec protocol
buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf)
_, _ = writer.Write(buf)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -195,7 +195,7 @@ func TestParseV1Malformed(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte(tt.data))
_, _ = writer.Write([]byte(tt.data))
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
@@ -211,7 +211,7 @@ func TestParseV2Malformed(t *testing.T) {
go func() {
bad := make([]byte, v2HeaderLen)
bad[0] = v2Signature[0] // first byte matches but rest doesn't
writer.Write(bad)
_, _ = writer.Write(bad)
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
@@ -227,7 +227,7 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x31) // version 3, PROXY command
buf = append(buf, 0x11)
buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf)
_, _ = writer.Write(buf)
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
@@ -244,8 +244,8 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x11) // AF_INET, STREAM
buf = binary.BigEndian.AppendUint16(buf, 12)
buf = append(buf, 1, 2, 3) // only 3 bytes, need 12
writer.Write(buf)
writer.Close()
_, _ = writer.Write(buf)
_ = writer.Close()
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
@@ -261,7 +261,7 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x21) // version 2, PROXY
buf = append(buf, 0x31) // AF_UNIX (3), STREAM
buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf)
_, _ = writer.Write(buf)
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {
@@ -332,7 +332,7 @@ func TestRoundTripV2IPv4(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
WriteV2(writer, src, dst)
_ = WriteV2(writer, src, dst)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -361,7 +361,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
WriteV2(writer, src, dst)
_ = WriteV2(writer, src, dst)
}()
hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -380,7 +380,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
func TestParseGarbageFirstByte(t *testing.T) {
reader, writer := pipeWithDeadline(t)
go func() {
writer.Write([]byte{0xFF, 0x00, 0x01})
_, _ = writer.Write([]byte{0xFF, 0x00, 0x01})
}()
_, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil {

View File

@@ -69,19 +69,15 @@ func (ls *ListenerState) Routes() map[string]RouteInfo {
return m
}
// AddRoute adds a route to the listener. Returns an error if the hostname
// already exists.
func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) error {
// AddRoute adds or updates a route on the listener. If a route for the
// hostname already exists, it is replaced (upsert).
func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
if _, ok := ls.routes[key]; ok {
return fmt.Errorf("route %q already exists", hostname)
}
ls.routes[key] = info
return nil
}
// RemoveRoute removes a route from the listener. Returns an error if the
@@ -235,7 +231,7 @@ func (s *Server) Run(ctx context.Context) error {
ln, err := net.Listen("tcp", ls.Addr)
if err != nil {
for _, l := range netListeners {
l.Close()
_ = l.Close()
}
return fmt.Errorf("listening on %s: %w", ls.Addr, err)
}
@@ -253,7 +249,7 @@ func (s *Server) Run(ctx context.Context) error {
s.logger.Info("shutting down")
for _, ln := range netListeners {
ln.Close()
_ = ln.Close()
}
done := make(chan struct{})
@@ -272,7 +268,7 @@ func (s *Server) Run(ctx context.Context) error {
<-done
}
s.fw.Close()
_ = s.fw.Close()
return nil
}
@@ -294,7 +290,7 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
// Enforce per-listener connection limit.
if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections {
conn.Close()
_ = conn.Close()
s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections)
continue
}
@@ -311,7 +307,7 @@ func (s *Server) forceCloseAll() {
for _, ls := range s.listeners {
ls.connMu.Lock()
for conn := range ls.activeConns {
conn.Close()
_ = conn.Close()
}
ls.connMu.Unlock()
}
@@ -321,7 +317,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
defer s.wg.Done()
defer ls.ActiveConnections.Add(-1)
defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec()
defer conn.Close()
defer func() { _ = conn.Close() }()
ls.connMu.Lock()
ls.activeConns[conn] = struct{}{}
@@ -392,7 +388,7 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
return
}
defer backendConn.Close()
defer func() { _ = backendConn.Close() }()
// Send PROXY protocol v2 header to backend if configured.
if route.SendProxyProtocol {

View File

@@ -43,8 +43,8 @@ func echoServer(t *testing.T, ln net.Listener) {
if err != nil {
return
}
defer conn.Close()
io.Copy(conn, conn)
defer func() { _ = conn.Close() }()
_, _ = io.Copy(conn, conn)
}
// newTestServer creates a Server with the given listener data and no firewall rules.
@@ -92,7 +92,7 @@ func TestProxyRoundTrip(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn)
// Pick a free port for the proxy listener.
@@ -101,7 +101,7 @@ func TestProxyRoundTrip(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -121,7 +121,7 @@ func TestProxyRoundTrip(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("echo.test")
if _, err := conn.Write(hello); err != nil {
@@ -130,7 +130,7 @@ func TestProxyRoundTrip(t *testing.T) {
// The backend will echo our ClientHello back. Read it.
echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err)
}
@@ -157,7 +157,7 @@ func TestNoRouteResets(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -176,7 +176,7 @@ func TestNoRouteResets(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("unknown.test")
if _, err := conn.Write(hello); err != nil {
@@ -184,7 +184,7 @@ func TestNoRouteResets(t *testing.T) {
}
// The proxy should close the connection (no route match).
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed, but read succeeded")
@@ -197,7 +197,7 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1)
go func() {
@@ -205,7 +205,7 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil {
return
}
conn.Close()
_ = conn.Close()
reached <- struct{}{}
}()
@@ -214,7 +214,7 @@ func TestFirewallBlocks(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
// Create a firewall that blocks 127.0.0.1 (the test client).
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
@@ -245,7 +245,7 @@ func TestFirewallBlocks(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
srv.Run(ctx)
_ = srv.Run(ctx)
}()
time.Sleep(50 * time.Millisecond)
@@ -253,13 +253,13 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("echo.test")
conn.Write(hello)
_, _ = conn.Write(hello)
// Connection should be closed (blocked by firewall).
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed by firewall")
@@ -283,7 +283,7 @@ func TestNotTLSResets(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -300,12 +300,12 @@ func TestNotTLSResets(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Send HTTP, not TLS.
conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
_, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed for non-TLS data")
@@ -318,7 +318,7 @@ func TestConnectionTracking(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
var backendConns []net.Conn
var mu sync.Mutex
@@ -332,7 +332,7 @@ func TestConnectionTracking(t *testing.T) {
backendConns = append(backendConns, conn)
mu.Unlock()
// Hold connection open, drain input.
go io.Copy(io.Discard, conn)
go func() { _, _ = io.Copy(io.Discard, conn) }()
}
}()
@@ -341,7 +341,7 @@ func TestConnectionTracking(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -382,10 +382,10 @@ func TestConnectionTracking(t *testing.T) {
}
// Close one client and its corresponding backend connection.
clientConns[0].Close()
_ = clientConns[0].Close()
mu.Lock()
if len(backendConns) > 0 {
backendConns[0].Close()
_ = backendConns[0].Close()
}
mu.Unlock()
@@ -402,10 +402,10 @@ func TestConnectionTracking(t *testing.T) {
}
// Clean up.
clientConns[1].Close()
_ = clientConns[1].Close()
mu.Lock()
for _, c := range backendConns {
c.Close()
_ = c.Close()
}
mu.Unlock()
}
@@ -416,13 +416,13 @@ func TestMultipleListeners(t *testing.T) {
if err != nil {
t.Fatalf("backend A listen: %v", err)
}
defer backendA.Close()
defer func() { _ = backendA.Close() }()
backendB, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("backend B listen: %v", err)
}
defer backendB.Close()
defer func() { _ = backendB.Close() }()
// Each backend writes its identity and closes.
serve := func(ln net.Listener, id string) {
@@ -430,10 +430,10 @@ func TestMultipleListeners(t *testing.T) {
if err != nil {
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Drain the incoming data, then write identity.
go io.Copy(io.Discard, conn)
conn.Write([]byte(id))
go func() { _, _ = io.Copy(io.Discard, conn) }()
_, _ = conn.Write([]byte(id))
}
go serve(backendA, "A")
go serve(backendB, "B")
@@ -444,14 +444,14 @@ func TestMultipleListeners(t *testing.T) {
t.Fatalf("finding free port 1: %v", err)
}
addr1 := ln1.Addr().String()
ln1.Close()
_ = ln1.Close()
ln2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port 2: %v", err)
}
addr2 := ln2.Addr().String()
ln2.Close()
_ = ln2.Close()
srv := newTestServer(t, []ListenerData{
{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
@@ -467,12 +467,12 @@ func TestMultipleListeners(t *testing.T) {
if err != nil {
t.Fatalf("dial %s: %v", proxyAddr, err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("svc.test")
conn.Write(hello)
_, _ = conn.Write(hello)
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := make([]byte, 128)
// Read what the backend sends back: echoed ClientHello + ID.
// The backend drains input and writes the ID, so we read until we
@@ -508,7 +508,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn)
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -516,7 +516,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -537,7 +537,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("ECHO.TEST")
if _, err := conn.Write(hello); err != nil {
@@ -545,7 +545,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
}
echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err)
}
@@ -558,14 +558,14 @@ func TestBackendUnreachable(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
deadAddr := ln.Addr().String()
ln.Close()
_ = ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -584,13 +584,13 @@ func TestBackendUnreachable(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("dead.test")
conn.Write(hello)
_, _ = conn.Write(hello)
// Proxy should close the connection after failing to dial backend.
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed when backend is unreachable")
@@ -603,15 +603,15 @@ func TestGracefulShutdown(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
go func() {
conn, err := backendLn.Accept()
if err != nil {
return
}
defer conn.Close()
io.Copy(io.Discard, conn)
defer func() { _ = conn.Close() }()
_, _ = io.Copy(io.Discard, conn)
}()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -619,7 +619,7 @@ func TestGracefulShutdown(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
fw, err := firewall.New("", nil, nil, nil, 0, 0)
if err != nil {
@@ -649,10 +649,10 @@ func TestGracefulShutdown(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("hold.test")
conn.Write(hello)
_, _ = conn.Write(hello)
time.Sleep(50 * time.Millisecond)
// Trigger shutdown.
@@ -679,16 +679,12 @@ func TestListenerStateRoutes(t *testing.T) {
}
// AddRoute
if err := ls.AddRoute("b.test", l4Route("127.0.0.1:2")); err != nil {
t.Fatalf("AddRoute: %v", err)
}
ls.AddRoute("b.test", l4Route("127.0.0.1:2"))
// AddRoute duplicate
if err := ls.AddRoute("b.test", l4Route("127.0.0.1:3")); err == nil {
t.Fatal("expected error for duplicate route")
}
// AddRoute duplicate (upsert — updates in place)
ls.AddRoute("b.test", l4Route("127.0.0.1:3"))
// Routes snapshot
// Routes snapshot — still 2 routes, the duplicate replaced the first.
routes := ls.Routes()
if len(routes) != 2 {
t.Fatalf("expected 2 routes, got %d", len(routes))
@@ -708,8 +704,8 @@ func TestListenerStateRoutes(t *testing.T) {
if len(routes) != 1 {
t.Fatalf("expected 1 route, got %d", len(routes))
}
if routes["b.test"].Backend != "127.0.0.1:2" {
t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"].Backend)
if routes["b.test"].Backend != "127.0.0.1:3" {
t.Fatalf("expected b.test → 127.0.0.1:3 (upserted), got %q", routes["b.test"].Backend)
}
}
@@ -719,7 +715,7 @@ func TestProxyProtocolReceive(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn)
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -727,7 +723,7 @@ func TestProxyProtocolReceive(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -747,22 +743,22 @@ func TestProxyProtocolReceive(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Send PROXY v2 header followed by TLS ClientHello.
var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf,
_ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("203.0.113.50:12345"),
netip.MustParseAddrPort("198.51.100.1:443"),
)
conn.Write(ppBuf.Bytes())
_, _ = conn.Write(ppBuf.Bytes())
hello := buildClientHello("echo.test")
conn.Write(hello)
_, _ = conn.Write(hello)
// Backend should echo the ClientHello back (not the PROXY header).
echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err)
}
@@ -774,7 +770,7 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -794,13 +790,13 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Send garbage instead of a valid PROXY header.
conn.Write([]byte("NOT A PROXY HEADER\r\n"))
_, _ = conn.Write([]byte("NOT A PROXY HEADER\r\n"))
// Connection should be closed.
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed for invalid PROXY header")
@@ -813,7 +809,7 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
received := make(chan []byte, 1)
go func() {
@@ -821,9 +817,9 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil {
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Read all available data; the proxy sends PROXY header + ClientHello.
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var all []byte
buf := make([]byte, 4096)
for {
@@ -845,7 +841,7 @@ func TestProxyProtocolSend(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -868,10 +864,10 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("pp.test")
conn.Write(hello)
_, _ = conn.Write(hello)
// The backend should receive: PROXY v2 header + ClientHello.
select {
@@ -904,7 +900,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
received := make(chan []byte, 1)
go func() {
@@ -912,7 +908,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil {
return
}
defer conn.Close()
defer func() { _ = conn.Close() }()
buf := make([]byte, 4096)
n, _ := conn.Read(buf)
received <- buf[:n]
@@ -923,7 +919,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -942,10 +938,10 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
hello := buildClientHello("nopp.test")
conn.Write(hello)
_, _ = conn.Write(hello)
select {
case data := <-received:
@@ -964,7 +960,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1)
go func() {
@@ -972,7 +968,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil {
return
}
conn.Close()
_ = conn.Close()
reached <- struct{}{}
}()
@@ -981,7 +977,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
t.Fatalf("finding free port: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
// Block 203.0.113.50 (the "real" client IP from PROXY header).
// 127.0.0.1 (the actual TCP peer) is NOT blocked.
@@ -1014,7 +1010,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
srv.Run(ctx)
_ = srv.Run(ctx)
}()
time.Sleep(50 * time.Millisecond)
@@ -1022,19 +1018,19 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil {
t.Fatalf("dial proxy: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
// Send PROXY v2 with the blocked real IP.
var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf,
_ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("203.0.113.50:12345"),
netip.MustParseAddrPort("198.51.100.1:443"),
)
conn.Write(ppBuf.Bytes())
conn.Write(buildClientHello("blocked.test"))
_, _ = conn.Write(ppBuf.Bytes())
_, _ = conn.Write(buildClientHello("blocked.test"))
// Connection should be closed (firewall blocks real IP).
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed by firewall")
@@ -1060,7 +1056,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
go func() {
for {
@@ -1068,7 +1064,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil {
return
}
go io.Copy(io.Discard, conn)
go func() { _, _ = io.Copy(io.Discard, conn) }()
}
}()
@@ -1077,7 +1073,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
t.Fatalf("proxy listen: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -1100,7 +1096,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil {
t.Fatalf("dial %d: %v", i, err)
}
conn.Write(buildClientHello("limit.test"))
_, _ = conn.Write(buildClientHello("limit.test"))
conns = append(conns, conn)
}
time.Sleep(100 * time.Millisecond)
@@ -1110,16 +1106,16 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil {
t.Fatalf("dial 3: %v", err)
}
conn3.Write(buildClientHello("limit.test"))
conn3.SetReadDeadline(time.Now().Add(2 * time.Second))
_, _ = conn3.Write(buildClientHello("limit.test"))
_ = conn3.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn3.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected 3rd connection to be closed due to limit")
}
conn3.Close()
_ = conn3.Close()
// Close one existing connection.
conns[0].Close()
_ = conns[0].Close()
time.Sleep(200 * time.Millisecond)
// Now a new connection should succeed.
@@ -1127,8 +1123,8 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil {
t.Fatalf("dial 4: %v", err)
}
defer conn4.Close()
conn4.Write(buildClientHello("limit.test"))
defer func() { _ = conn4.Close() }()
_, _ = conn4.Write(buildClientHello("limit.test"))
// Give it time to be proxied.
time.Sleep(100 * time.Millisecond)
@@ -1138,7 +1134,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
// Clean up.
for _, c := range conns[1:] {
c.Close()
_ = c.Close()
}
}
@@ -1155,7 +1151,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
// h2c backend on origin that echoes the X-Forwarded-For.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
_, _ = fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
}))
// Origin proxy: proxy_protocol=true listener, L7 route to backend.
@@ -1164,7 +1160,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
t.Fatalf("origin listen: %v", err)
}
originAddr := originLn.Addr().String()
originLn.Close()
_ = originLn.Close()
originFw, _ := firewall.New("", nil, nil, nil, 0, 0)
originCfg := &config.Config{
@@ -1196,7 +1192,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
originWg.Add(1)
go func() {
defer originWg.Done()
originSrv.Run(originCtx)
_ = originSrv.Run(originCtx)
}()
time.Sleep(50 * time.Millisecond)
defer func() {
@@ -1210,7 +1206,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
t.Fatalf("edge listen: %v", err)
}
edgeAddr := edgeLn.Addr().String()
edgeLn.Close()
_ = edgeLn.Close()
edgeFw, _ := firewall.New("", nil, nil, nil, 0, 0)
edgeSrv := New(originCfg, edgeFw, []ListenerData{
@@ -1232,7 +1228,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
edgeWg.Add(1)
go func() {
defer edgeWg.Done()
edgeSrv.Run(edgeCtx)
_ = edgeSrv.Run(edgeCtx)
}()
time.Sleep(50 * time.Millisecond)
defer func() {
@@ -1253,7 +1249,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
if err != nil {
t.Fatalf("TLS dial edge: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn)
@@ -1266,7 +1262,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
if err != nil {
t.Fatalf("RoundTrip: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 {
@@ -1289,7 +1285,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil {
t.Fatalf("backend listen: %v", err)
}
defer backendLn.Close()
defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1)
go func() {
@@ -1297,7 +1293,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil {
return
}
conn.Close()
_ = conn.Close()
reached <- struct{}{}
}()
@@ -1306,7 +1302,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
t.Fatalf("origin listen: %v", err)
}
originAddr := originLn.Addr().String()
originLn.Close()
_ = originLn.Close()
// Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header.
originFw, _ := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0)
@@ -1334,7 +1330,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
originSrv.Run(ctx)
_ = originSrv.Run(ctx)
}()
time.Sleep(50 * time.Millisecond)
@@ -1343,18 +1339,18 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil {
t.Fatalf("dial origin: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf,
_ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("198.51.100.99:12345"),
netip.MustParseAddrPort("10.0.0.1:443"),
)
conn.Write(ppBuf.Bytes())
conn.Write(buildClientHello("blocked.test"))
_, _ = conn.Write(ppBuf.Bytes())
_, _ = conn.Write(buildClientHello("blocked.test"))
// Connection should be dropped by firewall.
conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1))
if err == nil {
t.Fatal("expected connection to be closed")
@@ -1396,12 +1392,12 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
certPath = filepath.Join(dir, "cert.pem")
keyPath = filepath.Join(dir, "key.pem")
cf, _ := os.Create(certPath)
pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
cf.Close()
_ = pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
_ = cf.Close()
keyDER, _ := x509.MarshalECPrivateKey(key)
kf, _ := os.Create(keyPath)
pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
kf.Close()
_ = pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
_ = kf.Close()
return
}
@@ -1417,8 +1413,8 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
if err != nil {
t.Fatalf("listen: %v", err)
}
t.Cleanup(func() { srv.Close(); ln.Close() })
go srv.Serve(ln)
t.Cleanup(func() { _ = srv.Close(); _ = ln.Close() })
go func() { _ = srv.Serve(ln) }()
return ln.Addr().String()
}
@@ -1426,7 +1422,7 @@ func TestL7ThroughServer(t *testing.T) {
certPath, keyPath := testCert(t, "l7srv.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
_, _ = fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -1434,7 +1430,7 @@ func TestL7ThroughServer(t *testing.T) {
t.Fatalf("proxy listen: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -1467,7 +1463,7 @@ func TestL7ThroughServer(t *testing.T) {
if err != nil {
t.Fatalf("TLS dial: %v", err)
}
defer conn.Close()
defer func() { _ = conn.Close() }()
tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn)
@@ -1480,7 +1476,7 @@ func TestL7ThroughServer(t *testing.T) {
if err != nil {
t.Fatalf("RoundTrip: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
@@ -1502,12 +1498,12 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil {
t.Fatalf("l4 backend listen: %v", err)
}
defer l4BackendLn.Close()
defer func() { _ = l4BackendLn.Close() }()
go echoServer(t, l4BackendLn)
// L7 backend: h2c HTTP server.
l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "l7-response")
_, _ = fmt.Fprint(w, "l7-response")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -1515,7 +1511,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
t.Fatalf("proxy listen: %v", err)
}
proxyAddr := proxyLn.Addr().String()
proxyLn.Close()
_ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{
{
@@ -1541,11 +1537,11 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil {
t.Fatalf("dial L4: %v", err)
}
defer l4Conn.Close()
defer func() { _ = l4Conn.Close() }()
hello := buildClientHello("l4echo.test")
l4Conn.Write(hello)
_, _ = l4Conn.Write(hello)
echoed := make([]byte, len(hello))
l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_ = l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(l4Conn, echoed); err != nil {
t.Fatalf("L4 echo read: %v", err)
}
@@ -1563,7 +1559,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil {
t.Fatalf("TLS dial L7: %v", err)
}
defer l7Conn.Close()
defer func() { _ = l7Conn.Close() }()
tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(l7Conn)
@@ -1576,7 +1572,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil {
t.Fatalf("L7 RoundTrip: %v", err)
}
defer resp.Body.Close()
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if string(body) != "l7-response" {
@@ -1636,4 +1632,3 @@ func sniExtension(serverName string) []byte {
return ext
}

View File

@@ -17,8 +17,8 @@ const maxBufferSize = 16384 // 16 KiB, max TLS record size
//
// A read deadline is set on the connection to prevent slowloris attacks.
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{})
_ = conn.SetReadDeadline(deadline)
defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
// Read TLS record header (5 bytes).
header := make([]byte, 5)

View File

@@ -22,13 +22,13 @@ func TestExtract(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
defer func() { _ = client.Close() }()
defer func() { _ = server.Close() }()
hello := buildClientHello(tt.sni)
go func() {
client.Write(hello)
_, _ = client.Write(hello)
}()
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
@@ -53,13 +53,13 @@ func TestExtract(t *testing.T) {
func TestExtractNoSNI(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
defer func() { _ = client.Close() }()
defer func() { _ = server.Close() }()
hello := buildClientHelloNoSNI()
go func() {
client.Write(hello)
_, _ = client.Write(hello)
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -70,11 +70,11 @@ func TestExtractNoSNI(t *testing.T) {
func TestExtractNotTLS(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
defer func() { _ = client.Close() }()
defer func() { _ = server.Close() }()
go func() {
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
_, _ = client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -85,13 +85,13 @@ func TestExtractNotTLS(t *testing.T) {
func TestExtractTruncated(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
defer func() { _ = client.Close() }()
defer func() { _ = server.Close() }()
go func() {
// Write just the TLS record header, then close.
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
client.Close()
_, _ = client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
_ = client.Close()
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -102,15 +102,15 @@ func TestExtractTruncated(t *testing.T) {
func TestExtractOversizedRecord(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
defer func() { _ = client.Close() }()
defer func() { _ = server.Close() }()
go func() {
// Record header claiming a length larger than 16 KiB.
header := []byte{0x16, 0x03, 0x01}
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
client.Write(header)
client.Close()
_, _ = client.Write(header)
_ = client.Close()
}()
_, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -121,13 +121,13 @@ func TestExtractOversizedRecord(t *testing.T) {
func TestExtractMultipleExtensions(t *testing.T) {
client, server := net.Pipe()
defer client.Close()
defer server.Close()
defer func() { _ = client.Close() }()
defer func() { _ = server.Close() }()
hello := buildClientHelloWithExtraExtensions("target.example.com")
go func() {
client.Write(hello)
_, _ = client.Write(hello)
}()
hostname, _, err := Extract(server, time.Now().Add(5*time.Second))