Phase 10: gRPC admin API with interceptor chain

Proto definitions for 4 services (RegistryService, PolicyService,
AuditService, AdminService) with hand-written Go stubs using JSON
codec until protobuf tooling is available.

Interceptor chain: logging (method, peer IP, duration, never logs
auth metadata) → auth (bearer token via MCIAS, Health bypasses) →
admin (role check for GC, policy, delete, audit RPCs).

All RPCs share business logic with REST handlers via internal/db
and internal/gc packages. TLS 1.3 minimum on gRPC listener.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-19 20:46:21 -07:00
parent 562b69e875
commit 185b68ff6d
30 changed files with 3616 additions and 4 deletions

View File

@@ -0,0 +1,16 @@
package grpcserver
import (
"context"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
)
// adminService implements pb.AdminServiceServer.
type adminService struct {
pb.UnimplementedAdminServiceServer
}
func (s *adminService) Health(_ context.Context, _ *pb.HealthRequest) (*pb.HealthResponse, error) {
return &pb.HealthResponse{Status: "ok"}, nil
}

View File

@@ -0,0 +1,43 @@
package grpcserver
import (
"context"
"testing"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/auth"
)
func TestHealthReturnsOk(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewAdminServiceClient(cc)
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
if err != nil {
t.Fatalf("Health: %v", err)
}
if resp.GetStatus() != "ok" {
t.Fatalf("status: got %q, want %q", resp.Status, "ok")
}
}
func TestHealthWithoutAuth(t *testing.T) {
database := openTestDB(t)
// Use a validator that always rejects.
validator := &fakeValidator{err: auth.ErrUnauthorized}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
client := pb.NewAdminServiceClient(cc)
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
if err != nil {
t.Fatalf("Health without auth should succeed: %v", err)
}
if resp.GetStatus() != "ok" {
t.Fatalf("status: got %q, want %q", resp.Status, "ok")
}
}

View File

@@ -0,0 +1,61 @@
package grpcserver
import (
"context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/db"
)
// auditService implements pb.AuditServiceServer.
type auditService struct {
pb.UnimplementedAuditServiceServer
db *db.DB
}
func (s *auditService) ListAuditEvents(_ context.Context, req *pb.ListAuditEventsRequest) (*pb.ListAuditEventsResponse, error) {
limit := int32(50)
offset := int32(0)
if req.GetPagination() != nil {
if req.Pagination.Limit > 0 {
limit = req.Pagination.Limit
}
if req.Pagination.Offset >= 0 {
offset = req.Pagination.Offset
}
}
filter := db.AuditFilter{
EventType: req.GetEventType(),
ActorID: req.GetActorId(),
Repository: req.GetRepository(),
Since: req.GetSince(),
Until: req.GetUntil(),
Limit: int(limit),
Offset: int(offset),
}
events, err := s.db.ListAuditEvents(filter)
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
var result []*pb.AuditEvent
for _, e := range events {
result = append(result, &pb.AuditEvent{
Id: e.ID,
EventTime: e.EventTime,
EventType: e.EventType,
ActorId: e.ActorID,
Repository: e.Repository,
Digest: e.Digest,
IpAddress: e.IPAddress,
Details: e.Details,
})
}
return &pb.ListAuditEventsResponse{Events: result}, nil
}

View File

@@ -0,0 +1,95 @@
package grpcserver
import (
"testing"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
)
func TestListAuditEventsEmpty(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewAuditServiceClient(cc)
resp, err := client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{})
if err != nil {
t.Fatalf("ListAuditEvents: %v", err)
}
if len(resp.GetEvents()) != 0 {
t.Fatalf("expected 0 events, got %d", len(resp.Events))
}
}
func TestListAuditEventsWithData(t *testing.T) {
deps := adminDeps(t)
// Write some audit events directly.
err := deps.DB.WriteAuditEvent("test_event", "actor-1", "repo/test", "", "1.2.3.4", map[string]string{"key": "value"})
if err != nil {
t.Fatalf("WriteAuditEvent: %v", err)
}
err = deps.DB.WriteAuditEvent("other_event", "actor-2", "", "", "5.6.7.8", nil)
if err != nil {
t.Fatalf("WriteAuditEvent: %v", err)
}
cc := startTestServer(t, deps)
client := pb.NewAuditServiceClient(cc)
// List all events.
resp, err := client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{})
if err != nil {
t.Fatalf("ListAuditEvents: %v", err)
}
if len(resp.GetEvents()) != 2 {
t.Fatalf("expected 2 events, got %d", len(resp.Events))
}
// Filter by event type.
resp, err = client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{
EventType: "test_event",
})
if err != nil {
t.Fatalf("ListAuditEvents with filter: %v", err)
}
if len(resp.GetEvents()) != 1 {
t.Fatalf("expected 1 event, got %d", len(resp.Events))
}
if resp.Events[0].EventType != "test_event" {
t.Fatalf("event_type: got %q, want %q", resp.Events[0].EventType, "test_event")
}
if resp.Events[0].ActorId != "actor-1" {
t.Fatalf("actor_id: got %q, want %q", resp.Events[0].ActorId, "actor-1")
}
if resp.Events[0].Details["key"] != "value" {
t.Fatalf("details: got %v, want key=value", resp.Events[0].Details)
}
}
func TestListAuditEventsPagination(t *testing.T) {
deps := adminDeps(t)
// Write 5 events.
for i := range 5 {
err := deps.DB.WriteAuditEvent("event", "actor", "", "", "", map[string]string{
"index": string(rune('0' + i)),
})
if err != nil {
t.Fatalf("WriteAuditEvent %d: %v", i, err)
}
}
cc := startTestServer(t, deps)
client := pb.NewAuditServiceClient(cc)
// Get first 2 events.
resp, err := client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{
Pagination: &pb.PaginationRequest{Limit: 2},
})
if err != nil {
t.Fatalf("ListAuditEvents: %v", err)
}
if len(resp.GetEvents()) != 2 {
t.Fatalf("expected 2 events, got %d", len(resp.Events))
}
}

