diff --git a/PROGRESS.md b/PROGRESS.md index e2ba26b..8295919 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -35,11 +35,11 @@ proceeds. Each item is marked: ## Phase 4: gRPC API & CLI Updates -- [ ] 4.1 Proto updates (new fields on `Route`, `AddRouteRequest`, `ListenerStatus`) -- [ ] 4.2 gRPC server updates (accept/validate/persist new route fields) -- [ ] 4.3 Client package updates (new fields on `Route`, `ListenerStatus`) -- [ ] 4.4 mcproxyctl updates (flags for `routes add`, display in `routes list`) -- [ ] 4.5 Tests (gRPC round-trip with new fields, backward compatibility) +- [x] 4.1 Proto updates (new fields on `Route`, `AddRouteRequest`, `ListenerStatus`) +- [x] 4.2 gRPC server updates (accept/validate/persist new route fields) +- [x] 4.3 Client package updates (new fields on `Route`, `ListenerStatus`) +- [x] 4.4 mcproxyctl updates (flags for `routes add`, display in `routes list`) +- [x] 4.5 Tests (gRPC round-trip with new fields, backward compatibility) ## Phase 5: Integration & Polish diff --git a/client/mcproxy/client.go b/client/mcproxy/client.go index 9e772f3..1fe82a9 100644 --- a/client/mcproxy/client.go +++ b/client/mcproxy/client.go @@ -40,10 +40,15 @@ func (c *Client) Close() error { return c.conn.Close() } -// Route represents a hostname to backend mapping. +// Route represents a hostname to backend mapping with mode and options. type Route struct { - Hostname string - Backend string + Hostname string + Backend string + Mode string // "l4" or "l7" + TLSCert string + TLSKey string + BackendTLS bool + SendProxyProtocol bool } // ListRoutes returns all routes for the given listener address. @@ -58,20 +63,30 @@ func (c *Client) ListRoutes(ctx context.Context, listenerAddr string) ([]Route, routes := make([]Route, len(resp.Routes)) for i, r := range resp.Routes { routes[i] = Route{ - Hostname: r.Hostname, - Backend: r.Backend, + Hostname: r.Hostname, + Backend: r.Backend, + Mode: r.Mode, + TLSCert: r.TlsCert, + TLSKey: r.TlsKey, + BackendTLS: r.BackendTls, + SendProxyProtocol: r.SendProxyProtocol, } } return routes, nil } // AddRoute adds a route to the given listener. -func (c *Client) AddRoute(ctx context.Context, listenerAddr, hostname, backend string) error { +func (c *Client) AddRoute(ctx context.Context, listenerAddr string, route Route) error { _, err := c.admin.AddRoute(ctx, &pb.AddRouteRequest{ ListenerAddr: listenerAddr, Route: &pb.Route{ - Hostname: hostname, - Backend: backend, + Hostname: route.Hostname, + Backend: route.Backend, + Mode: route.Mode, + TlsCert: route.TLSCert, + TlsKey: route.TLSKey, + BackendTls: route.BackendTLS, + SendProxyProtocol: route.SendProxyProtocol, }, }) return err @@ -145,6 +160,7 @@ type ListenerStatus struct { Addr string RouteCount int ActiveConnections int64 + ProxyProtocol bool } // Status contains the server's current status. @@ -176,6 +192,7 @@ func (c *Client) GetStatus(ctx context.Context) (*Status, error) { Addr: ls.Addr, RouteCount: int(ls.RouteCount), ActiveConnections: ls.ActiveConnections, + ProxyProtocol: ls.ProxyProtocol, } } diff --git a/client/mcproxy/client_test.go b/client/mcproxy/client_test.go index f60b0e1..da9c540 100644 --- a/client/mcproxy/client_test.go +++ b/client/mcproxy/client_test.go @@ -219,7 +219,7 @@ func TestClientAddRemoveRoute(t *testing.T) { ctx := context.Background() // Add a new route. - err := client.AddRoute(ctx, ":443", "new.test", "127.0.0.1:9443") + err := client.AddRoute(ctx, ":443", Route{Hostname: "new.test", Backend: "127.0.0.1:9443"}) if err != nil { t.Fatalf("AddRoute: %v", err) } diff --git a/cmd/mcproxyctl/routes.go b/cmd/mcproxyctl/routes.go index 5e5526f..fa04827 100644 --- a/cmd/mcproxyctl/routes.go +++ b/cmd/mcproxyctl/routes.go @@ -6,6 +6,8 @@ import ( "time" "github.com/spf13/cobra" + + mcproxy "git.wntrmute.dev/kyle/mc-proxy/client/mcproxy" ) func routesCmd() *cobra.Command { @@ -45,7 +47,7 @@ func routesListCmd() *cobra.Command { return nil } - // Find max hostname length for alignment + // Find max hostname length for alignment. maxHostLen := 0 for _, r := range routes { if len(r.Hostname) > maxHostLen { @@ -55,7 +57,18 @@ func routesListCmd() *cobra.Command { fmt.Printf("Routes for %s:\n", listenerAddr) for _, r := range routes { - fmt.Printf(" %-*s -> %s\n", maxHostLen, r.Hostname, r.Backend) + mode := r.Mode + if mode == "" { + mode = "l4" + } + extra := "" + if r.SendProxyProtocol { + extra += " [proxy-protocol]" + } + if r.BackendTLS { + extra += " [backend-tls]" + } + fmt.Printf(" %-*s -> %-25s [%s]%s\n", maxHostLen, r.Hostname, r.Backend, mode, extra) } return nil @@ -64,7 +77,15 @@ func routesListCmd() *cobra.Command { } func routesAddCmd() *cobra.Command { - return &cobra.Command{ + var ( + mode string + tlsCert string + tlsKey string + backendTLS bool + sendProxyProtocol bool + ) + + cmd := &cobra.Command{ Use: "add LISTENER HOSTNAME BACKEND", Short: "Add a route", Long: "Add a route mapping a hostname to a backend for the specified listener.", @@ -78,14 +99,31 @@ func routesAddCmd() *cobra.Command { ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second) defer cancel() - if err := client.AddRoute(ctx, listenerAddr, hostname, backend); err != nil { + route := mcproxy.Route{ + Hostname: hostname, + Backend: backend, + Mode: mode, + TLSCert: tlsCert, + TLSKey: tlsKey, + BackendTLS: backendTLS, + SendProxyProtocol: sendProxyProtocol, + } + if err := client.AddRoute(ctx, listenerAddr, route); err != nil { return fmt.Errorf("adding route: %w", err) } - fmt.Printf("Added route: %s -> %s on %s\n", hostname, backend, listenerAddr) + fmt.Printf("Added route: %s -> %s on %s [%s]\n", hostname, backend, listenerAddr, mode) return nil }, } + + cmd.Flags().StringVar(&mode, "mode", "l4", "route mode: l4 (passthrough) or l7 (TLS-terminating)") + cmd.Flags().StringVar(&tlsCert, "tls-cert", "", "TLS certificate path (L7 only)") + cmd.Flags().StringVar(&tlsKey, "tls-key", "", "TLS private key path (L7 only)") + cmd.Flags().BoolVar(&backendTLS, "backend-tls", false, "re-encrypt to backend (L7 only)") + cmd.Flags().BoolVar(&sendProxyProtocol, "send-proxy-protocol", false, "send PROXY v2 header to backend") + + return cmd } func routesRemoveCmd() *cobra.Command { diff --git a/gen/mc_proxy/v1/admin.pb.go b/gen/mc_proxy/v1/admin.pb.go index f6f30c3..b937671 100644 --- a/gen/mc_proxy/v1/admin.pb.go +++ b/gen/mc_proxy/v1/admin.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.36.11 -// protoc v6.33.4 +// protoc v5.29.5 // source: proto/mc_proxy/v1/admin.proto package mcproxyv1 @@ -75,11 +75,16 @@ func (FirewallRuleType) EnumDescriptor() ([]byte, []int) { } type Route struct { - state protoimpl.MessageState `protogen:"open.v1"` - Hostname string `protobuf:"bytes,1,opt,name=hostname,proto3" json:"hostname,omitempty"` - Backend string `protobuf:"bytes,2,opt,name=backend,proto3" json:"backend,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Hostname string `protobuf:"bytes,1,opt,name=hostname,proto3" json:"hostname,omitempty"` + Backend string `protobuf:"bytes,2,opt,name=backend,proto3" json:"backend,omitempty"` + Mode string `protobuf:"bytes,3,opt,name=mode,proto3" json:"mode,omitempty"` // "l4" (default) or "l7" + TlsCert string `protobuf:"bytes,4,opt,name=tls_cert,json=tlsCert,proto3" json:"tls_cert,omitempty"` // PEM certificate path (L7 only) + TlsKey string `protobuf:"bytes,5,opt,name=tls_key,json=tlsKey,proto3" json:"tls_key,omitempty"` // PEM private key path (L7 only) + BackendTls bool `protobuf:"varint,6,opt,name=backend_tls,json=backendTls,proto3" json:"backend_tls,omitempty"` // re-encrypt to backend (L7 only) + SendProxyProtocol bool `protobuf:"varint,7,opt,name=send_proxy_protocol,json=sendProxyProtocol,proto3" json:"send_proxy_protocol,omitempty"` // send PROXY v2 header to backend + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Route) Reset() { @@ -126,6 +131,41 @@ func (x *Route) GetBackend() string { return "" } +func (x *Route) GetMode() string { + if x != nil { + return x.Mode + } + return "" +} + +func (x *Route) GetTlsCert() string { + if x != nil { + return x.TlsCert + } + return "" +} + +func (x *Route) GetTlsKey() string { + if x != nil { + return x.TlsKey + } + return "" +} + +func (x *Route) GetBackendTls() bool { + if x != nil { + return x.BackendTls + } + return false +} + +func (x *Route) GetSendProxyProtocol() bool { + if x != nil { + return x.SendProxyProtocol + } + return false +} + type ListRoutesRequest struct { state protoimpl.MessageState `protogen:"open.v1"` ListenerAddr string `protobuf:"bytes,1,opt,name=listener_addr,json=listenerAddr,proto3" json:"listener_addr,omitempty"` @@ -695,6 +735,7 @@ type ListenerStatus struct { Addr string `protobuf:"bytes,1,opt,name=addr,proto3" json:"addr,omitempty"` RouteCount int32 `protobuf:"varint,2,opt,name=route_count,json=routeCount,proto3" json:"route_count,omitempty"` ActiveConnections int64 `protobuf:"varint,3,opt,name=active_connections,json=activeConnections,proto3" json:"active_connections,omitempty"` + ProxyProtocol bool `protobuf:"varint,4,opt,name=proxy_protocol,json=proxyProtocol,proto3" json:"proxy_protocol,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -750,6 +791,13 @@ func (x *ListenerStatus) GetActiveConnections() int64 { return 0 } +func (x *ListenerStatus) GetProxyProtocol() bool { + if x != nil { + return x.ProxyProtocol + } + return false +} + type GetStatusRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -858,10 +906,16 @@ var File_proto_mc_proxy_v1_admin_proto protoreflect.FileDescriptor const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" + "\n" + - "\x1dproto/mc_proxy/v1/admin.proto\x12\vmc_proxy.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"=\n" + + "\x1dproto/mc_proxy/v1/admin.proto\x12\vmc_proxy.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"\xd6\x01\n" + "\x05Route\x12\x1a\n" + "\bhostname\x18\x01 \x01(\tR\bhostname\x12\x18\n" + - "\abackend\x18\x02 \x01(\tR\abackend\"8\n" + + "\abackend\x18\x02 \x01(\tR\abackend\x12\x12\n" + + "\x04mode\x18\x03 \x01(\tR\x04mode\x12\x19\n" + + "\btls_cert\x18\x04 \x01(\tR\atlsCert\x12\x17\n" + + "\atls_key\x18\x05 \x01(\tR\x06tlsKey\x12\x1f\n" + + "\vbackend_tls\x18\x06 \x01(\bR\n" + + "backendTls\x12.\n" + + "\x13send_proxy_protocol\x18\a \x01(\bR\x11sendProxyProtocol\"8\n" + "\x11ListRoutesRequest\x12#\n" + "\rlistener_addr\x18\x01 \x01(\tR\flistenerAddr\"e\n" + "\x12ListRoutesResponse\x12#\n" + @@ -886,12 +940,13 @@ const file_proto_mc_proxy_v1_admin_proto_rawDesc = "" + "\x17AddFirewallRuleResponse\"J\n" + "\x19RemoveFirewallRuleRequest\x12-\n" + "\x04rule\x18\x01 \x01(\v2\x19.mc_proxy.v1.FirewallRuleR\x04rule\"\x1c\n" + - "\x1aRemoveFirewallRuleResponse\"t\n" + + "\x1aRemoveFirewallRuleResponse\"\x9b\x01\n" + "\x0eListenerStatus\x12\x12\n" + "\x04addr\x18\x01 \x01(\tR\x04addr\x12\x1f\n" + "\vroute_count\x18\x02 \x01(\x05R\n" + "routeCount\x12-\n" + - "\x12active_connections\x18\x03 \x01(\x03R\x11activeConnections\"\x12\n" + + "\x12active_connections\x18\x03 \x01(\x03R\x11activeConnections\x12%\n" + + "\x0eproxy_protocol\x18\x04 \x01(\bR\rproxyProtocol\"\x12\n" + "\x10GetStatusRequest\"\xd0\x01\n" + "\x11GetStatusResponse\x12\x18\n" + "\aversion\x18\x01 \x01(\tR\aversion\x129\n" + diff --git a/gen/mc_proxy/v1/admin_grpc.pb.go b/gen/mc_proxy/v1/admin_grpc.pb.go index c96d856..276bae5 100644 --- a/gen/mc_proxy/v1/admin_grpc.pb.go +++ b/gen/mc_proxy/v1/admin_grpc.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: // - protoc-gen-go-grpc v1.6.1 -// - protoc v6.33.4 +// - protoc v5.29.5 // source: proto/mc_proxy/v1/admin.proto package mcproxyv1 diff --git a/internal/grpcserver/grpcserver.go b/internal/grpcserver/grpcserver.go index 9621598..5aa619d 100644 --- a/internal/grpcserver/grpcserver.go +++ b/internal/grpcserver/grpcserver.go @@ -90,8 +90,13 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) ( } for hostname, route := range routes { resp.Routes = append(resp.Routes, &pb.Route{ - Hostname: hostname, - Backend: route.Backend, + Hostname: hostname, + Backend: route.Backend, + Mode: route.Mode, + TlsCert: route.TLSCert, + TlsKey: route.TLSKey, + BackendTls: route.BackendTLS, + SendProxyProtocol: route.SendProxyProtocol, }) } return resp, nil @@ -118,17 +123,42 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb. hostname := strings.ToLower(req.Route.Hostname) + // Normalize mode. + mode := req.Route.Mode + if mode == "" { + mode = "l4" + } + if mode != "l4" && mode != "l7" { + return nil, status.Errorf(codes.InvalidArgument, "mode must be \"l4\" or \"l7\", got %q", mode) + } + + // L7 routes require cert/key paths. + if mode == "l7" { + if req.Route.TlsCert == "" || req.Route.TlsKey == "" { + return nil, status.Error(codes.InvalidArgument, "L7 routes require tls_cert and tls_key") + } + } + // Write-through: DB first, then memory. - if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, "l4", "", "", false, false); err != nil { + if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, mode, + req.Route.TlsCert, req.Route.TlsKey, req.Route.BackendTls, req.Route.SendProxyProtocol); err != nil { return nil, status.Errorf(codes.AlreadyExists, "%v", err) } - if err := ls.AddRoute(hostname, server.RouteInfo{Backend: req.Route.Backend, Mode: "l4"}); err != nil { + info := server.RouteInfo{ + Backend: req.Route.Backend, + Mode: mode, + TLSCert: req.Route.TlsCert, + TLSKey: req.Route.TlsKey, + BackendTLS: req.Route.BackendTls, + SendProxyProtocol: req.Route.SendProxyProtocol, + } + if err := ls.AddRoute(hostname, info); err != nil { // DB succeeded but memory failed (should not happen since DB enforces uniqueness). a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err) } - a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend) + a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend, "mode", mode) return &pb.AddRouteResponse{}, nil } @@ -287,6 +317,7 @@ func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb. Addr: ls.Addr, RouteCount: int32(len(routes)), ActiveConnections: ls.ActiveConnections.Load(), + ProxyProtocol: ls.ProxyProtocol, }) } diff --git a/internal/grpcserver/grpcserver_test.go b/internal/grpcserver/grpcserver_test.go index dd96238..061fc01 100644 --- a/internal/grpcserver/grpcserver_test.go +++ b/internal/grpcserver/grpcserver_test.go @@ -517,3 +517,162 @@ func TestRemoveFirewallRuleNotFound(t *testing.T) { t.Fatal("expected error for removing nonexistent rule") } } + +func TestAddRouteL7(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{ + Hostname: "l7.test", + Backend: "127.0.0.1:8080", + Mode: "l7", + TlsCert: "/certs/l7.crt", + TlsKey: "/certs/l7.key", + BackendTls: false, + SendProxyProtocol: true, + }, + }) + if err != nil { + t.Fatalf("AddRoute L7: %v", err) + } + + // Verify in-memory via ListRoutes. + resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"}) + if err != nil { + t.Fatalf("ListRoutes: %v", err) + } + + var found *pb.Route + for _, r := range resp.Routes { + if r.Hostname == "l7.test" { + found = r + break + } + } + if found == nil { + t.Fatal("L7 route not found in ListRoutes response") + } + if found.Mode != "l7" { + t.Fatalf("mode = %q, want %q", found.Mode, "l7") + } + if found.TlsCert != "/certs/l7.crt" { + t.Fatalf("tls_cert = %q, want %q", found.TlsCert, "/certs/l7.crt") + } + if found.TlsKey != "/certs/l7.key" { + t.Fatalf("tls_key = %q, want %q", found.TlsKey, "/certs/l7.key") + } + if found.BackendTls { + t.Fatal("expected backend_tls = false") + } + if !found.SendProxyProtocol { + t.Fatal("expected send_proxy_protocol = true") + } + + // Verify DB persistence. + dbListeners, _ := env.store.ListListeners() + dbRoutes, _ := env.store.ListRoutes(dbListeners[0].ID) + var dbRoute *db.Route + for i := range dbRoutes { + if dbRoutes[i].Hostname == "l7.test" { + dbRoute = &dbRoutes[i] + break + } + } + if dbRoute == nil { + t.Fatal("L7 route not found in DB") + } + if dbRoute.Mode != "l7" { + t.Fatalf("DB mode = %q, want %q", dbRoute.Mode, "l7") + } + if !dbRoute.SendProxyProtocol { + t.Fatal("DB send_proxy_protocol should be true") + } +} + +func TestAddRouteL7MissingCert(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{ + Hostname: "nocert.test", + Backend: "127.0.0.1:8080", + Mode: "l7", + }, + }) + if err == nil { + t.Fatal("expected error for L7 route without cert/key") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.InvalidArgument { + t.Fatalf("expected InvalidArgument, got %v", err) + } +} + +func TestAddRouteInvalidMode(t *testing.T) { + env := setup(t) + ctx := context.Background() + + _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{ + Hostname: "badmode.test", + Backend: "127.0.0.1:8080", + Mode: "l5", + }, + }) + if err == nil { + t.Fatal("expected error for invalid mode") + } + if s, ok := status.FromError(err); !ok || s.Code() != codes.InvalidArgument { + t.Fatalf("expected InvalidArgument, got %v", err) + } +} + +func TestAddRouteDefaultsToL4(t *testing.T) { + env := setup(t) + ctx := context.Background() + + // Add route without specifying mode — should default to "l4". + _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: ":443", + Route: &pb.Route{ + Hostname: "default.test", + Backend: "127.0.0.1:9443", + }, + }) + if err != nil { + t.Fatalf("AddRoute: %v", err) + } + + resp, _ := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"}) + for _, r := range resp.Routes { + if r.Hostname == "default.test" { + if r.Mode != "l4" { + t.Fatalf("mode = %q, want %q", r.Mode, "l4") + } + return + } + } + t.Fatal("route not found") +} + +func TestGetStatusProxyProtocol(t *testing.T) { + env := setup(t) + ctx := context.Background() + + resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{}) + if err != nil { + t.Fatalf("GetStatus: %v", err) + } + + // The seeded listener has proxy_protocol = false. + if len(resp.Listeners) != 1 { + t.Fatalf("got %d listeners, want 1", len(resp.Listeners)) + } + if resp.Listeners[0].ProxyProtocol { + t.Fatal("expected proxy_protocol = false") + } +} diff --git a/proto/mc_proxy/v1/admin.proto b/proto/mc_proxy/v1/admin.proto index e65bdd9..7ba758f 100644 --- a/proto/mc_proxy/v1/admin.proto +++ b/proto/mc_proxy/v1/admin.proto @@ -26,6 +26,11 @@ service ProxyAdminService { message Route { string hostname = 1; string backend = 2; + string mode = 3; // "l4" (default) or "l7" + string tls_cert = 4; // PEM certificate path (L7 only) + string tls_key = 5; // PEM private key path (L7 only) + bool backend_tls = 6; // re-encrypt to backend (L7 only) + bool send_proxy_protocol = 7; // send PROXY v2 header to backend } message ListRoutesRequest { @@ -89,6 +94,7 @@ message ListenerStatus { string addr = 1; int32 route_count = 2; int64 active_connections = 3; + bool proxy_protocol = 4; } message GetStatusRequest {}