package grpcserver

import (
	"context"
	"net"
	"path/filepath"
	"testing"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"

	pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
	"git.wntrmute.dev/kyle/mcr/internal/auth"
	"git.wntrmute.dev/kyle/mcr/internal/db"
)

// fakeValidator is a test double for server.TokenValidator. It hands back
// the same preconfigured claims and error on every call.
type fakeValidator struct {
	claims *auth.Claims
	err    error
}

// ValidateToken ignores the supplied token and returns the configured
// claims and error unchanged.
func (f *fakeValidator) ValidateToken(_ string) (*auth.Claims, error) {
	return f.claims, f.err
}

// openTestDB opens a database file named "test.db" inside t.TempDir(),
// registers its Close with t.Cleanup, and then runs Migrate on it.
// A failure in Open or Migrate aborts the calling test.
func openTestDB(t *testing.T) *db.DB {
	t.Helper()

	database, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	// Close is best-effort during teardown; its error is deliberately dropped.
	t.Cleanup(func() { _ = database.Close() })

	if migrateErr := database.Migrate(); migrateErr != nil {
		t.Fatalf("Migrate: %v", migrateErr)
	}

	return database
}

// startTestServer builds a server with New("", "", deps), serves it on a
// loopback listener with an OS-assigned port, and returns a client
// connection pointed at that listener.
//
// Only the connection is returned. Teardown (closing the connection and
// gracefully stopping the server) is registered through t.Cleanup, so
// callers have nothing to release themselves.
func startTestServer(t *testing.T, deps Deps) *grpc.ClientConn {
	t.Helper()

	srv, err := New("", "", deps)
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Listen: %v", err)
	}

	// Serve runs in the background for the life of the test; its return
	// value is intentionally ignored.
	go func() {
		_ = srv.Serve(listener)
	}()
	t.Cleanup(func() {
		srv.GracefulStop()
	})

	//nolint:gosec // insecure credentials for testing only
	conn, err := grpc.NewClient(
		listener.Addr().String(),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		// Every call made on this connection is forced onto pb.JSONCodec.
		grpc.WithDefaultCallOptions(grpc.ForceCodecV2(pb.JSONCodec{})),
	)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	// Registered after the server stop above, so this cleanup runs first.
	t.Cleanup(func() {
		_ = conn.Close()
	})

	return conn
}