View File

@@ -0,0 +1,165 @@
package grpcserver
import (
"context"
"log"
"strings"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/peer"
"google.golang.org/grpc/status"
"git.wntrmute.dev/kyle/mcr/internal/auth"
"git.wntrmute.dev/kyle/mcr/internal/server"
)
// authBypassMethods contains the full gRPC method names that bypass
// authentication. Health is the only method that does not require auth.
var authBypassMethods = map[string]bool{
"/mcr.v1.AdminService/Health": true,
}
// adminRequiredMethods contains the full gRPC method names that require
// the admin role. Adding an RPC without adding it to the correct map is
// a security defect per ARCHITECTURE.md.
var adminRequiredMethods = map[string]bool{
// Registry admin operations.
"/mcr.v1.RegistryService/DeleteRepository": true,
"/mcr.v1.RegistryService/GarbageCollect": true,
"/mcr.v1.RegistryService/GetGCStatus": true,
// Policy management — all RPCs require admin.
"/mcr.v1.PolicyService/ListPolicyRules": true,
"/mcr.v1.PolicyService/CreatePolicyRule": true,
"/mcr.v1.PolicyService/GetPolicyRule": true,
"/mcr.v1.PolicyService/UpdatePolicyRule": true,
"/mcr.v1.PolicyService/DeletePolicyRule": true,
// Audit — requires admin.
"/mcr.v1.AuditService/ListAuditEvents": true,
}
// authInterceptor validates bearer tokens from the authorization metadata.
type authInterceptor struct {
validator server.TokenValidator
}
func newAuthInterceptor(v server.TokenValidator) *authInterceptor {
return &authInterceptor{validator: v}
}
// unary is the unary server interceptor for auth.
func (a *authInterceptor) unary(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
// Health bypasses auth.
if authBypassMethods[info.FullMethod] {
return handler(ctx, req)
}
// Extract bearer token from authorization metadata.
token, err := extractToken(ctx)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
}
// Validate the token via MCIAS.
claims, err := a.validator.ValidateToken(token)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
}
// Inject claims into the context.
ctx = auth.ContextWithClaims(ctx, claims)
return handler(ctx, req)
}
// extractToken extracts a bearer token from the "authorization" gRPC metadata.
func extractToken(ctx context.Context) (string, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return "", status.Errorf(codes.Unauthenticated, "missing metadata")
}
vals := md.Get("authorization")
if len(vals) == 0 {
return "", status.Errorf(codes.Unauthenticated, "missing authorization metadata")
}
val := vals[0]
const prefix = "Bearer "
if !strings.HasPrefix(val, prefix) {
return "", status.Errorf(codes.Unauthenticated, "invalid authorization format")
}
token := strings.TrimSpace(val[len(prefix):])
if token == "" {
return "", status.Errorf(codes.Unauthenticated, "empty bearer token")
}
return token, nil
}
// adminInterceptor checks that the caller has the admin role for
// methods in adminRequiredMethods.
type adminInterceptor struct{}
func newAdminInterceptor() *adminInterceptor {
return &adminInterceptor{}
}
// unary is the unary server interceptor for admin role checks.
func (a *adminInterceptor) unary(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
if !adminRequiredMethods[info.FullMethod] {
return handler(ctx, req)
}
claims := auth.ClaimsFromContext(ctx)
if claims == nil {
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
}
if !hasRole(claims.Roles, "admin") {
return nil, status.Errorf(codes.PermissionDenied, "admin role required")
}
return handler(ctx, req)
}
// loggingInterceptor logs the method, peer IP, status code, and duration.
// It never logs the authorization metadata value.
func loggingInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
start := time.Now()
peerAddr := ""
if p, ok := peer.FromContext(ctx); ok {
peerAddr = p.Addr.String()
}
resp, err := handler(ctx, req)
duration := time.Since(start)
code := codes.OK
if err != nil {
if st, ok := status.FromError(err); ok {
code = st.Code()
} else {
code = codes.Unknown
}
}
log.Printf("grpc %s peer=%s code=%s duration=%s", info.FullMethod, peerAddr, code, duration)
return resp, err
}
// hasRole checks if any of the roles match the target role.
func hasRole(roles []string, target string) bool {
for _, r := range roles {
if r == target {
return true
}
}
return false
}

View File

