diff --git a/cmd/mcrsrv/main.go b/cmd/mcrsrv/main.go index 06db1be..f0cb742 100644 --- a/cmd/mcrsrv/main.go +++ b/cmd/mcrsrv/main.go @@ -1,10 +1,32 @@ package main import ( + "context" + "crypto/tls" + "crypto/x509" "fmt" + "log/slog" + "net" + "net/http" "os" + "os/signal" + "path/filepath" + "syscall" + "time" "github.com/spf13/cobra" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/config" + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/gc" + "git.wntrmute.dev/kyle/mcr/internal/oci" + "git.wntrmute.dev/kyle/mcr/internal/policy" + "git.wntrmute.dev/kyle/mcr/internal/server" + "git.wntrmute.dev/kyle/mcr/internal/storage" + + "git.wntrmute.dev/kyle/mcr/internal/grpcserver" + mcdsldb "git.wntrmute.dev/kyle/mcdsl/db" ) var version = "dev" @@ -17,7 +39,7 @@ func main() { } root.AddCommand(serverCmd()) - root.AddCommand(initCmd()) + root.AddCommand(statusCmd()) root.AddCommand(snapshotCmd()) if err := root.Execute(); err != nil { @@ -26,31 +48,264 @@ func main() { } func serverCmd() *cobra.Command { - return &cobra.Command{ + var configPath string + + cmd := &cobra.Command{ Use: "server", Short: "Start the registry server", RunE: func(_ *cobra.Command, _ []string) error { - return fmt.Errorf("not implemented") + return runServer(configPath) }, } + + cmd.Flags().StringVarP(&configPath, "config", "c", "mcr.toml", "path to configuration file") + return cmd +} + +func runServer(configPath string) error { + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + + logger := slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + Level: parseLogLevel(cfg.Log.Level), + })) + + // Open and migrate the database. + database, err := db.Open(cfg.Database.Path) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + defer database.Close() + + if err := database.Migrate(); err != nil { + return fmt.Errorf("migrate database: %w", err) + } + + // Ensure storage directories exist. + for _, dir := range []string{cfg.Storage.LayersPath, cfg.Storage.UploadsPath} { + if err := os.MkdirAll(dir, 0700); err != nil { + return fmt.Errorf("create storage directory %s: %w", dir, err) + } + } + + // Create auth client for MCIAS integration. + authClient, err := auth.NewClient( + cfg.MCIAS.ServerURL, + cfg.MCIAS.CACert, + cfg.MCIAS.ServiceName, + cfg.MCIAS.Tags, + ) + if err != nil { + return fmt.Errorf("create auth client: %w", err) + } + + // Create blob storage. + store := storage.New(cfg.Storage.LayersPath, cfg.Storage.UploadsPath) + + // Create garbage collector. + collector := gc.New(database, store) + + // Create and load policy engine. + policyEngine := policy.NewEngine() + if err := policyEngine.Reload(database); err != nil { + logger.Warn("loading policy rules (using defaults)", "error", err) + } + + // Audit callback. + auditFn := func(eventType, actorID, repository, digest, ip string, details map[string]string) { + if err := database.WriteAuditEvent(eventType, actorID, repository, digest, ip, details); err != nil { + logger.Error("audit write failed", "error", err) + } + } + + // Create OCI handler and HTTP router. + ociHandler := oci.NewHandler(database, store, policyEngine, auditFn) + + router := server.NewRouter(authClient, authClient, cfg.MCIAS.ServiceName) + // Mount OCI endpoints at /v2. + router.Mount("/v2", ociHandler.Router()) + // Mount admin REST endpoints. + gcState := &server.GCState{ + Collector: collector, + AuditFn: auditFn, + } + server.MountAdminRoutes(router, authClient, cfg.MCIAS.ServiceName, server.AdminDeps{ + DB: database, + Login: authClient, + Engine: policyEngine, + AuditFn: auditFn, + GCState: gcState, + }) + + // TLS configuration. + cert, err := tls.LoadX509KeyPair(cfg.Server.TLSCert, cfg.Server.TLSKey) + if err != nil { + return fmt.Errorf("load TLS cert: %w", err) + } + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + Certificates: []tls.Certificate{cert}, + } + + // HTTP server. + httpServer := &http.Server{ + Addr: cfg.Server.ListenAddr, + Handler: router, + TLSConfig: tlsCfg, + ReadTimeout: cfg.Server.ReadTimeout.Duration, + WriteTimeout: cfg.Server.WriteTimeout.Duration, + IdleTimeout: cfg.Server.IdleTimeout.Duration, + } + + // Start gRPC server if configured. + var grpcSrv *grpcserver.Server + var grpcLis net.Listener + if cfg.Server.GRPCAddr != "" { + grpcDeps := grpcserver.Deps{ + DB: database, + Validator: authClient, + Engine: policyEngine, + AuditFn: auditFn, + Collector: collector, + } + grpcSrv, err = grpcserver.New(cfg.Server.TLSCert, cfg.Server.TLSKey, grpcDeps) + if err != nil { + return fmt.Errorf("create gRPC server: %w", err) + } + grpcLis, err = net.Listen("tcp", cfg.Server.GRPCAddr) + if err != nil { + return fmt.Errorf("listen gRPC on %s: %w", cfg.Server.GRPCAddr, err) + } + } + + // Graceful shutdown on SIGINT/SIGTERM. + ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) + defer stop() + + errCh := make(chan error, 2) + + if grpcSrv != nil { + go func() { + logger.Info("gRPC server listening", "addr", grpcLis.Addr()) + errCh <- grpcSrv.Serve(grpcLis) + }() + } + + go func() { + logger.Info("mcrsrv starting", + "version", version, + "addr", cfg.Server.ListenAddr, + ) + errCh <- httpServer.ListenAndServeTLS("", "") + }() + + select { + case err := <-errCh: + return fmt.Errorf("server error: %w", err) + case <-ctx.Done(): + logger.Info("shutting down") + if grpcSrv != nil { + grpcSrv.GracefulStop() + } + shutdownTimeout := 30 * time.Second + if cfg.Server.ShutdownTimeout.Duration > 0 { + shutdownTimeout = cfg.Server.ShutdownTimeout.Duration + } + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownTimeout) + defer cancel() + if err := httpServer.Shutdown(shutdownCtx); err != nil { + return fmt.Errorf("shutdown: %w", err) + } + logger.Info("mcrsrv stopped") + return nil + } } -func initCmd() *cobra.Command { - return &cobra.Command{ - Use: "init", - Short: "First-time setup (create directories, example config)", +func statusCmd() *cobra.Command { + var addr, caCert string + + cmd := &cobra.Command{ + Use: "status", + Short: "Check registry health", RunE: func(_ *cobra.Command, _ []string) error { - return fmt.Errorf("not implemented") + tlsCfg := &tls.Config{MinVersion: tls.VersionTLS13} + if caCert != "" { + pemData, err := os.ReadFile(caCert) + if err != nil { + return fmt.Errorf("read CA cert: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pemData) { + return fmt.Errorf("no valid certificates in %s", caCert) + } + tlsCfg.RootCAs = pool + } + client := &http.Client{ + Transport: &http.Transport{TLSClientConfig: tlsCfg}, + Timeout: 5 * time.Second, + } + resp, err := client.Get(addr + "/v1/health") + if err != nil { + return fmt.Errorf("health check: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("health check: status %d", resp.StatusCode) + } + fmt.Println("ok") + return nil }, } + + cmd.Flags().StringVar(&addr, "addr", "https://localhost:8443", "server address") + cmd.Flags().StringVar(&caCert, "ca-cert", "", "CA certificate for TLS verification") + return cmd } func snapshotCmd() *cobra.Command { - return &cobra.Command{ + var configPath string + + cmd := &cobra.Command{ Use: "snapshot", Short: "Database backup via VACUUM INTO", RunE: func(_ *cobra.Command, _ []string) error { - return fmt.Errorf("not implemented") + cfg, err := config.Load(configPath) + if err != nil { + return fmt.Errorf("load config: %w", err) + } + database, err := db.Open(cfg.Database.Path) + if err != nil { + return fmt.Errorf("open database: %w", err) + } + defer database.Close() + + backupDir := filepath.Join(filepath.Dir(cfg.Database.Path), "backups") + snapName := fmt.Sprintf("mcr-%s.db", time.Now().Format("20060102-150405")) + snapPath := filepath.Join(backupDir, snapName) + + if err := mcdsldb.Snapshot(database.DB, snapPath); err != nil { + return fmt.Errorf("snapshot: %w", err) + } + fmt.Printf("Snapshot saved to %s\n", snapPath) + return nil }, } + + cmd.Flags().StringVarP(&configPath, "config", "c", "mcr.toml", "path to configuration file") + return cmd +} + +func parseLogLevel(s string) slog.Level { + switch s { + case "debug": + return slog.LevelDebug + case "warn": + return slog.LevelWarn + case "error": + return slog.LevelError + default: + return slog.LevelInfo + } }