Fix golangci-lint v2 compliance, make all passes clean

- Fix 314 errcheck violations (blank identifier for unrecoverable errors)
- Fix errorlint violation (errors.Is for io.EOF)
- Remove unused serveL7Route test helper
- Simplify Duration.Seconds() selectors in tests
- Remove unnecessary fmt.Sprintf in test
- Migrate exclusion rules from issues.exclusions to linters.exclusions (v2 schema)
- Add gosec test exclusions (G115, G304, G402, G705)
- Disable fieldalignment govet analyzer (optimization, not correctness)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-27 13:30:43 -07:00
parent 4f3249fdc3
commit a60e5cb86a
28 changed files with 343 additions and 354 deletions

View File

@@ -9,6 +9,20 @@ run:
tests: true tests: true
linters: linters:
exclusions:
paths:
- vendor
rules:
# In test files, suppress gosec rules that are false positives in test code:
# G101: hardcoded test credentials (intentional fixtures)
# G115: integer overflow in type conversions (test TLS packet builders)
# G304: file paths from variables (t.TempDir paths)
# G402: InsecureSkipVerify (required for test TLS clients)
# G705: XSS via taint analysis (test HTTP handlers, not real servers)
- path: "_test\\.go"
linters:
- gosec
text: "G101|G115|G304|G402|G705"
default: none default: none
enable: enable:
# --- Correctness --- # --- Correctness ---
@@ -52,12 +66,15 @@ linters:
check-type-assertions: true check-type-assertions: true
govet: govet:
# Enable all analyzers except shadow. The shadow analyzer flags the idiomatic # Enable all analyzers except shadow and fieldalignment. The shadow analyzer
# `if err := f(); err != nil { ... }` pattern as shadowing an outer `err`, # flags the idiomatic `if err := f(); err != nil { ... }` pattern as shadowing
# which is ubiquitous in Go and does not pose a security risk in this codebase. # an outer `err`, which is ubiquitous in Go. The fieldalignment analyzer
# suggests struct field reordering for memory efficiency — useful as a one-off
# audit but too noisy for CI (every struct change triggers it).
enable-all: true enable-all: true
disable: disable:
- shadow - shadow
- fieldalignment
gosec: gosec:
# Treat all gosec findings as errors, not warnings. # Treat all gosec findings as errors, not warnings.
@@ -110,15 +127,3 @@ issues:
# Do not cap the number of reported issues; in security code every finding matters. # Do not cap the number of reported issues; in security code every finding matters.
max-issues-per-linter: 0 max-issues-per-linter: 0
max-same-issues: 0 max-same-issues: 0
exclusions:
paths:
- vendor
rules:
# In test files, allow hardcoded test credentials (gosec G101) since they are
# intentional fixtures, not production secrets.
- path: "_test\\.go"
linters:
- gosec
text: "G101"

View File