@@ -0,0 +1,350 @@
package grpcserver
import (
"context"
"net"
"path/filepath"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/auth"
"git.wntrmute.dev/kyle/mcr/internal/db"
)
// fakeValidator is a test double for server.TokenValidator.
type fakeValidator struct {
claims *auth.Claims
err error
}
func (f *fakeValidator) ValidateToken(_ string) (*auth.Claims, error) {
return f.claims, f.err
}
// openTestDB creates a temporary test database with migrations applied.
func openTestDB(t *testing.T) *db.DB {
t.Helper()
path := filepath.Join(t.TempDir(), "test.db")
d, err := db.Open(path)
if err != nil {
t.Fatalf("Open: %v", err)
}
t.Cleanup(func() { _ = d.Close() })
if err := d.Migrate(); err != nil {
t.Fatalf("Migrate: %v", err)
}
return d
}
// startTestServer creates a gRPC server and client for testing.
// Returns the client connection and a cleanup function.
func startTestServer(t *testing.T, deps Deps) *grpc.ClientConn {
t.Helper()
srv, err := New("", "", deps)
if err != nil {
t.Fatalf("New: %v", err)
}
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("Listen: %v", err)
}
go func() {
_ = srv.Serve(lis)
}()
t.Cleanup(func() { srv.GracefulStop() })
//nolint:gosec // insecure credentials for testing only
cc, err := grpc.NewClient(
lis.Addr().String(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(grpc.ForceCodecV2(pb.JSONCodec{})),
)
if err != nil {
t.Fatalf("Dial: %v", err)
}
t.Cleanup(func() { _ = cc.Close() })
return cc
}
// withAuth adds a bearer token to the outgoing context metadata.
func withAuth(ctx context.Context, token string) context.Context {
return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token)
}
func TestHealthBypassesAuth(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{err: auth.ErrUnauthorized}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
client := pb.NewAdminServiceClient(cc)
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
if err != nil {
t.Fatalf("Health: %v", err)
}
if resp.Status != "ok" {
t.Fatalf("Health status: got %q, want %q", resp.Status, "ok")
}
}
func TestAuthInterceptorNoToken(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{err: auth.ErrUnauthorized}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
client := pb.NewRegistryServiceClient(cc)
_, err := client.ListRepositories(context.Background(), &pb.ListRepositoriesRequest{})
if err == nil {
t.Fatal("expected error for unauthenticated request")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.Unauthenticated {
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
}
}
func TestAuthInterceptorInvalidToken(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{err: auth.ErrUnauthorized}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
ctx := withAuth(context.Background(), "bad-token")
client := pb.NewRegistryServiceClient(cc)
_, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{})
if err == nil {
t.Fatal("expected error for invalid token")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.Unauthenticated {
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
}
}
func TestAuthInterceptorValidToken(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{
claims: &auth.Claims{Subject: "alice", AccountType: "human", Roles: []string{"user"}},
}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
ctx := withAuth(context.Background(), "valid-token")
client := pb.NewRegistryServiceClient(cc)
resp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{})
if err != nil {
t.Fatalf("ListRepositories: %v", err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
}
func TestAdminInterceptorDenied(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
ctx := withAuth(context.Background(), "valid-token")
// Policy RPCs require admin.
policyClient := pb.NewPolicyServiceClient(cc)
_, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{})
if err == nil {
t.Fatal("expected error for non-admin user")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}
func TestAdminInterceptorAllowed(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
ctx := withAuth(context.Background(), "valid-token")
// Admin user should be able to list policy rules.
policyClient := pb.NewPolicyServiceClient(cc)
resp, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{})
if err != nil {
t.Fatalf("ListPolicyRules: %v", err)
}
if resp == nil {
t.Fatal("expected non-nil response")
}
}
func TestAdminRequiredMethodsCompleteness(t *testing.T) {
// Verify that admin-required methods match our security spec.
// This test catches the security defect of adding an RPC without
// adding it to the adminRequiredMethods map.
expected := []string{
"/mcr.v1.RegistryService/DeleteRepository",
"/mcr.v1.RegistryService/GarbageCollect",
"/mcr.v1.RegistryService/GetGCStatus",
"/mcr.v1.PolicyService/ListPolicyRules",
"/mcr.v1.PolicyService/CreatePolicyRule",
"/mcr.v1.PolicyService/GetPolicyRule",
"/mcr.v1.PolicyService/UpdatePolicyRule",
"/mcr.v1.PolicyService/DeletePolicyRule",
"/mcr.v1.AuditService/ListAuditEvents",
}
for _, method := range expected {
if !adminRequiredMethods[method] {
t.Errorf("method %s should require admin but is not in adminRequiredMethods", method)
}
}
if len(adminRequiredMethods) != len(expected) {
t.Errorf("adminRequiredMethods has %d entries, expected %d", len(adminRequiredMethods), len(expected))
}
}
func TestAuthBypassMethodsCompleteness(t *testing.T) {
// Health is the only method that bypasses auth.
expected := []string{
"/mcr.v1.AdminService/Health",
}
for _, method := range expected {
if !authBypassMethods[method] {
t.Errorf("method %s should bypass auth but is not in authBypassMethods", method)
}
}
if len(authBypassMethods) != len(expected) {
t.Errorf("authBypassMethods has %d entries, expected %d", len(authBypassMethods), len(expected))
}
}
func TestDeleteRepoRequiresAdmin(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
ctx := withAuth(context.Background(), "valid-token")
client := pb.NewRegistryServiceClient(cc)
_, err := client.DeleteRepository(ctx, &pb.DeleteRepositoryRequest{Name: "test"})
if err == nil {
t.Fatal("expected error for non-admin user trying to delete repo")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}
func TestGCRequiresAdmin(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
ctx := withAuth(context.Background(), "valid-token")
client := pb.NewRegistryServiceClient(cc)
_, err := client.GarbageCollect(ctx, &pb.GarbageCollectRequest{})
if err == nil {
t.Fatal("expected error for non-admin user trying to trigger GC")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}
func TestAuditRequiresAdmin(t *testing.T) {
database := openTestDB(t)
validator := &fakeValidator{
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
}
cc := startTestServer(t, Deps{
DB: database,
Validator: validator,
})
ctx := withAuth(context.Background(), "valid-token")
client := pb.NewAuditServiceClient(cc)
_, err := client.ListAuditEvents(ctx, &pb.ListAuditEventsRequest{})
if err == nil {
t.Fatal("expected error for non-admin user trying to list audit events")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status error, got %v", err)
}
if st.Code() != codes.PermissionDenied {
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
}
}

View File

@@ -0,0 +1,292 @@
package grpcserver
import (
"context"
"errors"
"fmt"
"strconv"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/auth"
"git.wntrmute.dev/kyle/mcr/internal/db"
"git.wntrmute.dev/kyle/mcr/internal/policy"
)
var validActions = map[string]bool{
string(policy.ActionVersionCheck): true,
string(policy.ActionPull): true,
string(policy.ActionPush): true,
string(policy.ActionDelete): true,
string(policy.ActionCatalog): true,
string(policy.ActionPolicyManage): true,
}
// policyService implements pb.PolicyServiceServer.
type policyService struct {
pb.UnimplementedPolicyServiceServer
db *db.DB
engine PolicyReloader
auditFn AuditFunc
}
func (s *policyService) ListPolicyRules(_ context.Context, req *pb.ListPolicyRulesRequest) (*pb.ListPolicyRulesResponse, error) {
limit := int32(50)
offset := int32(0)
if req.GetPagination() != nil {
if req.Pagination.Limit > 0 {
limit = req.Pagination.Limit
}
if req.Pagination.Offset >= 0 {
offset = req.Pagination.Offset
}
}
rules, err := s.db.ListPolicyRules(int(limit), int(offset))
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
var result []*pb.PolicyRule
for _, r := range rules {
result = append(result, policyRuleRowToProto(&r))
}
return &pb.ListPolicyRulesResponse{Rules: result}, nil
}
func (s *policyService) CreatePolicyRule(ctx context.Context, req *pb.CreatePolicyRuleRequest) (*pb.PolicyRule, error) {
if req.GetPriority() < 1 {
return nil, status.Errorf(codes.InvalidArgument, "priority must be >= 1 (0 is reserved for built-ins)")
}
if req.GetDescription() == "" {
return nil, status.Errorf(codes.InvalidArgument, "description is required")
}
if err := validateEffect(req.GetEffect()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
}
if len(req.GetActions()) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "at least one action is required")
}
if err := validateActions(req.GetActions()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
}
claims := auth.ClaimsFromContext(ctx)
createdBy := ""
if claims != nil {
createdBy = claims.Subject
}
row := db.PolicyRuleRow{
Priority: int(req.Priority),
Description: req.Description,
Effect: req.Effect,
Roles: req.Roles,
AccountTypes: req.AccountTypes,
SubjectUUID: req.SubjectUuid,
Actions: req.Actions,
Repositories: req.Repositories,
Enabled: req.Enabled,
CreatedBy: createdBy,
}
id, err := s.db.CreatePolicyRule(row)
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
// Reload policy engine.
if s.engine != nil {
_ = s.engine.Reload(s.db)
}
if s.auditFn != nil {
s.auditFn("policy_rule_created", createdBy, "", "", "", map[string]string{
"rule_id": strconv.FormatInt(id, 10),
})
}
created, err := s.db.GetPolicyRule(id)
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
return policyRuleRowToProto(created), nil
}
func (s *policyService) GetPolicyRule(_ context.Context, req *pb.GetPolicyRuleRequest) (*pb.PolicyRule, error) {
if req.GetId() == 0 {
return nil, status.Errorf(codes.InvalidArgument, "rule ID required")
}
rule, err := s.db.GetPolicyRule(req.Id)
if err != nil {
if errors.Is(err, db.ErrPolicyRuleNotFound) {
return nil, status.Errorf(codes.NotFound, "policy rule not found")
}
return nil, status.Errorf(codes.Internal, "internal error")
}
return policyRuleRowToProto(rule), nil
}
func (s *policyService) UpdatePolicyRule(ctx context.Context, req *pb.UpdatePolicyRuleRequest) (*pb.PolicyRule, error) {
if req.GetId() == 0 {
return nil, status.Errorf(codes.InvalidArgument, "rule ID required")
}
mask := make(map[string]bool, len(req.GetUpdateMask()))
for _, f := range req.GetUpdateMask() {
mask[f] = true
}
// Validate fields if they are in the update mask.
if mask["priority"] && req.Priority < 1 {
return nil, status.Errorf(codes.InvalidArgument, "priority must be >= 1 (0 is reserved for built-ins)")
}
if mask["effect"] {
if err := validateEffect(req.GetEffect()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
}
}
if mask["actions"] {
if len(req.GetActions()) == 0 {
return nil, status.Errorf(codes.InvalidArgument, "at least one action is required")
}
if err := validateActions(req.GetActions()); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
}
}
updates := db.PolicyRuleRow{}
if mask["priority"] {
updates.Priority = int(req.Priority)
}
if mask["description"] {
updates.Description = req.Description
}
if mask["effect"] {
updates.Effect = req.Effect
}
if mask["roles"] {
updates.Roles = req.Roles
}
if mask["account_types"] {
updates.AccountTypes = req.AccountTypes
}
if mask["subject_uuid"] {
updates.SubjectUUID = req.SubjectUuid
}
if mask["actions"] {
updates.Actions = req.Actions
}
if mask["repositories"] {
updates.Repositories = req.Repositories
}
if err := s.db.UpdatePolicyRule(req.Id, updates); err != nil {
if errors.Is(err, db.ErrPolicyRuleNotFound) {
return nil, status.Errorf(codes.NotFound, "policy rule not found")
}
return nil, status.Errorf(codes.Internal, "internal error")
}
// Handle enabled separately since it's a bool.
if mask["enabled"] {
if err := s.db.SetPolicyRuleEnabled(req.Id, req.Enabled); err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
}
// Reload policy engine.
if s.engine != nil {
_ = s.engine.Reload(s.db)
}
if s.auditFn != nil {
claims := auth.ClaimsFromContext(ctx)
actorID := ""
if claims != nil {
actorID = claims.Subject
}
s.auditFn("policy_rule_updated", actorID, "", "", "", map[string]string{
"rule_id": strconv.FormatInt(req.Id, 10),
})
}
updated, err := s.db.GetPolicyRule(req.Id)
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
return policyRuleRowToProto(updated), nil
}
func (s *policyService) DeletePolicyRule(ctx context.Context, req *pb.DeletePolicyRuleRequest) (*pb.DeletePolicyRuleResponse, error) {
if req.GetId() == 0 {
return nil, status.Errorf(codes.InvalidArgument, "rule ID required")
}
if err := s.db.DeletePolicyRule(req.Id); err != nil {
if errors.Is(err, db.ErrPolicyRuleNotFound) {
return nil, status.Errorf(codes.NotFound, "policy rule not found")
}
return nil, status.Errorf(codes.Internal, "internal error")
}
// Reload policy engine.
if s.engine != nil {
_ = s.engine.Reload(s.db)
}
if s.auditFn != nil {
claims := auth.ClaimsFromContext(ctx)
actorID := ""
if claims != nil {
actorID = claims.Subject
}
s.auditFn("policy_rule_deleted", actorID, "", "", "", map[string]string{
"rule_id": strconv.FormatInt(req.Id, 10),
})
}
return &pb.DeletePolicyRuleResponse{}, nil
}
// policyRuleRowToProto converts a db.PolicyRuleRow to a protobuf PolicyRule.
func policyRuleRowToProto(r *db.PolicyRuleRow) *pb.PolicyRule {
return &pb.PolicyRule{
Id: r.ID,
Priority: int32(r.Priority), //nolint:gosec // priority is always small positive int
Description: r.Description,
Effect: r.Effect,
Roles: r.Roles,
AccountTypes: r.AccountTypes,
SubjectUuid: r.SubjectUUID,
Actions: r.Actions,
Repositories: r.Repositories,
Enabled: r.Enabled,
CreatedBy: r.CreatedBy,
CreatedAt: r.CreatedAt,
UpdatedAt: r.UpdatedAt,
}
}
func validateEffect(effect string) error {
if effect != "allow" && effect != "deny" {
return fmt.Errorf("invalid effect: %q (must be 'allow' or 'deny')", effect)
}
return nil
}
func validateActions(actions []string) error {
for _, a := range actions {
if !validActions[a] {
return fmt.Errorf("invalid action: %q", a)
}
}
return nil
}

