package grpcserver

import (
	"context"
	"errors"
	"io"
	"log/slog"
	"net"
	"path/filepath"
	"testing"
	"time"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/test/bufconn"

	pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
	"git.wntrmute.dev/mc/mc-proxy/internal/config"
	"git.wntrmute.dev/mc/mc-proxy/internal/db"
	"git.wntrmute.dev/mc/mc-proxy/internal/firewall"
	"git.wntrmute.dev/mc/mc-proxy/internal/server"
)

// testEnv bundles all the objects needed for a grpcserver test.
type testEnv struct {
	client pb.ProxyAdminServiceClient
	conn   *grpc.ClientConn
	store  *db.Store
	srv    *server.Server
}

// setup builds an isolated environment for one test. It creates:
//   - a migrated store in a temp dir, seeded with listener ":443" (route
//     "a.test" -> "127.0.0.1:8443") and one blocked IP;
//   - a server.Server whose in-memory state is loaded back from that store;
//   - an AdminServer served over an in-process bufconn listener, plus a
//     client connected to it.
//
// Every resource is released through t.Cleanup.
func setup(t *testing.T) *testEnv {
	t.Helper()

	// Database in temp dir.
	dbPath := filepath.Join(t.TempDir(), "test.db")
	store, err := db.Open(dbPath)
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	t.Cleanup(func() { store.Close() })
	if err := store.Migrate(); err != nil {
		t.Fatalf("migrate: %v", err)
	}

	// Seed with one listener, one route and one blocked IP. The same IP is
	// handed to the in-memory firewall below so both views start in agreement.
	const blockedIP = "10.0.0.1"
	listeners := []config.Listener{
		{
			Addr: ":443",
			Routes: []config.Route{
				{Hostname: "a.test", Backend: "127.0.0.1:8443"},
			},
		},
	}
	fw := config.Firewall{
		BlockedIPs: []string{blockedIP},
	}
	if err := store.Seed(listeners, fw); err != nil {
		t.Fatalf("seed: %v", err)
	}

	// Build server with matching in-memory state.
	fwObj, err := firewall.New("", []string{blockedIP}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))

	// Load listener data from DB so the in-memory IDs match the stored ones.
	dbListeners, err := store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	listenerData := make([]server.ListenerData, 0, len(dbListeners))
	for _, l := range dbListeners {
		dbRoutes, err := store.ListRoutes(l.ID)
		if err != nil {
			t.Fatalf("list routes: %v", err)
		}
		routes := make(map[string]server.RouteInfo, len(dbRoutes))
		for _, r := range dbRoutes {
			routes[r.Hostname] = server.RouteInfo{
				Backend: r.Backend,
				Mode:    r.Mode,
			}
		}
		listenerData = append(listenerData, server.ListenerData{
			ID:            l.ID,
			Addr:          l.Addr,
			ProxyProtocol: l.ProxyProtocol,
			Routes:        routes,
		})
	}

	srv := server.New(cfg, fwObj, listenerData, logger, "test-version")

	// Set up bufconn gRPC server (no TLS for tests).
	lis := bufconn.Listen(1024 * 1024)
	grpcSrv := grpc.NewServer()
	admin := &AdminServer{
		srv:    srv,
		store:  store,
		logger: logger,
	}
	pb.RegisterProxyAdminServiceServer(grpcSrv, admin)

	// Serve in the background. The cleanup stops the server and waits for
	// Serve to return, so the goroutine never outlives the test. Logging from
	// a goroutine after its test has completed panics.
	serveErr := make(chan error, 1)
	go func() { serveErr <- grpcSrv.Serve(lis) }()
	t.Cleanup(func() {
		grpcSrv.Stop()
		// ErrServerStopped is expected when Stop wins the race against Serve.
		if err := <-serveErr; err != nil && !errors.Is(err, grpc.ErrServerStopped) {
			t.Logf("grpc serve: %v", err)
		}
	})

	// "passthrough:///" (empty authority, explicit endpoint) is the canonical
	// form; the context dialer ignores the address and dials the bufconn.
	conn, err := grpc.NewClient("passthrough:///bufconn",
		grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
			return lis.DialContext(ctx)
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		t.Fatalf("dial bufconn: %v", err)
	}
	t.Cleanup(func() { conn.Close() })

	return &testEnv{
		client: pb.NewProxyAdminServiceClient(conn),
		conn:   conn,
		store:  store,
		srv:    srv,
	}
}

// routeByHostname lists the routes of listenerAddr through the admin API.
// It returns the route whose Hostname equals hostname, or nil if none
// matches. A failed ListRoutes RPC fails the test.
func (e *testEnv) routeByHostname(ctx context.Context, t *testing.T, listenerAddr, hostname string) *pb.Route {
	t.Helper()
	resp, err := e.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: listenerAddr})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	for _, r := range resp.Routes {
		if r.Hostname == hostname {
			return r
		}
	}
	return nil
}

