package grpcserver

import (
	"context"
	"log/slog"
	"path/filepath"
	"testing"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"

	pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
	"git.wntrmute.dev/kyle/metacrypt/internal/auth"
	"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
	"git.wntrmute.dev/kyle/metacrypt/internal/config"
	"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
	"git.wntrmute.dev/kyle/metacrypt/internal/db"
	"git.wntrmute.dev/kyle/metacrypt/internal/engine"
	"git.wntrmute.dev/kyle/metacrypt/internal/policy"
	"git.wntrmute.dev/kyle/metacrypt/internal/seal"
)

// ---- test helpers ----

// fastArgon2Params returns minimal Argon2 settings (a single pass on a single
// thread) so that seal-manager initialization stays quick in tests.
func fastArgon2Params() crypto.Argon2Params {
	const memory = 64 * 1024
	return crypto.Argon2Params{Memory: memory, Time: 1, Threads: 1}
}

// mockBarrier is a do-nothing barrier used to back the engine registry in
// tests. It always reports itself as unsealed. Writes, deletes, seal and
// unseal succeed without effect. List yields nothing, and every Get reports
// barrier.ErrNotFound.
type mockBarrier struct{}

func (*mockBarrier) Unseal([]byte) error { return nil }

func (*mockBarrier) Seal() error { return nil }

func (*mockBarrier) IsSealed() bool { return false }

func (*mockBarrier) Get(context.Context, string) ([]byte, error) {
	return nil, barrier.ErrNotFound
}

func (*mockBarrier) Put(context.Context, string, []byte) error { return nil }

func (*mockBarrier) Delete(context.Context, string) error { return nil }

func (*mockBarrier) List(context.Context, string) ([]string, error) { return nil, nil }

