Files
mcr/cmd/mcrsrv/main.go
Kyle Isom 758aa91bfc Migrate gRPC server to mcdsl grpcserver package
Replace MCR's custom auth, admin, and logging interceptors with the
shared mcdsl grpcserver package. This eliminates ~110 lines of
interceptor code and uses the same method-map auth pattern used by
metacrypt.

Key changes:
- server.go: delegate to mcdslgrpc.New() for TLS, logging, and auth
- interceptors.go: replaced with MethodMap definition (public, auth-required, admin-required)
- Handler files: switch from auth.ClaimsFromContext to mcdslauth.TokenInfoFromContext
- auth/client.go: add Authenticator() accessor for the underlying mcdsl authenticator
- Tests: use mock MCIAS HTTP server instead of fakeValidator interface
- Vendor: add mcdsl/grpcserver to vendor directory

ListRepositories and GetRepository are now explicitly auth-required
(not admin-required), matching the REST API. Previously they were
implicitly auth-required by not being in the bypass or admin maps.

Security: method map uses default-deny -- unmapped RPCs are rejected.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:46:03 -07:00

310 lines
7.9 KiB
Go

package main
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"log/slog"
"net"
"net/http"
"os"
"os/signal"
"path/filepath"
"syscall"
"time"
"github.com/spf13/cobra"
"git.wntrmute.dev/kyle/mcr/internal/auth"
"git.wntrmute.dev/kyle/mcr/internal/config"
"git.wntrmute.dev/kyle/mcr/internal/db"
"git.wntrmute.dev/kyle/mcr/internal/gc"
"git.wntrmute.dev/kyle/mcr/internal/oci"
"git.wntrmute.dev/kyle/mcr/internal/policy"
"git.wntrmute.dev/kyle/mcr/internal/server"
"git.wntrmute.dev/kyle/mcr/internal/storage"
"git.wntrmute.dev/kyle/mcr/internal/grpcserver"
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
)
// version is the build version string. It is set as the root command's
// Version and logged when the server starts; it defaults to "dev".
var version = "dev"
// main assembles the mcrsrv command tree (server, status, snapshot) and runs
// it, exiting with status 1 when command execution reports an error.
func main() {
	rootCmd := &cobra.Command{
		Use:     "mcrsrv",
		Short:   "Metacircular Container Registry server",
		Version: version,
	}

	// Subcommands are registered in the same order as before; AddCommand
	// accepts any number of commands in one call.
	rootCmd.AddCommand(
		serverCmd(),
		statusCmd(),
		snapshotCmd(),
	)

	err := rootCmd.Execute()
	if err != nil {
		os.Exit(1)
	}
}
// serverCmd builds the "server" subcommand, which starts the registry via
// runServer using the file named by --config/-c (default "mcr.toml").
func serverCmd() *cobra.Command {
	var cfgFile string

	// The closure reads cfgFile at run time, after flag parsing has filled it.
	start := func(_ *cobra.Command, _ []string) error {
		return runServer(cfgFile)
	}

	c := &cobra.Command{
		Use:   "server",
		Short: "Start the registry server",
		RunE:  start,
	}
	c.Flags().StringVarP(&cfgFile, "config", "c", "mcr.toml", "path to configuration file")

	return c
}
// runServer loads the configuration at configPath, wires up the registry's
// dependencies, and serves the HTTPS API (plus the gRPC API when
// cfg.Server.GRPCAddr is non-empty) until one of the servers fails or a
// SIGINT/SIGTERM arrives.
//
// It returns a wrapped error if any startup step fails, if either server
// stops with an error, or if the HTTP shutdown fails. It returns nil after a
// clean signal-initiated shutdown.
func runServer(configPath string) error {
	cfg, err := config.Load(configPath)
	if err != nil {
		return fmt.Errorf("load config: %w", err)
	}

	logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
		Level: parseLogLevel(cfg.Log.Level),
	}))

	// Open and migrate the database.
	database, err := db.Open(cfg.Database.Path)
	if err != nil {
		return fmt.Errorf("open database: %w", err)
	}
	defer database.Close()

	if err := database.Migrate(); err != nil {
		return fmt.Errorf("migrate database: %w", err)
	}

	// Ensure storage directories exist.
	for _, dir := range []string{cfg.Storage.LayersPath, cfg.Storage.UploadsPath} {
		if err := os.MkdirAll(dir, 0700); err != nil {
			return fmt.Errorf("create storage directory %s: %w", dir, err)
		}
	}

	// Create auth client for MCIAS integration.
	authClient, err := auth.NewClient(
		cfg.MCIAS.ServerURL,
		cfg.MCIAS.CACert,
		cfg.MCIAS.ServiceName,
		cfg.MCIAS.Tags,
	)
	if err != nil {
		return fmt.Errorf("create auth client: %w", err)
	}

	// Create blob storage.
	store := storage.New(cfg.Storage.LayersPath, cfg.Storage.UploadsPath)

	// Create garbage collector.
	collector := gc.New(database, store)

	// Create and load policy engine. A reload failure is logged and startup
	// continues rather than aborting.
	policyEngine := policy.NewEngine()
	if err := policyEngine.Reload(database); err != nil {
		logger.Warn("loading policy rules (using defaults)", "error", err)
	}

	// Audit callback. Write failures are logged, never propagated to callers.
	auditFn := func(eventType, actorID, repository, digest, ip string, details map[string]string) {
		if err := database.WriteAuditEvent(eventType, actorID, repository, digest, ip, details); err != nil {
			logger.Error("audit write failed", "error", err)
		}
	}

	// Create OCI handler and HTTP router.
	ociHandler := oci.NewHandler(database, store, policyEngine, auditFn, logger)
	router := server.NewRouter(authClient, authClient, cfg.MCIAS.ServiceName, ociHandler.Router())

	// Mount admin REST endpoints.
	gcState := &server.GCState{
		Collector: collector,
		AuditFn:   auditFn,
	}
	server.MountAdminRoutes(router, authClient, cfg.MCIAS.ServiceName, server.AdminDeps{
		DB:      database,
		Login:   authClient,
		Engine:  policyEngine,
		AuditFn: auditFn,
		GCState: gcState,
	})

	// TLS configuration.
	cert, err := tls.LoadX509KeyPair(cfg.Server.TLSCert, cfg.Server.TLSKey)
	if err != nil {
		return fmt.Errorf("load TLS cert: %w", err)
	}
	tlsCfg := &tls.Config{
		MinVersion:   tls.VersionTLS13,
		Certificates: []tls.Certificate{cert},
	}

	// HTTP server.
	httpServer := &http.Server{
		Addr:         cfg.Server.ListenAddr,
		Handler:      router,
		TLSConfig:    tlsCfg,
		ReadTimeout:  cfg.Server.ReadTimeout.Duration,
		WriteTimeout: cfg.Server.WriteTimeout.Duration,
		IdleTimeout:  cfg.Server.IdleTimeout.Duration,
	}

	// Build the gRPC server and bind its listener if configured; serving
	// starts further down, once signal handling is in place.
	var grpcSrv *grpcserver.Server
	var grpcLis net.Listener
	if cfg.Server.GRPCAddr != "" {
		grpcDeps := grpcserver.Deps{
			DB:            database,
			Authenticator: authClient.Authenticator(),
			Engine:        policyEngine,
			AuditFn:       auditFn,
			Collector:     collector,
		}
		grpcSrv, err = grpcserver.New(cfg.Server.TLSCert, cfg.Server.TLSKey, grpcDeps, logger)
		if err != nil {
			return fmt.Errorf("create gRPC server: %w", err)
		}
		grpcLis, err = net.Listen("tcp", cfg.Server.GRPCAddr)
		if err != nil {
			return fmt.Errorf("listen gRPC on %s: %w", cfg.Server.GRPCAddr, err)
		}
	}

	// Graceful shutdown on SIGINT/SIGTERM.
	ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer stop()

	// Capacity 2: one slot per server, so neither goroutine blocks on its
	// send after this function has stopped receiving.
	errCh := make(chan error, 2)

	if grpcSrv != nil {
		go func() {
			logger.Info("gRPC server listening", "addr", grpcLis.Addr())
			errCh <- grpcSrv.Serve(grpcLis)
		}()
	}

	go func() {
		logger.Info("mcrsrv starting",
			"version", version,
			"addr", cfg.Server.ListenAddr,
		)
		// Certificates come from httpServer.TLSConfig, hence the empty paths.
		errCh <- httpServer.ListenAndServeTLS("", "")
	}()

	select {
	case err := <-errCh:
		return fmt.Errorf("server error: %w", err)
	case <-ctx.Done():
		// Unregister the signal handler immediately so that a second
		// SIGINT/SIGTERM gets the default behavior and can terminate the
		// process if shutdown hangs. GracefulStop below is called without
		// any deadline, and stop is safe to call again from the defer.
		stop()

		logger.Info("shutting down")

		if grpcSrv != nil {
			grpcSrv.GracefulStop()
		}

		// Fall back to 30s when no positive shutdown timeout is configured.
		shutdownTimeout := 30 * time.Second
		if cfg.Server.ShutdownTimeout.Duration > 0 {
			shutdownTimeout = cfg.Server.ShutdownTimeout.Duration
		}
		shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
		defer cancel()

		if err := httpServer.Shutdown(shutdownCtx); err != nil {
			return fmt.Errorf("shutdown: %w", err)
		}
		logger.Info("mcrsrv stopped")
		return nil
	}
}
// statusCmd builds the "status" subcommand. It issues an HTTPS GET to
// <addr>/v1/health with a 5-second client timeout and prints "ok" when the
// response status is 200; any other outcome is returned as an error.
//
// Flags:
//   - --addr: base URL of the server (default "https://localhost:8443").
//   - --ca-cert: optional PEM file whose certificates become the client's
//     root CA pool.
func statusCmd() *cobra.Command {
	var (
		serverAddr string
		caPath     string
	)

	check := func(_ *cobra.Command, _ []string) error {
		clientTLS := &tls.Config{MinVersion: tls.VersionTLS13}

		// Only replace the root pool when a CA file was supplied.
		if caPath != "" {
			caPEM, err := os.ReadFile(caPath)
			if err != nil {
				return fmt.Errorf("read CA cert: %w", err)
			}
			roots := x509.NewCertPool()
			if ok := roots.AppendCertsFromPEM(caPEM); !ok {
				return fmt.Errorf("no valid certificates in %s", caPath)
			}
			clientTLS.RootCAs = roots
		}

		transport := &http.Transport{TLSClientConfig: clientTLS}
		httpClient := &http.Client{
			Transport: transport,
			Timeout:   5 * time.Second,
		}

		healthURL := serverAddr + "/v1/health"
		resp, err := httpClient.Get(healthURL)
		if err != nil {
			return fmt.Errorf("health check: %w", err)
		}
		defer resp.Body.Close()

		if code := resp.StatusCode; code != http.StatusOK {
			return fmt.Errorf("health check: status %d", code)
		}

		fmt.Println("ok")
		return nil
	}

	cmd := &cobra.Command{
		Use:   "status",
		Short: "Check registry health",
		RunE:  check,
	}

	flags := cmd.Flags()
	flags.StringVar(&serverAddr, "addr", "https://localhost:8443", "server address")
	flags.StringVar(&caPath, "ca-cert", "", "CA certificate for TLS verification")

	return cmd
}
// snapshotCmd builds the "snapshot" subcommand. It opens the database named
// in the configuration file (--config/-c, default "mcr.toml") and passes it
// to mcdsldb.Snapshot, targeting backups/mcr-<YYYYMMDD-HHMMSS>.db in the
// directory that holds the database file. On success the snapshot path is
// printed.
func snapshotCmd() *cobra.Command {
	var cfgFile string

	takeSnapshot := func(_ *cobra.Command, _ []string) error {
		cfg, err := config.Load(cfgFile)
		if err != nil {
			return fmt.Errorf("load config: %w", err)
		}

		database, err := db.Open(cfg.Database.Path)
		if err != nil {
			return fmt.Errorf("open database: %w", err)
		}
		defer database.Close()

		// The timestamp uses the local clock at second resolution.
		stamp := time.Now().Format("20060102-150405")
		dest := filepath.Join(
			filepath.Dir(cfg.Database.Path),
			"backups",
			"mcr-"+stamp+".db",
		)

		err = mcdsldb.Snapshot(database.DB, dest)
		if err != nil {
			return fmt.Errorf("snapshot: %w", err)
		}

		fmt.Printf("Snapshot saved to %s\n", dest)
		return nil
	}

	c := &cobra.Command{
		Use:   "snapshot",
		Short: "Database backup via VACUUM INTO",
		RunE:  takeSnapshot,
	}
	c.Flags().StringVarP(&cfgFile, "config", "c", "mcr.toml", "path to configuration file")

	return c
}
// parseLogLevel converts a configured level name into a slog.Level.
//
// Parsing is delegated to slog.Level.UnmarshalText, so "debug", "info",
// "warn", and "error" are matched case-insensitively (a config value such as
// "DEBUG" no longer silently degrades to info), and slog's offset forms such
// as "warn+2" are accepted as well. Any string slog cannot parse, including
// the empty string, yields slog.LevelInfo.
func parseLogLevel(s string) slog.Level {
	var level slog.Level
	if err := level.UnmarshalText([]byte(s)); err != nil {
		// Unknown or empty names keep the historical default.
		return slog.LevelInfo
	}
	return level
}