Phase 10: gRPC admin API with interceptor chain
Proto definitions for 4 services (RegistryService, PolicyService, AuditService, AdminService) with hand-written Go stubs using JSON codec until protobuf tooling is available. Interceptor chain: logging (method, peer IP, duration, never logs auth metadata) → auth (bearer token via MCIAS, Health bypasses) → admin (role check for GC, policy, delete, audit RPCs). All RPCs share business logic with REST handlers via internal/db and internal/gc packages. TLS 1.3 minimum on gRPC listener. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
16
internal/grpcserver/admin.go
Normal file
16
internal/grpcserver/admin.go
Normal file
@@ -0,0 +1,16 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
)
|
||||
|
||||
// adminService implements pb.AdminServiceServer.
|
||||
type adminService struct {
|
||||
pb.UnimplementedAdminServiceServer
|
||||
}
|
||||
|
||||
func (s *adminService) Health(_ context.Context, _ *pb.HealthRequest) (*pb.HealthResponse, error) {
|
||||
return &pb.HealthResponse{Status: "ok"}, nil
|
||||
}
|
||||
43
internal/grpcserver/admin_test.go
Normal file
43
internal/grpcserver/admin_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
)
|
||||
|
||||
func TestHealthReturnsOk(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewAdminServiceClient(cc)
|
||||
|
||||
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Health: %v", err)
|
||||
}
|
||||
if resp.GetStatus() != "ok" {
|
||||
t.Fatalf("status: got %q, want %q", resp.Status, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHealthWithoutAuth(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
// Use a validator that always rejects.
|
||||
validator := &fakeValidator{err: auth.ErrUnauthorized}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
client := pb.NewAdminServiceClient(cc)
|
||||
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Health without auth should succeed: %v", err)
|
||||
}
|
||||
if resp.GetStatus() != "ok" {
|
||||
t.Fatalf("status: got %q, want %q", resp.Status, "ok")
|
||||
}
|
||||
}
|
||||
61
internal/grpcserver/audit.go
Normal file
61
internal/grpcserver/audit.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
// auditService implements pb.AuditServiceServer.
|
||||
type auditService struct {
|
||||
pb.UnimplementedAuditServiceServer
|
||||
db *db.DB
|
||||
}
|
||||
|
||||
func (s *auditService) ListAuditEvents(_ context.Context, req *pb.ListAuditEventsRequest) (*pb.ListAuditEventsResponse, error) {
|
||||
limit := int32(50)
|
||||
offset := int32(0)
|
||||
if req.GetPagination() != nil {
|
||||
if req.Pagination.Limit > 0 {
|
||||
limit = req.Pagination.Limit
|
||||
}
|
||||
if req.Pagination.Offset >= 0 {
|
||||
offset = req.Pagination.Offset
|
||||
}
|
||||
}
|
||||
|
||||
filter := db.AuditFilter{
|
||||
EventType: req.GetEventType(),
|
||||
ActorID: req.GetActorId(),
|
||||
Repository: req.GetRepository(),
|
||||
Since: req.GetSince(),
|
||||
Until: req.GetUntil(),
|
||||
Limit: int(limit),
|
||||
Offset: int(offset),
|
||||
}
|
||||
|
||||
events, err := s.db.ListAuditEvents(filter)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
var result []*pb.AuditEvent
|
||||
for _, e := range events {
|
||||
result = append(result, &pb.AuditEvent{
|
||||
Id: e.ID,
|
||||
EventTime: e.EventTime,
|
||||
EventType: e.EventType,
|
||||
ActorId: e.ActorID,
|
||||
Repository: e.Repository,
|
||||
Digest: e.Digest,
|
||||
IpAddress: e.IPAddress,
|
||||
Details: e.Details,
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.ListAuditEventsResponse{Events: result}, nil
|
||||
}
|
||||
95
internal/grpcserver/audit_test.go
Normal file
95
internal/grpcserver/audit_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
)
|
||||
|
||||
func TestListAuditEventsEmpty(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewAuditServiceClient(cc)
|
||||
|
||||
resp, err := client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAuditEvents: %v", err)
|
||||
}
|
||||
if len(resp.GetEvents()) != 0 {
|
||||
t.Fatalf("expected 0 events, got %d", len(resp.Events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAuditEventsWithData(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
|
||||
// Write some audit events directly.
|
||||
err := deps.DB.WriteAuditEvent("test_event", "actor-1", "repo/test", "", "1.2.3.4", map[string]string{"key": "value"})
|
||||
if err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
err = deps.DB.WriteAuditEvent("other_event", "actor-2", "", "", "5.6.7.8", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewAuditServiceClient(cc)
|
||||
|
||||
// List all events.
|
||||
resp, err := client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAuditEvents: %v", err)
|
||||
}
|
||||
if len(resp.GetEvents()) != 2 {
|
||||
t.Fatalf("expected 2 events, got %d", len(resp.Events))
|
||||
}
|
||||
|
||||
// Filter by event type.
|
||||
resp, err = client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{
|
||||
EventType: "test_event",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAuditEvents with filter: %v", err)
|
||||
}
|
||||
if len(resp.GetEvents()) != 1 {
|
||||
t.Fatalf("expected 1 event, got %d", len(resp.Events))
|
||||
}
|
||||
if resp.Events[0].EventType != "test_event" {
|
||||
t.Fatalf("event_type: got %q, want %q", resp.Events[0].EventType, "test_event")
|
||||
}
|
||||
if resp.Events[0].ActorId != "actor-1" {
|
||||
t.Fatalf("actor_id: got %q, want %q", resp.Events[0].ActorId, "actor-1")
|
||||
}
|
||||
if resp.Events[0].Details["key"] != "value" {
|
||||
t.Fatalf("details: got %v, want key=value", resp.Events[0].Details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListAuditEventsPagination(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
|
||||
// Write 5 events.
|
||||
for i := range 5 {
|
||||
err := deps.DB.WriteAuditEvent("event", "actor", "", "", "", map[string]string{
|
||||
"index": string(rune('0' + i)),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("WriteAuditEvent %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewAuditServiceClient(cc)
|
||||
|
||||
// Get first 2 events.
|
||||
resp, err := client.ListAuditEvents(adminCtx(), &pb.ListAuditEventsRequest{
|
||||
Pagination: &pb.PaginationRequest{Limit: 2},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListAuditEvents: %v", err)
|
||||
}
|
||||
if len(resp.GetEvents()) != 2 {
|
||||
t.Fatalf("expected 2 events, got %d", len(resp.Events))
|
||||
}
|
||||
}
|
||||
165
internal/grpcserver/interceptors.go
Normal file
165
internal/grpcserver/interceptors.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/peer"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/server"
|
||||
)
|
||||
|
||||
// authBypassMethods contains the full gRPC method names that bypass
|
||||
// authentication. Health is the only method that does not require auth.
|
||||
var authBypassMethods = map[string]bool{
|
||||
"/mcr.v1.AdminService/Health": true,
|
||||
}
|
||||
|
||||
// adminRequiredMethods contains the full gRPC method names that require
|
||||
// the admin role. Adding an RPC without adding it to the correct map is
|
||||
// a security defect per ARCHITECTURE.md.
|
||||
var adminRequiredMethods = map[string]bool{
|
||||
// Registry admin operations.
|
||||
"/mcr.v1.RegistryService/DeleteRepository": true,
|
||||
"/mcr.v1.RegistryService/GarbageCollect": true,
|
||||
"/mcr.v1.RegistryService/GetGCStatus": true,
|
||||
|
||||
// Policy management — all RPCs require admin.
|
||||
"/mcr.v1.PolicyService/ListPolicyRules": true,
|
||||
"/mcr.v1.PolicyService/CreatePolicyRule": true,
|
||||
"/mcr.v1.PolicyService/GetPolicyRule": true,
|
||||
"/mcr.v1.PolicyService/UpdatePolicyRule": true,
|
||||
"/mcr.v1.PolicyService/DeletePolicyRule": true,
|
||||
|
||||
// Audit — requires admin.
|
||||
"/mcr.v1.AuditService/ListAuditEvents": true,
|
||||
}
|
||||
|
||||
// authInterceptor validates bearer tokens from the authorization metadata.
|
||||
type authInterceptor struct {
|
||||
validator server.TokenValidator
|
||||
}
|
||||
|
||||
func newAuthInterceptor(v server.TokenValidator) *authInterceptor {
|
||||
return &authInterceptor{validator: v}
|
||||
}
|
||||
|
||||
// unary is the unary server interceptor for auth.
|
||||
func (a *authInterceptor) unary(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
// Health bypasses auth.
|
||||
if authBypassMethods[info.FullMethod] {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// Extract bearer token from authorization metadata.
|
||||
token, err := extractToken(ctx)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
|
||||
}
|
||||
|
||||
// Validate the token via MCIAS.
|
||||
claims, err := a.validator.ValidateToken(token)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "invalid token")
|
||||
}
|
||||
|
||||
// Inject claims into the context.
|
||||
ctx = auth.ContextWithClaims(ctx, claims)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// extractToken extracts a bearer token from the "authorization" gRPC metadata.
|
||||
func extractToken(ctx context.Context) (string, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
return "", status.Errorf(codes.Unauthenticated, "missing metadata")
|
||||
}
|
||||
|
||||
vals := md.Get("authorization")
|
||||
if len(vals) == 0 {
|
||||
return "", status.Errorf(codes.Unauthenticated, "missing authorization metadata")
|
||||
}
|
||||
|
||||
val := vals[0]
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(val, prefix) {
|
||||
return "", status.Errorf(codes.Unauthenticated, "invalid authorization format")
|
||||
}
|
||||
|
||||
token := strings.TrimSpace(val[len(prefix):])
|
||||
if token == "" {
|
||||
return "", status.Errorf(codes.Unauthenticated, "empty bearer token")
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// adminInterceptor checks that the caller has the admin role for
|
||||
// methods in adminRequiredMethods.
|
||||
type adminInterceptor struct{}
|
||||
|
||||
func newAdminInterceptor() *adminInterceptor {
|
||||
return &adminInterceptor{}
|
||||
}
|
||||
|
||||
// unary is the unary server interceptor for admin role checks.
|
||||
func (a *adminInterceptor) unary(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
if !adminRequiredMethods[info.FullMethod] {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
claims := auth.ClaimsFromContext(ctx)
|
||||
if claims == nil {
|
||||
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
|
||||
}
|
||||
|
||||
if !hasRole(claims.Roles, "admin") {
|
||||
return nil, status.Errorf(codes.PermissionDenied, "admin role required")
|
||||
}
|
||||
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
// loggingInterceptor logs the method, peer IP, status code, and duration.
|
||||
// It never logs the authorization metadata value.
|
||||
func loggingInterceptor(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
start := time.Now()
|
||||
|
||||
peerAddr := ""
|
||||
if p, ok := peer.FromContext(ctx); ok {
|
||||
peerAddr = p.Addr.String()
|
||||
}
|
||||
|
||||
resp, err := handler(ctx, req)
|
||||
|
||||
duration := time.Since(start)
|
||||
code := codes.OK
|
||||
if err != nil {
|
||||
if st, ok := status.FromError(err); ok {
|
||||
code = st.Code()
|
||||
} else {
|
||||
code = codes.Unknown
|
||||
}
|
||||
}
|
||||
|
||||
log.Printf("grpc %s peer=%s code=%s duration=%s", info.FullMethod, peerAddr, code, duration)
|
||||
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// hasRole checks if any of the roles match the target role.
|
||||
func hasRole(roles []string, target string) bool {
|
||||
for _, r := range roles {
|
||||
if r == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
350
internal/grpcserver/interceptors_test.go
Normal file
350
internal/grpcserver/interceptors_test.go
Normal file
@@ -0,0 +1,350 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
// fakeValidator is a test double for server.TokenValidator.
|
||||
type fakeValidator struct {
|
||||
claims *auth.Claims
|
||||
err error
|
||||
}
|
||||
|
||||
func (f *fakeValidator) ValidateToken(_ string) (*auth.Claims, error) {
|
||||
return f.claims, f.err
|
||||
}
|
||||
|
||||
// openTestDB creates a temporary test database with migrations applied.
|
||||
func openTestDB(t *testing.T) *db.DB {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), "test.db")
|
||||
d, err := db.Open(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = d.Close() })
|
||||
if err := d.Migrate(); err != nil {
|
||||
t.Fatalf("Migrate: %v", err)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
// startTestServer creates a gRPC server and client for testing.
|
||||
// Returns the client connection and a cleanup function.
|
||||
func startTestServer(t *testing.T, deps Deps) *grpc.ClientConn {
|
||||
t.Helper()
|
||||
|
||||
srv, err := New("", "", deps)
|
||||
if err != nil {
|
||||
t.Fatalf("New: %v", err)
|
||||
}
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("Listen: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = srv.Serve(lis)
|
||||
}()
|
||||
t.Cleanup(func() { srv.GracefulStop() })
|
||||
|
||||
//nolint:gosec // insecure credentials for testing only
|
||||
cc, err := grpc.NewClient(
|
||||
lis.Addr().String(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(grpc.ForceCodecV2(pb.JSONCodec{})),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = cc.Close() })
|
||||
|
||||
return cc
|
||||
}
|
||||
|
||||
// withAuth adds a bearer token to the outgoing context metadata.
|
||||
func withAuth(ctx context.Context, token string) context.Context {
|
||||
return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
func TestHealthBypassesAuth(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{err: auth.ErrUnauthorized}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
client := pb.NewAdminServiceClient(cc)
|
||||
resp, err := client.Health(context.Background(), &pb.HealthRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("Health: %v", err)
|
||||
}
|
||||
if resp.Status != "ok" {
|
||||
t.Fatalf("Health status: got %q, want %q", resp.Status, "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthInterceptorNoToken(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{err: auth.ErrUnauthorized}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
_, err := client.ListRepositories(context.Background(), &pb.ListRepositoriesRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unauthenticated request")
|
||||
}
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status error, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.Unauthenticated {
|
||||
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthInterceptorInvalidToken(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{err: auth.ErrUnauthorized}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
ctx := withAuth(context.Background(), "bad-token")
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
_, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid token")
|
||||
}
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status error, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.Unauthenticated {
|
||||
t.Fatalf("code: got %v, want Unauthenticated", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthInterceptorValidToken(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "alice", AccountType: "human", Roles: []string{"user"}},
|
||||
}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
ctx := withAuth(context.Background(), "valid-token")
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
resp, err := client.ListRepositories(ctx, &pb.ListRepositoriesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListRepositories: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminInterceptorDenied(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
|
||||
}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
ctx := withAuth(context.Background(), "valid-token")
|
||||
|
||||
// Policy RPCs require admin.
|
||||
policyClient := pb.NewPolicyServiceClient(cc)
|
||||
_, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-admin user")
|
||||
}
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status error, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminInterceptorAllowed(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
|
||||
}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
ctx := withAuth(context.Background(), "valid-token")
|
||||
|
||||
// Admin user should be able to list policy rules.
|
||||
policyClient := pb.NewPolicyServiceClient(cc)
|
||||
resp, err := policyClient.ListPolicyRules(ctx, &pb.ListPolicyRulesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListPolicyRules: %v", err)
|
||||
}
|
||||
if resp == nil {
|
||||
t.Fatal("expected non-nil response")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminRequiredMethodsCompleteness(t *testing.T) {
|
||||
// Verify that admin-required methods match our security spec.
|
||||
// This test catches the security defect of adding an RPC without
|
||||
// adding it to the adminRequiredMethods map.
|
||||
expected := []string{
|
||||
"/mcr.v1.RegistryService/DeleteRepository",
|
||||
"/mcr.v1.RegistryService/GarbageCollect",
|
||||
"/mcr.v1.RegistryService/GetGCStatus",
|
||||
"/mcr.v1.PolicyService/ListPolicyRules",
|
||||
"/mcr.v1.PolicyService/CreatePolicyRule",
|
||||
"/mcr.v1.PolicyService/GetPolicyRule",
|
||||
"/mcr.v1.PolicyService/UpdatePolicyRule",
|
||||
"/mcr.v1.PolicyService/DeletePolicyRule",
|
||||
"/mcr.v1.AuditService/ListAuditEvents",
|
||||
}
|
||||
|
||||
for _, method := range expected {
|
||||
if !adminRequiredMethods[method] {
|
||||
t.Errorf("method %s should require admin but is not in adminRequiredMethods", method)
|
||||
}
|
||||
}
|
||||
|
||||
if len(adminRequiredMethods) != len(expected) {
|
||||
t.Errorf("adminRequiredMethods has %d entries, expected %d", len(adminRequiredMethods), len(expected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthBypassMethodsCompleteness(t *testing.T) {
|
||||
// Health is the only method that bypasses auth.
|
||||
expected := []string{
|
||||
"/mcr.v1.AdminService/Health",
|
||||
}
|
||||
|
||||
for _, method := range expected {
|
||||
if !authBypassMethods[method] {
|
||||
t.Errorf("method %s should bypass auth but is not in authBypassMethods", method)
|
||||
}
|
||||
}
|
||||
|
||||
if len(authBypassMethods) != len(expected) {
|
||||
t.Errorf("authBypassMethods has %d entries, expected %d", len(authBypassMethods), len(expected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRepoRequiresAdmin(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
|
||||
}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
ctx := withAuth(context.Background(), "valid-token")
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
_, err := client.DeleteRepository(ctx, &pb.DeleteRepositoryRequest{Name: "test"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-admin user trying to delete repo")
|
||||
}
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status error, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGCRequiresAdmin(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
|
||||
}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
ctx := withAuth(context.Background(), "valid-token")
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
_, err := client.GarbageCollect(ctx, &pb.GarbageCollectRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-admin user trying to trigger GC")
|
||||
}
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status error, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditRequiresAdmin(t *testing.T) {
|
||||
database := openTestDB(t)
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
|
||||
}
|
||||
|
||||
cc := startTestServer(t, Deps{
|
||||
DB: database,
|
||||
Validator: validator,
|
||||
})
|
||||
|
||||
ctx := withAuth(context.Background(), "valid-token")
|
||||
client := pb.NewAuditServiceClient(cc)
|
||||
_, err := client.ListAuditEvents(ctx, &pb.ListAuditEventsRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-admin user trying to list audit events")
|
||||
}
|
||||
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status error, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.PermissionDenied {
|
||||
t.Fatalf("code: got %v, want PermissionDenied", st.Code())
|
||||
}
|
||||
}
|
||||
292
internal/grpcserver/policy.go
Normal file
292
internal/grpcserver/policy.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||
)
|
||||
|
||||
var validActions = map[string]bool{
|
||||
string(policy.ActionVersionCheck): true,
|
||||
string(policy.ActionPull): true,
|
||||
string(policy.ActionPush): true,
|
||||
string(policy.ActionDelete): true,
|
||||
string(policy.ActionCatalog): true,
|
||||
string(policy.ActionPolicyManage): true,
|
||||
}
|
||||
|
||||
// policyService implements pb.PolicyServiceServer.
|
||||
type policyService struct {
|
||||
pb.UnimplementedPolicyServiceServer
|
||||
db *db.DB
|
||||
engine PolicyReloader
|
||||
auditFn AuditFunc
|
||||
}
|
||||
|
||||
func (s *policyService) ListPolicyRules(_ context.Context, req *pb.ListPolicyRulesRequest) (*pb.ListPolicyRulesResponse, error) {
|
||||
limit := int32(50)
|
||||
offset := int32(0)
|
||||
if req.GetPagination() != nil {
|
||||
if req.Pagination.Limit > 0 {
|
||||
limit = req.Pagination.Limit
|
||||
}
|
||||
if req.Pagination.Offset >= 0 {
|
||||
offset = req.Pagination.Offset
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := s.db.ListPolicyRules(int(limit), int(offset))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
var result []*pb.PolicyRule
|
||||
for _, r := range rules {
|
||||
result = append(result, policyRuleRowToProto(&r))
|
||||
}
|
||||
|
||||
return &pb.ListPolicyRulesResponse{Rules: result}, nil
|
||||
}
|
||||
|
||||
func (s *policyService) CreatePolicyRule(ctx context.Context, req *pb.CreatePolicyRuleRequest) (*pb.PolicyRule, error) {
|
||||
if req.GetPriority() < 1 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "priority must be >= 1 (0 is reserved for built-ins)")
|
||||
}
|
||||
if req.GetDescription() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "description is required")
|
||||
}
|
||||
if err := validateEffect(req.GetEffect()); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
if len(req.GetActions()) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "at least one action is required")
|
||||
}
|
||||
if err := validateActions(req.GetActions()); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
|
||||
claims := auth.ClaimsFromContext(ctx)
|
||||
createdBy := ""
|
||||
if claims != nil {
|
||||
createdBy = claims.Subject
|
||||
}
|
||||
|
||||
row := db.PolicyRuleRow{
|
||||
Priority: int(req.Priority),
|
||||
Description: req.Description,
|
||||
Effect: req.Effect,
|
||||
Roles: req.Roles,
|
||||
AccountTypes: req.AccountTypes,
|
||||
SubjectUUID: req.SubjectUuid,
|
||||
Actions: req.Actions,
|
||||
Repositories: req.Repositories,
|
||||
Enabled: req.Enabled,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
|
||||
id, err := s.db.CreatePolicyRule(row)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
// Reload policy engine.
|
||||
if s.engine != nil {
|
||||
_ = s.engine.Reload(s.db)
|
||||
}
|
||||
|
||||
if s.auditFn != nil {
|
||||
s.auditFn("policy_rule_created", createdBy, "", "", "", map[string]string{
|
||||
"rule_id": strconv.FormatInt(id, 10),
|
||||
})
|
||||
}
|
||||
|
||||
created, err := s.db.GetPolicyRule(id)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
return policyRuleRowToProto(created), nil
|
||||
}
|
||||
|
||||
func (s *policyService) GetPolicyRule(_ context.Context, req *pb.GetPolicyRuleRequest) (*pb.PolicyRule, error) {
|
||||
if req.GetId() == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "rule ID required")
|
||||
}
|
||||
|
||||
rule, err := s.db.GetPolicyRule(req.Id)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
return policyRuleRowToProto(rule), nil
|
||||
}
|
||||
|
||||
func (s *policyService) UpdatePolicyRule(ctx context.Context, req *pb.UpdatePolicyRuleRequest) (*pb.PolicyRule, error) {
|
||||
if req.GetId() == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "rule ID required")
|
||||
}
|
||||
|
||||
mask := make(map[string]bool, len(req.GetUpdateMask()))
|
||||
for _, f := range req.GetUpdateMask() {
|
||||
mask[f] = true
|
||||
}
|
||||
|
||||
// Validate fields if they are in the update mask.
|
||||
if mask["priority"] && req.Priority < 1 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "priority must be >= 1 (0 is reserved for built-ins)")
|
||||
}
|
||||
if mask["effect"] {
|
||||
if err := validateEffect(req.GetEffect()); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
if mask["actions"] {
|
||||
if len(req.GetActions()) == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "at least one action is required")
|
||||
}
|
||||
if err := validateActions(req.GetActions()); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
updates := db.PolicyRuleRow{}
|
||||
if mask["priority"] {
|
||||
updates.Priority = int(req.Priority)
|
||||
}
|
||||
if mask["description"] {
|
||||
updates.Description = req.Description
|
||||
}
|
||||
if mask["effect"] {
|
||||
updates.Effect = req.Effect
|
||||
}
|
||||
if mask["roles"] {
|
||||
updates.Roles = req.Roles
|
||||
}
|
||||
if mask["account_types"] {
|
||||
updates.AccountTypes = req.AccountTypes
|
||||
}
|
||||
if mask["subject_uuid"] {
|
||||
updates.SubjectUUID = req.SubjectUuid
|
||||
}
|
||||
if mask["actions"] {
|
||||
updates.Actions = req.Actions
|
||||
}
|
||||
if mask["repositories"] {
|
||||
updates.Repositories = req.Repositories
|
||||
}
|
||||
|
||||
if err := s.db.UpdatePolicyRule(req.Id, updates); err != nil {
|
||||
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
// Handle enabled separately since it's a bool.
|
||||
if mask["enabled"] {
|
||||
if err := s.db.SetPolicyRuleEnabled(req.Id, req.Enabled); err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
}
|
||||
|
||||
// Reload policy engine.
|
||||
if s.engine != nil {
|
||||
_ = s.engine.Reload(s.db)
|
||||
}
|
||||
|
||||
if s.auditFn != nil {
|
||||
claims := auth.ClaimsFromContext(ctx)
|
||||
actorID := ""
|
||||
if claims != nil {
|
||||
actorID = claims.Subject
|
||||
}
|
||||
s.auditFn("policy_rule_updated", actorID, "", "", "", map[string]string{
|
||||
"rule_id": strconv.FormatInt(req.Id, 10),
|
||||
})
|
||||
}
|
||||
|
||||
updated, err := s.db.GetPolicyRule(req.Id)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
return policyRuleRowToProto(updated), nil
|
||||
}
|
||||
|
||||
func (s *policyService) DeletePolicyRule(ctx context.Context, req *pb.DeletePolicyRuleRequest) (*pb.DeletePolicyRuleResponse, error) {
|
||||
if req.GetId() == 0 {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "rule ID required")
|
||||
}
|
||||
|
||||
if err := s.db.DeletePolicyRule(req.Id); err != nil {
|
||||
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
// Reload policy engine.
|
||||
if s.engine != nil {
|
||||
_ = s.engine.Reload(s.db)
|
||||
}
|
||||
|
||||
if s.auditFn != nil {
|
||||
claims := auth.ClaimsFromContext(ctx)
|
||||
actorID := ""
|
||||
if claims != nil {
|
||||
actorID = claims.Subject
|
||||
}
|
||||
s.auditFn("policy_rule_deleted", actorID, "", "", "", map[string]string{
|
||||
"rule_id": strconv.FormatInt(req.Id, 10),
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.DeletePolicyRuleResponse{}, nil
|
||||
}
|
||||
|
||||
// policyRuleRowToProto converts a db.PolicyRuleRow to a protobuf PolicyRule.
|
||||
func policyRuleRowToProto(r *db.PolicyRuleRow) *pb.PolicyRule {
|
||||
return &pb.PolicyRule{
|
||||
Id: r.ID,
|
||||
Priority: int32(r.Priority), //nolint:gosec // priority is always small positive int
|
||||
Description: r.Description,
|
||||
Effect: r.Effect,
|
||||
Roles: r.Roles,
|
||||
AccountTypes: r.AccountTypes,
|
||||
SubjectUuid: r.SubjectUUID,
|
||||
Actions: r.Actions,
|
||||
Repositories: r.Repositories,
|
||||
Enabled: r.Enabled,
|
||||
CreatedBy: r.CreatedBy,
|
||||
CreatedAt: r.CreatedAt,
|
||||
UpdatedAt: r.UpdatedAt,
|
||||
}
|
||||
}
|
||||
|
||||
func validateEffect(effect string) error {
|
||||
if effect != "allow" && effect != "deny" {
|
||||
return fmt.Errorf("invalid effect: %q (must be 'allow' or 'deny')", effect)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateActions(actions []string) error {
|
||||
for _, a := range actions {
|
||||
if !validActions[a] {
|
||||
return fmt.Errorf("invalid action: %q", a)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
306
internal/grpcserver/policy_test.go
Normal file
306
internal/grpcserver/policy_test.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||
)
|
||||
|
||||
type fakePolicyReloader struct {
|
||||
reloadCount int
|
||||
}
|
||||
|
||||
func (f *fakePolicyReloader) Reload(_ policy.RuleStore) error {
|
||||
f.reloadCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestCreatePolicyRule(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
reloader := &fakePolicyReloader{}
|
||||
deps.Engine = reloader
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
resp, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
|
||||
Priority: 10,
|
||||
Description: "allow pull for all",
|
||||
Effect: "allow",
|
||||
Actions: []string{"registry:pull"},
|
||||
Enabled: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicyRule: %v", err)
|
||||
}
|
||||
if resp.GetId() == 0 {
|
||||
t.Fatal("expected non-zero ID")
|
||||
}
|
||||
if resp.GetDescription() != "allow pull for all" {
|
||||
t.Fatalf("description: got %q, want %q", resp.Description, "allow pull for all")
|
||||
}
|
||||
if resp.GetEffect() != "allow" {
|
||||
t.Fatalf("effect: got %q, want %q", resp.Effect, "allow")
|
||||
}
|
||||
if !resp.GetEnabled() {
|
||||
t.Fatal("expected enabled=true")
|
||||
}
|
||||
if reloader.reloadCount != 1 {
|
||||
t.Fatalf("reloadCount: got %d, want 1", reloader.reloadCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreatePolicyRuleValidation(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
req *pb.CreatePolicyRuleRequest
|
||||
code codes.Code
|
||||
}{
|
||||
{
|
||||
name: "zero priority",
|
||||
req: &pb.CreatePolicyRuleRequest{
|
||||
Priority: 0,
|
||||
Description: "test",
|
||||
Effect: "allow",
|
||||
Actions: []string{"registry:pull"},
|
||||
},
|
||||
code: codes.InvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "empty description",
|
||||
req: &pb.CreatePolicyRuleRequest{
|
||||
Priority: 1,
|
||||
Effect: "allow",
|
||||
Actions: []string{"registry:pull"},
|
||||
},
|
||||
code: codes.InvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "invalid effect",
|
||||
req: &pb.CreatePolicyRuleRequest{
|
||||
Priority: 1,
|
||||
Description: "test",
|
||||
Effect: "maybe",
|
||||
Actions: []string{"registry:pull"},
|
||||
},
|
||||
code: codes.InvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "no actions",
|
||||
req: &pb.CreatePolicyRuleRequest{
|
||||
Priority: 1,
|
||||
Description: "test",
|
||||
Effect: "allow",
|
||||
},
|
||||
code: codes.InvalidArgument,
|
||||
},
|
||||
{
|
||||
name: "invalid action",
|
||||
req: &pb.CreatePolicyRuleRequest{
|
||||
Priority: 1,
|
||||
Description: "test",
|
||||
Effect: "allow",
|
||||
Actions: []string{"registry:fly"},
|
||||
},
|
||||
code: codes.InvalidArgument,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := client.CreatePolicyRule(adminCtx(), tt.req)
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != tt.code {
|
||||
t.Fatalf("code: got %v, want %v", st.Code(), tt.code)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPolicyRule(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
// Create a rule first.
|
||||
created, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
|
||||
Priority: 5,
|
||||
Description: "test rule",
|
||||
Effect: "deny",
|
||||
Actions: []string{"registry:push"},
|
||||
Enabled: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicyRule: %v", err)
|
||||
}
|
||||
|
||||
// Fetch it.
|
||||
got, err := client.GetPolicyRule(adminCtx(), &pb.GetPolicyRuleRequest{Id: created.Id})
|
||||
if err != nil {
|
||||
t.Fatalf("GetPolicyRule: %v", err)
|
||||
}
|
||||
if got.Id != created.Id {
|
||||
t.Fatalf("id: got %d, want %d", got.Id, created.Id)
|
||||
}
|
||||
if got.Effect != "deny" {
|
||||
t.Fatalf("effect: got %q, want %q", got.Effect, "deny")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetPolicyRuleNotFound(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
_, err := client.GetPolicyRule(adminCtx(), &pb.GetPolicyRuleRequest{Id: 99999})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.NotFound {
|
||||
t.Fatalf("code: got %v, want NotFound", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestListPolicyRules(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
// Create two rules.
|
||||
for i := range 2 {
|
||||
_, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
|
||||
Priority: int32(i + 1),
|
||||
Description: "rule",
|
||||
Effect: "allow",
|
||||
Actions: []string{"registry:pull"},
|
||||
Enabled: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicyRule %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := client.ListPolicyRules(adminCtx(), &pb.ListPolicyRulesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListPolicyRules: %v", err)
|
||||
}
|
||||
if len(resp.GetRules()) < 2 {
|
||||
t.Fatalf("expected at least 2 rules, got %d", len(resp.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeletePolicyRule(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
reloader := &fakePolicyReloader{}
|
||||
deps.Engine = reloader
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
// Create then delete.
|
||||
created, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
|
||||
Priority: 1,
|
||||
Description: "to be deleted",
|
||||
Effect: "allow",
|
||||
Actions: []string{"registry:pull"},
|
||||
Enabled: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicyRule: %v", err)
|
||||
}
|
||||
initialReloads := reloader.reloadCount
|
||||
|
||||
_, err = client.DeletePolicyRule(adminCtx(), &pb.DeletePolicyRuleRequest{Id: created.Id})
|
||||
if err != nil {
|
||||
t.Fatalf("DeletePolicyRule: %v", err)
|
||||
}
|
||||
|
||||
// Verify it was reloaded.
|
||||
if reloader.reloadCount != initialReloads+1 {
|
||||
t.Fatalf("reloadCount: got %d, want %d", reloader.reloadCount, initialReloads+1)
|
||||
}
|
||||
|
||||
// Verify it's gone.
|
||||
_, err = client.GetPolicyRule(adminCtx(), &pb.GetPolicyRuleRequest{Id: created.Id})
|
||||
if err == nil {
|
||||
t.Fatal("expected error after deletion")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.NotFound {
|
||||
t.Fatalf("code: got %v, want NotFound", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeletePolicyRuleNotFound(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
_, err := client.DeletePolicyRule(adminCtx(), &pb.DeletePolicyRuleRequest{Id: 99999})
|
||||
if err == nil {
|
||||
t.Fatal("expected error")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.NotFound {
|
||||
t.Fatalf("code: got %v, want NotFound", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdatePolicyRule(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
reloader := &fakePolicyReloader{}
|
||||
deps.Engine = reloader
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewPolicyServiceClient(cc)
|
||||
|
||||
// Create a rule.
|
||||
created, err := client.CreatePolicyRule(adminCtx(), &pb.CreatePolicyRuleRequest{
|
||||
Priority: 10,
|
||||
Description: "original",
|
||||
Effect: "allow",
|
||||
Actions: []string{"registry:pull"},
|
||||
Enabled: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("CreatePolicyRule: %v", err)
|
||||
}
|
||||
|
||||
// Update description.
|
||||
updated, err := client.UpdatePolicyRule(adminCtx(), &pb.UpdatePolicyRuleRequest{
|
||||
Id: created.Id,
|
||||
Description: "updated description",
|
||||
UpdateMask: []string{"description"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("UpdatePolicyRule: %v", err)
|
||||
}
|
||||
if updated.Description != "updated description" {
|
||||
t.Fatalf("description: got %q, want %q", updated.Description, "updated description")
|
||||
}
|
||||
// Effect should be unchanged.
|
||||
if updated.Effect != "allow" {
|
||||
t.Fatalf("effect: got %q, want %q", updated.Effect, "allow")
|
||||
}
|
||||
}
|
||||
203
internal/grpcserver/registry.go
Normal file
203
internal/grpcserver/registry.go
Normal file
@@ -0,0 +1,203 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/google/uuid"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/gc"
|
||||
)
|
||||
|
||||
// registryService implements pb.RegistryServiceServer.
|
||||
type registryService struct {
|
||||
pb.UnimplementedRegistryServiceServer
|
||||
db *db.DB
|
||||
collector *gc.Collector
|
||||
gcStatus *GCStatus
|
||||
auditFn AuditFunc
|
||||
}
|
||||
|
||||
func (s *registryService) ListRepositories(_ context.Context, req *pb.ListRepositoriesRequest) (*pb.ListRepositoriesResponse, error) {
|
||||
limit := int32(50)
|
||||
offset := int32(0)
|
||||
if req.GetPagination() != nil {
|
||||
if req.Pagination.Limit > 0 {
|
||||
limit = req.Pagination.Limit
|
||||
}
|
||||
if req.Pagination.Offset >= 0 {
|
||||
offset = req.Pagination.Offset
|
||||
}
|
||||
}
|
||||
|
||||
repos, err := s.db.ListRepositoriesWithMetadata(int(limit), int(offset))
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
var result []*pb.RepositoryMetadata
|
||||
for _, r := range repos {
|
||||
result = append(result, &pb.RepositoryMetadata{
|
||||
Name: r.Name,
|
||||
TagCount: int32(r.TagCount), //nolint:gosec // tag count fits int32
|
||||
ManifestCount: int32(r.ManifestCount), //nolint:gosec // manifest count fits int32
|
||||
TotalSize: r.TotalSize,
|
||||
CreatedAt: r.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.ListRepositoriesResponse{Repositories: result}, nil
|
||||
}
|
||||
|
||||
func (s *registryService) GetRepository(_ context.Context, req *pb.GetRepositoryRequest) (*pb.GetRepositoryResponse, error) {
|
||||
if req.GetName() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "repository name required")
|
||||
}
|
||||
|
||||
detail, err := s.db.GetRepositoryDetail(req.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrRepoNotFound) {
|
||||
return nil, status.Errorf(codes.NotFound, "repository not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
resp := &pb.GetRepositoryResponse{
|
||||
Name: detail.Name,
|
||||
TotalSize: detail.TotalSize,
|
||||
CreatedAt: detail.CreatedAt,
|
||||
}
|
||||
for _, t := range detail.Tags {
|
||||
resp.Tags = append(resp.Tags, &pb.TagInfo{
|
||||
Name: t.Name,
|
||||
Digest: t.Digest,
|
||||
})
|
||||
}
|
||||
for _, m := range detail.Manifests {
|
||||
resp.Manifests = append(resp.Manifests, &pb.ManifestInfo{
|
||||
Digest: m.Digest,
|
||||
MediaType: m.MediaType,
|
||||
Size: m.Size,
|
||||
CreatedAt: m.CreatedAt,
|
||||
})
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *registryService) DeleteRepository(ctx context.Context, req *pb.DeleteRepositoryRequest) (*pb.DeleteRepositoryResponse, error) {
|
||||
if req.GetName() == "" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "repository name required")
|
||||
}
|
||||
|
||||
if err := s.db.DeleteRepository(req.Name); err != nil {
|
||||
if errors.Is(err, db.ErrRepoNotFound) {
|
||||
return nil, status.Errorf(codes.NotFound, "repository not found")
|
||||
}
|
||||
return nil, status.Errorf(codes.Internal, "internal error")
|
||||
}
|
||||
|
||||
if s.auditFn != nil {
|
||||
claims := auth.ClaimsFromContext(ctx)
|
||||
actorID := ""
|
||||
if claims != nil {
|
||||
actorID = claims.Subject
|
||||
}
|
||||
s.auditFn("repo_deleted", actorID, req.Name, "", "", nil)
|
||||
}
|
||||
|
||||
return &pb.DeleteRepositoryResponse{}, nil
|
||||
}
|
||||
|
||||
func (s *registryService) GarbageCollect(_ context.Context, _ *pb.GarbageCollectRequest) (*pb.GarbageCollectResponse, error) {
|
||||
s.gcStatus.mu.Lock()
|
||||
if s.gcStatus.running {
|
||||
s.gcStatus.mu.Unlock()
|
||||
return nil, status.Errorf(codes.AlreadyExists, "garbage collection already running")
|
||||
}
|
||||
s.gcStatus.running = true
|
||||
s.gcStatus.mu.Unlock()
|
||||
|
||||
gcID := uuid.New().String()
|
||||
|
||||
// Run GC asynchronously like the REST handler. GC is a long-running
|
||||
// background operation that must not be tied to the request context,
|
||||
// so we intentionally use context.Background() inside runGC.
|
||||
go s.runGC(gcID) //nolint:gosec // G118: GC must outlive the triggering RPC
|
||||
|
||||
return &pb.GarbageCollectResponse{Id: gcID}, nil
|
||||
}
|
||||
|
||||
// runGC executes garbage collection in the background. It uses
|
||||
// context.Background() because GC must not be cancelled when the
|
||||
// triggering RPC completes.
|
||||
func (s *registryService) runGC(gcID string) {
|
||||
startedAt := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
if s.auditFn != nil {
|
||||
s.auditFn("gc_started", "", "", "", "", map[string]string{
|
||||
"gc_id": gcID,
|
||||
})
|
||||
}
|
||||
|
||||
var blobsRemoved int
|
||||
var bytesFreed int64
|
||||
var gcErr error
|
||||
if s.collector != nil {
|
||||
r, err := s.collector.Run(context.Background()) //nolint:gosec // GC is intentionally background, not request-scoped
|
||||
if err != nil {
|
||||
gcErr = err
|
||||
}
|
||||
if r != nil {
|
||||
blobsRemoved = r.BlobsRemoved
|
||||
bytesFreed = r.BytesFreed
|
||||
}
|
||||
}
|
||||
|
||||
completedAt := time.Now().UTC().Format(time.RFC3339)
|
||||
|
||||
s.gcStatus.mu.Lock()
|
||||
s.gcStatus.running = false
|
||||
s.gcStatus.lastRun = &gcLastRun{
|
||||
StartedAt: startedAt,
|
||||
CompletedAt: completedAt,
|
||||
BlobsRemoved: blobsRemoved,
|
||||
BytesFreed: bytesFreed,
|
||||
}
|
||||
s.gcStatus.mu.Unlock()
|
||||
|
||||
if s.auditFn != nil && gcErr == nil {
|
||||
details := map[string]string{
|
||||
"gc_id": gcID,
|
||||
"blobs_removed": fmt.Sprintf("%d", blobsRemoved),
|
||||
"bytes_freed": fmt.Sprintf("%d", bytesFreed),
|
||||
}
|
||||
s.auditFn("gc_completed", "", "", "", "", details)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *registryService) GetGCStatus(_ context.Context, _ *pb.GetGCStatusRequest) (*pb.GetGCStatusResponse, error) {
|
||||
s.gcStatus.mu.Lock()
|
||||
resp := &pb.GetGCStatusResponse{
|
||||
Running: s.gcStatus.running,
|
||||
}
|
||||
if s.gcStatus.lastRun != nil {
|
||||
resp.LastRun = &pb.GCLastRun{
|
||||
StartedAt: s.gcStatus.lastRun.StartedAt,
|
||||
CompletedAt: s.gcStatus.lastRun.CompletedAt,
|
||||
BlobsRemoved: int32(s.gcStatus.lastRun.BlobsRemoved), //nolint:gosec // blob count fits int32
|
||||
BytesFreed: s.gcStatus.lastRun.BytesFreed,
|
||||
}
|
||||
}
|
||||
s.gcStatus.mu.Unlock()
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
144
internal/grpcserver/registry_test.go
Normal file
144
internal/grpcserver/registry_test.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
)
|
||||
|
||||
func adminDeps(t *testing.T) Deps {
|
||||
t.Helper()
|
||||
return Deps{
|
||||
DB: openTestDB(t),
|
||||
Validator: &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func adminCtx() context.Context {
|
||||
return withAuth(context.Background(), "admin-token")
|
||||
}
|
||||
|
||||
func TestListRepositoriesEmpty(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
|
||||
resp, err := client.ListRepositories(adminCtx(), &pb.ListRepositoriesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("ListRepositories: %v", err)
|
||||
}
|
||||
if len(resp.GetRepositories()) != 0 {
|
||||
t.Fatalf("expected 0 repos, got %d", len(resp.Repositories))
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRepositoryNotFound(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
|
||||
_, err := client.GetRepository(adminCtx(), &pb.GetRepositoryRequest{Name: "nonexistent"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent repo")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.NotFound {
|
||||
t.Fatalf("code: got %v, want NotFound", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetRepositoryEmptyName(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
|
||||
_, err := client.GetRepository(adminCtx(), &pb.GetRepositoryRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty name")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.InvalidArgument {
|
||||
t.Fatalf("code: got %v, want InvalidArgument", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRepositoryNotFound(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
|
||||
_, err := client.DeleteRepository(adminCtx(), &pb.DeleteRepositoryRequest{Name: "nonexistent"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent repo")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.NotFound {
|
||||
t.Fatalf("code: got %v, want NotFound", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRepositoryEmptyName(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
|
||||
_, err := client.DeleteRepository(adminCtx(), &pb.DeleteRepositoryRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty name")
|
||||
}
|
||||
st, ok := status.FromError(err)
|
||||
if !ok {
|
||||
t.Fatalf("expected gRPC status, got %v", err)
|
||||
}
|
||||
if st.Code() != codes.InvalidArgument {
|
||||
t.Fatalf("code: got %v, want InvalidArgument", st.Code())
|
||||
}
|
||||
}
|
||||
|
||||
func TestGCStatusInitial(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
|
||||
resp, err := client.GetGCStatus(adminCtx(), &pb.GetGCStatusRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetGCStatus: %v", err)
|
||||
}
|
||||
if resp.Running {
|
||||
t.Fatal("expected running=false on startup")
|
||||
}
|
||||
if resp.LastRun != nil {
|
||||
t.Fatal("expected no last_run on startup")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGarbageCollectTrigger(t *testing.T) {
|
||||
deps := adminDeps(t)
|
||||
cc := startTestServer(t, deps)
|
||||
client := pb.NewRegistryServiceClient(cc)
|
||||
|
||||
// Trigger GC without a collector (no-op but should return an ID).
|
||||
resp, err := client.GarbageCollect(adminCtx(), &pb.GarbageCollectRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GarbageCollect: %v", err)
|
||||
}
|
||||
if resp.GetId() == "" {
|
||||
t.Fatal("expected non-empty GC ID")
|
||||
}
|
||||
}
|
||||
141
internal/grpcserver/server.go
Normal file
141
internal/grpcserver/server.go
Normal file
@@ -0,0 +1,141 @@
|
||||
// Package grpcserver implements the MCR gRPC admin API server.
|
||||
//
|
||||
// It provides the same business logic as the REST admin API in
|
||||
// internal/server/, using shared internal/db and internal/gc packages.
|
||||
// The server enforces TLS 1.3 minimum, auth via MCIAS token validation,
|
||||
// and admin role checks on privileged RPCs.
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"sync"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/gc"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/server"
|
||||
)
|
||||
|
||||
// AuditFunc is a callback for recording audit events. It follows the same
|
||||
// signature as db.WriteAuditEvent but without an error return -- audit
|
||||
// failures should not block request processing.
|
||||
type AuditFunc func(eventType, actorID, repository, digest, ip string, details map[string]string)
|
||||
|
||||
// Deps holds the dependencies injected into the gRPC server.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Validator server.TokenValidator
|
||||
Engine PolicyReloader
|
||||
AuditFn AuditFunc
|
||||
Collector *gc.Collector
|
||||
}
|
||||
|
||||
// PolicyReloader can reload policy rules from a store.
|
||||
type PolicyReloader interface {
|
||||
Reload(store policy.RuleStore) error
|
||||
}
|
||||
|
||||
// GCStatus tracks the current state of garbage collection for the gRPC server.
|
||||
type GCStatus struct {
|
||||
mu sync.Mutex
|
||||
running bool
|
||||
lastRun *gcLastRun
|
||||
}
|
||||
|
||||
type gcLastRun struct {
|
||||
StartedAt string
|
||||
CompletedAt string
|
||||
BlobsRemoved int
|
||||
BytesFreed int64
|
||||
}
|
||||
|
||||
// Server wraps a grpc.Server with MCR-specific configuration.
|
||||
type Server struct {
|
||||
gs *grpc.Server
|
||||
deps Deps
|
||||
gcStatus *GCStatus
|
||||
}
|
||||
|
||||
// New creates a configured gRPC server with the interceptor chain:
|
||||
// [Request Logger] -> [Auth Interceptor] -> [Admin Interceptor] -> [Handler]
|
||||
//
|
||||
// The TLS config enforces TLS 1.3 minimum. If certFile or keyFile is
|
||||
// empty, the server is created without TLS (for testing only).
|
||||
func New(certFile, keyFile string, deps Deps) (*Server, error) {
|
||||
authInt := newAuthInterceptor(deps.Validator)
|
||||
adminInt := newAdminInterceptor()
|
||||
|
||||
chain := grpc.ChainUnaryInterceptor(
|
||||
loggingInterceptor,
|
||||
authInt.unary,
|
||||
adminInt.unary,
|
||||
)
|
||||
|
||||
var opts []grpc.ServerOption
|
||||
opts = append(opts, chain)
|
||||
|
||||
// Configure TLS if cert and key are provided.
|
||||
if certFile != "" && keyFile != "" {
|
||||
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("grpcserver: load TLS cert: %w", err)
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
opts = append(opts, grpc.Creds(credentials.NewTLS(tlsCfg)))
|
||||
}
|
||||
|
||||
// The JSON codec is registered globally via init() in gen/mcr/v1/codec.go.
|
||||
// The client must use grpc.ForceCodecV2(mcrv1.JSONCodec{}) to match.
|
||||
_ = pb.JSONCodec{} // ensure the gen/mcr/v1 init() runs (codec registration)
|
||||
|
||||
gs := grpc.NewServer(opts...)
|
||||
|
||||
gcStatus := &GCStatus{}
|
||||
|
||||
s := &Server{gs: gs, deps: deps, gcStatus: gcStatus}
|
||||
|
||||
// Register all services.
|
||||
pb.RegisterRegistryServiceServer(gs, ®istryService{
|
||||
db: deps.DB,
|
||||
collector: deps.Collector,
|
||||
gcStatus: gcStatus,
|
||||
auditFn: deps.AuditFn,
|
||||
})
|
||||
pb.RegisterPolicyServiceServer(gs, &policyService{
|
||||
db: deps.DB,
|
||||
engine: deps.Engine,
|
||||
auditFn: deps.AuditFn,
|
||||
})
|
||||
pb.RegisterAuditServiceServer(gs, &auditService{
|
||||
db: deps.DB,
|
||||
})
|
||||
pb.RegisterAdminServiceServer(gs, &adminService{})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Serve starts the gRPC server on the given listener.
|
||||
func (s *Server) Serve(lis net.Listener) error {
|
||||
log.Printf("grpc server listening on %s", lis.Addr())
|
||||
return s.gs.Serve(lis)
|
||||
}
|
||||
|
||||
// GracefulStop gracefully stops the gRPC server.
|
||||
func (s *Server) GracefulStop() {
|
||||
s.gs.GracefulStop()
|
||||
}
|
||||
|
||||
// GRPCServer returns the underlying grpc.Server for testing.
|
||||
func (s *Server) GRPCServer() *grpc.Server {
|
||||
return s.gs
|
||||
}
|
||||
Reference in New Issue
Block a user