// seededDBRoutes returns the routes persisted in the store for its first
// listener; setup seeds exactly one. It fails the test if the store cannot
// be read or contains no listeners.
func (e *testEnv) seededDBRoutes(t *testing.T) []db.Route {
	t.Helper()
	dbListeners, err := e.store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	if len(dbListeners) == 0 {
		t.Fatal("DB has no listeners")
	}
	dbRoutes, err := e.store.ListRoutes(dbListeners[0].ID)
	if err != nil {
		t.Fatalf("list routes: %v", err)
	}
	return dbRoutes
}

func TestGetStatus(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}
	if resp.Version != "test-version" {
		t.Fatalf("got version %q, want %q", resp.Version, "test-version")
	}
	if len(resp.Listeners) != 1 {
		t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
	}
	if resp.Listeners[0].Addr != ":443" {
		t.Fatalf("got listener addr %q, want %q", resp.Listeners[0].Addr, ":443")
	}
	if resp.Listeners[0].RouteCount != 1 {
		t.Fatalf("got route count %d, want 1", resp.Listeners[0].RouteCount)
	}
}

func TestListRoutes(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if len(resp.Routes) != 1 {
		t.Fatalf("got %d routes, want 1", len(resp.Routes))
	}
	if resp.Routes[0].Hostname != "a.test" {
		t.Fatalf("got hostname %q, want %q", resp.Routes[0].Hostname, "a.test")
	}
	if resp.Routes[0].Backend != "127.0.0.1:8443" {
		t.Fatalf("got backend %q, want %q", resp.Routes[0].Backend, "127.0.0.1:8443")
	}
}

func TestListRoutesNotFound(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":9999"})
	if err == nil {
		t.Fatal("expected error for nonexistent listener")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound {
		t.Fatalf("expected NotFound, got %v", err)
	}
}

func TestAddRoute(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Hostname: "b.test", Backend: "127.0.0.1:9443"},
	})
	if err != nil {
		t.Fatalf("AddRoute: %v", err)
	}

	// Verify in-memory.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if len(resp.Routes) != 2 {
		t.Fatalf("got %d routes, want 2", len(resp.Routes))
	}

	// Verify in DB.
	if dbRoutes := env.seededDBRoutes(t); len(dbRoutes) != 2 {
		t.Fatalf("DB has %d routes, want 2", len(dbRoutes))
	}
}

func TestAddRouteDuplicate(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"},
	})
	if err == nil {
		t.Fatal("expected error for duplicate route")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.AlreadyExists {
		t.Fatalf("expected AlreadyExists, got %v", err)
	}
}

func TestAddRouteValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Missing route.
	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ListenerAddr: ":443"})
	if err == nil {
		t.Fatal("expected error for nil route")
	}

	// Missing hostname.
	_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Backend: "127.0.0.1:1"},
	})
	if err == nil {
		t.Fatal("expected error for empty hostname")
	}

	// Missing backend.
	_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Hostname: "x.test"},
	})
	if err == nil {
		t.Fatal("expected error for empty backend")
	}

	// Invalid backend (not host:port).
	_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"},
	})
	if err == nil {
		t.Fatal("expected error for invalid backend address")
	}
}

func TestRemoveRoute(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err != nil {
		t.Fatalf("RemoveRoute: %v", err)
	}

	// Verify removed from memory.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if len(resp.Routes) != 0 {
		t.Fatalf("got %d routes, want 0", len(resp.Routes))
	}
}

func TestRemoveRouteNotFound(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
		ListenerAddr: ":443",
		Hostname:     "nonexistent.test",
	})
	if err == nil {
		t.Fatal("expected error for removing nonexistent route")
	}
}

func TestGetFirewallRules(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if len(resp.Rules) != 1 {
		t.Fatalf("got %d rules, want 1", len(resp.Rules))
	}
	if resp.Rules[0].Type != pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP {
		t.Fatalf("got type %v, want IP", resp.Rules[0].Type)
	}
	if resp.Rules[0].Value != "10.0.0.1" {
		t.Fatalf("got value %q, want %q", resp.Rules[0].Value, "10.0.0.1")
	}
}

