Files
mcr/internal/grpcserver/interceptors_test.go
Kyle Isom d5580f01f2 Migrate module path from kyle/ to mc/ org
All import paths updated to git.wntrmute.dev/mc/. Bumps mcdsl to v1.2.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:05:59 -07:00

415 lines
11 KiB
Go

package grpcserver
import (
"context"
"encoding/json"
"log/slog"
"net"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
mcdslauth "git.wntrmute.dev/mc/mcdsl/auth"
pb "git.wntrmute.dev/mc/mcr/gen/mcr/v1"
"git.wntrmute.dev/mc/mcr/internal/db"
)
// mockMCIAS starts a fake MCIAS HTTP server for token validation.
// Recognized tokens:
// - "admin-token" → valid, username=admin-uuid, roles=[admin]
// - "user-token" → valid, username=user-uuid, account_type=human, roles=[user]
// - "alice-token" → valid, username=alice, account_type=human, roles=[user]
// - anything else → invalid
func mockMCIAS(t *testing.T) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
mux.HandleFunc("POST /v1/token/validate", func(w http.ResponseWriter, r *http.Request) {
var req struct {
Token string `json:"token"`
}
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
switch req.Token {
case "admin-token":
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"valid": true,
"username": "admin-uuid",
"account_type": "human",
"roles": []string{"admin"},
})
case "user-token":
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"valid": true,
"username": "user-uuid",
"account_type": "human",
"roles": []string{"user"},
})
case "alice-token":
_ = json.NewEncoder(w).Encode(map[string]interface{}{
"valid": true,
"username": "alice",
"account_type": "human",
"roles": []string{"user"},
})
default:
_ = json.NewEncoder(w).Encode(map[string]interface{}{"valid": false})
}
})
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv
}
// testAuthenticator builds an mcdsl/auth.Authenticator whose ServerURL is
// serverURL (normally a mockMCIAS instance) and logs through slog.Default().
// The test is aborted if construction fails.
func testAuthenticator(t *testing.T, serverURL string) *mcdslauth.Authenticator {
	t.Helper()
	cfg := mcdslauth.Config{ServerURL: serverURL}
	authenticator, err := mcdslauth.New(cfg, slog.Default())
	if err != nil {
		t.Fatalf("auth.New: %v", err)
	}
	return authenticator
}
// openTestDB opens a new database file named test.db inside t.TempDir(),
// arranges for it to be closed during test cleanup, and runs Migrate on it
// before handing it back. Any failure aborts the test.
func openTestDB(t *testing.T) *db.DB {
	t.Helper()
	database, err := db.Open(filepath.Join(t.TempDir(), "test.db"))
	if err != nil {
		t.Fatalf("Open: %v", err)
	}
	// Registered before migrating so the handle is released even if Migrate fails.
	t.Cleanup(func() { _ = database.Close() })
	if err = database.Migrate(); err != nil {
		t.Fatalf("Migrate: %v", err)
	}
	return database
}
// startTestServer constructs a server with New (passing empty strings for its
// first two arguments — presumably the TLS cert/key paths, since the client
// below dials without TLS; confirm against New), serves it on an ephemeral
// 127.0.0.1 port from a background goroutine, and returns a client connection
// that forces pb.JSONCodec for every call.
//
// The server is stopped with GracefulStop and the client is closed during test
// cleanup. Because t.Cleanup runs functions in LIFO order, the client is closed
// before the server's graceful stop begins.
func startTestServer(t *testing.T, deps Deps) *grpc.ClientConn {
	t.Helper()

	server, err := New("", "", deps, slog.Default())
	if err != nil {
		t.Fatalf("New: %v", err)
	}

	listener, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Listen: %v", err)
	}
	// Serve's return value is deliberately ignored; the server is shut down in cleanup.
	go func() { _ = server.Serve(listener) }()
	t.Cleanup(func() { server.GracefulStop() })

	//nolint:gosec // insecure credentials for testing only
	conn, err := grpc.NewClient(
		listener.Addr().String(),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(grpc.ForceCodecV2(pb.JSONCodec{})),
	)
	if err != nil {
		t.Fatalf("Dial: %v", err)
	}
	t.Cleanup(func() { _ = conn.Close() })

	return conn
}
// withAuth derives a context from ctx whose outgoing gRPC metadata additionally
// carries the pair authorization="Bearer <token>".
func withAuth(ctx context.Context, token string) context.Context {
	bearer := "Bearer " + token
	return metadata.AppendToOutgoingContext(ctx, "authorization", bearer)
}
func TestHealthBypassesAuth(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
client := pb.NewAdminServiceClient(cc)
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
if err != nil {
t.Fatalf("Health: %v", err)
}
if resp.Status != "ok" {
t.Fatalf("Health status: got %q, want %q", resp.Status, "ok")
}
}
func TestAuthInterceptorNoToken(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
client := pb.NewRegistryServiceClient(cc)
_, err := client.ListRepositories(context.Background(), &pb.ListRepositoriesRequest{})
if err == nil {
t.Fatal("expected error for unauthenticated request")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.Unauthenticated {
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
}
}
func TestAuthInterceptorInvalidToken(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
ctx := withAuth(context.Background(), "bad-token")
client := pb.NewRegistryServiceClient(cc)
_, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{})
if err == nil {
t.Fatal("expected error for invalid token")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.Unauthenticated {
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
}
}
func TestAuthInterceptorValidToken(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewRegistryServiceClient(cc)
resp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{})
if err != nil {
t.Fatalf("ListRepositories: %v", err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
}
func TestAdminInterceptorDenied(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
ctx := withAuth(context.Background(), "user-token")
// Policy RPCs require admin.
policyClient := pb.NewPolicyServiceClient(cc)
_, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{})
if err == nil {
t.Fatal("expected error for non-admin user")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}
func TestAdminInterceptorAllowed(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
ctx := withAuth(context.Background(), "admin-token")
// Admin user should be able to list policy rules.
policyClient := pb.NewPolicyServiceClient(cc)
resp, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{})
if err != nil {
t.Fatalf("ListPolicyRules: %v", err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
}
// TestMethodMapCompleteness pins every category returned by methodMap()
// (AdminRequired, Public, AuthRequired) to the security spec. Each expected
// method must be present in its category, and no unexpected method may appear
// in any category.
func TestMethodMapCompleteness(t *testing.T) {
	mm := methodMap()

	// checkExact asserts that got contains exactly the methods in want.
	// It reports each missing method and each unexpected method by name, so a
	// failure identifies the offending RPC instead of only a size mismatch.
	// requirement completes the phrase "method X should ...".
	checkExact := func(mapName, requirement string, got map[string]bool, want []string) {
		t.Helper()
		wantSet := make(map[string]bool, len(want))
		for _, method := range want {
			wantSet[method] = true
			if !got[method] {
				t.Errorf("method %s should %s but is not in %s", method, requirement, mapName)
			}
		}
		for method := range got {
			if !wantSet[method] {
				t.Errorf("method %s is in %s but is not expected there", method, mapName)
			}
		}
		// Kept alongside the name checks: it also catches keys stored with a false value.
		if len(got) != len(want) {
			t.Errorf("%s has %d entries, expected %d", mapName, len(got), len(want))
		}
	}

	// Admin-required methods, per the security spec.
	checkExact("AdminRequired", "require admin", mm.AdminRequired, []string{
		"/mcr.v1.RegistryService/DeleteRepository",
		"/mcr.v1.RegistryService/GarbageCollect",
		"/mcr.v1.RegistryService/GetGCStatus",
		"/mcr.v1.PolicyService/ListPolicyRules",
		"/mcr.v1.PolicyService/CreatePolicyRule",
		"/mcr.v1.PolicyService/GetPolicyRule",
		"/mcr.v1.PolicyService/UpdatePolicyRule",
		"/mcr.v1.PolicyService/DeletePolicyRule",
		"/mcr.v1.AuditService/ListAuditEvents",
	})

	// Health is the only public method.
	checkExact("Public", "be public", mm.Public, []string{
		"/mcr.v1.AdminService/Health",
	})

	// Methods that need any authenticated caller.
	checkExact("AuthRequired", "require auth", mm.AuthRequired, []string{
		"/mcr.v1.RegistryService/ListRepositories",
		"/mcr.v1.RegistryService/GetRepository",
	})
}
func TestDeleteRepoRequiresAdmin(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewRegistryServiceClient(cc)
_, err := client.DeleteRepository(ctx, &pb.DeleteRepositoryRequest{Name: "test"})
if err == nil {
t.Fatal("expected error for non-admin user trying to delete repo")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}
func TestGCRequiresAdmin(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewRegistryServiceClient(cc)
_, err := client.GarbageCollect(ctx, &pb.GarbageCollectRequest{})
if err == nil {
t.Fatal("expected error for non-admin user trying to trigger GC")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}
func TestAuditRequiresAdmin(t *testing.T) {
mcias := mockMCIAS(t)
auth := testAuthenticator(t, mcias.URL)
database := openTestDB(t)
cc := startTestServer(t, Deps{
DB: database,
Authenticator: auth,
})
ctx := withAuth(context.Background(), "user-token")
client := pb.NewAuditServiceClient(cc)
_, err := client.ListAuditEvents(ctx, &pb.ListAuditEventsRequest{})
if err == nil {
t.Fatal("expected error for non-admin user trying to list audit events")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}