@@ -32,7 +32,7 @@ func setupTestClient(t *testing.T) *Client {
if err != nil { if err != nil {
t.Fatalf("open db: %v", err) t.Fatalf("open db: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err) t.Fatalf("migrate: %v", err)
@@ -128,7 +128,7 @@ func setupTestClient(t *testing.T) *Client {
if err != nil { if err != nil {
t.Fatalf("dial bufconn: %v", err) t.Fatalf("dial bufconn: %v", err)
} }
t.Cleanup(func() { conn.Close() }) t.Cleanup(func() { _ = conn.Close() })
return &Client{ return &Client{
conn: conn, conn: conn,

View File

@@ -40,7 +40,7 @@ func serverCmd() *cobra.Command {
if err != nil { if err != nil {
return fmt.Errorf("opening database: %w", err) return fmt.Errorf("opening database: %w", err)
} }
defer store.Close() defer func() { _ = store.Close() }()
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
return fmt.Errorf("running migrations: %w", err) return fmt.Errorf("running migrations: %w", err)
@@ -93,7 +93,7 @@ func serverCmd() *cobra.Command {
}() }()
defer func() { defer func() {
grpcSrv.GracefulStop() grpcSrv.GracefulStop()
os.Remove(cfg.GRPC.SocketPath()) _ = os.Remove(cfg.GRPC.SocketPath())
}() }()
} }

View File

@@ -32,7 +32,7 @@ func snapshotCmd() *cobra.Command {
if err != nil { if err != nil {
return fmt.Errorf("opening database: %w", err) return fmt.Errorf("opening database: %w", err)
} }
defer store.Close() defer func() { _ = store.Close() }()
dataDir := filepath.Dir(cfg.Database.Path) dataDir := filepath.Dir(cfg.Database.Path)

View File

@@ -33,7 +33,7 @@ func statusCmd() *cobra.Command {
if err != nil { if err != nil {
return fmt.Errorf("connecting to gRPC API: %w", err) return fmt.Errorf("connecting to gRPC API: %w", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
client := pb.NewProxyAdminServiceClient(conn) client := pb.NewProxyAdminServiceClient(conn)

View File

@@ -304,7 +304,7 @@ func TestDuration(t *testing.T) {
if err := d.UnmarshalText([]byte("5s")); err != nil { if err := d.UnmarshalText([]byte("5s")); err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
if d.Duration.Seconds() != 5 { if d.Seconds() != 5 {
t.Fatalf("got %v, want 5s", d.Duration) t.Fatalf("got %v, want 5s", d.Duration)
} }
} }
@@ -340,7 +340,7 @@ level = "info"
if cfg.Log.Level != "debug" { if cfg.Log.Level != "debug" {
t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug") t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug")
} }
if cfg.Proxy.IdleTimeout.Duration.Seconds() != 600 { if cfg.Proxy.IdleTimeout.Seconds() != 600 {
t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration) t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration)
} }
if cfg.Database.Path != "/override/test.db" { if cfg.Database.Path != "/override/test.db" {

View File

@@ -17,7 +17,7 @@ func openTestDB(t *testing.T) *Store {
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err) t.Fatalf("migrate: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
return store return store
} }
@@ -266,8 +266,8 @@ func TestRouteCascadeDelete(t *testing.T) {
store := openTestDB(t) store := openTestDB(t)
listenerID, _ := store.CreateListener(":443", false, 0) listenerID, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false) _, _ = store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false)
store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false) _, _ = store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false)
if err := store.DeleteListener(listenerID); err != nil { if err := store.DeleteListener(listenerID); err != nil {
t.Fatalf("delete listener: %v", err) t.Fatalf("delete listener: %v", err)
@@ -412,7 +412,7 @@ func TestSeed(t *testing.T) {
func TestSnapshot(t *testing.T) { func TestSnapshot(t *testing.T) {
store := openTestDB(t) store := openTestDB(t)
store.CreateListener(":443", false, 0) _, _ = store.CreateListener(":443", false, 0)
dest := filepath.Join(t.TempDir(), "backup.db") dest := filepath.Join(t.TempDir(), "backup.db")
if err := store.Snapshot(dest); err != nil { if err := store.Snapshot(dest); err != nil {
@@ -424,7 +424,7 @@ func TestSnapshot(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("open backup: %v", err) t.Fatalf("open backup: %v", err)
} }
defer backup.Close() defer func() { _ = backup.Close() }()
if err := backup.Migrate(); err != nil { if err := backup.Migrate(); err != nil {
t.Fatalf("migrate backup: %v", err) t.Fatalf("migrate backup: %v", err)
@@ -463,7 +463,7 @@ func TestMigrationV2Upgrade(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("open: %v", err) t.Fatalf("open: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
// Run full migrations (v1 + v2). // Run full migrations (v1 + v2).
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
@@ -556,10 +556,10 @@ func TestL7PolicyCascadeDelete(t *testing.T) {
lid, _ := store.CreateListener(":443", false, 0) lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false) rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
store.CreateL7Policy(rid, "block_user_agent", "Bot") _, _ = store.CreateL7Policy(rid, "block_user_agent", "Bot")
// Deleting the route should cascade-delete its policies. // Deleting the route should cascade-delete its policies.
store.DeleteRoute(lid, "api.test") _ = store.DeleteRoute(lid, "api.test")
policies, _ := store.ListL7Policies(rid) policies, _ := store.ListL7Policies(rid)
if len(policies) != 0 { if len(policies) != 0 {
@@ -585,7 +585,7 @@ func TestGetRouteID(t *testing.T) {
store := openTestDB(t) store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0) lid, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false) _, _ = store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
rid, err := store.GetRouteID(lid, "api.test") rid, err := store.GetRouteID(lid, "api.test")
if err != nil { if err != nil {

View File

@@ -15,7 +15,7 @@ func (s *Store) ListFirewallRules() ([]FirewallRule, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying firewall rules: %w", err) return nil, fmt.Errorf("querying firewall rules: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var rules []FirewallRule var rules []FirewallRule
for rows.Next() { for rows.Next() {

View File

@@ -19,7 +19,7 @@ func (s *Store) ListL7Policies(routeID int64) ([]L7Policy, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying l7 policies: %w", err) return nil, fmt.Errorf("querying l7 policies: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var policies []L7Policy var policies []L7Policy
for rows.Next() { for rows.Next() {

View File

@@ -16,7 +16,7 @@ func (s *Store) ListListeners() ([]Listener, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying listeners: %w", err) return nil, fmt.Errorf("querying listeners: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var listeners []Listener var listeners []Listener
for rows.Next() { for rows.Next() {

View File

@@ -25,7 +25,7 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("querying routes: %w", err) return nil, fmt.Errorf("querying routes: %w", err)
} }
defer rows.Close() defer func() { _ = rows.Close() }()
var routes []Route var routes []Route
for rows.Next() { for rows.Next() {

View File

@@ -14,7 +14,7 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
if err != nil { if err != nil {
return fmt.Errorf("beginning seed transaction: %w", err) return fmt.Errorf("beginning seed transaction: %w", err)
} }
defer tx.Rollback() defer func() { _ = tx.Rollback() }()
for _, l := range listeners { for _, l := range listeners {
result, err := tx.Exec( result, err := tx.Exec(

View File

@@ -234,7 +234,7 @@ func (f *Firewall) loadGeoDB(path string) error {
f.mu.Unlock() f.mu.Unlock()
if old != nil { if old != nil {
old.Close() _ = old.Close()
} }
return nil return nil
} }

View File

@@ -11,7 +11,7 @@ func TestEmptyFirewall(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"} addrs := []string{"192.168.1.1", "10.0.0.1", "::1", "2001:db8::1"}
for _, a := range addrs { for _, a := range addrs {
@@ -27,7 +27,7 @@ func TestIPBlocking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -52,7 +52,7 @@ func TestCIDRBlocking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -78,7 +78,7 @@ func TestIPv4MappedIPv6(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("::ffff:192.0.2.1") addr := netip.MustParseAddr("::ffff:192.0.2.1")
if !fw.Blocked(addr) { if !fw.Blocked(addr) {
@@ -105,7 +105,7 @@ func TestCombinedRules(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -130,7 +130,7 @@ func TestRateLimitBlocking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("10.0.0.1") addr := netip.MustParseAddr("10.0.0.1")
@@ -151,7 +151,7 @@ func TestRateLimitBlocklistFirst(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
blockedAddr := netip.MustParseAddr("10.0.0.1") blockedAddr := netip.MustParseAddr("10.0.0.1")
otherAddr := netip.MustParseAddr("10.0.0.2") otherAddr := netip.MustParseAddr("10.0.0.2")
@@ -175,7 +175,7 @@ func TestBlockedWithReason(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
tests := []struct { tests := []struct {
addr string addr string
@@ -216,7 +216,7 @@ func TestRuntimeMutation(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("unexpected error: %v", err) t.Fatalf("unexpected error: %v", err)
} }
defer fw.Close() defer func() { _ = fw.Close() }()
addr := netip.MustParseAddr("10.0.0.1") addr := netip.MustParseAddr("10.0.0.1")
if fw.Blocked(addr) { if fw.Blocked(addr) {

View File

@@ -38,7 +38,7 @@ func (rl *rateLimiter) Allow(addr netip.Addr) bool {
now := rl.now().UnixNano() now := rl.now().UnixNano()
val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{}) val, _ := rl.entries.LoadOrStore(addr, &rateLimitEntry{})
entry := val.(*rateLimitEntry) entry, _ := val.(*rateLimitEntry)
windowStart := entry.start.Load() windowStart := entry.start.Load()
if now-windowStart >= rl.window.Nanoseconds() { if now-windowStart >= rl.window.Nanoseconds() {
@@ -70,7 +70,7 @@ func (rl *rateLimiter) cleanup() {
case <-ticker.C: case <-ticker.C:
cutoff := rl.now().Add(-2 * rl.window).UnixNano() cutoff := rl.now().Add(-2 * rl.window).UnixNano()
rl.entries.Range(func(key, value any) bool { rl.entries.Range(func(key, value any) bool {
entry := value.(*rateLimitEntry) entry, _ := value.(*rateLimitEntry)
if entry.start.Load() < cutoff { if entry.start.Load() < cutoff {
rl.entries.Delete(key) rl.entries.Delete(key)
} }

View File

@@ -53,7 +53,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
path := cfg.SocketPath() path := cfg.SocketPath()
// Remove stale socket file from a previous run. // Remove stale socket file from a previous run.
os.Remove(path) _ = os.Remove(path)
ln, err := net.Listen("unix", path) ln, err := net.Listen("unix", path)
if err != nil { if err != nil {
@@ -61,7 +61,7 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
} }
if err := os.Chmod(path, 0600); err != nil { if err := os.Chmod(path, 0600); err != nil {
ln.Close() _ = ln.Close()
return nil, nil, fmt.Errorf("setting socket permissions: %w", err) return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
} }
@@ -446,7 +446,7 @@ func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.
} }
listeners = append(listeners, &pb.ListenerStatus{ listeners = append(listeners, &pb.ListenerStatus{
Addr: ls.Addr, Addr: ls.Addr,
RouteCount: int32(len(routes)), RouteCount: int32(len(routes)), //nolint:gosec // route count can never exceed int32
ActiveConnections: ls.ActiveConnections.Load(), ActiveConnections: ls.ActiveConnections.Load(),
ProxyProtocol: ls.ProxyProtocol, ProxyProtocol: ls.ProxyProtocol,
MaxConnections: ls.MaxConnections, MaxConnections: ls.MaxConnections,

View File

@@ -39,7 +39,7 @@ func setup(t *testing.T) *testEnv {
if err != nil { if err != nil {
t.Fatalf("open db: %v", err) t.Fatalf("open db: %v", err)
} }
t.Cleanup(func() { store.Close() }) t.Cleanup(func() { _ = store.Close() })
if err := store.Migrate(); err != nil { if err := store.Migrate(); err != nil {
t.Fatalf("migrate: %v", err) t.Fatalf("migrate: %v", err)
@@ -130,7 +130,7 @@ func setup(t *testing.T) *testEnv {
if err != nil { if err != nil {
t.Fatalf("dial bufconn: %v", err) t.Fatalf("dial bufconn: %v", err)
} }
t.Cleanup(func() { conn.Close() }) t.Cleanup(func() { _ = conn.Close() })
return &testEnv{ return &testEnv{
client: pb.NewProxyAdminServiceClient(conn), client: pb.NewProxyAdminServiceClient(conn),
@@ -775,7 +775,7 @@ func TestRemoveL7Policy(t *testing.T) {
env := setup(t) env := setup(t)
ctx := context.Background() ctx := context.Background()
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{ _, _ = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443", ListenerAddr: ":443",
Hostname: "a.test", Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"}, Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},

View File

@@ -13,27 +13,27 @@ func TestPrefixConnRead(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
conn.Write([]byte("WORLD")) _, _ = conn.Write([]byte("WORLD"))
}() }()
conn, err := net.Dial("tcp", ln.Addr().String()) conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("HELLO")) pc := NewPrefixConn(conn, []byte("HELLO"))
// Read all data: should get "HELLOWORLD". // Read all data: should get "HELLOWORLD".
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc) all, err := io.ReadAll(pc)
if err != nil { if err != nil {
t.Fatalf("ReadAll: %v", err) t.Fatalf("ReadAll: %v", err)
@@ -48,22 +48,22 @@ func TestPrefixConnSmallReads(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
conn.Write([]byte("CD")) _, _ = conn.Write([]byte("CD"))
}() }()
conn, err := net.Dial("tcp", ln.Addr().String()) conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("AB")) pc := NewPrefixConn(conn, []byte("AB"))
@@ -79,7 +79,7 @@ func TestPrefixConnSmallReads(t *testing.T) {
} }
// Now reads come from the underlying conn. // Now reads come from the underlying conn.
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
rest, err := io.ReadAll(pc) rest, err := io.ReadAll(pc)
if err != nil { if err != nil {
t.Fatalf("ReadAll: %v", err) t.Fatalf("ReadAll: %v", err)
@@ -94,25 +94,25 @@ func TestPrefixConnEmptyPrefix(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, err := ln.Accept() conn, err := ln.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
conn.Write([]byte("DATA")) _, _ = conn.Write([]byte("DATA"))
}() }()
conn, err := net.Dial("tcp", ln.Addr().String()) conn, err := net.Dial("tcp", ln.Addr().String())
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, nil) pc := NewPrefixConn(conn, nil)
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
all, err := io.ReadAll(pc) all, err := io.ReadAll(pc)
if err != nil { if err != nil {
t.Fatalf("ReadAll: %v", err) t.Fatalf("ReadAll: %v", err)
@@ -127,12 +127,12 @@ func TestPrefixConnDelegates(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
go func() { go func() {
conn, _ := ln.Accept() conn, _ := ln.Accept()
if conn != nil { if conn != nil {
conn.Close() _ = conn.Close()
} }
}() }()
@@ -140,7 +140,7 @@ func TestPrefixConnDelegates(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
pc := NewPrefixConn(conn, []byte("X")) pc := NewPrefixConn(conn, []byte("X"))

View File

@@ -120,7 +120,7 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
ReadHeaderTimeout: 30 * time.Second, ReadHeaderTimeout: 30 * time.Second,
} }
singleConn := newSingleConnListener(tlsConn) singleConn := newSingleConnListener(tlsConn)
srv.Serve(singleConn) _ = srv.Serve(singleConn)
} }
return nil return nil
@@ -213,7 +213,7 @@ func dialBackend(ctx context.Context, network, addr string, timeout time.Duratio
backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String()) backendAddr, _ := netip.ParseAddrPort(conn.RemoteAddr().String())
if clientAddr.IsValid() { if clientAddr.IsValid() {
if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil { if err := proxyproto.WriteV2(conn, clientAddr, backendAddr); err != nil {
conn.Close() _ = conn.Close()
return nil, fmt.Errorf("writing PROXY protocol header: %w", err) return nil, fmt.Errorf("writing PROXY protocol header: %w", err)
} }
} }

View File

@@ -58,8 +58,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
if err != nil { if err != nil {
t.Fatalf("creating cert file: %v", err) t.Fatalf("creating cert file: %v", err)
} }
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) _ = pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
certFile.Close() _ = certFile.Close()
keyDER, err := x509.MarshalECPrivateKey(key) keyDER, err := x509.MarshalECPrivateKey(key)
if err != nil { if err != nil {
@@ -69,8 +69,8 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
if err != nil { if err != nil {
t.Fatalf("creating key file: %v", err) t.Fatalf("creating key file: %v", err)
} }
pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) _ = pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
keyFile.Close() _ = keyFile.Close()
return certPath, keyPath return certPath, keyPath
} }
@@ -91,11 +91,11 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
t.Cleanup(func() { t.Cleanup(func() {
srv.Close() _ = srv.Close()
ln.Close() _ = ln.Close()
}) })
go srv.Serve(ln) go func() { _ = srv.Serve(ln) }()
return ln.Addr().String() return ln.Addr().String()
} }
@@ -118,7 +118,7 @@ func dialTLSToProxy(t *testing.T, proxyAddr, serverName string) *http.Client {
if err != nil { if err != nil {
t.Fatalf("TLS dial: %v", err) t.Fatalf("TLS dial: %v", err)
} }
t.Cleanup(func() { conn.Close() }) t.Cleanup(func() { _ = conn.Close() })
// Create an HTTP/2 client transport over this single connection. // Create an HTTP/2 client transport over this single connection.
tr := &http2.Transport{} tr := &http2.Transport{}
@@ -142,29 +142,13 @@ func (s *singleConnRoundTripper) RoundTrip(req *http.Request) (*http.Response, e
return s.cc.RoundTrip(req) return s.cc.RoundTrip(req)
} }
// serveL7Route starts l7.Serve in a goroutine for a single connection.
// Returns when the goroutine completes.
func serveL7Route(t *testing.T, conn net.Conn, peeked []byte, route RouteConfig) {
t.Helper()
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
ctx := context.Background()
go func() {
l7Err := Serve(ctx, conn, peeked, route, clientAddr, logger)
if l7Err != nil {
t.Logf("l7.Serve: %v", l7Err)
}
}()
}
func TestL7H2CBackend(t *testing.T) { func TestL7H2CBackend(t *testing.T) {
certPath, keyPath := testCert(t, "l7.test") certPath, keyPath := testCert(t, "l7.test")
// Start an h2c backend. // Start an h2c backend.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Backend", "ok") w.Header().Set("X-Backend", "ok")
fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path) _, _ = fmt.Fprintf(w, "hello from backend, path=%s", r.URL.Path)
})) }))
// Start a TCP listener for the L7 proxy. // Start a TCP listener for the L7 proxy.
@@ -172,7 +156,7 @@ func TestL7H2CBackend(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -190,17 +174,17 @@ func TestL7H2CBackend(t *testing.T) {
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
// No peeked bytes — the client is connecting directly with TLS. // No peeked bytes — the client is connecting directly with TLS.
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
// Connect as an HTTP/2 TLS client. // Connect as an HTTP/2 TLS client.
client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "l7.test")
resp, err := client.Get(fmt.Sprintf("https://l7.test/foo")) resp, err := client.Get("https://l7.test/foo")
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode) t.Fatalf("status = %d, want 200", resp.StatusCode)
@@ -221,7 +205,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
// Backend that echoes the forwarding headers. // Backend that echoes the forwarding headers.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "xff=%s xfp=%s xri=%s", _, _ = fmt.Fprintf(w, "xff=%s xfp=%s xri=%s",
r.Header.Get("X-Forwarded-For"), r.Header.Get("X-Forwarded-For"),
r.Header.Get("X-Forwarded-Proto"), r.Header.Get("X-Forwarded-Proto"),
r.Header.Get("X-Real-IP"), r.Header.Get("X-Real-IP"),
@@ -232,7 +216,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -248,7 +232,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "headers.test")
@@ -256,7 +240,7 @@ func TestL7ForwardingHeaders(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
want := "xff=203.0.113.50 xfp=https xri=203.0.113.50" want := "xff=203.0.113.50 xfp=https xri=203.0.113.50"
@@ -274,13 +258,13 @@ func TestL7BackendUnreachable(t *testing.T) {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
deadAddr := ln.Addr().String() deadAddr := ln.Addr().String()
ln.Close() _ = ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: deadAddr, Backend: deadAddr,
@@ -296,7 +280,7 @@ func TestL7BackendUnreachable(t *testing.T) {
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "unreachable.test")
@@ -304,7 +288,7 @@ func TestL7BackendUnreachable(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusBadGateway { if resp.StatusCode != http.StatusBadGateway {
t.Fatalf("status = %d, want 502", resp.StatusCode) t.Fatalf("status = %d, want 502", resp.StatusCode)
@@ -342,14 +326,14 @@ func TestL7MultipleRequests(t *testing.T) {
var reqCount int var reqCount int
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
reqCount++ reqCount++
fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path) _, _ = fmt.Fprintf(w, "req=%d path=%s", reqCount, r.URL.Path)
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -365,7 +349,7 @@ func TestL7MultipleRequests(t *testing.T) {
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
clientAddr := netip.MustParseAddrPort("203.0.113.50:12345") clientAddr := netip.MustParseAddrPort("203.0.113.50:12345")
Serve(context.Background(), conn, nil, route, clientAddr, logger) _ = Serve(context.Background(), conn, nil, route, clientAddr, logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "multi.test")
@@ -378,7 +362,7 @@ func TestL7MultipleRequests(t *testing.T) {
t.Fatalf("GET %s: %v", path, err) t.Fatalf("GET %s: %v", path, err)
} }
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
resp.Body.Close() _ = resp.Body.Close()
want := fmt.Sprintf("req=%d path=%s", i+1, path) want := fmt.Sprintf("req=%d path=%s", i+1, path)
if string(body) != want { if string(body) != want {
@@ -396,14 +380,14 @@ func TestL7LargeResponse(t *testing.T) {
largeBody[i] = byte(i % 256) largeBody[i] = byte(i % 256)
} }
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write(largeBody) _, _ = w.Write(largeBody)
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -418,7 +402,7 @@ func TestL7LargeResponse(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "large.test")
@@ -426,7 +410,7 @@ func TestL7LargeResponse(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if len(body) != len(largeBody) { if len(body) != len(largeBody) {
@@ -455,7 +439,7 @@ func TestL7GRPCTrailers(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -470,7 +454,7 @@ func TestL7GRPCTrailers(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "trailers.test")
@@ -480,10 +464,10 @@ func TestL7GRPCTrailers(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("POST: %v", err) t.Fatalf("POST: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
// Read body to trigger trailer delivery. // Read body to trigger trailer delivery.
io.ReadAll(resp.Body) _, _ = io.ReadAll(resp.Body)
// Verify trailers were forwarded through the proxy. // Verify trailers were forwarded through the proxy.
grpcStatus := resp.Trailer.Get("Grpc-Status") grpcStatus := resp.Trailer.Get("Grpc-Status")
@@ -500,14 +484,14 @@ func TestL7HTTP11Fallback(t *testing.T) {
certPath, keyPath := testCert(t, "http11.test") certPath, keyPath := testCert(t, "http11.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "proto=%s", r.Proto) _, _ = fmt.Fprintf(w, "proto=%s", r.Proto)
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -522,7 +506,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
// Connect with HTTP/1.1 only (no h2 ALPN). // Connect with HTTP/1.1 only (no h2 ALPN).
@@ -538,7 +522,7 @@ func TestL7HTTP11Fallback(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
t.Fatalf("status = %d, want 200", resp.StatusCode) t.Fatalf("status = %d, want 200", resp.StatusCode)
@@ -556,14 +540,14 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
certPath, keyPath := testCert(t, "policy.test") certPath, keyPath := testCert(t, "policy.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "should-not-reach") _, _ = fmt.Fprint(w, "should-not-reach")
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -581,7 +565,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
return return
} }
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test") client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
@@ -591,7 +575,7 @@ func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET: %v", err) t.Fatalf("GET: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != 403 { if resp.StatusCode != 403 {
t.Fatalf("status = %d, want 403", resp.StatusCode) t.Fatalf("status = %d, want 403", resp.StatusCode)
@@ -602,14 +586,14 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
certPath, keyPath := testCert(t, "reqhdr.test") certPath, keyPath := testCert(t, "reqhdr.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "ok") _, _ = fmt.Fprint(w, "ok")
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
defer proxyLn.Close() defer func() { _ = proxyLn.Close() }()
route := RouteConfig{ route := RouteConfig{
Backend: backendAddr, Backend: backendAddr,
@@ -630,7 +614,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
} }
go func() { go func() {
logger := slog.New(slog.NewTextHandler(io.Discard, nil)) logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger) _ = Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}() }()
} }
}() }()
@@ -641,7 +625,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET without header: %v", err) t.Fatalf("GET without header: %v", err)
} }
resp1.Body.Close() _ = resp1.Body.Close()
if resp1.StatusCode != 403 { if resp1.StatusCode != 403 {
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode) t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
} }
@@ -654,7 +638,7 @@ func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("GET with header: %v", err) t.Fatalf("GET with header: %v", err)
} }
defer resp2.Body.Close() defer func() { _ = resp2.Body.Close() }()
body, _ := io.ReadAll(resp2.Body) body, _ := io.ReadAll(resp2.Body)
if resp2.StatusCode != 200 { if resp2.StatusCode != 200 {
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode) t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)

View File

@@ -2,6 +2,7 @@ package proxy
import ( import (
"context" "context"
"errors"
"io" "io"
"net" "net"
"sync" "sync"
@@ -31,8 +32,8 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
go func() { go func() {
<-ctx.Done() <-ctx.Done()
client.Close() _ = client.Close()
backend.Close() _ = backend.Close()
}() }()
var ( var (
@@ -50,7 +51,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout) result.ClientBytes, errC2B = copyWithIdleTimeout(backend, client, idleTimeout)
// Half-close backend's write side. // Half-close backend's write side.
if hc, ok := backend.(interface{ CloseWrite() error }); ok { if hc, ok := backend.(interface{ CloseWrite() error }); ok {
hc.CloseWrite() _ = hc.CloseWrite()
} }
}() }()
@@ -60,7 +61,7 @@ func Relay(ctx context.Context, client, backend net.Conn, peeked []byte, idleTim
result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout) result.BackendBytes, errB2C = copyWithIdleTimeout(client, backend, idleTimeout)
// Half-close client's write side. // Half-close client's write side.
if hc, ok := client.(interface{ CloseWrite() error }); ok { if hc, ok := client.(interface{ CloseWrite() error }); ok {
hc.CloseWrite() _ = hc.CloseWrite()
} }
}() }()
@@ -85,10 +86,10 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
var total int64 var total int64
for { for {
src.SetReadDeadline(time.Now().Add(idleTimeout)) _ = src.SetReadDeadline(time.Now().Add(idleTimeout))
nr, readErr := src.Read(buf) nr, readErr := src.Read(buf)
if nr > 0 { if nr > 0 {
dst.SetWriteDeadline(time.Now().Add(idleTimeout)) _ = dst.SetWriteDeadline(time.Now().Add(idleTimeout))
nw, writeErr := dst.Write(buf[:nr]) nw, writeErr := dst.Write(buf[:nr])
total += int64(nw) total += int64(nw)
if writeErr != nil { if writeErr != nil {
@@ -96,7 +97,7 @@ func copyWithIdleTimeout(dst, src net.Conn, idleTimeout time.Duration) (int64, e
} }
} }
if readErr != nil { if readErr != nil {
if readErr == io.EOF { if errors.Is(readErr, io.EOF) {
return total, nil return total, nil
} }
return total, readErr return total, readErr

View File

@@ -16,7 +16,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
peeked := []byte("peeked-hello-bytes") peeked := []byte("peeked-hello-bytes")
clientData := []byte("data from client") clientData := []byte("data from client")
@@ -29,7 +29,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Read everything the backend receives. // Read everything the backend receives.
received, _ := io.ReadAll(conn) received, _ := io.ReadAll(conn)
@@ -40,21 +40,21 @@ func TestRelayBasic(t *testing.T) {
}() }()
// Restructure: use a more controlled flow. // Restructure: use a more controlled flow.
backendLn.Close() _ = backendLn.Close()
// Use a real TCP pair for proper half-close. // Use a real TCP pair for proper half-close.
backendLn2, err := net.Listen("tcp", "127.0.0.1:0") backendLn2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer backendLn2.Close() defer func() { _ = backendLn2.Close() }()
go func() { go func() {
conn, err := backendLn2.Accept() conn, err := backendLn2.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Read peeked + client data. // Read peeked + client data.
buf := make([]byte, len(peeked)+len(clientData)) buf := make([]byte, len(peeked)+len(clientData))
@@ -62,11 +62,11 @@ func TestRelayBasic(t *testing.T) {
backendDone <- buf[:n] backendDone <- buf[:n]
// Send response. // Send response.
conn.Write(backendData) _, _ = conn.Write(backendData)
// Close write side to signal EOF. // Close write side to signal EOF.
if tc, ok := conn.(*net.TCPConn); ok { if tc, ok := conn.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -81,7 +81,7 @@ func TestRelayBasic(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer clientLn.Close() defer func() { _ = clientLn.Close() }()
clientConn, err := net.Dial("tcp", clientLn.Addr().String()) clientConn, err := net.Dial("tcp", clientLn.Addr().String())
if err != nil { if err != nil {
@@ -94,9 +94,9 @@ func TestRelayBasic(t *testing.T) {
// Client sends data then closes write. // Client sends data then closes write.
go func() { go func() {
clientConn.Write(clientData) _, _ = clientConn.Write(clientData)
if tc, ok := clientConn.(*net.TCPConn); ok { if tc, ok := clientConn.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -114,7 +114,7 @@ func TestRelayBasic(t *testing.T) {
} }
// Verify client received backend data. // Verify client received backend data.
clientConn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = clientConn.SetReadDeadline(time.Now().Add(2 * time.Second))
clientReceived, _ := io.ReadAll(clientConn) clientReceived, _ := io.ReadAll(clientConn)
if !bytes.Equal(clientReceived, backendData) { if !bytes.Equal(clientReceived, backendData) {
t.Fatalf("client received %q, want %q", clientReceived, backendData) t.Fatalf("client received %q, want %q", clientReceived, backendData)
@@ -131,12 +131,12 @@ func TestRelayBasic(t *testing.T) {
func TestRelayIdleTimeout(t *testing.T) { func TestRelayIdleTimeout(t *testing.T) {
// Two connected pairs via TCP. // Two connected pairs via TCP.
clientA, clientB := tcpPair(t) clientA, clientB := tcpPair(t)
defer clientA.Close() defer func() { _ = clientA.Close() }()
defer clientB.Close() defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t) backendA, backendB := tcpPair(t)
defer backendA.Close() defer func() { _ = backendA.Close() }()
defer backendB.Close() defer func() { _ = backendB.Close() }()
start := time.Now() start := time.Now()
_, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond) _, err := Relay(context.Background(), clientB, backendA, nil, 100*time.Millisecond)
@@ -154,18 +154,18 @@ func TestRelayIdleTimeout(t *testing.T) {
func TestRelayContextCancel(t *testing.T) { func TestRelayContextCancel(t *testing.T) {
clientA, clientB := tcpPair(t) clientA, clientB := tcpPair(t)
defer clientA.Close() defer func() { _ = clientA.Close() }()
defer clientB.Close() defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t) backendA, backendB := tcpPair(t)
defer backendA.Close() defer func() { _ = backendA.Close() }()
defer backendB.Close() defer func() { _ = backendB.Close() }()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{}) done := make(chan struct{})
go func() { go func() {
Relay(ctx, clientB, backendA, nil, time.Minute) _, _ = Relay(ctx, clientB, backendA, nil, time.Minute)
close(done) close(done)
}() }()
@@ -185,12 +185,12 @@ func TestRelayContextCancel(t *testing.T) {
func TestRelayLargeTransfer(t *testing.T) { func TestRelayLargeTransfer(t *testing.T) {
clientA, clientB := tcpPair(t) clientA, clientB := tcpPair(t)
defer clientA.Close() defer func() { _ = clientA.Close() }()
defer clientB.Close() defer func() { _ = clientB.Close() }()
backendA, backendB := tcpPair(t) backendA, backendB := tcpPair(t)
defer backendA.Close() defer func() { _ = backendA.Close() }()
defer backendB.Close() defer func() { _ = backendB.Close() }()
// 1 MB of random data. // 1 MB of random data.
data := make([]byte, 1<<20) data := make([]byte, 1<<20)
@@ -199,9 +199,9 @@ func TestRelayLargeTransfer(t *testing.T) {
} }
go func() { go func() {
clientA.Write(data) _, _ = clientA.Write(data)
if tc, ok := clientA.(*net.TCPConn); ok { if tc, ok := clientA.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -211,14 +211,14 @@ func TestRelayLargeTransfer(t *testing.T) {
for { for {
n, err := backendB.Read(buf) n, err := backendB.Read(buf)
if n > 0 { if n > 0 {
backendB.Write(buf[:n]) _, _ = backendB.Write(buf[:n])
} }
if err != nil { if err != nil {
break break
} }
} }
if tc, ok := backendB.(*net.TCPConn); ok { if tc, ok := backendB.(*net.TCPConn); ok {
tc.CloseWrite() _ = tc.CloseWrite()
} }
}() }()
@@ -240,7 +240,7 @@ func tcpPair(t *testing.T) (net.Conn, net.Conn) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
defer ln.Close() defer func() { _ = ln.Close() }()
var serverConn net.Conn var serverConn net.Conn
done := make(chan struct{}) done := make(chan struct{})

View File

@@ -50,8 +50,8 @@ const (
// It reads only the exact bytes of the PROXY header, leaving the connection // It reads only the exact bytes of the PROXY header, leaving the connection
// positioned at the first byte after the header (e.g., TLS ClientHello). // positioned at the first byte after the header (e.g., TLS ClientHello).
func Parse(conn net.Conn, deadline time.Time) (Header, error) { func Parse(conn net.Conn, deadline time.Time) (Header, error) {
conn.SetReadDeadline(deadline) _ = conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{}) defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
// Read the first byte to determine version. // Read the first byte to determine version.
var first [1]byte var first [1]byte

View File

@@ -16,7 +16,7 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
t.Cleanup(func() { ln.Close() }) t.Cleanup(func() { _ = ln.Close() })
ch := make(chan net.Conn, 1) ch := make(chan net.Conn, 1)
go func() { go func() {
@@ -31,10 +31,10 @@ func pipeWithDeadline(t *testing.T) (reader net.Conn, writer net.Conn) {
if err != nil { if err != nil {
t.Fatalf("dial: %v", err) t.Fatalf("dial: %v", err)
} }
t.Cleanup(func() { w.Close() }) t.Cleanup(func() { _ = w.Close() })
r := <-ch r := <-ch
t.Cleanup(func() { r.Close() }) t.Cleanup(func() { _ = r.Close() })
return r, w return r, w
} }
@@ -42,7 +42,7 @@ func TestParseV1TCP4(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n")) _, _ = writer.Write([]byte("PROXY TCP4 192.168.1.1 10.0.0.1 56324 443\r\n"))
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -68,7 +68,7 @@ func TestParseV1TCP6(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n")) _, _ = writer.Write([]byte("PROXY TCP6 2001:db8::1 2001:db8::2 56324 8443\r\n"))
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -98,7 +98,7 @@ func TestParseV2TCP4(t *testing.T) {
buf = append(buf, 10, 0, 0, 1) // dst IP buf = append(buf, 10, 0, 0, 1) // dst IP
buf = binary.BigEndian.AppendUint16(buf, 12345) buf = binary.BigEndian.AppendUint16(buf, 12345)
buf = binary.BigEndian.AppendUint16(buf, 443) buf = binary.BigEndian.AppendUint16(buf, 443)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -137,7 +137,7 @@ func TestParseV2TCP6(t *testing.T) {
buf = append(buf, dst[:]...) buf = append(buf, dst[:]...)
buf = binary.BigEndian.AppendUint16(buf, 56324) buf = binary.BigEndian.AppendUint16(buf, 56324)
buf = binary.BigEndian.AppendUint16(buf, 8443) buf = binary.BigEndian.AppendUint16(buf, 8443)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -163,7 +163,7 @@ func TestParseV2Local(t *testing.T) {
buf = append(buf, 0x20) // version 2, LOCAL command buf = append(buf, 0x20) // version 2, LOCAL command
buf = append(buf, 0x00) // unspec family, unspec protocol buf = append(buf, 0x00) // unspec family, unspec protocol
buf = binary.BigEndian.AppendUint16(buf, 0) buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -195,7 +195,7 @@ func TestParseV1Malformed(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte(tt.data)) _, _ = writer.Write([]byte(tt.data))
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -211,7 +211,7 @@ func TestParseV2Malformed(t *testing.T) {
go func() { go func() {
bad := make([]byte, v2HeaderLen) bad := make([]byte, v2HeaderLen)
bad[0] = v2Signature[0] // first byte matches but rest doesn't bad[0] = v2Signature[0] // first byte matches but rest doesn't
writer.Write(bad) _, _ = writer.Write(bad)
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -227,7 +227,7 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x31) // version 3, PROXY command buf = append(buf, 0x31) // version 3, PROXY command
buf = append(buf, 0x11) buf = append(buf, 0x11)
buf = binary.BigEndian.AppendUint16(buf, 0) buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -244,8 +244,8 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x11) // AF_INET, STREAM buf = append(buf, 0x11) // AF_INET, STREAM
buf = binary.BigEndian.AppendUint16(buf, 12) buf = binary.BigEndian.AppendUint16(buf, 12)
buf = append(buf, 1, 2, 3) // only 3 bytes, need 12 buf = append(buf, 1, 2, 3) // only 3 bytes, need 12
writer.Write(buf) _, _ = writer.Write(buf)
writer.Close() _ = writer.Close()
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -261,7 +261,7 @@ func TestParseV2Malformed(t *testing.T) {
buf = append(buf, 0x21) // version 2, PROXY buf = append(buf, 0x21) // version 2, PROXY
buf = append(buf, 0x31) // AF_UNIX (3), STREAM buf = append(buf, 0x31) // AF_UNIX (3), STREAM
buf = binary.BigEndian.AppendUint16(buf, 0) buf = binary.BigEndian.AppendUint16(buf, 0)
writer.Write(buf) _, _ = writer.Write(buf)
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {
@@ -332,7 +332,7 @@ func TestRoundTripV2IPv4(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
WriteV2(writer, src, dst) _ = WriteV2(writer, src, dst)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -361,7 +361,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
WriteV2(writer, src, dst) _ = WriteV2(writer, src, dst)
}() }()
hdr, err := Parse(reader, time.Now().Add(5*time.Second)) hdr, err := Parse(reader, time.Now().Add(5*time.Second))
@@ -380,7 +380,7 @@ func TestRoundTripV2IPv6(t *testing.T) {
func TestParseGarbageFirstByte(t *testing.T) { func TestParseGarbageFirstByte(t *testing.T) {
reader, writer := pipeWithDeadline(t) reader, writer := pipeWithDeadline(t)
go func() { go func() {
writer.Write([]byte{0xFF, 0x00, 0x01}) _, _ = writer.Write([]byte{0xFF, 0x00, 0x01})
}() }()
_, err := Parse(reader, time.Now().Add(2*time.Second)) _, err := Parse(reader, time.Now().Add(2*time.Second))
if err == nil { if err == nil {

View File

@@ -235,7 +235,7 @@ func (s *Server) Run(ctx context.Context) error {
ln, err := net.Listen("tcp", ls.Addr) ln, err := net.Listen("tcp", ls.Addr)
if err != nil { if err != nil {
for _, l := range netListeners { for _, l := range netListeners {
l.Close() _ = l.Close()
} }
return fmt.Errorf("listening on %s: %w", ls.Addr, err) return fmt.Errorf("listening on %s: %w", ls.Addr, err)
} }
@@ -253,7 +253,7 @@ func (s *Server) Run(ctx context.Context) error {
s.logger.Info("shutting down") s.logger.Info("shutting down")
for _, ln := range netListeners { for _, ln := range netListeners {
ln.Close() _ = ln.Close()
} }
done := make(chan struct{}) done := make(chan struct{})
@@ -272,7 +272,7 @@ func (s *Server) Run(ctx context.Context) error {
<-done <-done
} }
s.fw.Close() _ = s.fw.Close()
return nil return nil
} }
@@ -294,7 +294,7 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState)
// Enforce per-listener connection limit. // Enforce per-listener connection limit.
if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections { if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections {
conn.Close() _ = conn.Close()
s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections) s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections)
continue continue
} }
@@ -311,7 +311,7 @@ func (s *Server) forceCloseAll() {
for _, ls := range s.listeners { for _, ls := range s.listeners {
ls.connMu.Lock() ls.connMu.Lock()
for conn := range ls.activeConns { for conn := range ls.activeConns {
conn.Close() _ = conn.Close()
} }
ls.connMu.Unlock() ls.connMu.Unlock()
} }
@@ -321,7 +321,7 @@ func (s *Server) handleConn(ctx context.Context, conn net.Conn, ls *ListenerStat
defer s.wg.Done() defer s.wg.Done()
defer ls.ActiveConnections.Add(-1) defer ls.ActiveConnections.Add(-1)
defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec() defer metrics.ConnectionsActive.WithLabelValues(ls.Addr).Dec()
defer conn.Close() defer func() { _ = conn.Close() }()
ls.connMu.Lock() ls.connMu.Lock()
ls.activeConns[conn] = struct{}{} ls.activeConns[conn] = struct{}{}
@@ -392,7 +392,7 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err) s.logger.Error("backend dial failed", "hostname", hostname, "backend", route.Backend, "error", err)
return return
} }
defer backendConn.Close() defer func() { _ = backendConn.Close() }()
// Send PROXY protocol v2 header to backend if configured. // Send PROXY protocol v2 header to backend if configured.
if route.SendProxyProtocol { if route.SendProxyProtocol {

View File

@@ -43,8 +43,8 @@ func echoServer(t *testing.T, ln net.Listener) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
io.Copy(conn, conn) _, _ = io.Copy(conn, conn)
} }
// newTestServer creates a Server with the given listener data and no firewall rules. // newTestServer creates a Server with the given listener data and no firewall rules.
@@ -92,7 +92,7 @@ func TestProxyRoundTrip(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn) go echoServer(t, backendLn)
// Pick a free port for the proxy listener. // Pick a free port for the proxy listener.
@@ -101,7 +101,7 @@ func TestProxyRoundTrip(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -121,7 +121,7 @@ func TestProxyRoundTrip(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("echo.test") hello := buildClientHello("echo.test")
if _, err := conn.Write(hello); err != nil { if _, err := conn.Write(hello); err != nil {
@@ -130,7 +130,7 @@ func TestProxyRoundTrip(t *testing.T) {
// The backend will echo our ClientHello back. Read it. // The backend will echo our ClientHello back. Read it.
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil { if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err) t.Fatalf("read echoed data: %v", err)
} }
@@ -157,7 +157,7 @@ func TestNoRouteResets(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -176,7 +176,7 @@ func TestNoRouteResets(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("unknown.test") hello := buildClientHello("unknown.test")
if _, err := conn.Write(hello); err != nil { if _, err := conn.Write(hello); err != nil {
@@ -184,7 +184,7 @@ func TestNoRouteResets(t *testing.T) {
} }
// The proxy should close the connection (no route match). // The proxy should close the connection (no route match).
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed, but read succeeded") t.Fatal("expected connection to be closed, but read succeeded")
@@ -197,7 +197,7 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1) reached := make(chan struct{}, 1)
go func() { go func() {
@@ -205,7 +205,7 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil { if err != nil {
return return
} }
conn.Close() _ = conn.Close()
reached <- struct{}{} reached <- struct{}{}
}() }()
@@ -214,7 +214,7 @@ func TestFirewallBlocks(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
// Create a firewall that blocks 127.0.0.1 (the test client). // Create a firewall that blocks 127.0.0.1 (the test client).
fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0) fw, err := firewall.New("", []string{"127.0.0.1"}, nil, nil, 0, 0)
@@ -245,7 +245,7 @@ func TestFirewallBlocks(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
srv.Run(ctx) _ = srv.Run(ctx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@@ -253,13 +253,13 @@ func TestFirewallBlocks(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("echo.test") hello := buildClientHello("echo.test")
conn.Write(hello) _, _ = conn.Write(hello)
// Connection should be closed (blocked by firewall). // Connection should be closed (blocked by firewall).
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed by firewall") t.Fatal("expected connection to be closed by firewall")
@@ -283,7 +283,7 @@ func TestNotTLSResets(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -300,12 +300,12 @@ func TestNotTLSResets(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send HTTP, not TLS. // Send HTTP, not TLS.
conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n")) _, _ = conn.Write([]byte("GET / HTTP/1.1\r\nHost: x.test\r\n\r\n"))
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed for non-TLS data") t.Fatal("expected connection to be closed for non-TLS data")
@@ -318,7 +318,7 @@ func TestConnectionTracking(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
var backendConns []net.Conn var backendConns []net.Conn
var mu sync.Mutex var mu sync.Mutex
@@ -332,7 +332,7 @@ func TestConnectionTracking(t *testing.T) {
backendConns = append(backendConns, conn) backendConns = append(backendConns, conn)
mu.Unlock() mu.Unlock()
// Hold connection open, drain input. // Hold connection open, drain input.
go io.Copy(io.Discard, conn) go func() { _, _ = io.Copy(io.Discard, conn) }()
} }
}() }()
@@ -341,7 +341,7 @@ func TestConnectionTracking(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -382,10 +382,10 @@ func TestConnectionTracking(t *testing.T) {
} }
// Close one client and its corresponding backend connection. // Close one client and its corresponding backend connection.
clientConns[0].Close() _ = clientConns[0].Close()
mu.Lock() mu.Lock()
if len(backendConns) > 0 { if len(backendConns) > 0 {
backendConns[0].Close() _ = backendConns[0].Close()
} }
mu.Unlock() mu.Unlock()
@@ -402,10 +402,10 @@ func TestConnectionTracking(t *testing.T) {
} }
// Clean up. // Clean up.
clientConns[1].Close() _ = clientConns[1].Close()
mu.Lock() mu.Lock()
for _, c := range backendConns { for _, c := range backendConns {
c.Close() _ = c.Close()
} }
mu.Unlock() mu.Unlock()
} }
@@ -416,13 +416,13 @@ func TestMultipleListeners(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend A listen: %v", err) t.Fatalf("backend A listen: %v", err)
} }
defer backendA.Close() defer func() { _ = backendA.Close() }()
backendB, err := net.Listen("tcp", "127.0.0.1:0") backendB, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("backend B listen: %v", err) t.Fatalf("backend B listen: %v", err)
} }
defer backendB.Close() defer func() { _ = backendB.Close() }()
// Each backend writes its identity and closes. // Each backend writes its identity and closes.
serve := func(ln net.Listener, id string) { serve := func(ln net.Listener, id string) {
@@ -430,10 +430,10 @@ func TestMultipleListeners(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Drain the incoming data, then write identity. // Drain the incoming data, then write identity.
go io.Copy(io.Discard, conn) go func() { _, _ = io.Copy(io.Discard, conn) }()
conn.Write([]byte(id)) _, _ = conn.Write([]byte(id))
} }
go serve(backendA, "A") go serve(backendA, "A")
go serve(backendB, "B") go serve(backendB, "B")
@@ -444,14 +444,14 @@ func TestMultipleListeners(t *testing.T) {
t.Fatalf("finding free port 1: %v", err) t.Fatalf("finding free port 1: %v", err)
} }
addr1 := ln1.Addr().String() addr1 := ln1.Addr().String()
ln1.Close() _ = ln1.Close()
ln2, err := net.Listen("tcp", "127.0.0.1:0") ln2, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("finding free port 2: %v", err) t.Fatalf("finding free port 2: %v", err)
} }
addr2 := ln2.Addr().String() addr2 := ln2.Addr().String()
ln2.Close() _ = ln2.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}}, {ID: 1, Addr: addr1, Routes: map[string]RouteInfo{"svc.test": l4Route(backendA.Addr().String())}},
@@ -467,12 +467,12 @@ func TestMultipleListeners(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial %s: %v", proxyAddr, err) t.Fatalf("dial %s: %v", proxyAddr, err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("svc.test") hello := buildClientHello("svc.test")
conn.Write(hello) _, _ = conn.Write(hello)
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
buf := make([]byte, 128) buf := make([]byte, 128)
// Read what the backend sends back: echoed ClientHello + ID. // Read what the backend sends back: echoed ClientHello + ID.
// The backend drains input and writes the ID, so we read until we // The backend drains input and writes the ID, so we read until we
@@ -508,7 +508,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn) go echoServer(t, backendLn)
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -516,7 +516,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -537,7 +537,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("ECHO.TEST") hello := buildClientHello("ECHO.TEST")
if _, err := conn.Write(hello); err != nil { if _, err := conn.Write(hello); err != nil {
@@ -545,7 +545,7 @@ func TestCaseInsensitiveRouting(t *testing.T) {
} }
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil { if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err) t.Fatalf("read echoed data: %v", err)
} }
@@ -558,14 +558,14 @@ func TestBackendUnreachable(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
deadAddr := ln.Addr().String() deadAddr := ln.Addr().String()
ln.Close() _ = ln.Close()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -584,13 +584,13 @@ func TestBackendUnreachable(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("dead.test") hello := buildClientHello("dead.test")
conn.Write(hello) _, _ = conn.Write(hello)
// Proxy should close the connection after failing to dial backend. // Proxy should close the connection after failing to dial backend.
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed when backend is unreachable") t.Fatal("expected connection to be closed when backend is unreachable")
@@ -603,15 +603,15 @@ func TestGracefulShutdown(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go func() { go func() {
conn, err := backendLn.Accept() conn, err := backendLn.Accept()
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
io.Copy(io.Discard, conn) _, _ = io.Copy(io.Discard, conn)
}() }()
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -619,7 +619,7 @@ func TestGracefulShutdown(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
fw, err := firewall.New("", nil, nil, nil, 0, 0) fw, err := firewall.New("", nil, nil, nil, 0, 0)
if err != nil { if err != nil {
@@ -649,10 +649,10 @@ func TestGracefulShutdown(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("hold.test") hello := buildClientHello("hold.test")
conn.Write(hello) _, _ = conn.Write(hello)
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
// Trigger shutdown. // Trigger shutdown.
@@ -719,7 +719,7 @@ func TestProxyProtocolReceive(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go echoServer(t, backendLn) go echoServer(t, backendLn)
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -727,7 +727,7 @@ func TestProxyProtocolReceive(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -747,22 +747,22 @@ func TestProxyProtocolReceive(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send PROXY v2 header followed by TLS ClientHello. // Send PROXY v2 header followed by TLS ClientHello.
var ppBuf bytes.Buffer var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf, _ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("203.0.113.50:12345"), netip.MustParseAddrPort("203.0.113.50:12345"),
netip.MustParseAddrPort("198.51.100.1:443"), netip.MustParseAddrPort("198.51.100.1:443"),
) )
conn.Write(ppBuf.Bytes()) _, _ = conn.Write(ppBuf.Bytes())
hello := buildClientHello("echo.test") hello := buildClientHello("echo.test")
conn.Write(hello) _, _ = conn.Write(hello)
// Backend should echo the ClientHello back (not the PROXY header). // Backend should echo the ClientHello back (not the PROXY header).
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(conn, echoed); err != nil { if _, err := io.ReadFull(conn, echoed); err != nil {
t.Fatalf("read echoed data: %v", err) t.Fatalf("read echoed data: %v", err)
} }
@@ -774,7 +774,7 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -794,13 +794,13 @@ func TestProxyProtocolReceiveGarbage(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send garbage instead of a valid PROXY header. // Send garbage instead of a valid PROXY header.
conn.Write([]byte("NOT A PROXY HEADER\r\n")) _, _ = conn.Write([]byte("NOT A PROXY HEADER\r\n"))
// Connection should be closed. // Connection should be closed.
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed for invalid PROXY header") t.Fatal("expected connection to be closed for invalid PROXY header")
@@ -813,7 +813,7 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
received := make(chan []byte, 1) received := make(chan []byte, 1)
go func() { go func() {
@@ -821,9 +821,9 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Read all available data; the proxy sends PROXY header + ClientHello. // Read all available data; the proxy sends PROXY header + ClientHello.
conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(5 * time.Second))
var all []byte var all []byte
buf := make([]byte, 4096) buf := make([]byte, 4096)
for { for {
@@ -845,7 +845,7 @@ func TestProxyProtocolSend(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -868,10 +868,10 @@ func TestProxyProtocolSend(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("pp.test") hello := buildClientHello("pp.test")
conn.Write(hello) _, _ = conn.Write(hello)
// The backend should receive: PROXY v2 header + ClientHello. // The backend should receive: PROXY v2 header + ClientHello.
select { select {
@@ -904,7 +904,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
received := make(chan []byte, 1) received := make(chan []byte, 1)
go func() { go func() {
@@ -912,7 +912,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil { if err != nil {
return return
} }
defer conn.Close() defer func() { _ = conn.Close() }()
buf := make([]byte, 4096) buf := make([]byte, 4096)
n, _ := conn.Read(buf) n, _ := conn.Read(buf)
received <- buf[:n] received <- buf[:n]
@@ -923,7 +923,7 @@ func TestProxyProtocolNotSent(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -942,10 +942,10 @@ func TestProxyProtocolNotSent(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
hello := buildClientHello("nopp.test") hello := buildClientHello("nopp.test")
conn.Write(hello) _, _ = conn.Write(hello)
select { select {
case data := <-received: case data := <-received:
@@ -964,7 +964,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1) reached := make(chan struct{}, 1)
go func() { go func() {
@@ -972,7 +972,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil { if err != nil {
return return
} }
conn.Close() _ = conn.Close()
reached <- struct{}{} reached <- struct{}{}
}() }()
@@ -981,7 +981,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
t.Fatalf("finding free port: %v", err) t.Fatalf("finding free port: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
// Block 203.0.113.50 (the "real" client IP from PROXY header). // Block 203.0.113.50 (the "real" client IP from PROXY header).
// 127.0.0.1 (the actual TCP peer) is NOT blocked. // 127.0.0.1 (the actual TCP peer) is NOT blocked.
@@ -1014,7 +1014,7 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
srv.Run(ctx) _ = srv.Run(ctx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@@ -1022,19 +1022,19 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial proxy: %v", err) t.Fatalf("dial proxy: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
// Send PROXY v2 with the blocked real IP. // Send PROXY v2 with the blocked real IP.
var ppBuf bytes.Buffer var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf, _ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("203.0.113.50:12345"), netip.MustParseAddrPort("203.0.113.50:12345"),
netip.MustParseAddrPort("198.51.100.1:443"), netip.MustParseAddrPort("198.51.100.1:443"),
) )
conn.Write(ppBuf.Bytes()) _, _ = conn.Write(ppBuf.Bytes())
conn.Write(buildClientHello("blocked.test")) _, _ = conn.Write(buildClientHello("blocked.test"))
// Connection should be closed (firewall blocks real IP). // Connection should be closed (firewall blocks real IP).
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed by firewall") t.Fatal("expected connection to be closed by firewall")
@@ -1060,7 +1060,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
go func() { go func() {
for { for {
@@ -1068,7 +1068,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
return return
} }
go io.Copy(io.Discard, conn) go func() { _, _ = io.Copy(io.Discard, conn) }()
} }
}() }()
@@ -1077,7 +1077,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -1100,7 +1100,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial %d: %v", i, err) t.Fatalf("dial %d: %v", i, err)
} }
conn.Write(buildClientHello("limit.test")) _, _ = conn.Write(buildClientHello("limit.test"))
conns = append(conns, conn) conns = append(conns, conn)
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -1110,16 +1110,16 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial 3: %v", err) t.Fatalf("dial 3: %v", err)
} }
conn3.Write(buildClientHello("limit.test")) _, _ = conn3.Write(buildClientHello("limit.test"))
conn3.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn3.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn3.Read(make([]byte, 1)) _, err = conn3.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected 3rd connection to be closed due to limit") t.Fatal("expected 3rd connection to be closed due to limit")
} }
conn3.Close() _ = conn3.Close()
// Close one existing connection. // Close one existing connection.
conns[0].Close() _ = conns[0].Close()
time.Sleep(200 * time.Millisecond) time.Sleep(200 * time.Millisecond)
// Now a new connection should succeed. // Now a new connection should succeed.
@@ -1127,8 +1127,8 @@ func TestConnectionLimitEnforced(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial 4: %v", err) t.Fatalf("dial 4: %v", err)
} }
defer conn4.Close() defer func() { _ = conn4.Close() }()
conn4.Write(buildClientHello("limit.test")) _, _ = conn4.Write(buildClientHello("limit.test"))
// Give it time to be proxied. // Give it time to be proxied.
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@@ -1138,7 +1138,7 @@ func TestConnectionLimitEnforced(t *testing.T) {
// Clean up. // Clean up.
for _, c := range conns[1:] { for _, c := range conns[1:] {
c.Close() _ = c.Close()
} }
} }
@@ -1155,7 +1155,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
// h2c backend on origin that echoes the X-Forwarded-For. // h2c backend on origin that echoes the X-Forwarded-For.
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For")) _, _ = fmt.Fprintf(w, "xff=%s", r.Header.Get("X-Forwarded-For"))
})) }))
// Origin proxy: proxy_protocol=true listener, L7 route to backend. // Origin proxy: proxy_protocol=true listener, L7 route to backend.
@@ -1164,7 +1164,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
t.Fatalf("origin listen: %v", err) t.Fatalf("origin listen: %v", err)
} }
originAddr := originLn.Addr().String() originAddr := originLn.Addr().String()
originLn.Close() _ = originLn.Close()
originFw, _ := firewall.New("", nil, nil, nil, 0, 0) originFw, _ := firewall.New("", nil, nil, nil, 0, 0)
originCfg := &config.Config{ originCfg := &config.Config{
@@ -1196,7 +1196,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
originWg.Add(1) originWg.Add(1)
go func() { go func() {
defer originWg.Done() defer originWg.Done()
originSrv.Run(originCtx) _ = originSrv.Run(originCtx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
defer func() { defer func() {
@@ -1210,7 +1210,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
t.Fatalf("edge listen: %v", err) t.Fatalf("edge listen: %v", err)
} }
edgeAddr := edgeLn.Addr().String() edgeAddr := edgeLn.Addr().String()
edgeLn.Close() _ = edgeLn.Close()
edgeFw, _ := firewall.New("", nil, nil, nil, 0, 0) edgeFw, _ := firewall.New("", nil, nil, nil, 0, 0)
edgeSrv := New(originCfg, edgeFw, []ListenerData{ edgeSrv := New(originCfg, edgeFw, []ListenerData{
@@ -1232,7 +1232,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
edgeWg.Add(1) edgeWg.Add(1)
go func() { go func() {
defer edgeWg.Done() defer edgeWg.Done()
edgeSrv.Run(edgeCtx) _ = edgeSrv.Run(edgeCtx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
defer func() { defer func() {
@@ -1253,7 +1253,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("TLS dial edge: %v", err) t.Fatalf("TLS dial edge: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
tr := &http2.Transport{} tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn) h2conn, err := tr.NewClientConn(conn)
@@ -1266,7 +1266,7 @@ func TestMultiHopProxyProtocol(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("RoundTrip: %v", err) t.Fatalf("RoundTrip: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != 200 { if resp.StatusCode != 200 {
@@ -1289,7 +1289,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("backend listen: %v", err) t.Fatalf("backend listen: %v", err)
} }
defer backendLn.Close() defer func() { _ = backendLn.Close() }()
reached := make(chan struct{}, 1) reached := make(chan struct{}, 1)
go func() { go func() {
@@ -1297,7 +1297,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil { if err != nil {
return return
} }
conn.Close() _ = conn.Close()
reached <- struct{}{} reached <- struct{}{}
}() }()
@@ -1306,7 +1306,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
t.Fatalf("origin listen: %v", err) t.Fatalf("origin listen: %v", err)
} }
originAddr := originLn.Addr().String() originAddr := originLn.Addr().String()
originLn.Close() _ = originLn.Close()
// Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header. // Block 198.51.100.99 — this is the "real client IP" we'll put in the PROXY header.
originFw, _ := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0) originFw, _ := firewall.New("", []string{"198.51.100.99"}, nil, nil, 0, 0)
@@ -1334,7 +1334,7 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
wg.Add(1) wg.Add(1)
go func() { go func() {
defer wg.Done() defer wg.Done()
originSrv.Run(ctx) _ = originSrv.Run(ctx)
}() }()
time.Sleep(50 * time.Millisecond) time.Sleep(50 * time.Millisecond)
@@ -1343,18 +1343,18 @@ func TestMultiHopFirewallBlocksRealIP(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial origin: %v", err) t.Fatalf("dial origin: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
var ppBuf bytes.Buffer var ppBuf bytes.Buffer
proxyproto.WriteV2(&ppBuf, _ = proxyproto.WriteV2(&ppBuf,
netip.MustParseAddrPort("198.51.100.99:12345"), netip.MustParseAddrPort("198.51.100.99:12345"),
netip.MustParseAddrPort("10.0.0.1:443"), netip.MustParseAddrPort("10.0.0.1:443"),
) )
conn.Write(ppBuf.Bytes()) _, _ = conn.Write(ppBuf.Bytes())
conn.Write(buildClientHello("blocked.test")) _, _ = conn.Write(buildClientHello("blocked.test"))
// Connection should be dropped by firewall. // Connection should be dropped by firewall.
conn.SetReadDeadline(time.Now().Add(2 * time.Second)) _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
_, err = conn.Read(make([]byte, 1)) _, err = conn.Read(make([]byte, 1))
if err == nil { if err == nil {
t.Fatal("expected connection to be closed") t.Fatal("expected connection to be closed")
@@ -1396,12 +1396,12 @@ func testCert(t *testing.T, hostname string) (certPath, keyPath string) {
certPath = filepath.Join(dir, "cert.pem") certPath = filepath.Join(dir, "cert.pem")
keyPath = filepath.Join(dir, "key.pem") keyPath = filepath.Join(dir, "key.pem")
cf, _ := os.Create(certPath) cf, _ := os.Create(certPath)
pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) _ = pem.Encode(cf, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
cf.Close() _ = cf.Close()
keyDER, _ := x509.MarshalECPrivateKey(key) keyDER, _ := x509.MarshalECPrivateKey(key)
kf, _ := os.Create(keyPath) kf, _ := os.Create(keyPath)
pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) _ = pem.Encode(kf, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})
kf.Close() _ = kf.Close()
return return
} }
@@ -1417,8 +1417,8 @@ func startH2CBackend(t *testing.T, handler http.Handler) string {
if err != nil { if err != nil {
t.Fatalf("listen: %v", err) t.Fatalf("listen: %v", err)
} }
t.Cleanup(func() { srv.Close(); ln.Close() }) t.Cleanup(func() { _ = srv.Close(); _ = ln.Close() })
go srv.Serve(ln) go func() { _ = srv.Serve(ln) }()
return ln.Addr().String() return ln.Addr().String()
} }
@@ -1426,7 +1426,7 @@ func TestL7ThroughServer(t *testing.T) {
certPath, keyPath := testCert(t, "l7srv.test") certPath, keyPath := testCert(t, "l7srv.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For")) _, _ = fmt.Fprintf(w, "ok path=%s xff=%s", r.URL.Path, r.Header.Get("X-Forwarded-For"))
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -1434,7 +1434,7 @@ func TestL7ThroughServer(t *testing.T) {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -1467,7 +1467,7 @@ func TestL7ThroughServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("TLS dial: %v", err) t.Fatalf("TLS dial: %v", err)
} }
defer conn.Close() defer func() { _ = conn.Close() }()
tr := &http2.Transport{} tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(conn) h2conn, err := tr.NewClientConn(conn)
@@ -1480,7 +1480,7 @@ func TestL7ThroughServer(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("RoundTrip: %v", err) t.Fatalf("RoundTrip: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
// The X-Forwarded-For should be the TCP source IP (127.0.0.1) since // The X-Forwarded-For should be the TCP source IP (127.0.0.1) since
@@ -1502,12 +1502,12 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("l4 backend listen: %v", err) t.Fatalf("l4 backend listen: %v", err)
} }
defer l4BackendLn.Close() defer func() { _ = l4BackendLn.Close() }()
go echoServer(t, l4BackendLn) go echoServer(t, l4BackendLn)
// L7 backend: h2c HTTP server. // L7 backend: h2c HTTP server.
l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l7BackendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "l7-response") _, _ = fmt.Fprint(w, "l7-response")
})) }))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0") proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
@@ -1515,7 +1515,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
t.Fatalf("proxy listen: %v", err) t.Fatalf("proxy listen: %v", err)
} }
proxyAddr := proxyLn.Addr().String() proxyAddr := proxyLn.Addr().String()
proxyLn.Close() _ = proxyLn.Close()
srv := newTestServer(t, []ListenerData{ srv := newTestServer(t, []ListenerData{
{ {
@@ -1541,11 +1541,11 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("dial L4: %v", err) t.Fatalf("dial L4: %v", err)
} }
defer l4Conn.Close() defer func() { _ = l4Conn.Close() }()
hello := buildClientHello("l4echo.test") hello := buildClientHello("l4echo.test")
l4Conn.Write(hello) _, _ = l4Conn.Write(hello)
echoed := make([]byte, len(hello)) echoed := make([]byte, len(hello))
l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second)) _ = l4Conn.SetReadDeadline(time.Now().Add(5 * time.Second))
if _, err := io.ReadFull(l4Conn, echoed); err != nil { if _, err := io.ReadFull(l4Conn, echoed); err != nil {
t.Fatalf("L4 echo read: %v", err) t.Fatalf("L4 echo read: %v", err)
} }
@@ -1563,7 +1563,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("TLS dial L7: %v", err) t.Fatalf("TLS dial L7: %v", err)
} }
defer l7Conn.Close() defer func() { _ = l7Conn.Close() }()
tr := &http2.Transport{} tr := &http2.Transport{}
h2conn, err := tr.NewClientConn(l7Conn) h2conn, err := tr.NewClientConn(l7Conn)
@@ -1576,7 +1576,7 @@ func TestMixedL4L7SameListener(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("L7 RoundTrip: %v", err) t.Fatalf("L7 RoundTrip: %v", err)
} }
defer resp.Body.Close() defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body) body, _ := io.ReadAll(resp.Body)
if string(body) != "l7-response" { if string(body) != "l7-response" {
@@ -1593,11 +1593,11 @@ func buildClientHello(serverName string) []byte {
func buildClientHelloWithExtensions(extensions []byte) []byte { func buildClientHelloWithExtensions(extensions []byte) []byte {
var hello []byte var hello []byte
hello = append(hello, 0x03, 0x03) // TLS 1.2 hello = append(hello, 0x03, 0x03) // TLS 1.2
hello = append(hello, make([]byte, 32)...) // random hello = append(hello, make([]byte, 32)...) // random
hello = append(hello, 0x00) // session ID: empty hello = append(hello, 0x00) // session ID: empty
hello = append(hello, 0x00, 0x02, 0x00, 0x9C) // cipher suites hello = append(hello, 0x00, 0x02, 0x00, 0x9C) // cipher suites
hello = append(hello, 0x01, 0x00) // compression methods hello = append(hello, 0x01, 0x00) // compression methods
if len(extensions) > 0 { if len(extensions) > 0 {
hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions))) hello = binary.BigEndian.AppendUint16(hello, uint16(len(extensions)))
@@ -1636,4 +1636,3 @@ func sniExtension(serverName string) []byte {
return ext return ext
} }

View File

@@ -17,8 +17,8 @@ const maxBufferSize = 16384 // 16 KiB, max TLS record size
// //
// A read deadline is set on the connection to prevent slowloris attacks. // A read deadline is set on the connection to prevent slowloris attacks.
func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) { func Extract(conn net.Conn, deadline time.Time) (hostname string, peeked []byte, err error) {
conn.SetReadDeadline(deadline) _ = conn.SetReadDeadline(deadline)
defer conn.SetReadDeadline(time.Time{}) defer func() { _ = conn.SetReadDeadline(time.Time{}) }()
// Read TLS record header (5 bytes). // Read TLS record header (5 bytes).
header := make([]byte, 5) header := make([]byte, 5)

View File

@@ -22,13 +22,13 @@ func TestExtract(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
hello := buildClientHello(tt.sni) hello := buildClientHello(tt.sni)
go func() { go func() {
client.Write(hello) _, _ = client.Write(hello)
}() }()
hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second)) hostname, peeked, err := Extract(server, time.Now().Add(5*time.Second))
@@ -53,13 +53,13 @@ func TestExtract(t *testing.T) {
func TestExtractNoSNI(t *testing.T) { func TestExtractNoSNI(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
hello := buildClientHelloNoSNI() hello := buildClientHelloNoSNI()
go func() { go func() {
client.Write(hello) _, _ = client.Write(hello)
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -70,11 +70,11 @@ func TestExtractNoSNI(t *testing.T) {
func TestExtractNotTLS(t *testing.T) { func TestExtractNotTLS(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
go func() { go func() {
client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")) _, _ = client.Write([]byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n"))
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -85,13 +85,13 @@ func TestExtractNotTLS(t *testing.T) {
func TestExtractTruncated(t *testing.T) { func TestExtractTruncated(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
go func() { go func() {
// Write just the TLS record header, then close. // Write just the TLS record header, then close.
client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50}) _, _ = client.Write([]byte{0x16, 0x03, 0x01, 0x00, 0x50})
client.Close() _ = client.Close()
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -102,15 +102,15 @@ func TestExtractTruncated(t *testing.T) {
func TestExtractOversizedRecord(t *testing.T) { func TestExtractOversizedRecord(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
go func() { go func() {
// Record header claiming a length larger than 16 KiB. // Record header claiming a length larger than 16 KiB.
header := []byte{0x16, 0x03, 0x01} header := []byte{0x16, 0x03, 0x01}
header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5 header = binary.BigEndian.AppendUint16(header, 16384) // exceeds maxBufferSize - 5
client.Write(header) _, _ = client.Write(header)
client.Close() _ = client.Close()
}() }()
_, _, err := Extract(server, time.Now().Add(5*time.Second)) _, _, err := Extract(server, time.Now().Add(5*time.Second))
@@ -121,13 +121,13 @@ func TestExtractOversizedRecord(t *testing.T) {
func TestExtractMultipleExtensions(t *testing.T) { func TestExtractMultipleExtensions(t *testing.T) {
client, server := net.Pipe() client, server := net.Pipe()
defer client.Close() defer func() { _ = client.Close() }()
defer server.Close() defer func() { _ = server.Close() }()
hello := buildClientHelloWithExtraExtensions("target.example.com") hello := buildClientHelloWithExtraExtensions("target.example.com")
go func() { go func() {
client.Write(hello) _, _ = client.Write(hello)
}() }()
hostname, _, err := Extract(server, time.Now().Add(5*time.Second)) hostname, _, err := Extract(server, time.Now().Add(5*time.Second))