// withAuth returns ctx with an "authorization" entry of "Bearer "+token
// appended to its outgoing metadata.
func withAuth(ctx context.Context, token string) context.Context { return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token) } func TestHealthBypassesAuth(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{err: auth.ErrUnauthorized} cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) client := pb.NewAdminServiceClient(cc) resp, err := client.Health(context.Background(), &pb.HealthRequest{}) if err != nil { t.Fatalf("Health: %v", err) } if resp.Status != "ok" { t.Fatalf("Health status: got %q, want %q", resp.Status, "ok") } } func TestAuthInterceptorNoToken(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{err: auth.ErrUnauthorized} cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) client := pb.NewRegistryServiceClient(cc) _, err := client.ListRepositories(context.Background(), &pb.ListRepositoriesRequest{}) if err == nil { t.Fatal("expected error for unauthenticated request") } st, ok := status.FromError(err) if !ok { t.Fatalf("expected gRPC status error, got %v", err) } if st.Code() != codes.Unauthenticated { t.Fatalf("code: got %v, want Unauthenticated", st.Code()) } } func TestAuthInterceptorInvalidToken(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{err: auth.ErrUnauthorized} cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) ctx := withAuth(context.Background(), "bad-token") client := pb.NewRegistryServiceClient(cc) _, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{}) if err == nil { t.Fatal("expected error for invalid token") } st, ok := status.FromError(err) if !ok { t.Fatalf("expected gRPC status error, got %v", err) } if st.Code() != codes.Unauthenticated { t.Fatalf("code: got %v, want Unauthenticated", st.Code()) } } func TestAuthInterceptorValidToken(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{ claims: &auth.Claims{Subject: "alice", AccountType: "human", Roles: 
[]string{"user"}}, } cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) ctx := withAuth(context.Background(), "valid-token") client := pb.NewRegistryServiceClient(cc) resp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{}) if err != nil { t.Fatalf("ListRepositories: %v", err) } if resp == nil { t.Fatal("expected non-nil response") } } func TestAdminInterceptorDenied(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{ claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}}, } cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) ctx := withAuth(context.Background(), "valid-token") // Policy RPCs require admin. policyClient := pb.NewPolicyServiceClient(cc) _, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{}) if err == nil { t.Fatal("expected error for non-admin user") } st, ok := status.FromError(err) if !ok { t.Fatalf("expected gRPC status error, got %v", err) } if st.Code() != codes.PermissionDenied { t.Fatalf("code: got %v, want PermissionDenied", st.Code()) } } func TestAdminInterceptorAllowed(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{ claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}}, } cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) ctx := withAuth(context.Background(), "valid-token") // Admin user should be able to list policy rules. policyClient := pb.NewPolicyServiceClient(cc) resp, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{}) if err != nil { t.Fatalf("ListPolicyRules: %v", err) } if resp == nil { t.Fatal("expected non-nil response") } } func TestAdminRequiredMethodsCompleteness(t *testing.T) { // Verify that admin-required methods match our security spec. // This test catches the security defect of adding an RPC without // adding it to the adminRequiredMethods map. 
expected := []string{ "/mcr.v1.RegistryService/DeleteRepository", "/mcr.v1.RegistryService/GarbageCollect", "/mcr.v1.RegistryService/GetGCStatus", "/mcr.v1.PolicyService/ListPolicyRules", "/mcr.v1.PolicyService/CreatePolicyRule", "/mcr.v1.PolicyService/GetPolicyRule", "/mcr.v1.PolicyService/UpdatePolicyRule", "/mcr.v1.PolicyService/DeletePolicyRule", "/mcr.v1.AuditService/ListAuditEvents", } for _, method := range expected { if !adminRequiredMethods[method] { t.Errorf("method %s should require admin but is not in adminRequiredMethods", method) } } if len(adminRequiredMethods) != len(expected) { t.Errorf("adminRequiredMethods has %d entries, expected %d", len(adminRequiredMethods), len(expected)) } } func TestAuthBypassMethodsCompleteness(t *testing.T) { // Health is the only method that bypasses auth. expected := []string{ "/mcr.v1.AdminService/Health", } for _, method := range expected { if !authBypassMethods[method] { t.Errorf("method %s should bypass auth but is not in authBypassMethods", method) } } if len(authBypassMethods) != len(expected) { t.Errorf("authBypassMethods has %d entries, expected %d", len(authBypassMethods), len(expected)) } } func TestDeleteRepoRequiresAdmin(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{ claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}}, } cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) ctx := withAuth(context.Background(), "valid-token") client := pb.NewRegistryServiceClient(cc) _, err := client.DeleteRepository(ctx, &pb.DeleteRepositoryRequest{Name: "test"}) if err == nil { t.Fatal("expected error for non-admin user trying to delete repo") } st, ok := status.FromError(err) if !ok { t.Fatalf("expected gRPC status error, got %v", err) } if st.Code() != codes.PermissionDenied { t.Fatalf("code: got %v, want PermissionDenied", st.Code()) } } func TestGCRequiresAdmin(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{ claims: 
&auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}}, } cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) ctx := withAuth(context.Background(), "valid-token") client := pb.NewRegistryServiceClient(cc) _, err := client.GarbageCollect(ctx, &pb.GarbageCollectRequest{}) if err == nil { t.Fatal("expected error for non-admin user trying to trigger GC") } st, ok := status.FromError(err) if !ok { t.Fatalf("expected gRPC status error, got %v", err) } if st.Code() != codes.PermissionDenied { t.Fatalf("code: got %v, want PermissionDenied", st.Code()) } } func TestAuditRequiresAdmin(t *testing.T) { database := openTestDB(t) validator := &fakeValidator{ claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}}, } cc := startTestServer(t, Deps{ DB: database, Validator: validator, }) ctx := withAuth(context.Background(), "valid-token") client := pb.NewAuditServiceClient(cc) _, err := client.ListAuditEvents(ctx, &pb.ListAuditEventsRequest{}) if err == nil { t.Fatal("expected error for non-admin user trying to list audit events") } st, ok := status.FromError(err) if !ok { t.Fatalf("expected gRPC status error, got %v", err) } if st.Code() != codes.PermissionDenied { t.Fatalf("code: got %v, want PermissionDenied", st.Code()) } }