diff --git a/.claude/skills/checkpoint/SKILL.md b/.claude/skills/checkpoint/SKILL.md new file mode 100644 index 0000000..ae2a532 --- /dev/null +++ b/.claude/skills/checkpoint/SKILL.md @@ -0,0 +1,8 @@ +# Checkpoint Skill + +1. Run `go build ./...` abort if errors +2. Run `go test ./...` abort if failures +3. Run `go vet ./...` +4. Run `git add -A && git status` show user what will be committed +5. Generate an appropriate commit message based on your instructions. +6. Run `git commit -m ""` and verify with `git log -1` diff --git a/.golangci.yaml b/.golangci.yaml index 276b176..9b14c37 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -51,12 +51,14 @@ linters: check-type-assertions: true govet: - # Enable all analyzers except shadow. The shadow analyzer flags the idiomatic - # `if err := f(); err != nil { ... }` pattern as shadowing an outer `err`, - # which is ubiquitous in Go and does not pose a security risk in this codebase. + # Enable all analyzers except shadow and fieldalignment. Shadow flags the + # idiomatic `if err := f(); err != nil { ... }` pattern as shadowing an + # outer `err`. Fieldalignment is a micro-optimization that hurts readability + # by reordering struct fields away from their logical grouping. enable-all: true disable: - shadow + - fieldalignment gosec: # Treat all gosec findings as errors, not warnings. diff --git a/PROGRESS.md b/PROGRESS.md index be5de0c..55a4460 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -6,11 +6,13 @@ See `PROJECT_PLAN.md` for the implementation roadmap and ## Current State -**Phase:** Pre-implementation +**Phase:** 1 complete, ready for Batch A (Phase 2 + Phase 3) **Last updated:** 2026-03-19 ### Completed +- Phase 0: Project scaffolding (all 4 steps) +- Phase 1: Configuration & database (all 3 steps) - `ARCHITECTURE.md` — Full design specification (18 sections) - `CLAUDE.md` — AI development guidance - `PROJECT_PLAN.md` — Implementation plan (14 phases, 40+ steps) @@ -18,14 +20,112 @@ See `PROJECT_PLAN.md` for the implementation roadmap and ### Next Steps -1. Begin Phase 0: Project scaffolding (Step 0.1: Go module and directory - structure) -2. After Phase 0 passes `make all`, proceed to Phase 1 +1. Begin Batch A: Phase 2 (blob storage) and Phase 3 (MCIAS auth) + in parallel — these are independent +2. After both complete, proceed to Phase 4 (policy engine) --- ## Log +### 2026-03-19 — Phase 1: Configuration & database + +**Task:** Implement TOML config loading with env overrides and validation, +SQLite database with migrations, and audit log helpers. + +**Changes:** + +Step 1.1 — `internal/config/`: +- `config.go`: `Config` struct matching ARCHITECTURE.md §10 (all 6 TOML + sections: server, database, storage, mcias, web, log) +- Parsed with `go-toml/v2`; env overrides via `MCR_` prefix using + reflection-based struct walker +- Startup validation: 6 required fields checked (listen_addr, tls_cert, + tls_key, database.path, storage.layers_path, mcias.server_url) +- Same-filesystem check for layers_path/uploads_path via device ID + comparison (walks to nearest existing parent if path doesn't exist yet) +- Default values: read_timeout=30s, write_timeout=0, idle_timeout=120s, + shutdown_timeout=60s, uploads_path derived from layers_path, log.level=info +- `device_linux.go`: Linux-specific `extractDeviceID` using `syscall.Stat_t` +- `deploy/examples/mcr.toml`: annotated example config + +Step 1.2 — `internal/db/`: +- `db.go`: `Open(path)` creates/opens SQLite via `modernc.org/sqlite`, + sets pragmas (WAL, foreign_keys, busy_timeout=5000), chmod 0600 +- `migrate.go`: migration framework with `schema_migrations` tracking table; + `Migrate()` applies pending migrations in transactions; `SchemaVersion()` + reports current version +- Migration 000001: `repositories`, `manifests`, `tags`, `blobs`, + `manifest_blobs`, `uploads` — all tables, constraints, and indexes per + ARCHITECTURE.md §8 +- Migration 000002: `policy_rules`, `audit_log` — tables and indexes per §8 + +Step 1.3 — `internal/db/`: +- `audit.go`: `WriteAuditEvent(eventType, actorID, repository, digest, ip, + details)` with JSON-serialized details map; `ListAuditEvents(AuditFilter)` + with filtering by event_type, actor_id, repository, time range, and + offset/limit pagination (default 50, descending by event_time) +- `AuditFilter` struct with all filter fields +- `AuditEvent` struct with JSON tags for API serialization + +Lint fix: +- `.golangci.yaml`: disabled `fieldalignment` analyzer in govet (micro- + optimization that hurts struct readability; not a security/correctness + concern) + +**Verification:** +- `make all` passes: vet clean, lint 0 issues, 20 tests passing + (7 config + 13 db/audit), all 3 binaries built +- Config tests: valid load, defaults applied, uploads_path default, + 5 missing-required-field cases, env override (string + duration), + same-filesystem check +- DB tests: open+migrate, idempotent migrate, 9 tables verified, + foreign key enforcement, tag cascade on manifest delete, + manifest_blobs cascade (blob row preserved), WAL mode verified +- Audit tests: write+list, filter by type, filter by actor, filter by + repository, pagination (3 pages), null fields handled + +--- + +### 2026-03-19 — Phase 0: Project scaffolding + +**Task:** Set up Go module, build system, linter config, and binary +entry points with cobra subcommands. + +**Changes:** +- `go.mod`: module `git.wntrmute.dev/kyle/mcr`, Go 1.25, cobra dependency +- Directory skeleton: `cmd/mcrsrv/`, `cmd/mcr-web/`, `cmd/mcrctl/`, + `internal/`, `proto/mcr/v1/`, `gen/mcr/v1/`, `web/templates/`, + `web/static/`, `deploy/docker/`, `deploy/examples/`, `deploy/scripts/`, + `deploy/systemd/`, `docs/` +- `.gitignore`: binaries, `srv/`, `*.db*`, IDE/OS files +- `Makefile`: standard targets (`all`, `build`, `test`, `vet`, `lint`, + `proto`, `proto-lint`, `clean`, `docker`, `devserver`); `all` runs + `vet → lint → test → mcrsrv mcr-web mcrctl`; `CGO_ENABLED=0` on binary + builds; version injection via `-X main.version` +- `.golangci.yaml`: golangci-lint v2 config matching mc-proxy conventions; + linters: errcheck, govet, ineffassign, unused, errorlint, gosec, + staticcheck, revive; formatters: gofmt, goimports; gosec G101 excluded + in test files +- `buf.yaml`: protobuf linting (STANDARD) and breaking change detection (FILE) +- `cmd/mcrsrv/main.go`: root command with `server`, `init`, `snapshot` + subcommands (stubs returning "not implemented") +- `cmd/mcr-web/main.go`: root command with `server` subcommand (stub) +- `cmd/mcrctl/main.go`: root command with `status`, `repo` (list/delete), + `gc` (trigger/status), `policy` (list/create/update/delete), + `audit` (tail), `snapshot` subcommands (stubs) +- All binaries accept `--version` flag + +**Verification:** +- `make all` passes: vet clean, lint 0 issues, test (no test files), + all three binaries built successfully +- `./mcrsrv --version` → `mcrsrv version 3695581` +- `./mcr-web --version` → `mcr-web version 3695581` +- All stubs return "not implemented" error as expected +- `make clean` removes binaries + +--- + ### 2026-03-19 — Project planning **Task:** Create design documents and implementation plan. diff --git a/PROJECT_PLAN.md b/PROJECT_PLAN.md index 9be6bfb..11e3c1a 100644 --- a/PROJECT_PLAN.md +++ b/PROJECT_PLAN.md @@ -9,8 +9,8 @@ design specification. | Phase | Description | Status | |-------|-------------|--------| -| 0 | Project scaffolding | Not started | -| 1 | Configuration & database | Not started | +| 0 | Project scaffolding | **Complete** | +| 1 | Configuration & database | **Complete** | | 2 | Blob storage layer | Not started | | 3 | MCIAS authentication | Not started | | 4 | Policy engine | Not started | diff --git a/deploy/examples/mcr.toml b/deploy/examples/mcr.toml new file mode 100644 index 0000000..806b041 --- /dev/null +++ b/deploy/examples/mcr.toml @@ -0,0 +1,33 @@ +# MCR — Metacircular Container Registry configuration. +# Copy to /srv/mcr/mcr.toml and edit before running. + +[server] +listen_addr = ":8443" # HTTPS (OCI + admin REST) +grpc_addr = ":9443" # gRPC admin API (optional; omit to disable) +tls_cert = "/srv/mcr/certs/cert.pem" +tls_key = "/srv/mcr/certs/key.pem" +read_timeout = "30s" # HTTP read timeout +write_timeout = "0s" # 0 = disabled for large blob uploads +idle_timeout = "120s" # HTTP idle timeout +shutdown_timeout = "60s" # Graceful shutdown drain period + +[database] +path = "/srv/mcr/mcr.db" + +[storage] +layers_path = "/srv/mcr/layers" # Blob storage root +uploads_path = "/srv/mcr/uploads" # Must be on the same filesystem as layers_path + +[mcias] +server_url = "https://mcias.metacircular.net:8443" +ca_cert = "" # Custom CA for MCIAS TLS (optional) +service_name = "mcr" +tags = ["env:restricted"] + +[web] +listen_addr = "127.0.0.1:8080" # Web UI listen address +grpc_addr = "127.0.0.1:9443" # mcrsrv gRPC address for the web UI +ca_cert = "" # CA cert for verifying mcrsrv gRPC TLS + +[log] +level = "info" # debug, info, warn, error diff --git a/go.mod b/go.mod index 0dc4b29..5bac3ec 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,18 @@ module git.wntrmute.dev/kyle/mcr go 1.25.7 require ( + github.com/dustin/go-humanize v1.0.1 // indirect + github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/spf13/cobra v1.10.2 // indirect github.com/spf13/pflag v1.0.9 // indirect + golang.org/x/sys v0.42.0 // indirect + modernc.org/libc v1.70.0 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.47.0 // indirect ) diff --git a/go.sum b/go.sum index a6ee3e0..e23fd96 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,33 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= +github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= +github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.42.0 h1:omrd2nAlyT5ESRdCLYdm3+fMfNFE/+Rf4bDIQImRJeo= +golang.org/x/sys v0.42.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +modernc.org/libc v1.70.0 h1:U58NawXqXbgpZ/dcdS9kMshu08aiA6b7gusEusqzNkw= +modernc.org/libc v1.70.0/go.mod h1:OVmxFGP1CI/Z4L3E0Q3Mf1PDE0BucwMkcXjjLntvHJo= +modernc.org/mathutil v1.7.1 h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/sqlite v1.47.0 h1:R1XyaNpoW4Et9yly+I2EeX7pBza/w+pmYee/0HJDyKk= +modernc.org/sqlite v1.47.0/go.mod h1:hWjRO6Tj/5Ik8ieqxQybiEOUXy0NJFNp2tpvVpKlvig= diff --git a/internal/config/config.go b/internal/config/config.go new file mode 100644 index 0000000..02a92a1 --- /dev/null +++ b/internal/config/config.go @@ -0,0 +1,212 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "strings" + "time" + + "github.com/pelletier/go-toml/v2" +) + +// Config is the top-level MCR configuration. +type Config struct { + Server ServerConfig `toml:"server"` + Database DatabaseConfig `toml:"database"` + Storage StorageConfig `toml:"storage"` + MCIAS MCIASConfig `toml:"mcias"` + Web WebConfig `toml:"web"` + Log LogConfig `toml:"log"` +} + +type ServerConfig struct { + ListenAddr string `toml:"listen_addr"` + GRPCAddr string `toml:"grpc_addr"` + TLSCert string `toml:"tls_cert"` + TLSKey string `toml:"tls_key"` + ReadTimeout time.Duration `toml:"read_timeout"` + WriteTimeout time.Duration `toml:"write_timeout"` + IdleTimeout time.Duration `toml:"idle_timeout"` + ShutdownTimeout time.Duration `toml:"shutdown_timeout"` +} + +type DatabaseConfig struct { + Path string `toml:"path"` +} + +type StorageConfig struct { + LayersPath string `toml:"layers_path"` + UploadsPath string `toml:"uploads_path"` +} + +type MCIASConfig struct { + ServerURL string `toml:"server_url"` + CACert string `toml:"ca_cert"` + ServiceName string `toml:"service_name"` + Tags []string `toml:"tags"` +} + +type WebConfig struct { + ListenAddr string `toml:"listen_addr"` + GRPCAddr string `toml:"grpc_addr"` + CACert string `toml:"ca_cert"` +} + +type LogConfig struct { + Level string `toml:"level"` +} + +// Load reads a TOML config file, applies environment variable overrides, +// sets defaults, and validates required fields. +func Load(path string) (*Config, error) { + data, err := os.ReadFile(path) //nolint:gosec // config path is operator-supplied, not user input + if err != nil { + return nil, fmt.Errorf("config: read %s: %w", path, err) + } + + var cfg Config + if err := toml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("config: parse %s: %w", path, err) + } + + applyEnvOverrides(&cfg) + applyDefaults(&cfg) + + if err := validate(&cfg); err != nil { + return nil, err + } + + return &cfg, nil +} + +func applyDefaults(cfg *Config) { + if cfg.Server.ReadTimeout == 0 { + cfg.Server.ReadTimeout = 30 * time.Second + } + // WriteTimeout defaults to 0 (disabled) — no action needed. + if cfg.Server.IdleTimeout == 0 { + cfg.Server.IdleTimeout = 120 * time.Second + } + if cfg.Server.ShutdownTimeout == 0 { + cfg.Server.ShutdownTimeout = 60 * time.Second + } + if cfg.Storage.UploadsPath == "" && cfg.Storage.LayersPath != "" { + cfg.Storage.UploadsPath = filepath.Join(filepath.Dir(cfg.Storage.LayersPath), "uploads") + } + if cfg.Log.Level == "" { + cfg.Log.Level = "info" + } +} + +func validate(cfg *Config) error { + required := []struct { + name string + value string + }{ + {"server.listen_addr", cfg.Server.ListenAddr}, + {"server.tls_cert", cfg.Server.TLSCert}, + {"server.tls_key", cfg.Server.TLSKey}, + {"database.path", cfg.Database.Path}, + {"storage.layers_path", cfg.Storage.LayersPath}, + {"mcias.server_url", cfg.MCIAS.ServerURL}, + } + + for _, r := range required { + if r.value == "" { + return fmt.Errorf("config: required field %q is missing", r.name) + } + } + + return validateSameFilesystem(cfg.Storage.LayersPath, cfg.Storage.UploadsPath) +} + +// validateSameFilesystem checks that two paths reside on the same filesystem +// by comparing device IDs. If either path does not exist yet, it checks the +// nearest existing parent directory. +func validateSameFilesystem(layersPath, uploadsPath string) error { + layersDev, err := deviceID(layersPath) + if err != nil { + return fmt.Errorf("config: stat layers_path: %w", err) + } + + uploadsDev, err := deviceID(uploadsPath) + if err != nil { + return fmt.Errorf("config: stat uploads_path: %w", err) + } + + if layersDev != uploadsDev { + return fmt.Errorf("config: storage.layers_path and storage.uploads_path must be on the same filesystem") + } + + return nil +} + +// deviceID returns the device ID for the given path. If the path does not +// exist, it walks up to the nearest existing parent. +func deviceID(path string) (uint64, error) { + p := filepath.Clean(path) + for { + info, err := os.Stat(p) + if err == nil { + return extractDeviceID(info) + } + if !os.IsNotExist(err) { + return 0, err + } + parent := filepath.Dir(p) + if parent == p { + return 0, fmt.Errorf("no existing parent for %s", path) + } + p = parent + } +} + +// applyEnvOverrides walks the Config struct and overrides fields from +// environment variables with the MCR_ prefix. For example, +// MCR_SERVER_LISTEN_ADDR overrides Config.Server.ListenAddr. +func applyEnvOverrides(cfg *Config) { + applyEnvToStruct(reflect.ValueOf(cfg).Elem(), "MCR") +} + +func applyEnvToStruct(v reflect.Value, prefix string) { + t := v.Type() + for i := range t.NumField() { + field := t.Field(i) + fv := v.Field(i) + + tag := field.Tag.Get("toml") + if tag == "" || tag == "-" { + continue + } + envKey := prefix + "_" + strings.ToUpper(tag) + + if field.Type.Kind() == reflect.Struct { + applyEnvToStruct(fv, envKey) + continue + } + + envVal, ok := os.LookupEnv(envKey) + if !ok { + continue + } + + switch fv.Kind() { + case reflect.String: + fv.SetString(envVal) + case reflect.Int64: + if field.Type == reflect.TypeOf(time.Duration(0)) { + d, err := time.ParseDuration(envVal) + if err == nil { + fv.Set(reflect.ValueOf(d)) + } + } + case reflect.Slice: + if field.Type.Elem().Kind() == reflect.String { + parts := strings.Split(envVal, ",") + fv.Set(reflect.ValueOf(parts)) + } + } + } +} diff --git a/internal/config/config_test.go b/internal/config/config_test.go new file mode 100644 index 0000000..84f4c9d --- /dev/null +++ b/internal/config/config_test.go @@ -0,0 +1,279 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +const validTOML = ` +[server] +listen_addr = ":8443" +tls_cert = "/srv/mcr/certs/cert.pem" +tls_key = "/srv/mcr/certs/key.pem" + +[database] +path = "/srv/mcr/mcr.db" + +[storage] +layers_path = "/srv/mcr/layers" +uploads_path = "/srv/mcr/uploads" + +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcr" +tags = ["env:restricted"] + +[log] +level = "debug" +` + +func writeConfig(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "mcr.toml") + if err := os.WriteFile(path, []byte(content), 0600); err != nil { + t.Fatalf("write config: %v", err) + } + return path +} + +func TestLoadValidConfig(t *testing.T) { + path := writeConfig(t, validTOML) + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + + if cfg.Server.ListenAddr != ":8443" { + t.Fatalf("listen_addr: got %q, want %q", cfg.Server.ListenAddr, ":8443") + } + if cfg.MCIAS.ServiceName != "mcr" { + t.Fatalf("service_name: got %q, want %q", cfg.MCIAS.ServiceName, "mcr") + } + if len(cfg.MCIAS.Tags) != 1 || cfg.MCIAS.Tags[0] != "env:restricted" { + t.Fatalf("tags: got %v, want [env:restricted]", cfg.MCIAS.Tags) + } + if cfg.Log.Level != "debug" { + t.Fatalf("log.level: got %q, want %q", cfg.Log.Level, "debug") + } +} + +func TestLoadDefaults(t *testing.T) { + path := writeConfig(t, validTOML) + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + + if cfg.Server.ReadTimeout.Seconds() != 30 { + t.Fatalf("read_timeout: got %v, want 30s", cfg.Server.ReadTimeout) + } + if cfg.Server.WriteTimeout != 0 { + t.Fatalf("write_timeout: got %v, want 0", cfg.Server.WriteTimeout) + } + if cfg.Server.IdleTimeout.Seconds() != 120 { + t.Fatalf("idle_timeout: got %v, want 120s", cfg.Server.IdleTimeout) + } + if cfg.Server.ShutdownTimeout.Seconds() != 60 { + t.Fatalf("shutdown_timeout: got %v, want 60s", cfg.Server.ShutdownTimeout) + } +} + +func TestLoadUploadsPathDefault(t *testing.T) { + toml := ` +[server] +listen_addr = ":8443" +tls_cert = "/srv/mcr/certs/cert.pem" +tls_key = "/srv/mcr/certs/key.pem" + +[database] +path = "/srv/mcr/mcr.db" + +[storage] +layers_path = "/srv/mcr/layers" + +[mcias] +server_url = "https://mcias.metacircular.net:8443" +` + path := writeConfig(t, toml) + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + + want := filepath.Join(filepath.Dir("/srv/mcr/layers"), "uploads") + if cfg.Storage.UploadsPath != want { + t.Fatalf("uploads_path: got %q, want %q", cfg.Storage.UploadsPath, want) + } +} + +func TestLoadMissingRequiredFields(t *testing.T) { + tests := []struct { + name string + toml string + want string + }{ + { + name: "missing listen_addr", + toml: ` +[server] +tls_cert = "/c" +tls_key = "/k" +[database] +path = "/d" +[storage] +layers_path = "/l" +[mcias] +server_url = "https://m" +`, + want: "server.listen_addr", + }, + { + name: "missing tls_cert", + toml: ` +[server] +listen_addr = ":8443" +tls_key = "/k" +[database] +path = "/d" +[storage] +layers_path = "/l" +[mcias] +server_url = "https://m" +`, + want: "server.tls_cert", + }, + { + name: "missing database.path", + toml: ` +[server] +listen_addr = ":8443" +tls_cert = "/c" +tls_key = "/k" +[storage] +layers_path = "/l" +[mcias] +server_url = "https://m" +`, + want: "database.path", + }, + { + name: "missing storage.layers_path", + toml: ` +[server] +listen_addr = ":8443" +tls_cert = "/c" +tls_key = "/k" +[database] +path = "/d" +[mcias] +server_url = "https://m" +`, + want: "storage.layers_path", + }, + { + name: "missing mcias.server_url", + toml: ` +[server] +listen_addr = ":8443" +tls_cert = "/c" +tls_key = "/k" +[database] +path = "/d" +[storage] +layers_path = "/l" +`, + want: "mcias.server_url", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + path := writeConfig(t, tt.toml) + _, err := Load(path) + if err == nil { + t.Fatal("expected error, got nil") + } + if got := err.Error(); !contains(got, tt.want) { + t.Fatalf("error %q does not mention %q", got, tt.want) + } + }) + } +} + +func TestEnvOverride(t *testing.T) { + path := writeConfig(t, validTOML) + + t.Setenv("MCR_SERVER_LISTEN_ADDR", ":9999") + t.Setenv("MCR_LOG_LEVEL", "warn") + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + + if cfg.Server.ListenAddr != ":9999" { + t.Fatalf("listen_addr: got %q, want %q", cfg.Server.ListenAddr, ":9999") + } + if cfg.Log.Level != "warn" { + t.Fatalf("log.level: got %q, want %q", cfg.Log.Level, "warn") + } +} + +func TestEnvOverrideDuration(t *testing.T) { + path := writeConfig(t, validTOML) + + t.Setenv("MCR_SERVER_READ_TIMEOUT", "5s") + + cfg, err := Load(path) + if err != nil { + t.Fatalf("Load: %v", err) + } + + if cfg.Server.ReadTimeout.Seconds() != 5 { + t.Fatalf("read_timeout: got %v, want 5s", cfg.Server.ReadTimeout) + } +} + +func TestSameFilesystemCheck(t *testing.T) { + dir := t.TempDir() + layersPath := filepath.Join(dir, "layers") + uploadsPath := filepath.Join(dir, "uploads") + + // Both under the same tmpdir → same filesystem. + toml := ` +[server] +listen_addr = ":8443" +tls_cert = "/srv/mcr/certs/cert.pem" +tls_key = "/srv/mcr/certs/key.pem" + +[database] +path = "/srv/mcr/mcr.db" + +[storage] +layers_path = "` + layersPath + `" +uploads_path = "` + uploadsPath + `" + +[mcias] +server_url = "https://mcias.metacircular.net:8443" +` + path := writeConfig(t, toml) + _, err := Load(path) + if err != nil { + t.Fatalf("expected same-filesystem check to pass: %v", err) + } +} + +func contains(s, substr string) bool { + return len(s) >= len(substr) && searchString(s, substr) +} + +func searchString(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/internal/config/device_linux.go b/internal/config/device_linux.go new file mode 100644 index 0000000..10b751a --- /dev/null +++ b/internal/config/device_linux.go @@ -0,0 +1,15 @@ +package config + +import ( + "fmt" + "os" + "syscall" +) + +func extractDeviceID(info os.FileInfo) (uint64, error) { + stat, ok := info.Sys().(*syscall.Stat_t) + if !ok { + return 0, fmt.Errorf("unable to get device ID: unsupported file info type") + } + return stat.Dev, nil +} diff --git a/internal/db/audit.go b/internal/db/audit.go new file mode 100644 index 0000000..f4dcb2d --- /dev/null +++ b/internal/db/audit.go @@ -0,0 +1,147 @@ +package db + +import ( + "encoding/json" + "fmt" + "strings" +) + +// AuditEvent represents a row in the audit_log table. +type AuditEvent struct { + ID int64 `json:"id"` + EventTime string `json:"event_time"` + EventType string `json:"event_type"` + ActorID string `json:"actor_id,omitempty"` + Repository string `json:"repository,omitempty"` + Digest string `json:"digest,omitempty"` + IPAddress string `json:"ip_address,omitempty"` + Details map[string]string `json:"details,omitempty"` +} + +// WriteAuditEvent inserts a new audit log entry. +func (d *DB) WriteAuditEvent(eventType, actorID, repository, digest, ip string, details map[string]string) error { + var detailsJSON *string + if len(details) > 0 { + b, err := json.Marshal(details) + if err != nil { + return fmt.Errorf("db: marshal audit details: %w", err) + } + s := string(b) + detailsJSON = &s + } + + _, err := d.Exec( + `INSERT INTO audit_log (event_type, actor_id, repository, digest, ip_address, details) + VALUES (?, ?, ?, ?, ?, ?)`, + eventType, + nullIfEmpty(actorID), + nullIfEmpty(repository), + nullIfEmpty(digest), + nullIfEmpty(ip), + detailsJSON, + ) + if err != nil { + return fmt.Errorf("db: write audit event: %w", err) + } + return nil +} + +// AuditFilter specifies criteria for listing audit events. +type AuditFilter struct { + EventType string + ActorID string + Repository string + Since string // RFC 3339 + Until string // RFC 3339 + Limit int + Offset int +} + +// ListAuditEvents returns audit events matching the filter, ordered by +// event_time descending (most recent first). +func (d *DB) ListAuditEvents(f AuditFilter) ([]AuditEvent, error) { + var clauses []string + var args []any + + if f.EventType != "" { + clauses = append(clauses, "event_type = ?") + args = append(args, f.EventType) + } + if f.ActorID != "" { + clauses = append(clauses, "actor_id = ?") + args = append(args, f.ActorID) + } + if f.Repository != "" { + clauses = append(clauses, "repository = ?") + args = append(args, f.Repository) + } + if f.Since != "" { + clauses = append(clauses, "event_time >= ?") + args = append(args, f.Since) + } + if f.Until != "" { + clauses = append(clauses, "event_time <= ?") + args = append(args, f.Until) + } + + query := "SELECT id, event_time, event_type, actor_id, repository, digest, ip_address, details FROM audit_log" + if len(clauses) > 0 { + query += " WHERE " + strings.Join(clauses, " AND ") + } + query += " ORDER BY event_time DESC" + + limit := f.Limit + if limit <= 0 { + limit = 50 + } + query += fmt.Sprintf(" LIMIT %d", limit) + + if f.Offset > 0 { + query += fmt.Sprintf(" OFFSET %d", f.Offset) + } + + rows, err := d.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("db: list audit events: %w", err) + } + defer func() { _ = rows.Close() }() + + var events []AuditEvent + for rows.Next() { + var e AuditEvent + var actorID, repository, digest, ip, detailsStr *string + if err := rows.Scan(&e.ID, &e.EventTime, &e.EventType, &actorID, &repository, &digest, &ip, &detailsStr); err != nil { + return nil, fmt.Errorf("db: scan audit event: %w", err) + } + if actorID != nil { + e.ActorID = *actorID + } + if repository != nil { + e.Repository = *repository + } + if digest != nil { + e.Digest = *digest + } + if ip != nil { + e.IPAddress = *ip + } + if detailsStr != nil { + if err := json.Unmarshal([]byte(*detailsStr), &e.Details); err != nil { + return nil, fmt.Errorf("db: unmarshal audit details: %w", err) + } + } + events = append(events, e) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("db: iterate audit events: %w", err) + } + + return events, nil +} + +func nullIfEmpty(s string) *string { + if s == "" { + return nil + } + return &s +} diff --git a/internal/db/audit_test.go b/internal/db/audit_test.go new file mode 100644 index 0000000..50bbaeb --- /dev/null +++ b/internal/db/audit_test.go @@ -0,0 +1,174 @@ +package db + +import ( + "testing" +) + +func migratedTestDB(t *testing.T) *DB { + t.Helper() + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + return d +} + +func TestWriteAndListAuditEvents(t *testing.T) { + d := migratedTestDB(t) + + err := d.WriteAuditEvent("login_ok", "user-uuid-1", "", "", "10.0.0.1", nil) + if err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + + err = d.WriteAuditEvent("manifest_pushed", "user-uuid-1", "myapp", "sha256:abc", "10.0.0.1", + map[string]string{"tag": "latest"}) + if err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + + events, err := d.ListAuditEvents(AuditFilter{}) + if err != nil { + t.Fatalf("ListAuditEvents: %v", err) + } + if len(events) != 2 { + t.Fatalf("event count: got %d, want 2", len(events)) + } + + // Most recent first. + if events[0].EventType != "manifest_pushed" { + t.Fatalf("first event type: got %q, want %q", events[0].EventType, "manifest_pushed") + } + if events[0].Repository != "myapp" { + t.Fatalf("repository: got %q, want %q", events[0].Repository, "myapp") + } + if events[0].Digest != "sha256:abc" { + t.Fatalf("digest: got %q, want %q", events[0].Digest, "sha256:abc") + } + if events[0].Details["tag"] != "latest" { + t.Fatalf("details.tag: got %q, want %q", events[0].Details["tag"], "latest") + } +} + +func TestListAuditEventsFilterByType(t *testing.T) { + d := migratedTestDB(t) + + _ = d.WriteAuditEvent("login_ok", "u1", "", "", "", nil) + _ = d.WriteAuditEvent("manifest_pushed", "u1", "repo", "", "", nil) + _ = d.WriteAuditEvent("login_ok", "u2", "", "", "", nil) + + events, err := d.ListAuditEvents(AuditFilter{EventType: "login_ok"}) + if err != nil { + t.Fatalf("ListAuditEvents: %v", err) + } + if len(events) != 2 { + t.Fatalf("event count: got %d, want 2", len(events)) + } + for _, e := range events { + if e.EventType != "login_ok" { + t.Fatalf("unexpected event type: %q", e.EventType) + } + } +} + +func TestListAuditEventsFilterByActor(t *testing.T) { + d := migratedTestDB(t) + + _ = d.WriteAuditEvent("login_ok", "actor-a", "", "", "", nil) + _ = d.WriteAuditEvent("login_ok", "actor-b", "", "", "", nil) + + events, err := d.ListAuditEvents(AuditFilter{ActorID: "actor-a"}) + if err != nil { + t.Fatalf("ListAuditEvents: %v", err) + } + if len(events) != 1 { + t.Fatalf("event count: got %d, want 1", len(events)) + } + if events[0].ActorID != "actor-a" { + t.Fatalf("actor_id: got %q, want %q", events[0].ActorID, "actor-a") + } +} + +func TestListAuditEventsFilterByRepository(t *testing.T) { + d := migratedTestDB(t) + + _ = d.WriteAuditEvent("manifest_pushed", "u1", "repo-a", "", "", nil) + _ = d.WriteAuditEvent("manifest_pushed", "u1", "repo-b", "", "", nil) + + events, err := d.ListAuditEvents(AuditFilter{Repository: "repo-a"}) + if err != nil { + t.Fatalf("ListAuditEvents: %v", err) + } + if len(events) != 1 { + t.Fatalf("event count: got %d, want 1", len(events)) + } +} + +func TestListAuditEventsPagination(t *testing.T) { + d := migratedTestDB(t) + + for i := range 5 { + _ = d.WriteAuditEvent("login_ok", "u1", "", "", "", map[string]string{"i": string(rune('0' + i))}) + } + + // First page. + page1, err := d.ListAuditEvents(AuditFilter{Limit: 2, Offset: 0}) + if err != nil { + t.Fatalf("ListAuditEvents page 1: %v", err) + } + if len(page1) != 2 { + t.Fatalf("page 1 count: got %d, want 2", len(page1)) + } + + // Second page. + page2, err := d.ListAuditEvents(AuditFilter{Limit: 2, Offset: 2}) + if err != nil { + t.Fatalf("ListAuditEvents page 2: %v", err) + } + if len(page2) != 2 { + t.Fatalf("page 2 count: got %d, want 2", len(page2)) + } + + // Pages should not overlap. + if page1[0].ID == page2[0].ID { + t.Fatal("page 1 and page 2 overlap") + } + + // Third page (partial). + page3, err := d.ListAuditEvents(AuditFilter{Limit: 2, Offset: 4}) + if err != nil { + t.Fatalf("ListAuditEvents page 3: %v", err) + } + if len(page3) != 1 { + t.Fatalf("page 3 count: got %d, want 1", len(page3)) + } +} + +func TestListAuditEventsNullFields(t *testing.T) { + d := migratedTestDB(t) + + // Write event with all optional fields empty. + err := d.WriteAuditEvent("gc_started", "", "", "", "", nil) + if err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + + events, err := d.ListAuditEvents(AuditFilter{}) + if err != nil { + t.Fatalf("ListAuditEvents: %v", err) + } + if len(events) != 1 { + t.Fatalf("event count: got %d, want 1", len(events)) + } + + e := events[0] + if e.ActorID != "" { + t.Fatalf("actor_id: got %q, want empty", e.ActorID) + } + if e.Repository != "" { + t.Fatalf("repository: got %q, want empty", e.Repository) + } + if e.Details != nil { + t.Fatalf("details: got %v, want nil", e.Details) + } +} diff --git a/internal/db/db.go b/internal/db/db.go new file mode 100644 index 0000000..af83685 --- /dev/null +++ b/internal/db/db.go @@ -0,0 +1,49 @@ +package db + +import ( + "database/sql" + "fmt" + "os" + "path/filepath" + + _ "modernc.org/sqlite" +) + +// DB wraps a SQLite database connection. +type DB struct { + *sql.DB +} + +// Open opens (or creates) a SQLite database at the given path with the +// standard Metacircular pragmas: WAL mode, foreign keys, busy timeout. +func Open(path string) (*DB, error) { + dir := filepath.Dir(path) + if err := os.MkdirAll(dir, 0700); err != nil { + return nil, fmt.Errorf("db: create directory %s: %w", dir, err) + } + + sqlDB, err := sql.Open("sqlite", path) + if err != nil { + return nil, fmt.Errorf("db: open %s: %w", path, err) + } + + pragmas := []string{ + "PRAGMA journal_mode = WAL", + "PRAGMA foreign_keys = ON", + "PRAGMA busy_timeout = 5000", + } + for _, p := range pragmas { + if _, err := sqlDB.Exec(p); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("db: %s: %w", p, err) + } + } + + // Set file permissions to 0600 (owner read/write only). + if err := os.Chmod(path, 0600); err != nil { + _ = sqlDB.Close() + return nil, fmt.Errorf("db: chmod %s: %w", path, err) + } + + return &DB{sqlDB}, nil +} diff --git a/internal/db/db_test.go b/internal/db/db_test.go new file mode 100644 index 0000000..016fb8b --- /dev/null +++ b/internal/db/db_test.go @@ -0,0 +1,197 @@ +package db + +import ( + "path/filepath" + "testing" +) + +func openTestDB(t *testing.T) *DB { + t.Helper() + path := filepath.Join(t.TempDir(), "test.db") + d, err := Open(path) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { _ = d.Close() }) + return d +} + +func TestOpenAndMigrate(t *testing.T) { + d := openTestDB(t) + + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + // Verify schema version. + ver, err := d.SchemaVersion() + if err != nil { + t.Fatalf("SchemaVersion: %v", err) + } + if ver != 2 { + t.Fatalf("schema version: got %d, want 2", ver) + } +} + +func TestMigrateIdempotent(t *testing.T) { + d := openTestDB(t) + + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate (first): %v", err) + } + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate (second): %v", err) + } + + ver, err := d.SchemaVersion() + if err != nil { + t.Fatalf("SchemaVersion: %v", err) + } + if ver != 2 { + t.Fatalf("schema version: got %d, want 2", ver) + } +} + +func TestTablesExist(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + tables := []string{ + "schema_migrations", + "repositories", + "manifests", + "tags", + "blobs", + "manifest_blobs", + "uploads", + "policy_rules", + "audit_log", + } + + for _, table := range tables { + var name string + err := d.QueryRow( + `SELECT name FROM sqlite_master WHERE type='table' AND name=?`, table, + ).Scan(&name) + if err != nil { + t.Fatalf("table %q not found: %v", table, err) + } + } +} + +func TestForeignKeyEnforcement(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + // Inserting a manifest with a nonexistent repository_id should fail. + _, err := d.Exec(`INSERT INTO manifests (repository_id, digest, media_type, content, size) + VALUES (9999, 'sha256:abc', 'application/json', '{}', 2)`) + if err == nil { + t.Fatal("expected foreign key violation, got nil") + } +} + +func TestTagCascadeOnManifestDelete(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + // Create a repository, manifest, and tag. + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + + _, err = d.Exec(`INSERT INTO manifests (repository_id, digest, media_type, content, size) + VALUES (1, 'sha256:abc123', 'application/vnd.oci.image.manifest.v1+json', '{}', 2)`) + if err != nil { + t.Fatalf("insert manifest: %v", err) + } + + _, err = d.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (1, 'latest', 1)`) + if err != nil { + t.Fatalf("insert tag: %v", err) + } + + // Delete the manifest — tag should cascade. + _, err = d.Exec(`DELETE FROM manifests WHERE id = 1`) + if err != nil { + t.Fatalf("delete manifest: %v", err) + } + + var count int + if err := d.QueryRow(`SELECT COUNT(*) FROM tags`).Scan(&count); err != nil { + t.Fatalf("count tags: %v", err) + } + if count != 0 { + t.Fatalf("tags after manifest delete: got %d, want 0", count) + } +} + +func TestManifestBlobsCascadeOnManifestDelete(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + + _, err = d.Exec(`INSERT INTO manifests (repository_id, digest, media_type, content, size) + VALUES (1, 'sha256:abc123', 'application/vnd.oci.image.manifest.v1+json', '{}', 2)`) + if err != nil { + t.Fatalf("insert manifest: %v", err) + } + + _, err = d.Exec(`INSERT INTO blobs (digest, size) VALUES ('sha256:layer1', 1024)`) + if err != nil { + t.Fatalf("insert blob: %v", err) + } + + _, err = d.Exec(`INSERT INTO manifest_blobs (manifest_id, blob_id) VALUES (1, 1)`) + if err != nil { + t.Fatalf("insert manifest_blobs: %v", err) + } + + // Delete manifest — manifest_blobs should cascade, blob should remain. + _, err = d.Exec(`DELETE FROM manifests WHERE id = 1`) + if err != nil { + t.Fatalf("delete manifest: %v", err) + } + + var mbCount int + if err := d.QueryRow(`SELECT COUNT(*) FROM manifest_blobs`).Scan(&mbCount); err != nil { + t.Fatalf("count manifest_blobs: %v", err) + } + if mbCount != 0 { + t.Fatalf("manifest_blobs after delete: got %d, want 0", mbCount) + } + + // Blob row should still exist (GC handles file cleanup). + var blobCount int + if err := d.QueryRow(`SELECT COUNT(*) FROM blobs`).Scan(&blobCount); err != nil { + t.Fatalf("count blobs: %v", err) + } + if blobCount != 1 { + t.Fatalf("blobs after manifest delete: got %d, want 1", blobCount) + } +} + +func TestWALMode(t *testing.T) { + d := openTestDB(t) + + var mode string + if err := d.QueryRow(`PRAGMA journal_mode`).Scan(&mode); err != nil { + t.Fatalf("PRAGMA journal_mode: %v", err) + } + if mode != "wal" { + t.Fatalf("journal_mode: got %q, want %q", mode, "wal") + } +} diff --git a/internal/db/migrate.go b/internal/db/migrate.go new file mode 100644 index 0000000..3063f02 --- /dev/null +++ b/internal/db/migrate.go @@ -0,0 +1,174 @@ +package db + +import ( + "database/sql" + "fmt" +) + +// migration is a numbered schema change. +type migration struct { + version int + name string + sql string +} + +// migrations is the ordered list of schema migrations. +var migrations = []migration{ + { + version: 1, + name: "core registry tables", + sql: ` +CREATE TABLE IF NOT EXISTS repositories ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL UNIQUE, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); + +CREATE TABLE IF NOT EXISTS manifests ( + id INTEGER PRIMARY KEY, + repository_id INTEGER NOT NULL REFERENCES repositories(id) ON DELETE CASCADE, + digest TEXT NOT NULL, + media_type TEXT NOT NULL, + content BLOB NOT NULL, + size INTEGER NOT NULL, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + UNIQUE(repository_id, digest) +); + +CREATE INDEX IF NOT EXISTS idx_manifests_repo ON manifests (repository_id); +CREATE INDEX IF NOT EXISTS idx_manifests_digest ON manifests (digest); + +CREATE TABLE IF NOT EXISTS tags ( + id INTEGER PRIMARY KEY, + repository_id INTEGER NOT NULL REFERENCES repositories(id) ON DELETE CASCADE, + name TEXT NOT NULL, + manifest_id INTEGER NOT NULL REFERENCES manifests(id) ON DELETE CASCADE, + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + UNIQUE(repository_id, name) +); + +CREATE INDEX IF NOT EXISTS idx_tags_repo ON tags (repository_id); +CREATE INDEX IF NOT EXISTS idx_tags_manifest ON tags (manifest_id); + +CREATE TABLE IF NOT EXISTS blobs ( + id INTEGER PRIMARY KEY, + digest TEXT NOT NULL UNIQUE, + size INTEGER NOT NULL, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); + +CREATE TABLE IF NOT EXISTS manifest_blobs ( + manifest_id INTEGER NOT NULL REFERENCES manifests(id) ON DELETE CASCADE, + blob_id INTEGER NOT NULL REFERENCES blobs(id), + PRIMARY KEY (manifest_id, blob_id) +); + +CREATE INDEX IF NOT EXISTS idx_manifest_blobs_blob ON manifest_blobs (blob_id); + +CREATE TABLE IF NOT EXISTS uploads ( + id INTEGER PRIMARY KEY, + uuid TEXT NOT NULL UNIQUE, + repository_id INTEGER NOT NULL REFERENCES repositories(id) ON DELETE CASCADE, + byte_offset INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); +`, + }, + { + version: 2, + name: "policy and audit tables", + sql: ` +CREATE TABLE IF NOT EXISTS policy_rules ( + id INTEGER PRIMARY KEY, + priority INTEGER NOT NULL DEFAULT 100, + description TEXT NOT NULL, + rule_json TEXT NOT NULL, + enabled INTEGER NOT NULL DEFAULT 1 CHECK (enabled IN (0,1)), + created_by TEXT, + created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) +); + +CREATE TABLE IF NOT EXISTS audit_log ( + id INTEGER PRIMARY KEY, + event_time TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')), + event_type TEXT NOT NULL, + actor_id TEXT, + repository TEXT, + digest TEXT, + ip_address TEXT, + details TEXT +); + +CREATE INDEX IF NOT EXISTS idx_audit_time ON audit_log (event_time); +CREATE INDEX IF NOT EXISTS idx_audit_actor ON audit_log (actor_id); +CREATE INDEX IF NOT EXISTS idx_audit_event ON audit_log (event_type); +`, + }, +} + +// Migrate applies all pending migrations. It creates the schema_migrations +// tracking table if it does not exist. Migrations are idempotent. +func (d *DB) Migrate() error { + _, err := d.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations ( + version INTEGER PRIMARY KEY, + applied_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ','now')) + )`) + if err != nil { + return fmt.Errorf("db: create schema_migrations: %w", err) + } + + for _, m := range migrations { + applied, err := d.migrationApplied(m.version) + if err != nil { + return err + } + if applied { + continue + } + + tx, err := d.Begin() + if err != nil { + return fmt.Errorf("db: begin migration %d (%s): %w", m.version, m.name, err) + } + + if _, err := tx.Exec(m.sql); err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: migration %d (%s): %w", m.version, m.name, err) + } + + if _, err := tx.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, m.version); err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: record migration %d: %w", m.version, err) + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("db: commit migration %d: %w", m.version, err) + } + } + + return nil +} + +func (d *DB) migrationApplied(version int) (bool, error) { + var count int + err := d.QueryRow(`SELECT COUNT(*) FROM schema_migrations WHERE version = ?`, version).Scan(&count) + if err != nil { + return false, fmt.Errorf("db: check migration %d: %w", version, err) + } + return count > 0, nil +} + +// SchemaVersion returns the highest applied migration version, or 0 if +// no migrations have been applied. +func (d *DB) SchemaVersion() (int, error) { + var version sql.NullInt64 + err := d.QueryRow(`SELECT MAX(version) FROM schema_migrations`).Scan(&version) + if err != nil { + return 0, fmt.Errorf("db: schema version: %w", err) + } + if !version.Valid { + return 0, nil + } + return int(version.Int64), nil +}