func TestAddFirewallRule(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Add IP rule.
	_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
			Value: "10.0.0.2",
		},
	})
	if err != nil {
		t.Fatalf("AddFirewallRule IP: %v", err)
	}

	// Add CIDR rule.
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
			Value: "192.168.0.0/16",
		},
	})
	if err != nil {
		t.Fatalf("AddFirewallRule CIDR: %v", err)
	}

	// Add country rule.
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
			Value: "RU",
		},
	})
	if err != nil {
		t.Fatalf("AddFirewallRule country: %v", err)
	}

	// Verify.
	resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if len(resp.Rules) != 4 {
		t.Fatalf("got %d rules, want 4", len(resp.Rules))
	}

	// Verify DB persistence.
	dbRules, err := env.store.ListFirewallRules()
	if err != nil {
		t.Fatalf("list firewall rules: %v", err)
	}
	if len(dbRules) != 4 {
		t.Fatalf("DB has %d rules, want 4", len(dbRules))
	}
}

func TestAddFirewallRuleValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Nil rule.
	_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{})
	if err == nil {
		t.Fatal("expected error for nil rule")
	}

	// Unknown type.
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED,
			Value: "x",
		},
	})
	if err == nil {
		t.Fatal("expected error for unspecified rule type")
	}

	// Empty value.
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
		},
	})
	if err == nil {
		t.Fatal("expected error for empty value")
	}

	// Invalid IP address.
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
			Value: "not-an-ip",
		},
	})
	if err == nil {
		t.Fatal("expected error for invalid IP")
	}

	// Invalid CIDR.
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
			Value: "not-a-cidr",
		},
	})
	if err == nil {
		t.Fatal("expected error for invalid CIDR")
	}

	// Non-canonical CIDR.
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
			Value: "192.168.1.5/16",
		},
	})
	if err == nil {
		t.Fatal("expected error for non-canonical CIDR")
	}

	// Invalid country code (lowercase).
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
			Value: "cn",
		},
	})
	if err == nil {
		t.Fatal("expected error for lowercase country code")
	}

	// Invalid country code (too long).
	_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
			Value: "USA",
		},
	})
	if err == nil {
		t.Fatal("expected error for 3-letter country code")
	}
}

func TestRemoveFirewallRule(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
			Value: "10.0.0.1",
		},
	})
	if err != nil {
		t.Fatalf("RemoveFirewallRule: %v", err)
	}

	resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if len(resp.Rules) != 0 {
		t.Fatalf("got %d rules, want 0", len(resp.Rules))
	}
}

func TestRemoveFirewallRuleNotFound(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
			Value: "99.99.99.99",
		},
	})
	if err == nil {
		t.Fatal("expected error for removing nonexistent rule")
	}
}

func TestAddRouteL7(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route: &pb.Route{
			Hostname:          "l7.test",
			Backend:           "127.0.0.1:8080",
			Mode:              "l7",
			TlsCert:           "/certs/l7.crt",
			TlsKey:            "/certs/l7.key",
			BackendTls:        false,
			SendProxyProtocol: true,
		},
	})
	if err != nil {
		t.Fatalf("AddRoute L7: %v", err)
	}

	// Verify in-memory via ListRoutes.
	found := env.routeByHostname(ctx, t, ":443", "l7.test")
	if found == nil {
		t.Fatal("L7 route not found in ListRoutes response")
	}
	if found.Mode != "l7" {
		t.Fatalf("mode = %q, want %q", found.Mode, "l7")
	}
	if found.TlsCert != "/certs/l7.crt" {
		t.Fatalf("tls_cert = %q, want %q", found.TlsCert, "/certs/l7.crt")
	}
	if found.TlsKey != "/certs/l7.key" {
		t.Fatalf("tls_key = %q, want %q", found.TlsKey, "/certs/l7.key")
	}
	if found.BackendTls {
		t.Fatal("expected backend_tls = false")
	}
	if !found.SendProxyProtocol {
		t.Fatal("expected send_proxy_protocol = true")
	}

	// Verify DB persistence.
	dbRoutes := env.seededDBRoutes(t)
	var dbRoute *db.Route
	for i := range dbRoutes {
		if dbRoutes[i].Hostname == "l7.test" {
			dbRoute = &dbRoutes[i]
			break
		}
	}
	if dbRoute == nil {
		t.Fatal("L7 route not found in DB")
	}
	if dbRoute.Mode != "l7" {
		t.Fatalf("DB mode = %q, want %q", dbRoute.Mode, "l7")
	}
	if !dbRoute.SendProxyProtocol {
		t.Fatal("DB send_proxy_protocol should be true")
	}
}

func TestAddRouteL7MissingCert(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route: &pb.Route{
			Hostname: "nocert.test",
			Backend:  "127.0.0.1:8080",
			Mode:     "l7",
		},
	})
	if err == nil {
		t.Fatal("expected error for L7 route without cert/key")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.InvalidArgument {
		t.Fatalf("expected InvalidArgument, got %v", err)
	}
}

