diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index f372697..a5f6efa 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -177,8 +177,8 @@ gRPC admin API. ## gRPC Admin API The admin API is optional (disabled if `[grpc]` is omitted from the config). -When enabled, it requires TLS and supports optional mTLS for client -authentication. TLS 1.3 is enforced. The API provides runtime management +It listens on a Unix domain socket for security — access is controlled via +filesystem permissions (0600, owner-only). The API provides runtime management of routes and firewall rules without restarting the proxy. ### RPCs @@ -192,6 +192,7 @@ of routes and firewall rules without restarting the proxy. | `AddFirewallRule` | Add a firewall rule (write-through to DB) | | `RemoveFirewallRule` | Remove a firewall rule (write-through to DB) | | `GetStatus` | Return version, uptime, listener status, connection counts | +| `grpc.health.v1.Health` | Standard gRPC health check (Check, Watch) | ### Input Validation @@ -205,13 +206,11 @@ The admin API validates all inputs before persisting: ### Security The gRPC admin API has no MCIAS integration — mc-proxy is pre-auth -infrastructure. Access control relies on: +infrastructure. Access control relies on Unix socket filesystem permissions: -1. **Network binding**: bind to `127.0.0.1` (default) to restrict to local access. -2. **mTLS**: configure `client_ca` to require client certificates. - -If the admin API is exposed on a non-loopback interface without mTLS, -any network client can modify routing and firewall rules. +- Socket is created with mode `0600` (read/write for owner only) +- Only processes running as the same user can connect +- No network exposure — the API is not accessible over TCP --- @@ -252,15 +251,9 @@ addr = ":9443" backend = "127.0.0.1:28443" # gRPC admin API. Optional — omit or leave addr empty to disable. -# If enabled, tls_cert and tls_key are required (TLS 1.3 only). -# client_ca enables mTLS and is strongly recommended for non-loopback addresses. -# ca_cert is used by the `status` CLI command to verify the server certificate. +# Listens on a Unix socket; access controlled via filesystem permissions. [grpc] -addr = "127.0.0.1:9090" -tls_cert = "/srv/mc-proxy/certs/cert.pem" -tls_key = "/srv/mc-proxy/certs/key.pem" -client_ca = "/srv/mc-proxy/certs/ca.pem" -ca_cert = "/srv/mc-proxy/certs/ca.pem" +addr = "/var/run/mc-proxy.sock" # Firewall. Global blocklist, evaluated before routing. Default allow. [firewall] @@ -347,14 +340,14 @@ CREATE TABLE firewall_rules ( /srv/mc-proxy/ ├── mc-proxy.toml Configuration ├── mc-proxy.db SQLite database -├── certs/ TLS certificates (for gRPC admin API) +├── mc-proxy.sock Unix socket for gRPC admin API ├── GeoLite2-Country.mmdb GeoIP database (if using country blocks) └── backups/ Database snapshots ``` -mc-proxy does not terminate TLS on the proxy listeners, so no proxy -certificates are needed. The `certs/` directory is for the gRPC admin -API's TLS and optional mTLS keypair. +mc-proxy does not terminate TLS on any listener. The proxy listeners pass +through raw TLS streams, and the gRPC admin API uses a Unix socket +(filesystem permissions for access control). --- @@ -447,5 +440,4 @@ Items are listed roughly in priority order: | **User-agent blocking** | Block connections based on user-agent string (requires L7 mode). | | **Connection rate limiting** | Per-source-IP rate limits to mitigate connection floods. | | **Per-listener connection limits** | Cap maximum concurrent connections per listener. | -| **Health check endpoint** | Lightweight TCP or HTTP health check for load balancers and monitoring. | | **Metrics** | Prometheus-compatible metrics: connections per listener, firewall blocks by rule, backend dial latency, active connections. | diff --git a/client/mcproxy/client.go b/client/mcproxy/client.go new file mode 100644 index 0000000..9e772f3 --- /dev/null +++ b/client/mcproxy/client.go @@ -0,0 +1,238 @@ +// Package mcproxy provides a client for the mc-proxy gRPC admin API. +package mcproxy + +import ( + "context" + "fmt" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + + pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" +) + +// Client provides access to the mc-proxy admin API. +type Client struct { + conn *grpc.ClientConn + admin pb.ProxyAdminServiceClient + health healthpb.HealthClient +} + +// Dial connects to the mc-proxy admin API via Unix socket. +func Dial(socketPath string) (*Client, error) { + conn, err := grpc.NewClient("unix://"+socketPath, + grpc.WithTransportCredentials(insecure.NewCredentials())) + if err != nil { + return nil, fmt.Errorf("connecting to %s: %w", socketPath, err) + } + + return &Client{ + conn: conn, + admin: pb.NewProxyAdminServiceClient(conn), + health: healthpb.NewHealthClient(conn), + }, nil +} + +// Close closes the connection to the server. +func (c *Client) Close() error { + return c.conn.Close() +} + +// Route represents a hostname to backend mapping. +type Route struct { + Hostname string + Backend string +} + +// ListRoutes returns all routes for the given listener address. +func (c *Client) ListRoutes(ctx context.Context, listenerAddr string) ([]Route, error) { + resp, err := c.admin.ListRoutes(ctx, &pb.ListRoutesRequest{ + ListenerAddr: listenerAddr, + }) + if err != nil { + return nil, err + } + + routes := make([]Route, len(resp.Routes)) + for i, r := range resp.Routes { + routes[i] = Route{ + Hostname: r.Hostname, + Backend: r.Backend, + } + } + return routes, nil +} + +// AddRoute adds a route to the given listener. +func (c *Client) AddRoute(ctx context.Context, listenerAddr, hostname, backend string) error { + _, err := c.admin.AddRoute(ctx, &pb.AddRouteRequest{ + ListenerAddr: listenerAddr, + Route: &pb.Route{ + Hostname: hostname, + Backend: backend, + }, + }) + return err +} + +// RemoveRoute removes a route from the given listener. +func (c *Client) RemoveRoute(ctx context.Context, listenerAddr, hostname string) error { + _, err := c.admin.RemoveRoute(ctx, &pb.RemoveRouteRequest{ + ListenerAddr: listenerAddr, + Hostname: hostname, + }) + return err +} + +// FirewallRuleType represents the type of firewall rule. +type FirewallRuleType string + +const ( + FirewallRuleIP FirewallRuleType = "ip" + FirewallRuleCIDR FirewallRuleType = "cidr" + FirewallRuleCountry FirewallRuleType = "country" +) + +// FirewallRule represents a firewall block rule. +type FirewallRule struct { + Type FirewallRuleType + Value string +} + +// GetFirewallRules returns all firewall rules. +func (c *Client) GetFirewallRules(ctx context.Context) ([]FirewallRule, error) { + resp, err := c.admin.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{}) + if err != nil { + return nil, err + } + + rules := make([]FirewallRule, len(resp.Rules)) + for i, r := range resp.Rules { + rules[i] = FirewallRule{ + Type: protoToRuleType(r.Type), + Value: r.Value, + } + } + return rules, nil +} + +// AddFirewallRule adds a firewall rule. +func (c *Client) AddFirewallRule(ctx context.Context, ruleType FirewallRuleType, value string) error { + _, err := c.admin.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: ruleTypeToProto(ruleType), + Value: value, + }, + }) + return err +} + +// RemoveFirewallRule removes a firewall rule. +func (c *Client) RemoveFirewallRule(ctx context.Context, ruleType FirewallRuleType, value string) error { + _, err := c.admin.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{ + Rule: &pb.FirewallRule{ + Type: ruleTypeToProto(ruleType), + Value: value, + }, + }) + return err +} + +// ListenerStatus contains status information for a single listener. +type ListenerStatus struct { + Addr string + RouteCount int + ActiveConnections int64 +} + +// Status contains the server's current status. +type Status struct { + Version string + StartedAt time.Time + TotalConnections int64 + Listeners []ListenerStatus +} + +// GetStatus returns the server's current status. +func (c *Client) GetStatus(ctx context.Context) (*Status, error) { + resp, err := c.admin.GetStatus(ctx, &pb.GetStatusRequest{}) + if err != nil { + return nil, err + } + + status := &Status{ + Version: resp.Version, + TotalConnections: resp.TotalConnections, + } + if resp.StartedAt != nil { + status.StartedAt = resp.StartedAt.AsTime() + } + + status.Listeners = make([]ListenerStatus, len(resp.Listeners)) + for i, ls := range resp.Listeners { + status.Listeners[i] = ListenerStatus{ + Addr: ls.Addr, + RouteCount: int(ls.RouteCount), + ActiveConnections: ls.ActiveConnections, + } + } + + return status, nil +} + +// HealthStatus represents the health of the server. +type HealthStatus int + +const ( + HealthUnknown HealthStatus = 0 + HealthServing HealthStatus = 1 + HealthNotServing HealthStatus = 2 +) + +func (h HealthStatus) String() string { + switch h { + case HealthServing: + return "SERVING" + case HealthNotServing: + return "NOT_SERVING" + default: + return "UNKNOWN" + } +} + +// CheckHealth checks the health of the server. +func (c *Client) CheckHealth(ctx context.Context) (HealthStatus, error) { + resp, err := c.health.Check(ctx, &healthpb.HealthCheckRequest{}) + if err != nil { + return HealthUnknown, err + } + return HealthStatus(resp.Status), nil +} + +func protoToRuleType(t pb.FirewallRuleType) FirewallRuleType { + switch t { + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: + return FirewallRuleIP + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: + return FirewallRuleCIDR + case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: + return FirewallRuleCountry + default: + return "" + } +} + +func ruleTypeToProto(t FirewallRuleType) pb.FirewallRuleType { + switch t { + case FirewallRuleIP: + return pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP + case FirewallRuleCIDR: + return pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR + case FirewallRuleCountry: + return pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY + default: + return pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED + } +} diff --git a/client/mcproxy/client_test.go b/client/mcproxy/client_test.go new file mode 100644 index 0000000..010f240 --- /dev/null +++ b/client/mcproxy/client_test.go @@ -0,0 +1,331 @@ +package mcproxy + +import ( + "context" + "io" + "log/slog" + "net" + "path/filepath" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" + "google.golang.org/grpc/test/bufconn" + + pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" + "git.wntrmute.dev/kyle/mc-proxy/internal/config" + "git.wntrmute.dev/kyle/mc-proxy/internal/db" + "git.wntrmute.dev/kyle/mc-proxy/internal/firewall" + "git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver" + "git.wntrmute.dev/kyle/mc-proxy/internal/server" +) + +func setupTestClient(t *testing.T) *Client { + t.Helper() + + // Database in temp dir. + dbPath := filepath.Join(t.TempDir(), "test.db") + store, err := db.Open(dbPath) + if err != nil { + t.Fatalf("open db: %v", err) + } + t.Cleanup(func() { store.Close() }) + + if err := store.Migrate(); err != nil { + t.Fatalf("migrate: %v", err) + } + + // Seed with one listener and one route. + listeners := []config.Listener{ + { + Addr: ":443", + Routes: []config.Route{ + {Hostname: "example.test", Backend: "127.0.0.1:8443"}, + }, + }, + } + fw := config.Firewall{ + BlockedIPs: []string{"10.0.0.1"}, + } + if err := store.Seed(listeners, fw); err != nil { + t.Fatalf("seed: %v", err) + } + + // Build server with matching in-memory state. + fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0) + if err != nil { + t.Fatalf("firewall: %v", err) + } + + cfg := &config.Config{ + Proxy: config.Proxy{ + ConnectTimeout: config.Duration{Duration: 5 * time.Second}, + IdleTimeout: config.Duration{Duration: 30 * time.Second}, + ShutdownTimeout: config.Duration{Duration: 5 * time.Second}, + }, + } + logger := slog.New(slog.NewTextHandler(io.Discard, nil)) + + // Load listener data from DB to get correct IDs. + dbListeners, err := store.ListListeners() + if err != nil { + t.Fatalf("list listeners: %v", err) + } + var listenerData []server.ListenerData + for _, l := range dbListeners { + dbRoutes, err := store.ListRoutes(l.ID) + if err != nil { + t.Fatalf("list routes: %v", err) + } + routes := make(map[string]string, len(dbRoutes)) + for _, r := range dbRoutes { + routes[r.Hostname] = r.Backend + } + listenerData = append(listenerData, server.ListenerData{ + ID: l.ID, + Addr: l.Addr, + Routes: routes, + }) + } + + srv := server.New(cfg, fwObj, listenerData, logger, "test-version") + + // Set up bufconn gRPC server. + lis := bufconn.Listen(1024 * 1024) + grpcSrv := grpc.NewServer() + + pb.RegisterProxyAdminServiceServer(grpcSrv, &testAdminServer{ + srv: srv, + store: store, + logger: logger, + }) + + // Register health service. + healthServer := health.NewServer() + healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + healthpb.RegisterHealthServer(grpcSrv, healthServer) + + go func() { + if err := grpcSrv.Serve(lis); err != nil { + t.Logf("grpc serve: %v", err) + } + }() + t.Cleanup(grpcSrv.Stop) + + conn, err := grpc.NewClient("passthrough://bufconn", + grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) { + return lis.DialContext(ctx) + }), + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + t.Fatalf("dial bufconn: %v", err) + } + t.Cleanup(func() { conn.Close() }) + + return &Client{ + conn: conn, + admin: pb.NewProxyAdminServiceClient(conn), + health: healthpb.NewHealthClient(conn), + } +} + +// testAdminServer is a minimal implementation for testing. +// It delegates to the real grpcserver.AdminServer logic. +type testAdminServer struct { + pb.UnimplementedProxyAdminServiceServer + srv *server.Server + store *db.Store + logger *slog.Logger +} + +func (s *testAdminServer) GetStatus(ctx context.Context, req *pb.GetStatusRequest) (*pb.GetStatusResponse, error) { + return grpcserver.NewAdminServer(s.srv, s.store, s.logger).GetStatus(ctx, req) +} + +func (s *testAdminServer) ListRoutes(ctx context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) { + return grpcserver.NewAdminServer(s.srv, s.store, s.logger).ListRoutes(ctx, req) +} + +func (s *testAdminServer) AddRoute(ctx context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) { + return grpcserver.NewAdminServer(s.srv, s.store, s.logger).AddRoute(ctx, req) +} + +func (s *testAdminServer) RemoveRoute(ctx context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) { + return grpcserver.NewAdminServer(s.srv, s.store, s.logger).RemoveRoute(ctx, req) +} + +func (s *testAdminServer) GetFirewallRules(ctx context.Context, req *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) { + return grpcserver.NewAdminServer(s.srv, s.store, s.logger).GetFirewallRules(ctx, req) +} + +func (s *testAdminServer) AddFirewallRule(ctx context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) { + return grpcserver.NewAdminServer(s.srv, s.store, s.logger).AddFirewallRule(ctx, req) +} + +func (s *testAdminServer) RemoveFirewallRule(ctx context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) { + return grpcserver.NewAdminServer(s.srv, s.store, s.logger).RemoveFirewallRule(ctx, req) +} + +func TestClientGetStatus(t *testing.T) { + client := setupTestClient(t) + ctx := context.Background() + + status, err := client.GetStatus(ctx) + if err != nil { + t.Fatalf("GetStatus: %v", err) + } + + if status.Version != "test-version" { + t.Errorf("got version %q, want %q", status.Version, "test-version") + } + if len(status.Listeners) != 1 { + t.Errorf("got %d listeners, want 1", len(status.Listeners)) + } + if status.Listeners[0].Addr != ":443" { + t.Errorf("got listener addr %q, want %q", status.Listeners[0].Addr, ":443") + } +} + +func TestClientListRoutes(t *testing.T) { + client := setupTestClient(t) + ctx := context.Background() + + routes, err := client.ListRoutes(ctx, ":443") + if err != nil { + t.Fatalf("ListRoutes: %v", err) + } + + if len(routes) != 1 { + t.Fatalf("got %d routes, want 1", len(routes)) + } + if routes[0].Hostname != "example.test" { + t.Errorf("got hostname %q, want %q", routes[0].Hostname, "example.test") + } + if routes[0].Backend != "127.0.0.1:8443" { + t.Errorf("got backend %q, want %q", routes[0].Backend, "127.0.0.1:8443") + } +} + +func TestClientAddRemoveRoute(t *testing.T) { + client := setupTestClient(t) + ctx := context.Background() + + // Add a new route. + err := client.AddRoute(ctx, ":443", "new.test", "127.0.0.1:9443") + if err != nil { + t.Fatalf("AddRoute: %v", err) + } + + // Verify it was added. + routes, err := client.ListRoutes(ctx, ":443") + if err != nil { + t.Fatalf("ListRoutes: %v", err) + } + if len(routes) != 2 { + t.Fatalf("got %d routes after add, want 2", len(routes)) + } + + // Remove the route. + err = client.RemoveRoute(ctx, ":443", "new.test") + if err != nil { + t.Fatalf("RemoveRoute: %v", err) + } + + // Verify it was removed. + routes, err = client.ListRoutes(ctx, ":443") + if err != nil { + t.Fatalf("ListRoutes: %v", err) + } + if len(routes) != 1 { + t.Fatalf("got %d routes after remove, want 1", len(routes)) + } +} + +func TestClientGetFirewallRules(t *testing.T) { + client := setupTestClient(t) + ctx := context.Background() + + rules, err := client.GetFirewallRules(ctx) + if err != nil { + t.Fatalf("GetFirewallRules: %v", err) + } + + if len(rules) != 1 { + t.Fatalf("got %d rules, want 1", len(rules)) + } + if rules[0].Type != FirewallRuleIP { + t.Errorf("got type %q, want %q", rules[0].Type, FirewallRuleIP) + } + if rules[0].Value != "10.0.0.1" { + t.Errorf("got value %q, want %q", rules[0].Value, "10.0.0.1") + } +} + +func TestClientAddRemoveFirewallRule(t *testing.T) { + client := setupTestClient(t) + ctx := context.Background() + + // Add a CIDR rule. + err := client.AddFirewallRule(ctx, FirewallRuleCIDR, "192.168.0.0/16") + if err != nil { + t.Fatalf("AddFirewallRule: %v", err) + } + + // Verify it was added. + rules, err := client.GetFirewallRules(ctx) + if err != nil { + t.Fatalf("GetFirewallRules: %v", err) + } + if len(rules) != 2 { + t.Fatalf("got %d rules after add, want 2", len(rules)) + } + + // Remove the rule. + err = client.RemoveFirewallRule(ctx, FirewallRuleCIDR, "192.168.0.0/16") + if err != nil { + t.Fatalf("RemoveFirewallRule: %v", err) + } + + // Verify it was removed. + rules, err = client.GetFirewallRules(ctx) + if err != nil { + t.Fatalf("GetFirewallRules: %v", err) + } + if len(rules) != 1 { + t.Fatalf("got %d rules after remove, want 1", len(rules)) + } +} + +func TestClientCheckHealth(t *testing.T) { + client := setupTestClient(t) + ctx := context.Background() + + status, err := client.CheckHealth(ctx) + if err != nil { + t.Fatalf("CheckHealth: %v", err) + } + + if status != HealthServing { + t.Errorf("got health status %v, want %v", status, HealthServing) + } +} + +func TestHealthStatusString(t *testing.T) { + tests := []struct { + status HealthStatus + want string + }{ + {HealthUnknown, "UNKNOWN"}, + {HealthServing, "SERVING"}, + {HealthNotServing, "NOT_SERVING"}, + } + for _, tt := range tests { + if got := tt.status.String(); got != tt.want { + t.Errorf("HealthStatus(%d).String() = %q, want %q", tt.status, got, tt.want) + } + } +} diff --git a/client/mcproxy/doc.go b/client/mcproxy/doc.go new file mode 100644 index 0000000..b1057b5 --- /dev/null +++ b/client/mcproxy/doc.go @@ -0,0 +1,41 @@ +// Package mcproxy provides a Go client for the mc-proxy gRPC admin API. +// +// The client connects to mc-proxy via Unix socket and provides methods +// for managing routes, firewall rules, and querying server status. +// +// # Basic Usage +// +// client, err := mcproxy.Dial("/var/run/mc-proxy.sock") +// if err != nil { +// log.Fatal(err) +// } +// defer client.Close() +// +// // Get server status +// status, err := client.GetStatus(ctx) +// if err != nil { +// log.Fatal(err) +// } +// fmt.Printf("mc-proxy %s, %d connections\n", status.Version, status.TotalConnections) +// +// // List routes for a listener +// routes, err := client.ListRoutes(ctx, ":443") +// if err != nil { +// log.Fatal(err) +// } +// for _, r := range routes { +// fmt.Printf(" %s -> %s\n", r.Hostname, r.Backend) +// } +// +// // Add a route +// err = client.AddRoute(ctx, ":443", "example.com", "127.0.0.1:8443") +// +// // Add a firewall rule +// err = client.AddFirewallRule(ctx, mcproxy.FirewallRuleCIDR, "10.0.0.0/8") +// +// // Check health +// health, err := client.CheckHealth(ctx) +// if health == mcproxy.HealthServing { +// fmt.Println("Server is healthy") +// } +package mcproxy diff --git a/cmd/mc-proxy/server.go b/cmd/mc-proxy/server.go index 17453fb..212999b 100644 --- a/cmd/mc-proxy/server.go +++ b/cmd/mc-proxy/server.go @@ -92,9 +92,7 @@ func serverCmd() *cobra.Command { }() defer func() { grpcSrv.GracefulStop() - if cfg.GRPC.IsUnixSocket() { - os.Remove(cfg.GRPC.SocketPath()) - } + os.Remove(cfg.GRPC.SocketPath()) }() } diff --git a/cmd/mc-proxy/status.go b/cmd/mc-proxy/status.go index 0f77a46..d353105 100644 --- a/cmd/mc-proxy/status.go +++ b/cmd/mc-proxy/status.go @@ -2,15 +2,11 @@ package main import ( "context" - "crypto/tls" - "crypto/x509" "fmt" - "os" "time" "github.com/spf13/cobra" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1" @@ -71,37 +67,6 @@ func statusCmd() *cobra.Command { } func dialGRPC(cfg config.GRPC) (*grpc.ClientConn, error) { - if cfg.IsUnixSocket() { - return grpc.NewClient("unix://"+cfg.SocketPath(), - grpc.WithTransportCredentials(insecure.NewCredentials())) - } - - tlsConfig := &tls.Config{ - MinVersion: tls.VersionTLS13, - } - - // Load CA cert for verifying the server. - if cfg.CACert != "" { - caCert, err := os.ReadFile(cfg.CACert) - if err != nil { - return nil, fmt.Errorf("reading CA cert: %w", err) - } - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(caCert) { - return nil, fmt.Errorf("failed to parse CA certificate") - } - tlsConfig.RootCAs = pool - } - - // Load client cert for mTLS. - if cfg.TLSCert != "" && cfg.TLSKey != "" { - cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) - if err != nil { - return nil, fmt.Errorf("loading client cert: %w", err) - } - tlsConfig.Certificates = []tls.Certificate{cert} - } - - creds := credentials.NewTLS(tlsConfig) - return grpc.NewClient(cfg.Addr, grpc.WithTransportCredentials(creds)) + return grpc.NewClient("unix://"+cfg.SocketPath(), + grpc.WithTransportCredentials(insecure.NewCredentials())) } diff --git a/internal/config/config.go b/internal/config/config.go index 16e7074..5a3b2fe 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -23,11 +23,7 @@ type Database struct { } type GRPC struct { - Addr string `toml:"addr"` - TLSCert string `toml:"tls_cert"` - TLSKey string `toml:"tls_key"` - CACert string `toml:"ca_cert"` // CA cert for verifying the server (client-side) - ClientCA string `toml:"client_ca"` // CA cert for verifying clients (server-side mTLS) + Addr string `toml:"addr"` // Unix socket path (e.g., "/var/run/mc-proxy.sock") } type Listener struct { @@ -64,13 +60,7 @@ type Duration struct { time.Duration } -// IsUnixSocket returns true if the gRPC address refers to a Unix domain socket. -func (g GRPC) IsUnixSocket() bool { - path := strings.TrimPrefix(g.Addr, "unix:") - return strings.Contains(path, "/") -} - -// SocketPath returns the filesystem path for a Unix socket address, +// SocketPath returns the filesystem path for the Unix socket, // stripping any "unix:" prefix. func (g GRPC) SocketPath() string { return strings.TrimPrefix(g.Addr, "unix:") @@ -93,6 +83,10 @@ func Load(path string) (*Config, error) { return nil, fmt.Errorf("parsing config: %w", err) } + if err := cfg.applyEnvOverrides(); err != nil { + return nil, fmt.Errorf("applying env overrides: %w", err) + } + if err := cfg.validate(); err != nil { return nil, fmt.Errorf("invalid config: %w", err) } @@ -100,6 +94,69 @@ func Load(path string) (*Config, error) { return &cfg, nil } +// applyEnvOverrides applies environment variable overrides to the config. +// Variables use the MCPROXY_ prefix with underscore-separated paths. +func (c *Config) applyEnvOverrides() error { + // Database + if v := os.Getenv("MCPROXY_DATABASE_PATH"); v != "" { + c.Database.Path = v + } + + // gRPC + if v := os.Getenv("MCPROXY_GRPC_ADDR"); v != "" { + c.GRPC.Addr = v + } + + // Firewall + if v := os.Getenv("MCPROXY_FIREWALL_GEOIP_DB"); v != "" { + c.Firewall.GeoIPDB = v + } + if v := os.Getenv("MCPROXY_FIREWALL_RATE_LIMIT"); v != "" { + var n int64 + if _, err := fmt.Sscanf(v, "%d", &n); err != nil { + return fmt.Errorf("MCPROXY_FIREWALL_RATE_LIMIT: %w", err) + } + c.Firewall.RateLimit = n + } + if v := os.Getenv("MCPROXY_FIREWALL_RATE_WINDOW"); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + return fmt.Errorf("MCPROXY_FIREWALL_RATE_WINDOW: %w", err) + } + c.Firewall.RateWindow = Duration{d} + } + + // Proxy timeouts + if v := os.Getenv("MCPROXY_PROXY_CONNECT_TIMEOUT"); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + return fmt.Errorf("MCPROXY_PROXY_CONNECT_TIMEOUT: %w", err) + } + c.Proxy.ConnectTimeout = Duration{d} + } + if v := os.Getenv("MCPROXY_PROXY_IDLE_TIMEOUT"); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + return fmt.Errorf("MCPROXY_PROXY_IDLE_TIMEOUT: %w", err) + } + c.Proxy.IdleTimeout = Duration{d} + } + if v := os.Getenv("MCPROXY_PROXY_SHUTDOWN_TIMEOUT"); v != "" { + d, err := time.ParseDuration(v) + if err != nil { + return fmt.Errorf("MCPROXY_PROXY_SHUTDOWN_TIMEOUT: %w", err) + } + c.Proxy.ShutdownTimeout = Duration{d} + } + + // Log + if v := os.Getenv("MCPROXY_LOG_LEVEL"); v != "" { + c.Log.Level = v + } + + return nil +} + func (c *Config) validate() error { if c.Database.Path == "" { return fmt.Errorf("database.path is required") @@ -139,11 +196,11 @@ func (c *Config) validate() error { return fmt.Errorf("firewall.rate_window is required when rate_limit is set") } - // Validate gRPC config: if enabled, TLS cert and key are required - // (unless using a Unix socket, which doesn't need TLS). - if c.GRPC.Addr != "" && !c.GRPC.IsUnixSocket() { - if c.GRPC.TLSCert == "" || c.GRPC.TLSKey == "" { - return fmt.Errorf("grpc: tls_cert and tls_key are required when grpc.addr is a TCP address") + // Validate gRPC config: if enabled, addr must be a Unix socket path. + if c.GRPC.Addr != "" { + path := c.GRPC.SocketPath() + if !strings.Contains(path, "/") { + return fmt.Errorf("grpc.addr must be a Unix socket path (e.g., /var/run/mc-proxy.sock)") } } diff --git a/internal/config/config_test.go b/internal/config/config_test.go index d31c7e1..f15b86a 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -195,26 +195,23 @@ addr = ":8443" } } -func TestGRPCIsUnixSocket(t *testing.T) { +func TestGRPCSocketPath(t *testing.T) { tests := []struct { addr string - want bool + want string }{ - {"/var/run/mc-proxy.sock", true}, - {"unix:/var/run/mc-proxy.sock", true}, - {"127.0.0.1:9090", false}, - {":9090", false}, - {"", false}, + {"/var/run/mc-proxy.sock", "/var/run/mc-proxy.sock"}, + {"unix:/var/run/mc-proxy.sock", "/var/run/mc-proxy.sock"}, } for _, tt := range tests { g := GRPC{Addr: tt.addr} - if got := g.IsUnixSocket(); got != tt.want { - t.Fatalf("IsUnixSocket(%q) = %v, want %v", tt.addr, got, tt.want) + if got := g.SocketPath(); got != tt.want { + t.Fatalf("SocketPath(%q) = %q, want %q", tt.addr, got, tt.want) } } } -func TestValidateGRPCUnixNoTLS(t *testing.T) { +func TestValidateGRPCUnixSocket(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test.toml") @@ -231,11 +228,11 @@ addr = "/var/run/mc-proxy.sock" _, err := Load(path) if err != nil { - t.Fatalf("expected Unix socket without TLS to be valid, got: %v", err) + t.Fatalf("expected Unix socket to be valid, got: %v", err) } } -func TestValidateGRPCTCPRequiresTLS(t *testing.T) { +func TestValidateGRPCRejectsTCPAddr(t *testing.T) { dir := t.TempDir() path := filepath.Join(dir, "test.toml") @@ -252,7 +249,7 @@ addr = "127.0.0.1:9090" _, err := Load(path) if err == nil { - t.Fatal("expected error for TCP gRPC addr without TLS certs") + t.Fatal("expected error for TCP gRPC addr") } } @@ -311,3 +308,86 @@ func TestDuration(t *testing.T) { t.Fatalf("got %v, want 5s", d.Duration) } } + +func TestEnvOverrides(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" + +[proxy] +idle_timeout = "60s" + +[log] +level = "info" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + // Set env overrides. + t.Setenv("MCPROXY_LOG_LEVEL", "debug") + t.Setenv("MCPROXY_PROXY_IDLE_TIMEOUT", "600s") + t.Setenv("MCPROXY_DATABASE_PATH", "/override/test.db") + + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Log.Level != "debug" { + t.Fatalf("got log.level %q, want %q", cfg.Log.Level, "debug") + } + if cfg.Proxy.IdleTimeout.Duration.Seconds() != 600 { + t.Fatalf("got idle_timeout %v, want 600s", cfg.Proxy.IdleTimeout.Duration) + } + if cfg.Database.Path != "/override/test.db" { + t.Fatalf("got database.path %q, want %q", cfg.Database.Path, "/override/test.db") + } +} + +func TestEnvOverrideInvalidDuration(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + t.Setenv("MCPROXY_PROXY_IDLE_TIMEOUT", "not-a-duration") + + _, err := Load(path) + if err == nil { + t.Fatal("expected error for invalid duration") + } +} + +func TestEnvOverrideGRPCAddr(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "test.toml") + + data := ` +[database] +path = "/tmp/test.db" +` + if err := os.WriteFile(path, []byte(data), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + + t.Setenv("MCPROXY_GRPC_ADDR", "/var/run/override.sock") + + cfg, err := Load(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.GRPC.Addr != "/var/run/override.sock" { + t.Fatalf("got grpc.addr %q, want %q", cfg.GRPC.Addr, "/var/run/override.sock") + } +} diff --git a/internal/grpcserver/grpcserver.go b/internal/grpcserver/grpcserver.go index 9ddd353..197883c 100644 --- a/internal/grpcserver/grpcserver.go +++ b/internal/grpcserver/grpcserver.go @@ -2,8 +2,6 @@ package grpcserver import ( "context" - "crypto/tls" - "crypto/x509" "fmt" "log/slog" "net" @@ -14,7 +12,8 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" + "google.golang.org/grpc/health" + healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -34,8 +33,16 @@ type AdminServer struct { logger *slog.Logger } -// New creates a gRPC server. For Unix sockets, no TLS is used. For TCP -// addresses, TLS is required with optional mTLS. +// NewAdminServer creates an AdminServer for use in testing or custom setups. +func NewAdminServer(srv *server.Server, store *db.Store, logger *slog.Logger) *AdminServer { + return &AdminServer{ + srv: srv, + store: store, + logger: logger, + } +} + +// New creates a gRPC server listening on a Unix socket. func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) { admin := &AdminServer{ srv: srv, @@ -43,13 +50,6 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg logger: logger, } - if cfg.IsUnixSocket() { - return newUnixServer(cfg, admin) - } - return newTCPServer(cfg, admin) -} - -func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) { path := cfg.SocketPath() // Remove stale socket file from a previous run. @@ -67,41 +67,12 @@ func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Liste grpcServer := grpc.NewServer() pb.RegisterProxyAdminServiceServer(grpcServer, admin) - return grpcServer, ln, nil -} -func newTCPServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) { - cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey) - if err != nil { - return nil, nil, fmt.Errorf("loading TLS keypair: %w", err) - } - - tlsConfig := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS13, - } - - if cfg.ClientCA != "" { - caCert, err := os.ReadFile(cfg.ClientCA) - if err != nil { - return nil, nil, fmt.Errorf("reading client CA: %w", err) - } - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(caCert) { - return nil, nil, fmt.Errorf("failed to parse client CA certificate") - } - tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert - tlsConfig.ClientCAs = pool - } - - creds := credentials.NewTLS(tlsConfig) - grpcServer := grpc.NewServer(grpc.Creds(creds)) - pb.RegisterProxyAdminServiceServer(grpcServer, admin) - - ln, err := net.Listen("tcp", cfg.Addr) - if err != nil { - return nil, nil, fmt.Errorf("listening on %s: %w", cfg.Addr, err) - } + // Register standard gRPC health check service. + healthServer := health.NewServer() + healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING) + healthServer.SetServingStatus("mc_proxy.v1.ProxyAdminService", healthpb.HealthCheckResponse_SERVING) + healthpb.RegisterHealthServer(grpcServer, healthServer) return grpcServer, ln, nil }