diff --git a/PROGRESS.md b/PROGRESS.md index 213d787..95f785f 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -49,19 +49,42 @@ proceeds. Each item is marked: - [x] 5.4 Web UI through L7 validation (HTTP/1.1 fallback, HTTP/2) - [x] 5.5 Documentation (verified ARCHITECTURE.md, CLAUDE.md, Makefile) +## Phase 6: Per-Listener Connection Limits + +- [x] 6.1 Config: `MaxConnections` on Listener, validation +- [x] 6.2 DB: migration 3, CRUD updates, `UpdateListenerMaxConns` +- [x] 6.3 Server: limit enforcement in `serve()`, `SetMaxConnections` method +- [x] 6.4 Proto/gRPC: `SetListenerMaxConnections` RPC, `max_connections` in status +- [x] 6.5 Client: `SetListenerMaxConnections` method, `MaxConnections` in status +- [x] 6.6 Tests: DB CRUD, server limit enforcement, gRPC round-trip + +## Phase 7: L7 Policies + +- [ ] 7.1 Config: `L7Policy` struct, `L7Policies` on Route, validation +- [ ] 7.2 DB: migration 4, `l7_policies` table, CRUD in `l7policies.go` +- [ ] 7.3 L7 middleware: `PolicyMiddleware` in `internal/l7/policy.go` +- [ ] 7.4 Server/L7 integration: thread policies from RouteInfo to RouteConfig +- [ ] 7.5 Proto/gRPC: `L7Policy` message, policy management RPCs +- [ ] 7.6 Client/CLI: policy methods, `mcproxyctl policies` subcommand +- [ ] 7.7 Startup: load L7 policies per route in `loadListenersFromDB` +- [ ] 7.8 Tests: middleware unit, DB CRUD + cascade, gRPC round-trip, e2e + +## Phase 8: Prometheus Metrics + +- [ ] 8.1 Dependency: add `prometheus/client_golang` +- [ ] 8.2 Config: `Metrics` section (`addr`, `path`) +- [ ] 8.3 Package: `internal/metrics/` definitions and HTTP server +- [ ] 8.4 Instrumentation: connections, firewall, dial latency, bytes, HTTP status, policy blocks +- [ ] 8.5 Firewall: `BlockedWithReason()` method +- [ ] 8.6 L7: status recording on ResponseWriter +- [ ] 8.7 Startup: conditionally start metrics server +- [ ] 8.8 Tests: metric sanity, server endpoint, `BlockedWithReason` + --- ## Current State -All five phases are complete. The codebase implements the full dual-mode -(L4/L7) architecture with PROXY protocol support: - -- L4 passthrough: SNI extraction, raw TCP relay (original behavior) -- L7 terminating: TLS termination, HTTP/2 reverse proxy with h2c backends -- PROXY protocol: v1/v2 receive on listeners, v2 send on routes -- Per-route mode selection: L4 and L7 routes coexist on the same listener -- gRPC admin API: full CRUD for routes with L7 fields, PROXY protocol flags -- CLI tools: `mcproxyctl routes add` with `--mode`, `--tls-cert`, etc. -- Multi-hop deployment tested: edge→origin with real client IP preservation +Phases 1-6 complete. Per-listener connection limits are implemented and +tested. L7 policies and Prometheus metrics are next. `go vet` and `go test` pass across all 13 packages. diff --git a/PROJECT_PLAN.md b/PROJECT_PLAN.md index 3ee9a4c..5754ab4 100644 --- a/PROJECT_PLAN.md +++ b/PROJECT_PLAN.md @@ -329,3 +329,48 @@ Test that htmx-based web UIs work through the L7 proxy: - Verify ARCHITECTURE.md matches final implementation. - Update CLAUDE.md if any package structure or rules changed. - Update Makefile if new build targets are needed. + +--- + +## Phase 6: Per-Listener Connection Limits + +Add configurable maximum concurrent connection limits per listener. + +### 6.1 Config: `MaxConnections int64` on `Listener` (0 = unlimited) +### 6.2 DB: migration 3 adds `listeners.max_connections`, CRUD updates +### 6.3 Server: enforce limit in `serve()` after Accept, before handleConn +### 6.4 Proto/gRPC: `SetListenerMaxConnections` RPC, `max_connections` in `ListenerStatus` +### 6.5 Client/CLI: `SetListenerMaxConnections` method, status display +### 6.6 Tests: DB CRUD, server limit enforcement, gRPC round-trip + +--- + +## Phase 7: L7 Policies + +Per-route HTTP blocking rules for L7 routes: user-agent blocking +(substring match) and required header enforcement. + +### 7.1 Config: `L7Policy` struct (`type` + `value`), `L7Policies` on Route +### 7.2 DB: migration 4 creates `l7_policies` table, new `l7policies.go` CRUD +### 7.3 L7 middleware: `PolicyMiddleware` in `internal/l7/policy.go` +### 7.4 Server/L7 integration: thread policies from RouteInfo to RouteConfig +### 7.5 Proto/gRPC: `L7Policy` message, `ListL7Policies`/`AddL7Policy`/`RemoveL7Policy` RPCs +### 7.6 Client/CLI: policy methods, `mcproxyctl policies` subcommand +### 7.7 Startup: load L7 policies per route in `loadListenersFromDB` +### 7.8 Tests: middleware unit tests, DB CRUD + cascade, gRPC round-trip, e2e + +--- + +## Phase 8: Prometheus Metrics + +Instrument the proxy with Prometheus-compatible metrics exposed via a +separate HTTP endpoint. + +### 8.1 Dependency: add `prometheus/client_golang` +### 8.2 Config: `Metrics` section (`addr`, `path`) +### 8.3 Package: `internal/metrics/` with metric definitions and HTTP server +### 8.4 Instrumentation: connections, firewall blocks, dial latency, bytes, HTTP status codes, policy blocks +### 8.5 Firewall: add `BlockedWithReason()` method +### 8.6 L7: status recording wrapper on ResponseWriter +### 8.7 Startup: conditionally start metrics server +### 8.8 Tests: metric sanity, server endpoint, `BlockedWithReason` diff --git a/client/mcproxy/client.go b/client/mcproxy/client.go index 1fe82a9..353ca10 100644 --- a/client/mcproxy/client.go +++ b/client/mcproxy/client.go @@ -161,6 +161,7 @@ type ListenerStatus struct { RouteCount int ActiveConnections int64 ProxyProtocol bool + MaxConnections int64 } // Status contains the server's current status. @@ -193,12 +194,23 @@ func (c *Client) GetStatus(ctx context.Context) (*Status, error) { RouteCount: int(ls.RouteCount), ActiveConnections: ls.ActiveConnections, ProxyProtocol: ls.ProxyProtocol, + MaxConnections: ls.MaxConnections, } } return status, nil } +// SetListenerMaxConnections updates the per-listener connection limit. +// 0 means unlimited. +func (c *Client) SetListenerMaxConnections(ctx context.Context, listenerAddr string, maxConns int64) error { + _, err := c.admin.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{ + ListenerAddr: listenerAddr, + MaxConnections: maxConns, + }) + return err +} + // HealthStatus represents the health of the server. type HealthStatus int diff --git a/cmd/mc-proxy/server.go b/cmd/mc-proxy/server.go index cbb2b7e..0b7ca68 100644 --- a/cmd/mc-proxy/server.go +++ b/cmd/mc-proxy/server.go @@ -142,10 +142,11 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) { } } result = append(result, server.ListenerData{ - ID: l.ID, - Addr: l.Addr, - ProxyProtocol: l.ProxyProtocol, - Routes: routes, + ID: l.ID, + Addr: l.Addr, + ProxyProtocol: l.ProxyProtocol, + MaxConnections: l.MaxConnections, + Routes: routes, }) } return result, nil diff --git a/gen/mc_proxy/v1/admin.pb.go b/gen/mc_proxy/v1/admin.pb.go index b937671..516991b 100644 --- a/gen/mc_proxy/v1/admin.pb.go +++ b/gen/mc_proxy/v1/admin.pb.go @@ -730,19 +730,108 @@ func (*RemoveFirewallRuleResponse) Descriptor() ([]byte, []int) { return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{13} } +type SetListenerMaxConnectionsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + ListenerAddr string `protobuf:"bytes,1,opt,name=listener_addr,json=listenerAddr,proto3" json:"listener_addr,omitempty"` + MaxConnections int64 `protobuf:"varint,2,opt,name=max_connections,json=maxConnections,proto3" json:"max_connections,omitempty"` // 0 = unlimited + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetListenerMaxConnectionsRequest) Reset() { + *x = SetListenerMaxConnectionsRequest{} + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetListenerMaxConnectionsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetListenerMaxConnectionsRequest) ProtoMessage() {} + +func (x *SetListenerMaxConnectionsRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetListenerMaxConnectionsRequest.ProtoReflect.Descriptor instead. +func (*SetListenerMaxConnectionsRequest) Descriptor() ([]byte, []int) { + return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{14} +} + +func (x *SetListenerMaxConnectionsRequest) GetListenerAddr() string { + if x != nil { + return x.ListenerAddr + } + return "" +} + +func (x *SetListenerMaxConnectionsRequest) GetMaxConnections() int64 { + if x != nil { + return x.MaxConnections + } + return 0 +} + +type SetListenerMaxConnectionsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetListenerMaxConnectionsResponse) Reset() { + *x = SetListenerMaxConnectionsResponse{} + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetListenerMaxConnectionsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetListenerMaxConnectionsResponse) ProtoMessage() {} + +func (x *SetListenerMaxConnectionsResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetListenerMaxConnectionsResponse.ProtoReflect.Descriptor instead. +func (*SetListenerMaxConnectionsResponse) Descriptor() ([]byte, []int) { + return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{15} +} + type ListenerStatus struct { state protoimpl.MessageState `protogen:"open.v1"` Addr string `protobuf:"bytes,1,opt,name=addr,proto3" json:"addr,omitempty"` RouteCount int32 `protobuf:"varint,2,opt,name=route_count,json=routeCount,proto3" json:"route_count,omitempty"` ActiveConnections int64 `protobuf:"varint,3,opt,name=active_connections,json=activeConnections,proto3" json:"active_connections,omitempty"` ProxyProtocol bool `protobuf:"varint,4,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` + MaxConnections int64 `protobuf:"varint,5,opt,name=max_connections,json=maxConnections,proto3" json:"max_connections,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *ListenerStatus) Reset() { *x = ListenerStatus{} - mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[14] + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[16] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -754,7 +843,7 @@ func (x *ListenerStatus) String() string { func (*ListenerStatus) ProtoMessage() {} func (x *ListenerStatus) ProtoReflect() protoreflect.Message { - mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[14] + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[16] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -767,7 +856,7 @@ func (x *ListenerStatus) ProtoReflect() protoreflect.Message { // Deprecated: Use ListenerStatus.ProtoReflect.Descriptor instead. func (*ListenerStatus) Descriptor() ([]byte, []int) { - return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{14} + return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{16} } func (x *ListenerStatus) GetAddr() string { @@ -798,6 +887,13 @@ func (x *ListenerStatus) GetProxyProtocol() bool { return false } +func (x *ListenerStatus) GetMaxConnections() int64 { + if x != nil { + return x.MaxConnections + } + return 0 +} + type GetStatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -806,7 +902,7 @@ type GetStatusRequest struct { func (x *GetStatusRequest) Reset() { *x = GetStatusRequest{} - mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[15] + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[17] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -818,7 +914,7 @@ func (x *GetStatusRequest) String() string { func (*GetStatusRequest) ProtoMessage() {} func (x *GetStatusRequest) ProtoReflect() protoreflect.Message { - mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[15] + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[17] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -831,7 +927,7 @@ func (x *GetStatusRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use GetStatusRequest.ProtoReflect.Descriptor instead. func (*GetStatusRequest) Descriptor() ([]byte, []int) { - return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{15} + return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{17} } type GetStatusResponse struct { @@ -846,7 +942,7 @@ type GetStatusResponse struct { func (x *GetStatusResponse) Reset() { *x = GetStatusResponse{} - mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[16] + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[18] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -858,7 +954,7 @@ func (x *GetStatusResponse) String() string { func (*GetStatusResponse) ProtoMessage() {} func (x *GetStatusResponse) ProtoReflect() protoreflect.Message { - mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[16] + mi := &file_proto_mc_proxy_v1_admin_proto_msgTypes[18] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -871,7 +967,7 @@ func (x *GetStatusResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetStatusResponse.ProtoReflect.Descriptor instead. func (*GetStatusResponse) Descriptor() ([]byte, []int) { - return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{16} + return file_proto_mc_proxy_v1_admin_proto_rawDescGZIP(), []int{18} } func (x *GetStatusResponse) GetVersion() string { @@ -940,13 +1036,18 @@ const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" + "\x17AddFirewallRuleResponse\"J\n" + "\x19RemoveFirewallRuleRequest\x12-\n" + "\x04rule\x18\x01 \x01(\v2\x19.mc_proxy.v1.FirewallRuleR\x04rule\"\x1c\n" + - "\x1aRemoveFirewallRuleResponse\"\x9b\x01\n" + + "\x1aRemoveFirewallRuleResponse\"p\n" + + " SetListenerMaxConnectionsRequest\x12#\n" + + "\rlistener_addr\x18\x01 \x01(\tR\flistenerAddr\x12'\n" + + "\x0fmax_connections\x18\x02 \x01(\x03R\x0emaxConnections\"#\n" + + "!SetListenerMaxConnectionsResponse\"\xc4\x01\n" + "\x0eListenerStatus\x12\x12\n" + "\x04addr\x18\x01 \x01(\tR\x04addr\x12\x1f\n" + "\vroute_count\x18\x02 \x01(\x05R\n" + "routeCount\x12-\n" + "\x12active_connections\x18\x03 \x01(\x03R\x11activeConnections\x12%\n" + - "\x0eproxy_protocol\x18\x04 \x01(\bR\rproxyProtocol\"\x12\n" + + "\x0eproxy_protocol\x18\x04 \x01(\bR\rproxyProtocol\x12'\n" + + "\x0fmax_connections\x18\x05 \x01(\x03R\x0emaxConnections\"\x12\n" + "\x10GetStatusRequest\"\xd0\x01\n" + "\x11GetStatusResponse\x12\x18\n" + "\aversion\x18\x01 \x01(\tR\aversion\x129\n" + @@ -958,7 +1059,7 @@ const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" + "\x1eFIREWALL_RULE_TYPE_UNSPECIFIED\x10\x00\x12\x19\n" + "\x15FIREWALL_RULE_TYPE_IP\x10\x01\x12\x1b\n" + "\x17FIREWALL_RULE_TYPE_CIDR\x10\x02\x12\x1e\n" + - "\x1aFIREWALL_RULE_TYPE_COUNTRY\x10\x032\xef\x04\n" + + "\x1aFIREWALL_RULE_TYPE_COUNTRY\x10\x032\xeb\x05\n" + "\x11ProxyAdminService\x12M\n" + "\n" + "ListRoutes\x12\x1e.mc_proxy.v1.ListRoutesRequest\x1a\x1f.mc_proxy.v1.ListRoutesResponse\x12G\n" + @@ -966,7 +1067,8 @@ const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" + "\vRemoveRoute\x12\x1f.mc_proxy.v1.RemoveRouteRequest\x1a .mc_proxy.v1.RemoveRouteResponse\x12_\n" + "\x10GetFirewallRules\x12$.mc_proxy.v1.GetFirewallRulesRequest\x1a%.mc_proxy.v1.GetFirewallRulesResponse\x12\\\n" + "\x0fAddFirewallRule\x12#.mc_proxy.v1.AddFirewallRuleRequest\x1a$.mc_proxy.v1.AddFirewallRuleResponse\x12e\n" + - "\x12RemoveFirewallRule\x12&.mc_proxy.v1.RemoveFirewallRuleRequest\x1a'.mc_proxy.v1.RemoveFirewallRuleResponse\x12J\n" + + "\x12RemoveFirewallRule\x12&.mc_proxy.v1.RemoveFirewallRuleRequest\x1a'.mc_proxy.v1.RemoveFirewallRuleResponse\x12z\n" + + "\x19SetListenerMaxConnections\x12-.mc_proxy.v1.SetListenerMaxConnectionsRequest\x1a..mc_proxy.v1.SetListenerMaxConnectionsResponse\x12J\n" + "\tGetStatus\x12\x1d.mc_proxy.v1.GetStatusRequest\x1a\x1e.mc_proxy.v1.GetStatusResponseB:Z8git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1;mcproxyv1b\x06proto3" var ( @@ -982,27 +1084,29 @@ func file_proto_mc_proxy_v1_admin_proto_rawDescGZIP() []byte { } var file_proto_mc_proxy_v1_admin_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_proto_mc_proxy_v1_admin_proto_msgTypes = make([]protoimpl.MessageInfo, 17) +var file_proto_mc_proxy_v1_admin_proto_msgTypes = make([]protoimpl.MessageInfo, 19) var file_proto_mc_proxy_v1_admin_proto_goTypes = []any{ - (FirewallRuleType)(0), // 0: mc_proxy.v1.FirewallRuleType - (*Route)(nil), // 1: mc_proxy.v1.Route - (*ListRoutesRequest)(nil), // 2: mc_proxy.v1.ListRoutesRequest - (*ListRoutesResponse)(nil), // 3: mc_proxy.v1.ListRoutesResponse - (*AddRouteRequest)(nil), // 4: mc_proxy.v1.AddRouteRequest - (*AddRouteResponse)(nil), // 5: mc_proxy.v1.AddRouteResponse - (*RemoveRouteRequest)(nil), // 6: mc_proxy.v1.RemoveRouteRequest - (*RemoveRouteResponse)(nil), // 7: mc_proxy.v1.RemoveRouteResponse - (*FirewallRule)(nil), // 8: mc_proxy.v1.FirewallRule - (*GetFirewallRulesRequest)(nil), // 9: mc_proxy.v1.GetFirewallRulesRequest - (*GetFirewallRulesResponse)(nil), // 10: mc_proxy.v1.GetFirewallRulesResponse - (*AddFirewallRuleRequest)(nil), // 11: mc_proxy.v1.AddFirewallRuleRequest - (*AddFirewallRuleResponse)(nil), // 12: mc_proxy.v1.AddFirewallRuleResponse - (*RemoveFirewallRuleRequest)(nil), // 13: mc_proxy.v1.RemoveFirewallRuleRequest - (*RemoveFirewallRuleResponse)(nil), // 14: mc_proxy.v1.RemoveFirewallRuleResponse - (*ListenerStatus)(nil), // 15: mc_proxy.v1.ListenerStatus - (*GetStatusRequest)(nil), // 16: mc_proxy.v1.GetStatusRequest - (*GetStatusResponse)(nil), // 17: mc_proxy.v1.GetStatusResponse - (*timestamppb.Timestamp)(nil), // 18: google.protobuf.Timestamp + (FirewallRuleType)(0), // 0: mc_proxy.v1.FirewallRuleType + (*Route)(nil), // 1: mc_proxy.v1.Route + (*ListRoutesRequest)(nil), // 2: mc_proxy.v1.ListRoutesRequest + (*ListRoutesResponse)(nil), // 3: mc_proxy.v1.ListRoutesResponse + (*AddRouteRequest)(nil), // 4: mc_proxy.v1.AddRouteRequest + (*AddRouteResponse)(nil), // 5: mc_proxy.v1.AddRouteResponse + (*RemoveRouteRequest)(nil), // 6: mc_proxy.v1.RemoveRouteRequest + (*RemoveRouteResponse)(nil), // 7: mc_proxy.v1.RemoveRouteResponse + (*FirewallRule)(nil), // 8: mc_proxy.v1.FirewallRule + (*GetFirewallRulesRequest)(nil), // 9: mc_proxy.v1.GetFirewallRulesRequest + (*GetFirewallRulesResponse)(nil), // 10: mc_proxy.v1.GetFirewallRulesResponse + (*AddFirewallRuleRequest)(nil), // 11: mc_proxy.v1.AddFirewallRuleRequest + (*AddFirewallRuleResponse)(nil), // 12: mc_proxy.v1.AddFirewallRuleResponse + (*RemoveFirewallRuleRequest)(nil), // 13: mc_proxy.v1.RemoveFirewallRuleRequest + (*RemoveFirewallRuleResponse)(nil), // 14: mc_proxy.v1.RemoveFirewallRuleResponse + (*SetListenerMaxConnectionsRequest)(nil), // 15: mc_proxy.v1.SetListenerMaxConnectionsRequest + (*SetListenerMaxConnectionsResponse)(nil), // 16: mc_proxy.v1.SetListenerMaxConnectionsResponse + (*ListenerStatus)(nil), // 17: mc_proxy.v1.ListenerStatus + (*GetStatusRequest)(nil), // 18: mc_proxy.v1.GetStatusRequest + (*GetStatusResponse)(nil), // 19: mc_proxy.v1.GetStatusResponse + (*timestamppb.Timestamp)(nil), // 20: google.protobuf.Timestamp } var file_proto_mc_proxy_v1_admin_proto_depIdxs = []int32{ 1, // 0: mc_proxy.v1.ListRoutesResponse.routes:type_name -> mc_proxy.v1.Route @@ -1011,24 +1115,26 @@ var file_proto_mc_proxy_v1_admin_proto_depIdxs = []int32{ 8, // 3: mc_proxy.v1.GetFirewallRulesResponse.rules:type_name -> mc_proxy.v1.FirewallRule 8, // 4: mc_proxy.v1.AddFirewallRuleRequest.rule:type_name -> mc_proxy.v1.FirewallRule 8, // 5: mc_proxy.v1.RemoveFirewallRuleRequest.rule:type_name -> mc_proxy.v1.FirewallRule - 18, // 6: mc_proxy.v1.GetStatusResponse.started_at:type_name -> google.protobuf.Timestamp - 15, // 7: mc_proxy.v1.GetStatusResponse.listeners:type_name -> mc_proxy.v1.ListenerStatus + 20, // 6: mc_proxy.v1.GetStatusResponse.started_at:type_name -> google.protobuf.Timestamp + 17, // 7: mc_proxy.v1.GetStatusResponse.listeners:type_name -> mc_proxy.v1.ListenerStatus 2, // 8: mc_proxy.v1.ProxyAdminService.ListRoutes:input_type -> mc_proxy.v1.ListRoutesRequest 4, // 9: mc_proxy.v1.ProxyAdminService.AddRoute:input_type -> mc_proxy.v1.AddRouteRequest 6, // 10: mc_proxy.v1.ProxyAdminService.RemoveRoute:input_type -> mc_proxy.v1.RemoveRouteRequest 9, // 11: mc_proxy.v1.ProxyAdminService.GetFirewallRules:input_type -> mc_proxy.v1.GetFirewallRulesRequest 11, // 12: mc_proxy.v1.ProxyAdminService.AddFirewallRule:input_type -> mc_proxy.v1.AddFirewallRuleRequest 13, // 13: mc_proxy.v1.ProxyAdminService.RemoveFirewallRule:input_type -> mc_proxy.v1.RemoveFirewallRuleRequest - 16, // 14: mc_proxy.v1.ProxyAdminService.GetStatus:input_type -> mc_proxy.v1.GetStatusRequest - 3, // 15: mc_proxy.v1.ProxyAdminService.ListRoutes:output_type -> mc_proxy.v1.ListRoutesResponse - 5, // 16: mc_proxy.v1.ProxyAdminService.AddRoute:output_type -> mc_proxy.v1.AddRouteResponse - 7, // 17: mc_proxy.v1.ProxyAdminService.RemoveRoute:output_type -> mc_proxy.v1.RemoveRouteResponse - 10, // 18: mc_proxy.v1.ProxyAdminService.GetFirewallRules:output_type -> mc_proxy.v1.GetFirewallRulesResponse - 12, // 19: mc_proxy.v1.ProxyAdminService.AddFirewallRule:output_type -> mc_proxy.v1.AddFirewallRuleResponse - 14, // 20: mc_proxy.v1.ProxyAdminService.RemoveFirewallRule:output_type -> mc_proxy.v1.RemoveFirewallRuleResponse - 17, // 21: mc_proxy.v1.ProxyAdminService.GetStatus:output_type -> mc_proxy.v1.GetStatusResponse - 15, // [15:22] is the sub-list for method output_type - 8, // [8:15] is the sub-list for method input_type + 15, // 14: mc_proxy.v1.ProxyAdminService.SetListenerMaxConnections:input_type -> mc_proxy.v1.SetListenerMaxConnectionsRequest + 18, // 15: mc_proxy.v1.ProxyAdminService.GetStatus:input_type -> mc_proxy.v1.GetStatusRequest + 3, // 16: mc_proxy.v1.ProxyAdminService.ListRoutes:output_type -> mc_proxy.v1.ListRoutesResponse + 5, // 17: mc_proxy.v1.ProxyAdminService.AddRoute:output_type -> mc_proxy.v1.AddRouteResponse + 7, // 18: mc_proxy.v1.ProxyAdminService.RemoveRoute:output_type -> mc_proxy.v1.RemoveRouteResponse + 10, // 19: mc_proxy.v1.ProxyAdminService.GetFirewallRules:output_type -> mc_proxy.v1.GetFirewallRulesResponse + 12, // 20: mc_proxy.v1.ProxyAdminService.AddFirewallRule:output_type -> mc_proxy.v1.AddFirewallRuleResponse + 14, // 21: mc_proxy.v1.ProxyAdminService.RemoveFirewallRule:output_type -> mc_proxy.v1.RemoveFirewallRuleResponse + 16, // 22: mc_proxy.v1.ProxyAdminService.SetListenerMaxConnections:output_type -> mc_proxy.v1.SetListenerMaxConnectionsResponse + 19, // 23: mc_proxy.v1.ProxyAdminService.GetStatus:output_type -> mc_proxy.v1.GetStatusResponse + 16, // [16:24] is the sub-list for method output_type + 8, // [8:16] is the sub-list for method input_type 8, // [8:8] is the sub-list for extension type_name 8, // [8:8] is the sub-list for extension extendee 0, // [0:8] is the sub-list for field type_name @@ -1045,7 +1151,7 @@ func file_proto_mc_proxy_v1_admin_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_mc_proxy_v1_admin_proto_rawDesc), len(file_proto_mc_proxy_v1_admin_proto_rawDesc)), NumEnums: 1, - NumMessages: 17, + NumMessages: 19, NumExtensions: 0, NumServices: 1, }, diff --git a/gen/mc_proxy/v1/admin_grpc.pb.go b/gen/mc_proxy/v1/admin_grpc.pb.go index 276bae5..e72bc32 100644 --- a/gen/mc_proxy/v1/admin_grpc.pb.go +++ b/gen/mc_proxy/v1/admin_grpc.pb.go @@ -19,13 +19,14 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( - ProxyAdminService_ListRoutes_FullMethodName = "/mc_proxy.v1.ProxyAdminService/ListRoutes" - ProxyAdminService_AddRoute_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddRoute" - ProxyAdminService_RemoveRoute_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveRoute" - ProxyAdminService_GetFirewallRules_FullMethodName = "/mc_proxy.v1.ProxyAdminService/GetFirewallRules" - ProxyAdminService_AddFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddFirewallRule" - ProxyAdminService_RemoveFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveFirewallRule" - ProxyAdminService_GetStatus_FullMethodName = "/mc_proxy.v1.ProxyAdminService/GetStatus" + ProxyAdminService_ListRoutes_FullMethodName = "/mc_proxy.v1.ProxyAdminService/ListRoutes" + ProxyAdminService_AddRoute_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddRoute" + ProxyAdminService_RemoveRoute_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveRoute" + ProxyAdminService_GetFirewallRules_FullMethodName = "/mc_proxy.v1.ProxyAdminService/GetFirewallRules" + ProxyAdminService_AddFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddFirewallRule" + ProxyAdminService_RemoveFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveFirewallRule" + ProxyAdminService_SetListenerMaxConnections_FullMethodName = "/mc_proxy.v1.ProxyAdminService/SetListenerMaxConnections" + ProxyAdminService_GetStatus_FullMethodName = "/mc_proxy.v1.ProxyAdminService/GetStatus" ) // ProxyAdminServiceClient is the client API for ProxyAdminService service. @@ -40,6 +41,8 @@ type ProxyAdminServiceClient interface { GetFirewallRules(ctx context.Context, in *GetFirewallRulesRequest, opts ...grpc.CallOption) (*GetFirewallRulesResponse, error) AddFirewallRule(ctx context.Context, in *AddFirewallRuleRequest, opts ...grpc.CallOption) (*AddFirewallRuleResponse, error) RemoveFirewallRule(ctx context.Context, in *RemoveFirewallRuleRequest, opts ...grpc.CallOption) (*RemoveFirewallRuleResponse, error) + // Connection limits + SetListenerMaxConnections(ctx context.Context, in *SetListenerMaxConnectionsRequest, opts ...grpc.CallOption) (*SetListenerMaxConnectionsResponse, error) // Status GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error) } @@ -112,6 +115,16 @@ func (c *proxyAdminServiceClient) RemoveFirewallRule(ctx context.Context, in *Re return out, nil } +func (c *proxyAdminServiceClient) SetListenerMaxConnections(ctx context.Context, in *SetListenerMaxConnectionsRequest, opts ...grpc.CallOption) (*SetListenerMaxConnectionsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(SetListenerMaxConnectionsResponse) + err := c.cc.Invoke(ctx, ProxyAdminService_SetListenerMaxConnections_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *proxyAdminServiceClient) GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetStatusResponse) @@ -134,6 +147,8 @@ type ProxyAdminServiceServer interface { GetFirewallRules(context.Context, *GetFirewallRulesRequest) (*GetFirewallRulesResponse, error) AddFirewallRule(context.Context, *AddFirewallRuleRequest) (*AddFirewallRuleResponse, error) RemoveFirewallRule(context.Context, *RemoveFirewallRuleRequest) (*RemoveFirewallRuleResponse, error) + // Connection limits + SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error) // Status GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error) mustEmbedUnimplementedProxyAdminServiceServer() @@ -164,6 +179,9 @@ func (UnimplementedProxyAdminServiceServer) AddFirewallRule(context.Context, *Ad func (UnimplementedProxyAdminServiceServer) RemoveFirewallRule(context.Context, *RemoveFirewallRuleRequest) (*RemoveFirewallRuleResponse, error) { return nil, status.Error(codes.Unimplemented, "method RemoveFirewallRule not implemented") } +func (UnimplementedProxyAdminServiceServer) SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error) { + return nil, status.Error(codes.Unimplemented, "method SetListenerMaxConnections not implemented") +} func (UnimplementedProxyAdminServiceServer) GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error) { return nil, status.Error(codes.Unimplemented, "method GetStatus not implemented") } @@ -296,6 +314,24 @@ func _ProxyAdminService_RemoveFirewallRule_Handler(srv interface{}, ctx context. return interceptor(ctx, in, info, handler) } +func _ProxyAdminService_SetListenerMaxConnections_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetListenerMaxConnectionsRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(ProxyAdminServiceServer).SetListenerMaxConnections(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: ProxyAdminService_SetListenerMaxConnections_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(ProxyAdminServiceServer).SetListenerMaxConnections(ctx, req.(*SetListenerMaxConnectionsRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _ProxyAdminService_GetStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetStatusRequest) if err := dec(in); err != nil { @@ -345,6 +381,10 @@ var ProxyAdminService_ServiceDesc = grpc.ServiceDesc{ MethodName: "RemoveFirewallRule", Handler: _ProxyAdminService_RemoveFirewallRule_Handler, }, + { + MethodName: "SetListenerMaxConnections", + Handler: _ProxyAdminService_SetListenerMaxConnections_Handler, + }, { MethodName: "GetStatus", Handler: _ProxyAdminService_GetStatus_Handler, diff --git a/internal/config/config.go b/internal/config/config.go index f770a0b..b82e311 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -29,9 +29,10 @@ type GRPC struct { } type Listener struct { - Addr string `toml:"addr"` - ProxyProtocol bool `toml:"proxy_protocol"` - Routes []Route `toml:"routes"` + Addr string `toml:"addr"` + ProxyProtocol bool `toml:"proxy_protocol"` + MaxConnections int64 `toml:"max_connections"` // 0 = unlimited + Routes []Route `toml:"routes"` } type Route struct { @@ -176,6 +177,9 @@ func (c *Config) validate() error { if l.Addr == "" { return fmt.Errorf("listener %d: addr is required", i) } + if l.MaxConnections < 0 { + return fmt.Errorf("listener %d (%s): max_connections must not be negative", i, l.Addr) + } seen := make(map[string]bool) for j := range l.Routes { r := &l.Routes[j] diff --git a/internal/db/db_test.go b/internal/db/db_test.go index eff7b9b..a8f50f5 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -41,7 +41,7 @@ func TestIsEmpty(t *testing.T) { t.Fatal("expected empty database") } - if _, err := store.CreateListener(":443", false); err != nil { + if _, err := store.CreateListener(":443", false, 0); err != nil { t.Fatalf("create listener: %v", err) } @@ -57,7 +57,7 @@ func TestIsEmpty(t *testing.T) { func TestListenerCRUD(t *testing.T) { store := openTestDB(t) - id, err := store.CreateListener(":443", false) + id, err := store.CreateListener(":443", false, 0) if err != nil { t.Fatalf("create: %v", err) } @@ -103,7 +103,7 @@ func TestListenerCRUD(t *testing.T) { func TestListenerProxyProtocol(t *testing.T) { store := openTestDB(t) - id, err := store.CreateListener(":443", true) + id, err := store.CreateListener(":443", true, 0) if err != nil { t.Fatalf("create: %v", err) } @@ -120,13 +120,52 @@ func TestListenerProxyProtocol(t *testing.T) { } } +func TestListenerMaxConnections(t *testing.T) { + store := openTestDB(t) + + id, err := store.CreateListener(":443", false, 5000) + if err != nil { + t.Fatalf("create: %v", err) + } + + l, err := store.GetListenerByAddr(":443") + if err != nil { + t.Fatalf("get: %v", err) + } + if l.MaxConnections != 5000 { + t.Fatalf("max_connections = %d, want 5000", l.MaxConnections) + } + + // Update max connections. + if err := store.UpdateListenerMaxConns(id, 10000); err != nil { + t.Fatalf("update: %v", err) + } + + l, err = store.GetListenerByAddr(":443") + if err != nil { + t.Fatalf("get after update: %v", err) + } + if l.MaxConnections != 10000 { + t.Fatalf("max_connections = %d, want 10000", l.MaxConnections) + } + + // Set to 0 (unlimited). + if err := store.UpdateListenerMaxConns(id, 0); err != nil { + t.Fatalf("update to 0: %v", err) + } + l, _ = store.GetListenerByAddr(":443") + if l.MaxConnections != 0 { + t.Fatalf("max_connections = %d, want 0", l.MaxConnections) + } +} + func TestListenerDuplicateAddr(t *testing.T) { store := openTestDB(t) - if _, err := store.CreateListener(":443", false); err != nil { + if _, err := store.CreateListener(":443", false, 0); err != nil { t.Fatalf("first create: %v", err) } - if _, err := store.CreateListener(":443", false); err == nil { + if _, err := store.CreateListener(":443", false, 0); err == nil { t.Fatal("expected error for duplicate addr") } } @@ -134,7 +173,7 @@ func TestListenerDuplicateAddr(t *testing.T) { func TestRouteCRUD(t *testing.T) { store := openTestDB(t) - listenerID, err := store.CreateListener(":443", false) + listenerID, err := store.CreateListener(":443", false, 0) if err != nil { t.Fatalf("create listener: %v", err) } @@ -177,7 +216,7 @@ func TestRouteCRUD(t *testing.T) { func TestRouteL7Fields(t *testing.T) { store := openTestDB(t) - listenerID, _ := store.CreateListener(":443", false) + listenerID, _ := store.CreateListener(":443", false, 0) _, err := store.CreateRoute(listenerID, "api.example.com", "127.0.0.1:8080", "l7", "/certs/api.crt", "/certs/api.key", false, true) @@ -214,7 +253,7 @@ func TestRouteL7Fields(t *testing.T) { func TestRouteDuplicateHostname(t *testing.T) { store := openTestDB(t) - listenerID, _ := store.CreateListener(":443", false) + listenerID, _ := store.CreateListener(":443", false, 0) if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443", "l4", "", "", false, false); err != nil { t.Fatalf("first create: %v", err) } @@ -226,7 +265,7 @@ func TestRouteDuplicateHostname(t *testing.T) { func TestRouteCascadeDelete(t *testing.T) { store := openTestDB(t) - listenerID, _ := store.CreateListener(":443", false) + listenerID, _ := store.CreateListener(":443", false, 0) store.CreateRoute(listenerID, "a.example.com", "127.0.0.1:8443", "l4", "", "", false, false) store.CreateRoute(listenerID, "b.example.com", "127.0.0.1:9443", "l4", "", "", false, false) @@ -373,7 +412,7 @@ func TestSeed(t *testing.T) { func TestSnapshot(t *testing.T) { store := openTestDB(t) - store.CreateListener(":443", false) + store.CreateListener(":443", false, 0) dest := filepath.Join(t.TempDir(), "backup.db") if err := store.Snapshot(dest); err != nil { @@ -432,7 +471,7 @@ func TestMigrationV2Upgrade(t *testing.T) { } // Insert a listener and route with defaults to verify new columns work. - lid, err := store.CreateListener(":443", false) + lid, err := store.CreateListener(":443", false, 0) if err != nil { t.Fatalf("create listener: %v", err) } diff --git a/internal/db/listeners.go b/internal/db/listeners.go index bed5e02..750de2d 100644 --- a/internal/db/listeners.go +++ b/internal/db/listeners.go @@ -4,14 +4,15 @@ import "fmt" // Listener is a database listener record. type Listener struct { - ID int64 - Addr string - ProxyProtocol bool + ID int64 + Addr string + ProxyProtocol bool + MaxConnections int64 } // ListListeners returns all listeners. func (s *Store) ListListeners() ([]Listener, error) { - rows, err := s.db.Query("SELECT id, addr, proxy_protocol FROM listeners ORDER BY id") + rows, err := s.db.Query("SELECT id, addr, proxy_protocol, max_connections FROM listeners ORDER BY id") if err != nil { return nil, fmt.Errorf("querying listeners: %w", err) } @@ -20,7 +21,7 @@ func (s *Store) ListListeners() ([]Listener, error) { var listeners []Listener for rows.Next() { var l Listener - if err := rows.Scan(&l.ID, &l.Addr, &l.ProxyProtocol); err != nil { + if err := rows.Scan(&l.ID, &l.Addr, &l.ProxyProtocol, &l.MaxConnections); err != nil { return nil, fmt.Errorf("scanning listener: %w", err) } listeners = append(listeners, l) @@ -29,10 +30,10 @@ func (s *Store) ListListeners() ([]Listener, error) { } // CreateListener inserts a listener and returns its ID. -func (s *Store) CreateListener(addr string, proxyProtocol bool) (int64, error) { +func (s *Store) CreateListener(addr string, proxyProtocol bool, maxConnections int64) (int64, error) { result, err := s.db.Exec( - "INSERT INTO listeners (addr, proxy_protocol) VALUES (?, ?)", - addr, proxyProtocol, + "INSERT INTO listeners (addr, proxy_protocol, max_connections) VALUES (?, ?, ?)", + addr, proxyProtocol, maxConnections, ) if err != nil { return 0, fmt.Errorf("inserting listener: %w", err) @@ -56,10 +57,26 @@ func (s *Store) DeleteListener(id int64) error { // GetListenerByAddr returns a listener by its address. func (s *Store) GetListenerByAddr(addr string) (Listener, error) { var l Listener - err := s.db.QueryRow("SELECT id, addr, proxy_protocol FROM listeners WHERE addr = ?", addr). - Scan(&l.ID, &l.Addr, &l.ProxyProtocol) + err := s.db.QueryRow("SELECT id, addr, proxy_protocol, max_connections FROM listeners WHERE addr = ?", addr). + Scan(&l.ID, &l.Addr, &l.ProxyProtocol, &l.MaxConnections) if err != nil { return Listener{}, fmt.Errorf("querying listener by addr %q: %w", addr, err) } return l, nil } + +// UpdateListenerMaxConns updates the max_connections for a listener. +func (s *Store) UpdateListenerMaxConns(listenerID int64, maxConns int64) error { + result, err := s.db.Exec( + "UPDATE listeners SET max_connections = ? WHERE id = ?", + maxConns, listenerID, + ) + if err != nil { + return fmt.Errorf("updating max_connections: %w", err) + } + n, _ := result.RowsAffected() + if n == 0 { + return fmt.Errorf("listener %d not found", listenerID) + } + return nil +} diff --git a/internal/db/migrations.go b/internal/db/migrations.go index 49cb4b2..82e30b6 100644 --- a/internal/db/migrations.go +++ b/internal/db/migrations.go @@ -14,6 +14,7 @@ type migration struct { var migrations = []migration{ {1, "create_core_tables", migrate001CreateCoreTables}, {2, "add_proxy_protocol_and_l7_fields", migrate002AddL7Fields}, + {3, "add_listener_max_connections", migrate003AddListenerMaxConnections}, } // Migrate runs all unapplied migrations sequentially. @@ -110,3 +111,8 @@ func migrate002AddL7Fields(tx *sql.Tx) error { } return nil } + +func migrate003AddListenerMaxConnections(tx *sql.Tx) error { + _, err := tx.Exec(`ALTER TABLE listeners ADD COLUMN max_connections INTEGER NOT NULL DEFAULT 0`) + return err +} diff --git a/internal/db/seed.go b/internal/db/seed.go index 2abb018..3ac0b06 100644 --- a/internal/db/seed.go +++ b/internal/db/seed.go @@ -18,8 +18,8 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error { for _, l := range listeners { result, err := tx.Exec( - "INSERT INTO listeners (addr, proxy_protocol) VALUES (?, ?)", - l.Addr, l.ProxyProtocol, + "INSERT INTO listeners (addr, proxy_protocol, max_connections) VALUES (?, ?, ?)", + l.Addr, l.ProxyProtocol, l.MaxConnections, ) if err != nil { return fmt.Errorf("seeding listener %q: %w", l.Addr, err) diff --git a/internal/grpcserver/grpcserver.go b/internal/grpcserver/grpcserver.go index 5aa619d..2f0f58d 100644 --- a/internal/grpcserver/grpcserver.go +++ b/internal/grpcserver/grpcserver.go @@ -308,6 +308,28 @@ func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewa return &pb.RemoveFirewallRuleResponse{}, nil } +// SetListenerMaxConnections updates the per-listener connection limit. +func (a *AdminServer) SetListenerMaxConnections(_ context.Context, req *pb.SetListenerMaxConnectionsRequest) (*pb.SetListenerMaxConnectionsResponse, error) { + if req.MaxConnections < 0 { + return nil, status.Error(codes.InvalidArgument, "max_connections must not be negative") + } + + ls, err := a.findListener(req.ListenerAddr) + if err != nil { + return nil, err + } + + // Write-through: DB first, then memory. + if err := a.store.UpdateListenerMaxConns(ls.ID, req.MaxConnections); err != nil { + return nil, status.Errorf(codes.Internal, "%v", err) + } + + ls.SetMaxConnections(req.MaxConnections) + + a.logger.Info("connection limit updated", "listener", ls.Addr, "max_connections", req.MaxConnections) + return &pb.SetListenerMaxConnectionsResponse{}, nil +} + // GetStatus returns the proxy's current status. func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) { var listeners []*pb.ListenerStatus @@ -318,6 +340,7 @@ func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb. RouteCount: int32(len(routes)), ActiveConnections: ls.ActiveConnections.Load(), ProxyProtocol: ls.ProxyProtocol, + MaxConnections: ls.MaxConnections, }) } diff --git a/internal/grpcserver/grpcserver_test.go b/internal/grpcserver/grpcserver_test.go index 061fc01..30e1f2f 100644 --- a/internal/grpcserver/grpcserver_test.go +++ b/internal/grpcserver/grpcserver_test.go @@ -676,3 +676,62 @@ func TestGetStatusProxyProtocol(t *testing.T) { t.Fatal("expected proxy_protocol = false") } } + +func TestSetListenerMaxConnections(t *testing.T) { + env := setup(t) + ctx := context.Background() + + // Set max connections. + _, err := env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{ + ListenerAddr: ":443", + MaxConnections: 5000, + }) + if err != nil { + t.Fatalf("SetListenerMaxConnections: %v", err) + } + + // Verify via GetStatus. + resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{}) + if err != nil { + t.Fatalf("GetStatus: %v", err) + } + if resp.Listeners[0].MaxConnections != 5000 { + t.Fatalf("max_connections = %d, want 5000", resp.Listeners[0].MaxConnections) + } + + // Verify DB persistence. + l, _ := env.store.GetListenerByAddr(":443") + if l.MaxConnections != 5000 { + t.Fatalf("DB max_connections = %d, want 5000", l.MaxConnections) + } + + // Set to 0 (unlimited). + _, err = env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{ + ListenerAddr: ":443", + MaxConnections: 0, + }) + if err != nil { + t.Fatalf("SetListenerMaxConnections to 0: %v", err) + } + + resp, _ = env.client.GetStatus(ctx, &pb.GetStatusRequest{}) + if resp.Listeners[0].MaxConnections != 0 { + t.Fatalf("max_connections = %d, want 0", resp.Listeners[0].MaxConnections) + } +} + +func TestSetListenerMaxConnectionsNotFound(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{ + ListenerAddr: ":9999", + MaxConnections: 100, + }) + if err == nil { + t.Fatal("expected error for nonexistent listener") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound { + t.Fatalf("expected NotFound, got %v", err) + } +} diff --git a/internal/server/server.go b/internal/server/server.go index 2cc2edd..7babce6 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -34,6 +34,7 @@ type ListenerState struct { ID int64 // database primary key Addr string ProxyProtocol bool + MaxConnections int64 // 0 = unlimited routes map[string]RouteInfo // lowercase hostname → route info mu sync.RWMutex ActiveConnections atomic.Int64 @@ -41,6 +42,13 @@ type ListenerState struct { connMu sync.Mutex } +// SetMaxConnections updates the connection limit at runtime. +func (ls *ListenerState) SetMaxConnections(n int64) { + ls.mu.Lock() + defer ls.mu.Unlock() + ls.MaxConnections = n +} + // Routes returns a snapshot of the listener's route table. func (ls *ListenerState) Routes() map[string]RouteInfo { ls.mu.RLock() @@ -93,10 +101,11 @@ func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) { // ListenerData holds the data needed to construct a ListenerState. type ListenerData struct { - ID int64 - Addr string - ProxyProtocol bool - Routes map[string]RouteInfo // lowercase hostname → route info + ID int64 + Addr string + ProxyProtocol bool + MaxConnections int64 + Routes map[string]RouteInfo // lowercase hostname → route info } // Server is the mc-proxy server. It manages listeners, firewall evaluation, @@ -116,11 +125,12 @@ func New(cfg *config.Config, fw *firewall.Firewall, listenerData []ListenerData, var listeners []*ListenerState for _, ld := range listenerData { listeners = append(listeners, &ListenerState{ - ID: ld.ID, - Addr: ld.Addr, - ProxyProtocol: ld.ProxyProtocol, - routes: ld.Routes, - activeConns: make(map[net.Conn]struct{}), + ID: ld.ID, + Addr: ld.Addr, + ProxyProtocol: ld.ProxyProtocol, + MaxConnections: ld.MaxConnections, + routes: ld.Routes, + activeConns: make(map[net.Conn]struct{}), }) } @@ -229,6 +239,13 @@ func (s *Server) serve(ctx context.Context, ln net.Listener, ls *ListenerState) continue } + // Enforce per-listener connection limit. + if ls.MaxConnections > 0 && ls.ActiveConnections.Load() >= ls.MaxConnections { + conn.Close() + s.logger.Debug("connection limit reached", "addr", ls.Addr, "limit", ls.MaxConnections) + continue + } + s.wg.Add(1) ls.ActiveConnections.Add(1) go s.handleConn(ctx, conn, ls) diff --git a/internal/server/server_test.go b/internal/server/server_test.go index bc41ec2..0e12f7b 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -1052,6 +1052,96 @@ func TestProxyProtocolFirewallUsesRealIP(t *testing.T) { wg.Wait() } +// --- Connection limit tests --- + +func TestConnectionLimitEnforced(t *testing.T) { + // Backend that holds connections open. + backendLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("backend listen: %v", err) + } + defer backendLn.Close() + + go func() { + for { + conn, err := backendLn.Accept() + if err != nil { + return + } + go io.Copy(io.Discard, conn) + } + }() + + proxyLn, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("proxy listen: %v", err) + } + proxyAddr := proxyLn.Addr().String() + proxyLn.Close() + + srv := newTestServer(t, []ListenerData{ + { + ID: 1, + Addr: proxyAddr, + MaxConnections: 2, + Routes: map[string]RouteInfo{ + "limit.test": l4Route(backendLn.Addr().String()), + }, + }, + }) + + stop := startAndStop(t, srv) + defer stop() + + // Open 2 connections (should succeed). + var conns []net.Conn + for i := range 2 { + conn, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial %d: %v", i, err) + } + conn.Write(buildClientHello("limit.test")) + conns = append(conns, conn) + } + time.Sleep(100 * time.Millisecond) + + // 3rd connection should be rejected (closed immediately). + conn3, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial 3: %v", err) + } + conn3.Write(buildClientHello("limit.test")) + conn3.SetReadDeadline(time.Now().Add(2 * time.Second)) + _, err = conn3.Read(make([]byte, 1)) + if err == nil { + t.Fatal("expected 3rd connection to be closed due to limit") + } + conn3.Close() + + // Close one existing connection. + conns[0].Close() + time.Sleep(200 * time.Millisecond) + + // Now a new connection should succeed. + conn4, err := net.DialTimeout("tcp", proxyAddr, 2*time.Second) + if err != nil { + t.Fatalf("dial 4: %v", err) + } + defer conn4.Close() + conn4.Write(buildClientHello("limit.test")) + + // Give it time to be proxied. + time.Sleep(100 * time.Millisecond) + if got := srv.TotalConnections(); got < 2 { + t.Fatalf("expected at least 2 connections, got %d", got) + } + + // Clean up. + for _, c := range conns[1:] { + c.Close() + } +} + // --- Multi-hop integration tests --- func TestMultiHopProxyProtocol(t *testing.T) { diff --git a/proto/mc_proxy/v1/admin.proto b/proto/mc_proxy/v1/admin.proto index 7ba758f..819b4b3 100644 --- a/proto/mc_proxy/v1/admin.proto +++ b/proto/mc_proxy/v1/admin.proto @@ -17,6 +17,9 @@ service ProxyAdminService { rpc AddFirewallRule(AddFirewallRuleRequest) returns (AddFirewallRuleResponse); rpc RemoveFirewallRule(RemoveFirewallRuleRequest) returns (RemoveFirewallRuleResponse); + // Connection limits + rpc SetListenerMaxConnections(SetListenerMaxConnectionsRequest) returns (SetListenerMaxConnectionsResponse); + // Status rpc GetStatus(GetStatusRequest) returns (GetStatusResponse); } @@ -90,11 +93,19 @@ message RemoveFirewallRuleResponse {} // Status +message SetListenerMaxConnectionsRequest { + string listener_addr = 1; + int64 max_connections = 2; // 0 = unlimited +} + +message SetListenerMaxConnectionsResponse {} + message ListenerStatus { string addr = 1; int32 route_count = 2; int64 active_connections = 3; bool proxy_protocol = 4; + int64 max_connections = 5; } message GetStatusRequest {}