// mockEngine is a minimal engine.Engine for registry tests.
type mockEngine struct{ t engine.EngineType } func (m *mockEngine) Type() engine.EngineType { return m.t } func (m *mockEngine) Initialize(_ context.Context, _ barrier.Barrier, _ string, _ map[string]interface{}) error { return nil } func (m *mockEngine) Unseal(_ context.Context, _ barrier.Barrier, _ string) error { return nil } func (m *mockEngine) Seal() error { return nil } func (m *mockEngine) HandleRequest(_ context.Context, _ *engine.Request) (*engine.Response, error) { return &engine.Response{Data: map[string]interface{}{"ok": true}}, nil } func newTestRegistry() *engine.Registry { reg := engine.NewRegistry(&mockBarrier{}, slog.Default()) reg.RegisterFactory(engine.EngineTypeTransit, func() engine.Engine { return &mockEngine{t: engine.EngineTypeTransit} }) return reg } func newTestGRPCServer(t *testing.T) (*GRPCServer, func()) { t.Helper() dir := t.TempDir() database, err := db.Open(filepath.Join(dir, "test.db")) if err != nil { t.Fatalf("open db: %v", err) } if err := db.Migrate(database); err != nil { t.Fatalf("migrate: %v", err) } b := barrier.NewAESGCMBarrier(database) sealMgr := seal.NewManager(database, b, slog.Default()) policyEngine := policy.NewEngine(b) reg := newTestRegistry() authenticator := auth.NewAuthenticator(nil, slog.Default()) cfg := &config.Config{ Seal: config.SealConfig{ Argon2Time: 1, Argon2Memory: 64 * 1024, Argon2Threads: 1, }, } srv := New(cfg, sealMgr, authenticator, policyEngine, reg, slog.Default()) return srv, func() { _ = database.Close() } } // okHandler is a grpc.UnaryHandler that always succeeds. func okHandler(_ context.Context, _ interface{}) (interface{}, error) { return "ok", nil } func methodInfo(name string) *grpc.UnaryServerInfo { return &grpc.UnaryServerInfo{FullMethod: name} } // ---- interceptor tests ---- func TestSealInterceptor_Unsealed(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() // Initialize and unseal so state == StateUnsealed. 
if err := srv.sealMgr.Initialize(context.Background(), []byte("pw"), fastArgon2Params()); err != nil { t.Fatalf("initialize: %v", err) } methods := map[string]bool{"/test.Service/Method": true} interceptor := sealInterceptor(srv.sealMgr, slog.Default(), methods) resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler) if err != nil { t.Fatalf("expected success when unsealed, got: %v", err) } if resp != "ok" { t.Errorf("expected 'ok', got %v", resp) } } func TestSealInterceptor_Sealed(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() // Initialize then seal. if err := srv.sealMgr.Initialize(context.Background(), []byte("pw"), fastArgon2Params()); err != nil { t.Fatalf("initialize: %v", err) } if err := srv.sealMgr.Seal(); err != nil { t.Fatalf("seal: %v", err) } methods := map[string]bool{"/test.Service/Method": true} interceptor := sealInterceptor(srv.sealMgr, slog.Default(), methods) _, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler) if err == nil { t.Fatal("expected error when sealed") } if code := status.Code(err); code != codes.FailedPrecondition { t.Errorf("expected FailedPrecondition, got %v", code) } } func TestSealInterceptor_SkipsUnlistedMethod(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() // State is uninitialized (sealed), but method is not in the list. 
methods := map[string]bool{"/test.Service/Other": true} interceptor := sealInterceptor(srv.sealMgr, slog.Default(), methods) resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler) if err != nil { t.Fatalf("expected pass-through, got: %v", err) } if resp != "ok" { t.Errorf("expected 'ok', got %v", resp) } } func TestAuthInterceptor_MissingToken(t *testing.T) { authenticator := auth.NewAuthenticator(nil, slog.Default()) methods := map[string]bool{"/test.Service/Method": true} interceptor := authInterceptor(authenticator, slog.Default(), methods) _, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler) if err == nil { t.Fatal("expected error for missing token") } if code := status.Code(err); code != codes.Unauthenticated { t.Errorf("expected Unauthenticated, got %v", code) } } func TestAuthInterceptor_SkipsUnlistedMethod(t *testing.T) { authenticator := auth.NewAuthenticator(nil, slog.Default()) methods := map[string]bool{"/test.Service/Other": true} interceptor := authInterceptor(authenticator, slog.Default(), methods) resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler) if err != nil { t.Fatalf("expected pass-through, got: %v", err) } if resp != "ok" { t.Errorf("expected 'ok', got %v", resp) } } func TestAdminInterceptor_NoTokenInfo(t *testing.T) { methods := map[string]bool{"/test.Service/Admin": true} interceptor := adminInterceptor(slog.Default(), methods) _, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Admin"), okHandler) if err == nil { t.Fatal("expected error when no token info in context") } if code := status.Code(err); code != codes.PermissionDenied { t.Errorf("expected PermissionDenied, got %v", code) } } func TestAdminInterceptor_NonAdmin(t *testing.T) { methods := map[string]bool{"/test.Service/Admin": true} interceptor := adminInterceptor(slog.Default(), methods) ctx := 
context.WithValue(context.Background(), tokenInfoKey, &auth.TokenInfo{ Username: "user", IsAdmin: false, }) _, err := interceptor(ctx, nil, methodInfo("/test.Service/Admin"), okHandler) if err == nil { t.Fatal("expected error for non-admin") } if code := status.Code(err); code != codes.PermissionDenied { t.Errorf("expected PermissionDenied, got %v", code) } } func TestAdminInterceptor_Admin(t *testing.T) { methods := map[string]bool{"/test.Service/Admin": true} interceptor := adminInterceptor(slog.Default(), methods) ctx := context.WithValue(context.Background(), tokenInfoKey, &auth.TokenInfo{ Username: "admin", IsAdmin: true, }) resp, err := interceptor(ctx, nil, methodInfo("/test.Service/Admin"), okHandler) if err != nil { t.Fatalf("expected success for admin, got: %v", err) } if resp != "ok" { t.Errorf("expected 'ok', got %v", resp) } } func TestAdminInterceptor_SkipsUnlistedMethod(t *testing.T) { methods := map[string]bool{"/test.Service/Other": true} interceptor := adminInterceptor(slog.Default(), methods) // No token info in context — but method not listed, so should pass through. 
resp, err := interceptor(context.Background(), nil, methodInfo("/test.Service/Method"), okHandler) if err != nil { t.Fatalf("expected pass-through, got: %v", err) } if resp != "ok" { t.Errorf("expected 'ok', got %v", resp) } } func TestChainInterceptors(t *testing.T) { var order []int makeInterceptor := func(n int) grpc.UnaryServerInterceptor { return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { order = append(order, n) return handler(ctx, req) } } chained := chainInterceptors(makeInterceptor(1), makeInterceptor(2), makeInterceptor(3)) _, err := chained(context.Background(), nil, methodInfo("/test/Method"), okHandler) if err != nil { t.Fatalf("chain: %v", err) } if len(order) != 3 || order[0] != 1 || order[1] != 2 || order[2] != 3 { t.Errorf("expected execution order [1 2 3], got %v", order) } } func TestExtractToken(t *testing.T) { tests := []struct { name string md metadata.MD expected string }{ {"no metadata", nil, ""}, {"no authorization", metadata.Pairs("other", "val"), ""}, {"bearer token", metadata.Pairs("authorization", "Bearer mytoken"), "mytoken"}, {"raw token", metadata.Pairs("authorization", "mytoken"), "mytoken"}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { var ctx context.Context if tc.md != nil { ctx = metadata.NewIncomingContext(context.Background(), tc.md) } else { ctx = context.Background() } got := extractToken(ctx) if got != tc.expected { t.Errorf("extractToken: got %q, want %q", got, tc.expected) } }) } } // ---- systemServer tests ---- func TestSystemStatus(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ss := &systemServer{s: srv} resp, err := ss.Status(context.Background(), &pb.StatusRequest{}) if err != nil { t.Fatalf("Status: %v", err) } if resp.State != "uninitialized" { t.Errorf("expected 'uninitialized', got %q", resp.State) } } func TestSystemInit_EmptyPassword(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer 
cleanup() ss := &systemServer{s: srv} _, err := ss.Init(context.Background(), &pb.InitRequest{Password: ""}) if err == nil { t.Fatal("expected error for empty password") } if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("expected InvalidArgument, got %v", code) } } func TestSystemInit_Success(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ss := &systemServer{s: srv} resp, err := ss.Init(context.Background(), &pb.InitRequest{Password: "testpassword"}) if err != nil { t.Fatalf("Init: %v", err) } if resp.State != "unsealed" { t.Errorf("expected 'unsealed' after init, got %q", resp.State) } } func TestSystemInit_AlreadyInitialized(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ss := &systemServer{s: srv} if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"}); err != nil { t.Fatalf("first Init: %v", err) } _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"}) if err == nil { t.Fatal("expected error on second Init") } if code := status.Code(err); code != codes.AlreadyExists { t.Errorf("expected AlreadyExists, got %v", code) } } func TestSystemUnseal_NotInitialized(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ss := &systemServer{s: srv} _, err := ss.Unseal(context.Background(), &pb.UnsealRequest{Password: "pw"}) if err == nil { t.Fatal("expected error when not initialized") } if code := status.Code(err); code != codes.FailedPrecondition { t.Errorf("expected FailedPrecondition, got %v", code) } } func TestSystemUnseal_InvalidPassword(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ss := &systemServer{s: srv} if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "correct"}); err != nil { t.Fatalf("Init: %v", err) } if err := srv.sealMgr.Seal(); err != nil { t.Fatalf("Seal: %v", err) } _, err := ss.Unseal(context.Background(), &pb.UnsealRequest{Password: "wrong"}) if err == nil { t.Fatal("expected error for 
wrong password") } if code := status.Code(err); code != codes.Unauthenticated { t.Errorf("expected Unauthenticated, got %v", code) } } func TestSystemUnseal_Success(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ss := &systemServer{s: srv} if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"}); err != nil { t.Fatalf("Init: %v", err) } if err := srv.sealMgr.Seal(); err != nil { t.Fatalf("Seal: %v", err) } resp, err := ss.Unseal(context.Background(), &pb.UnsealRequest{Password: "pw"}) if err != nil { t.Fatalf("Unseal: %v", err) } if resp.State != "unsealed" { t.Errorf("expected 'unsealed', got %q", resp.State) } } func TestSystemSeal_Success(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ss := &systemServer{s: srv} if _, err := ss.Init(context.Background(), &pb.InitRequest{Password: "pw"}); err != nil { t.Fatalf("Init: %v", err) } resp, err := ss.Seal(context.Background(), &pb.SealRequest{}) if err != nil { t.Fatalf("Seal: %v", err) } if resp.State != "sealed" { t.Errorf("expected 'sealed', got %q", resp.State) } } // ---- engineServer tests ---- func TestEngineMount_MissingFields(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "", Type: "transit"}) if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("empty name: expected InvalidArgument, got %v", code) } _, err = es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: ""}) if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("empty type: expected InvalidArgument, got %v", code) } } func TestEngineMount_UnknownType(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "test", Type: "unknown"}) if err == nil { t.Fatal("expected error for unknown engine type") } if code := status.Code(err); 
code != codes.InvalidArgument { t.Errorf("expected InvalidArgument, got %v", code) } } func TestEngineMount_Success(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"}) if err != nil { t.Fatalf("Mount: %v", err) } } func TestEngineMount_Duplicate(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} if _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"}); err != nil { t.Fatalf("first Mount: %v", err) } _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"}) if err == nil { t.Fatal("expected error for duplicate mount") } if code := status.Code(err); code != codes.AlreadyExists { t.Errorf("expected AlreadyExists, got %v", code) } } func TestEngineUnmount_MissingName(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} _, err := es.Unmount(context.Background(), &pb.UnmountRequest{Name: ""}) if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("expected InvalidArgument, got %v", code) } } func TestEngineUnmount_NotFound(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} _, err := es.Unmount(context.Background(), &pb.UnmountRequest{Name: "nonexistent"}) if code := status.Code(err); code != codes.NotFound { t.Errorf("expected NotFound, got %v", code) } } func TestEngineUnmount_Success(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} if _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "default", Type: "transit"}); err != nil { t.Fatalf("Mount: %v", err) } if _, err := es.Unmount(context.Background(), &pb.UnmountRequest{Name: "default"}); err != nil { t.Fatalf("Unmount: %v", err) } } func TestEngineListMounts(t *testing.T) { srv, cleanup := 
newTestGRPCServer(t) defer cleanup() es := &engineServer{s: srv} resp, err := es.ListMounts(context.Background(), &pb.ListMountsRequest{}) if err != nil { t.Fatalf("ListMounts: %v", err) } if len(resp.Mounts) != 0 { t.Errorf("expected 0 mounts, got %d", len(resp.Mounts)) } if _, err := es.Mount(context.Background(), &pb.MountRequest{Name: "eng1", Type: "transit"}); err != nil { t.Fatalf("Mount: %v", err) } resp, err = es.ListMounts(context.Background(), &pb.ListMountsRequest{}) if err != nil { t.Fatalf("ListMounts after mount: %v", err) } if len(resp.Mounts) != 1 { t.Errorf("expected 1 mount, got %d", len(resp.Mounts)) } if resp.Mounts[0].Name != "eng1" { t.Errorf("mount name: got %q, want %q", resp.Mounts[0].Name, "eng1") } } // ---- policyServer tests ---- func TestPolicyCreate_MissingID(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ps := &policyServer{s: srv} _, err := ps.CreatePolicy(context.Background(), &pb.CreatePolicyRequest{ Rule: &pb.PolicyRule{Id: ""}, }) if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("expected InvalidArgument, got %v", code) } } func TestPolicyCreate_NilRule(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ps := &policyServer{s: srv} _, err := ps.CreatePolicy(context.Background(), &pb.CreatePolicyRequest{Rule: nil}) if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("expected InvalidArgument, got %v", code) } } func TestPolicyRoundtrip(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() // Policy engine needs an unsealed barrier; unseal it via the seal manager. if err := srv.sealMgr.Initialize(context.Background(), []byte("pw"), fastArgon2Params()); err != nil { t.Fatalf("initialize: %v", err) } ps := &policyServer{s: srv} rule := &pb.PolicyRule{ Id: "rule-1", Priority: 10, Effect: "allow", Usernames: []string{"alice"}, Resources: []string{"/ca/*"}, Actions: []string{"read"}, } // Create. 
createResp, err := ps.CreatePolicy(context.Background(), &pb.CreatePolicyRequest{Rule: rule}) if err != nil { t.Fatalf("CreatePolicy: %v", err) } if createResp.Rule.Id != "rule-1" { t.Errorf("created rule id: got %q, want %q", createResp.Rule.Id, "rule-1") } // Get. getResp, err := ps.GetPolicy(context.Background(), &pb.GetPolicyRequest{Id: "rule-1"}) if err != nil { t.Fatalf("GetPolicy: %v", err) } if getResp.Rule.Id != "rule-1" { t.Errorf("get rule id: got %q, want %q", getResp.Rule.Id, "rule-1") } // List. listResp, err := ps.ListPolicies(context.Background(), &pb.ListPoliciesRequest{}) if err != nil { t.Fatalf("ListPolicies: %v", err) } if len(listResp.Rules) != 1 { t.Errorf("expected 1 rule, got %d", len(listResp.Rules)) } // Delete. if _, err := ps.DeletePolicy(context.Background(), &pb.DeletePolicyRequest{Id: "rule-1"}); err != nil { t.Fatalf("DeletePolicy: %v", err) } // Get after delete should fail. _, err = ps.GetPolicy(context.Background(), &pb.GetPolicyRequest{Id: "rule-1"}) if code := status.Code(err); code != codes.NotFound { t.Errorf("expected NotFound after delete, got %v", code) } } func TestPolicyGet_MissingID(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ps := &policyServer{s: srv} _, err := ps.GetPolicy(context.Background(), &pb.GetPolicyRequest{Id: ""}) if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("expected InvalidArgument, got %v", code) } } func TestPolicyDelete_MissingID(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() ps := &policyServer{s: srv} _, err := ps.DeletePolicy(context.Background(), &pb.DeletePolicyRequest{Id: ""}) if code := status.Code(err); code != codes.InvalidArgument { t.Errorf("expected InvalidArgument, got %v", code) } } // ---- authServer tests ---- func TestAuthTokenInfo_FromContext(t *testing.T) { srv, cleanup := newTestGRPCServer(t) defer cleanup() as := &authServer{s: srv} ti := &auth.TokenInfo{Username: "alice", Roles: []string{"user"}, IsAdmin: 
false} ctx := context.WithValue(context.Background(), tokenInfoKey, ti) resp, err := as.TokenInfo(ctx, &pb.TokenInfoRequest{}) if err != nil { t.Fatalf("TokenInfo: %v", err) } if resp.Username != "alice" { t.Errorf("username: got %q, want %q", resp.Username, "alice") } if resp.IsAdmin { t.Error("expected IsAdmin=false") } } // ---- pbToRule / ruleToPB conversion tests ---- func TestPbToRuleRoundtrip(t *testing.T) { original := &pb.PolicyRule{ Id: "test-rule", Priority: 5, Effect: "deny", Usernames: []string{"bob"}, Roles: []string{"operator"}, Resources: []string{"/pki/*"}, Actions: []string{"write", "delete"}, } rule := pbToRule(original) if rule.ID != original.Id { t.Errorf("ID: got %q, want %q", rule.ID, original.Id) } if rule.Priority != int(original.Priority) { t.Errorf("Priority: got %d, want %d", rule.Priority, original.Priority) } if string(rule.Effect) != original.Effect { t.Errorf("Effect: got %q, want %q", rule.Effect, original.Effect) } back := ruleToPB(rule) if back.Id != original.Id { t.Errorf("roundtrip Id: got %q, want %q", back.Id, original.Id) } if back.Priority != original.Priority { t.Errorf("roundtrip Priority: got %d, want %d", back.Priority, original.Priority) } if back.Effect != original.Effect { t.Errorf("roundtrip Effect: got %q, want %q", back.Effect, original.Effect) } }