View File

@@ -0,0 +1,306 @@
package grpcserver
import (
"testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/policy"
)
type fakePolicyReloader struct {
reloadCount int
}
func (f *fakePolicyReloader) Reload(_ policy.RuleStore) error {
f.reloadCount++
return nil
}
func TestCreatePolicyRule(t *testing.T) {
deps := adminDeps(t)
reloader := &fakePolicyReloader{}
deps.Engine = reloader
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
resp, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
Priority: 10,
Description: "allow pull for all",
Effect: "allow",
Actions: []string{"registry:pull"},
Enabled: true,
})
if err != nil {
t.Fatalf("CreatePolicyRule: %v", err)
}
if resp.GetId() == 0 {
t.Fatal("expected non-zero ID")
}
if resp.GetDescription() != "allow pull for all" {
t.Fatalf("description: got %q, want %q", resp.Description, "allow pull for all")
}
if resp.GetEffect() != "allow" {
t.Fatalf("effect: got %q, want %q", resp.Effect, "allow")
}
if !resp.GetEnabled() {
t.Fatal("expected enabled=true")
}
if reloader.reloadCount != 1 {
t.Fatalf("reloadCount: got %d, want 1", reloader.reloadCount)
}
}
func TestCreatePolicyRuleValidation(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
tests := []struct {
name string
req *pb.CreatePolicyRuleRequest
code codes.Code
}{
{
name: "zero priority",
req: &pb.CreatePolicyRuleRequest{
Priority: 0,
Description: "test",
Effect: "allow",
Actions: []string{"registry:pull"},
},
code: codes.InvalidArgument,
},
{
name: "empty description",
req: &pb.CreatePolicyRuleRequest{
Priority: 1,
Effect: "allow",
Actions: []string{"registry:pull"},
},
code: codes.InvalidArgument,
},
{
name: "invalid effect",
req: &pb.CreatePolicyRuleRequest{
Priority: 1,
Description: "test",
Effect: "maybe",
Actions: []string{"registry:pull"},
},
code: codes.InvalidArgument,
},
{
name: "no actions",
req: &pb.CreatePolicyRuleRequest{
Priority: 1,
Description: "test",
Effect: "allow",
},
code: codes.InvalidArgument,
},
{
name: "invalid action",
req: &pb.CreatePolicyRuleRequest{
Priority: 1,
Description: "test",
Effect: "allow",
Actions: []string{"registry:fly"},
},
code: codes.InvalidArgument,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := client.CreatePolicyRule(adminCtx(), tt.req)
if err == nil {
t.Fatal("expected error")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != tt.code {
t.Fatalf("code: got %v, want %v", st.Code(), tt.code)
}
})
}
}
func TestGetPolicyRule(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
// Create a rule first.
created, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
Priority: 5,
Description: "test rule",
Effect: "deny",
Actions: []string{"registry:push"},
Enabled: true,
})
if err != nil {
t.Fatalf("CreatePolicyRule: %v", err)
}
// Fetch it.
got, err := client.GetPolicyRule(adminCtx(), &pb.GetPolicyRuleRequest{Id: created.Id})
if err != nil {
t.Fatalf("GetPolicyRule: %v", err)
}
if got.Id != created.Id {
t.Fatalf("id: got %d, want %d", got.Id, created.Id)
}
if got.Effect != "deny" {
t.Fatalf("effect: got %q, want %q", got.Effect, "deny")
}
}
func TestGetPolicyRuleNotFound(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
_, err := client.GetPolicyRule(adminCtx(), &pb.GetPolicyRuleRequest{Id: 99999})
if err == nil {
t.Fatal("expected error")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestListPolicyRules(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
// Create two rules.
for i := range 2 {
_, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
Priority: int32(i + 1),
Description: "rule",
Effect: "allow",
Actions: []string{"registry:pull"},
Enabled: true,
})
if err != nil {
t.Fatalf("CreatePolicyRule %d: %v", i, err)
}
}
resp, err := client.ListPolicyRules(adminCtx(), &pb.ListPolicyRulesRequest{})
if err != nil {
t.Fatalf("ListPolicyRules: %v", err)
}
if len(resp.GetRules()) < 2 {
t.Fatalf("expected at least 2 rules, got %d", len(resp.Rules))
}
}
func TestDeletePolicyRule(t *testing.T) {
deps := adminDeps(t)
reloader := &fakePolicyReloader{}
deps.Engine = reloader
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
// Create then delete.
created, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
Priority: 1,
Description: "to be deleted",
Effect: "allow",
Actions: []string{"registry:pull"},
Enabled: true,
})
if err != nil {
t.Fatalf("CreatePolicyRule: %v", err)
}
initialReloads := reloader.reloadCount
_, err = client.DeletePolicyRule(adminCtx(), &pb.DeletePolicyRuleRequest{Id: created.Id})
if err != nil {
t.Fatalf("DeletePolicyRule: %v", err)
}
// Verify it was reloaded.
if reloader.reloadCount != initialReloads+1 {
t.Fatalf("reloadCount: got %d, want %d", reloader.reloadCount, initialReloads+1)
}
// Verify it's gone.
_, err = client.GetPolicyRule(adminCtx(), &pb.GetPolicyRuleRequest{Id: created.Id})
if err == nil {
t.Fatal("expected error after deletion")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestDeletePolicyRuleNotFound(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
_, err := client.DeletePolicyRule(adminCtx(), &pb.DeletePolicyRuleRequest{Id: 99999})
if err == nil {
t.Fatal("expected error")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestUpdatePolicyRule(t *testing.T) {
deps := adminDeps(t)
reloader := &fakePolicyReloader{}
deps.Engine = reloader
cc := startTestServer(t, deps)
client := pb.NewPolicyServiceClient(cc)
// Create a rule.
created, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
Priority: 10,
Description: "original",
Effect: "allow",
Actions: []string{"registry:pull"},
Enabled: true,
})
if err != nil {
t.Fatalf("CreatePolicyRule: %v", err)
}
// Update description.
updated, err := client.UpdatePolicyRule(adminCtx(), &pb.UpdatePolicyRuleRequest{
Id: created.Id,
Description: "updated description",
UpdateMask: []string{"description"},
})
if err != nil {
t.Fatalf("UpdatePolicyRule: %v", err)
}
if updated.Description != "updated description" {
t.Fatalf("description: got %q, want %q", updated.Description, "updated description")
}
// Effect should be unchanged.
if updated.Effect != "allow" {
t.Fatalf("effect: got %q, want %q", updated.Effect, "allow")
}
}

View File

@@ -0,0 +1,203 @@
package grpcserver
import (
"context"
"errors"
"fmt"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/google/uuid"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/auth"
"git.wntrmute.dev/kyle/mcr/internal/db"
"git.wntrmute.dev/kyle/mcr/internal/gc"
)
// registryService implements pb.RegistryServiceServer.
type registryService struct {
pb.UnimplementedRegistryServiceServer
db *db.DB
collector *gc.Collector
gcStatus *GCStatus
auditFn AuditFunc
}
func (s *registryService) ListRepositories(_ context.Context, req *pb.ListRepositoriesRequest) (*pb.ListRepositoriesResponse, error) {
limit := int32(50)
offset := int32(0)
if req.GetPagination() != nil {
if req.Pagination.Limit > 0 {
limit = req.Pagination.Limit
}
if req.Pagination.Offset >= 0 {
offset = req.Pagination.Offset
}
}
repos, err := s.db.ListRepositoriesWithMetadata(int(limit), int(offset))
if err != nil {
return nil, status.Errorf(codes.Internal, "internal error")
}
var result []*pb.RepositoryMetadata
for _, r := range repos {
result = append(result, &pb.RepositoryMetadata{
Name: r.Name,
TagCount: int32(r.TagCount), //nolint:gosec // tag count fits int32
ManifestCount: int32(r.ManifestCount), //nolint:gosec // manifest count fits int32
TotalSize: r.TotalSize,
CreatedAt: r.CreatedAt,
})
}
return &pb.ListRepositoriesResponse{Repositories: result}, nil
}
func (s *registryService) GetRepository(_ context.Context, req *pb.GetRepositoryRequest) (*pb.GetRepositoryResponse, error) {
if req.GetName() == "" {
return nil, status.Errorf(codes.InvalidArgument, "repository name required")
}
detail, err := s.db.GetRepositoryDetail(req.Name)
if err != nil {
if errors.Is(err, db.ErrRepoNotFound) {
return nil, status.Errorf(codes.NotFound, "repository not found")
}
return nil, status.Errorf(codes.Internal, "internal error")
}
resp := &pb.GetRepositoryResponse{
Name: detail.Name,
TotalSize: detail.TotalSize,
CreatedAt: detail.CreatedAt,
}
for _, t := range detail.Tags {
resp.Tags = append(resp.Tags, &pb.TagInfo{
Name: t.Name,
Digest: t.Digest,
})
}
for _, m := range detail.Manifests {
resp.Manifests = append(resp.Manifests, &pb.ManifestInfo{
Digest: m.Digest,
MediaType: m.MediaType,
Size: m.Size,
CreatedAt: m.CreatedAt,
})
}
return resp, nil
}
func (s *registryService) DeleteRepository(ctx context.Context, req *pb.DeleteRepositoryRequest) (*pb.DeleteRepositoryResponse, error) {
if req.GetName() == "" {
return nil, status.Errorf(codes.InvalidArgument, "repository name required")
}
if err := s.db.DeleteRepository(req.Name); err != nil {
if errors.Is(err, db.ErrRepoNotFound) {
return nil, status.Errorf(codes.NotFound, "repository not found")
}
return nil, status.Errorf(codes.Internal, "internal error")
}
if s.auditFn != nil {
claims := auth.ClaimsFromContext(ctx)
actorID := ""
if claims != nil {
actorID = claims.Subject
}
s.auditFn("repo_deleted", actorID, req.Name, "", "", nil)
}
return &pb.DeleteRepositoryResponse{}, nil
}
func (s *registryService) GarbageCollect(_ context.Context, _ *pb.GarbageCollectRequest) (*pb.GarbageCollectResponse, error) {
s.gcStatus.mu.Lock()
if s.gcStatus.running {
s.gcStatus.mu.Unlock()
return nil, status.Errorf(codes.AlreadyExists, "garbage collection already running")
}
s.gcStatus.running = true
s.gcStatus.mu.Unlock()
gcID := uuid.New().String()
// Run GC asynchronously like the REST handler. GC is a long-running
// background operation that must not be tied to the request context,
// so we intentionally use context.Background() inside runGC.
go s.runGC(gcID) //nolint:gosec // G118: GC must outlive the triggering RPC
return &pb.GarbageCollectResponse{Id: gcID}, nil
}
// runGC executes garbage collection in the background. It uses
// context.Background() because GC must not be cancelled when the
// triggering RPC completes.
func (s *registryService) runGC(gcID string) {
startedAt := time.Now().UTC().Format(time.RFC3339)
if s.auditFn != nil {
s.auditFn("gc_started", "", "", "", "", map[string]string{
"gc_id": gcID,
})
}
var blobsRemoved int
var bytesFreed int64
var gcErr error
if s.collector != nil {
r, err := s.collector.Run(context.Background()) //nolint:gosec // GC is intentionally background, not request-scoped
if err != nil {
gcErr = err
}
if r != nil {
blobsRemoved = r.BlobsRemoved
bytesFreed = r.BytesFreed
}
}
completedAt := time.Now().UTC().Format(time.RFC3339)
s.gcStatus.mu.Lock()
s.gcStatus.running = false
s.gcStatus.lastRun = &gcLastRun{
StartedAt: startedAt,
CompletedAt: completedAt,
BlobsRemoved: blobsRemoved,
BytesFreed: bytesFreed,
}
s.gcStatus.mu.Unlock()
if s.auditFn != nil && gcErr == nil {
details := map[string]string{
"gc_id": gcID,
"blobs_removed": fmt.Sprintf("%d", blobsRemoved),
"bytes_freed": fmt.Sprintf("%d", bytesFreed),
}
s.auditFn("gc_completed", "", "", "", "", details)
}
}
func (s *registryService) GetGCStatus(_ context.Context, _ *pb.GetGCStatusRequest) (*pb.GetGCStatusResponse, error) {
s.gcStatus.mu.Lock()
resp := &pb.GetGCStatusResponse{
Running: s.gcStatus.running,
}
if s.gcStatus.lastRun != nil {
resp.LastRun = &pb.GCLastRun{
StartedAt: s.gcStatus.lastRun.StartedAt,
CompletedAt: s.gcStatus.lastRun.CompletedAt,
BlobsRemoved: int32(s.gcStatus.lastRun.BlobsRemoved), //nolint:gosec // blob count fits int32
BytesFreed: s.gcStatus.lastRun.BytesFreed,
}
}
s.gcStatus.mu.Unlock()
return resp, nil
}

View File

@@ -0,0 +1,144 @@
package grpcserver
import (
"context"
"testing"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/auth"
)
func adminDeps(t *testing.T) Deps {
t.Helper()
return Deps{
DB: openTestDB(t),
Validator: &fakeValidator{
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
},
}
}
func adminCtx() context.Context {
return withAuth(context.Background(), "admin-token")
}
func TestListRepositoriesEmpty(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewRegistryServiceClient(cc)
resp, err := client.ListRepositories(adminCtx(), &pb.ListRepositoriesRequest{})
if err != nil {
t.Fatalf("ListRepositories: %v", err)
}
if len(resp.GetRepositories()) != 0 {
t.Fatalf("expected 0 repos, got %d", len(resp.Repositories))
}
}
func TestGetRepositoryNotFound(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewRegistryServiceClient(cc)
_, err := client.GetRepository(adminCtx(), &pb.GetRepositoryRequest{Name: "nonexistent"})
if err == nil {
t.Fatal("expected error for nonexistent repo")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestGetRepositoryEmptyName(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewRegistryServiceClient(cc)
_, err := client.GetRepository(adminCtx(), &pb.GetRepositoryRequest{})
if err == nil {
t.Fatal("expected error for empty name")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.InvalidArgument {
t.Fatalf("code: got %v, want InvalidArgument", st.Code())
}
}
func TestDeleteRepositoryNotFound(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewRegistryServiceClient(cc)
_, err := client.DeleteRepository(adminCtx(), &pb.DeleteRepositoryRequest{Name: "nonexistent"})
if err == nil {
t.Fatal("expected error for nonexistent repo")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.NotFound {
t.Fatalf("code: got %v, want NotFound", st.Code())
}
}
func TestDeleteRepositoryEmptyName(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewRegistryServiceClient(cc)
_, err := client.DeleteRepository(adminCtx(), &pb.DeleteRepositoryRequest{})
if err == nil {
t.Fatal("expected error for empty name")
}
st, ok := status.FromError(err)
if !ok {
t.Fatalf("expected gRPC status, got %v", err)
}
if st.Code() != codes.InvalidArgument {
t.Fatalf("code: got %v, want InvalidArgument", st.Code())
}
}
func TestGCStatusInitial(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewRegistryServiceClient(cc)
resp, err := client.GetGCStatus(adminCtx(), &pb.GetGCStatusRequest{})
if err != nil {
t.Fatalf("GetGCStatus: %v", err)
}
if resp.Running {
t.Fatal("expected running=false on startup")
}
if resp.LastRun != nil {
t.Fatal("expected no last_run on startup")
}
}
func TestGarbageCollectTrigger(t *testing.T) {
deps := adminDeps(t)
cc := startTestServer(t, deps)
client := pb.NewRegistryServiceClient(cc)
// Trigger GC without a collector (no-op but should return an ID).
resp, err := client.GarbageCollect(adminCtx(), &pb.GarbageCollectRequest{})
if err != nil {
t.Fatalf("GarbageCollect: %v", err)
}
if resp.GetId() == "" {
t.Fatal("expected non-empty GC ID")
}
}

View File

@@ -0,0 +1,141 @@
// Package grpcserver implements the MCR gRPC admin API server.
//
// It provides the same business logic as the REST admin API in
// internal/server/, using shared internal/db and internal/gc packages.
// The server enforces TLS 1.3 minimum, auth via MCIAS token validation,
// and admin role checks on privileged RPCs.
package grpcserver
import (
"crypto/tls"
"fmt"
"log"
"net"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/internal/db"
"git.wntrmute.dev/kyle/mcr/internal/gc"
"git.wntrmute.dev/kyle/mcr/internal/policy"
"git.wntrmute.dev/kyle/mcr/internal/server"
)
// AuditFunc is a callback for recording audit events. It follows the same
// signature as db.WriteAuditEvent but without an error return -- audit
// failures should not block request processing.
type AuditFunc func(eventType, actorID, repository, digest, ip string, details map[string]string)
// Deps holds the dependencies injected into the gRPC server.
type Deps struct {
DB *db.DB
Validator server.TokenValidator
Engine PolicyReloader
AuditFn AuditFunc
Collector *gc.Collector
}
// PolicyReloader can reload policy rules from a store.
type PolicyReloader interface {
Reload(store policy.RuleStore) error
}
// GCStatus tracks the current state of garbage collection for the gRPC server.
type GCStatus struct {
mu sync.Mutex
running bool
lastRun *gcLastRun
}
type gcLastRun struct {
StartedAt string
CompletedAt string
BlobsRemoved int
BytesFreed int64
}
// Server wraps a grpc.Server with MCR-specific configuration.
type Server struct {
gs *grpc.Server
deps Deps
gcStatus *GCStatus
}
// New creates a configured gRPC server with the interceptor chain:
// [Request Logger] -> [Auth Interceptor] -> [Admin Interceptor] -> [Handler]
//
// The TLS config enforces TLS 1.3 minimum. If certFile or keyFile is
// empty, the server is created without TLS (for testing only).
func New(certFile, keyFile string, deps Deps) (*Server, error) {
authInt := newAuthInterceptor(deps.Validator)
adminInt := newAdminInterceptor()
chain := grpc.ChainUnaryInterceptor(
loggingInterceptor,
authInt.unary,
adminInt.unary,
)
var opts []grpc.ServerOption
opts = append(opts, chain)
// Configure TLS if cert and key are provided.
if certFile != "" && keyFile != "" {
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, fmt.Errorf("grpcserver: load TLS cert: %w", err)
}
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}
opts = append(opts, grpc.Creds(credentials.NewTLS(tlsCfg)))
}
// The JSON codec is registered globally via init() in gen/mcr/v1/codec.go.
// The client must use grpc.ForceCodecV2(mcrv1.JSONCodec{}) to match.
_ = pb.JSONCodec{} // ensure the gen/mcr/v1 init() runs (codec registration)
gs := grpc.NewServer(opts...)
gcStatus := &GCStatus{}
s := &Server{gs: gs, deps: deps, gcStatus: gcStatus}
// Register all services.
pb.RegisterRegistryServiceServer(gs, &registryService{
db: deps.DB,
collector: deps.Collector,
gcStatus: gcStatus,
auditFn: deps.AuditFn,
})
pb.RegisterPolicyServiceServer(gs, &policyService{
db: deps.DB,
engine: deps.Engine,
auditFn: deps.AuditFn,
})
pb.RegisterAuditServiceServer(gs, &auditService{
db: deps.DB,
})
pb.RegisterAdminServiceServer(gs, &adminService{})
return s, nil
}
// Serve starts the gRPC server on the given listener.
func (s *Server) Serve(lis net.Listener) error {
log.Printf("grpc server listening on %s", lis.Addr())
return s.gs.Serve(lis)
}
// GracefulStop gracefully stops the gRPC server.
func (s *Server) GracefulStop() {
s.gs.GracefulStop()
}
// GRPCServer returns the underlying grpc.Server for testing.
func (s *Server) GRPCServer() *grpc.Server {
return s.gs
}