func TestAddRouteInvalidMode(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route: &pb.Route{
			Hostname: "badmode.test",
			Backend:  "127.0.0.1:8080",
			Mode:     "l5",
		},
	})
	if err == nil {
		t.Fatal("expected error for invalid mode")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.InvalidArgument {
		t.Fatalf("expected InvalidArgument, got %v", err)
	}
}

func TestAddRouteDefaultsToL4(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Add route without specifying mode — should default to "l4".
	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route: &pb.Route{
			Hostname: "default.test",
			Backend:  "127.0.0.1:9443",
		},
	})
	if err != nil {
		t.Fatalf("AddRoute: %v", err)
	}

	r := env.routeByHostname(ctx, t, ":443", "default.test")
	if r == nil {
		t.Fatal("route not found")
	}
	if r.Mode != "l4" {
		t.Fatalf("mode = %q, want %q", r.Mode, "l4")
	}
}

func TestGetStatusProxyProtocol(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}

	// The seeded listener has proxy_protocol = false.
	if len(resp.Listeners) != 1 {
		t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
	}
	if resp.Listeners[0].ProxyProtocol {
		t.Fatal("expected proxy_protocol = false")
	}
}

func TestSetListenerMaxConnections(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Set max connections.
	_, err := env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{
		ListenerAddr:   ":443",
		MaxConnections: 5000,
	})
	if err != nil {
		t.Fatalf("SetListenerMaxConnections: %v", err)
	}

	// Verify via GetStatus.
	resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}
	if len(resp.Listeners) != 1 {
		t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
	}
	if resp.Listeners[0].MaxConnections != 5000 {
		t.Fatalf("max_connections = %d, want 5000", resp.Listeners[0].MaxConnections)
	}

	// Verify DB persistence.
	l, err := env.store.GetListenerByAddr(":443")
	if err != nil {
		t.Fatalf("get listener by addr: %v", err)
	}
	if l.MaxConnections != 5000 {
		t.Fatalf("DB max_connections = %d, want 5000", l.MaxConnections)
	}

	// Set to 0 (unlimited).
	_, err = env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{
		ListenerAddr:   ":443",
		MaxConnections: 0,
	})
	if err != nil {
		t.Fatalf("SetListenerMaxConnections to 0: %v", err)
	}

	resp, err = env.client.GetStatus(ctx, &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}
	if len(resp.Listeners) != 1 {
		t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
	}
	if resp.Listeners[0].MaxConnections != 0 {
		t.Fatalf("max_connections = %d, want 0", resp.Listeners[0].MaxConnections)
	}
}

func TestSetListenerMaxConnectionsNotFound(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{
		ListenerAddr:   ":9999",
		MaxConnections: 100,
	})
	if err == nil {
		t.Fatal("expected error for nonexistent listener")
	}
	if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound {
		t.Fatalf("expected NotFound, got %v", err)
	}
}

func TestAddListL7Policy(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "block_user_agent", Value: "BadBot"},
	})
	if err != nil {
		t.Fatalf("AddL7Policy: %v", err)
	}

	resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err != nil {
		t.Fatalf("ListL7Policies: %v", err)
	}
	if len(resp.Policies) != 1 {
		t.Fatalf("got %d policies, want 1", len(resp.Policies))
	}
	if resp.Policies[0].Type != "block_user_agent" || resp.Policies[0].Value != "BadBot" {
		t.Fatalf("policy = %v, want block_user_agent/BadBot", resp.Policies[0])
	}

	// The policy must also be visible on the route returned by ListRoutes.
	r := env.routeByHostname(ctx, t, ":443", "a.test")
	if r == nil {
		t.Fatal("ListRoutes: route a.test not found")
	}
	if len(r.L7Policies) != 1 {
		t.Fatalf("ListRoutes: route has %d policies, want 1", len(r.L7Policies))
	}
}

func TestRemoveL7Policy(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// The add must succeed, otherwise the removal below proves nothing.
	_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "require_header", Value: "X-Token"},
	})
	if err != nil {
		t.Fatalf("AddL7Policy: %v", err)
	}

	_, err = env.client.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "require_header", Value: "X-Token"},
	})
	if err != nil {
		t.Fatalf("RemoveL7Policy: %v", err)
	}

	resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err != nil {
		t.Fatalf("ListL7Policies: %v", err)
	}
	if len(resp.Policies) != 0 {
		t.Fatalf("got %d policies after remove, want 0", len(resp.Policies))
	}
}

func TestAddL7PolicyValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "invalid_type", Value: "x"},
	})
	if err == nil {
		t.Fatal("expected error for invalid policy type")
	}

	_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "block_user_agent", Value: ""},
	})
	if err == nil {
		t.Fatal("expected error for empty policy value")
	}

	_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err == nil {
		t.Fatal("expected error for nil policy")
	}
}