From dddc66f31b1125bfda2e118dd9437acc00182131 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Thu, 19 Mar 2026 18:25:18 -0700 Subject: [PATCH] Phases 5, 6, 8: OCI pull/push paths and admin REST API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Phase 5 (OCI pull): internal/oci/ package with manifest GET/HEAD by tag/digest, blob GET/HEAD with repo membership check, tag listing with OCI pagination, catalog listing. Multi-segment repo names via parseOCIPath() right-split routing. DB query layer in internal/db/repository.go. Phase 6 (OCI push): blob uploads (monolithic and chunked) with uploadManager tracking in-progress BlobWriters, manifest push implementing full ARCHITECTURE.md §5 flow in a single SQLite transaction (create repo, upsert manifest, populate manifest_blobs, atomic tag move). Digest verification on both blob commit and manifest push-by-digest. Phase 8 (admin REST): /v1 endpoints for auth (login/logout/health), repository management (list/detail/delete), policy CRUD with engine reload, audit log listing with filters, GC trigger/status stubs. RequireAdmin middleware, platform-standard error format. Co-Authored-By: Claude Opus 4.6 (1M context) --- PROGRESS.md | 177 +++++++- PROJECT_PLAN.md | 6 +- internal/db/admin.go | 429 +++++++++++++++++++ internal/db/admin_test.go | 593 +++++++++++++++++++++++++++ internal/db/errors.go | 9 + internal/db/push.go | 154 +++++++ internal/db/push_test.go | 278 +++++++++++++ internal/db/repository.go | 154 +++++++ internal/db/repository_test.go | 429 +++++++++++++++++++ internal/db/upload.go | 80 ++++ internal/db/upload_test.go | 124 ++++++ internal/oci/blob.go | 98 +++++ internal/oci/blob_test.go | 144 +++++++ internal/oci/catalog.go | 68 +++ internal/oci/catalog_test.go | 124 ++++++ internal/oci/handler.go | 119 ++++++ internal/oci/handler_test.go | 321 +++++++++++++++ internal/oci/manifest.go | 222 ++++++++++ internal/oci/manifest_test.go | 187 +++++++++ internal/oci/ocierror.go | 23 ++ internal/oci/push_test.go | 266 ++++++++++++ internal/oci/routes.go | 171 ++++++++ internal/oci/routes_test.go | 141 +++++++ internal/oci/tags.go | 84 ++++ internal/oci/tags_test.go | 154 +++++++ internal/oci/upload.go | 241 +++++++++++ internal/oci/upload_test.go | 291 +++++++++++++ internal/server/admin.go | 52 +++ internal/server/admin_audit.go | 48 +++ internal/server/admin_audit_test.go | 152 +++++++ internal/server/admin_auth.go | 60 +++ internal/server/admin_auth_test.go | 116 ++++++ internal/server/admin_gc.go | 66 +++ internal/server/admin_gc_test.go | 94 +++++ internal/server/admin_policy.go | 343 ++++++++++++++++ internal/server/admin_policy_test.go | 337 +++++++++++++++ internal/server/admin_repo.go | 94 +++++ internal/server/admin_repo_test.go | 186 +++++++++ internal/server/admin_routes.go | 56 +++ internal/server/admin_test.go | 148 +++++++ 40 files changed, 6832 insertions(+), 7 deletions(-) create mode 100644 internal/db/admin.go create mode 100644 internal/db/admin_test.go create mode 100644 internal/db/errors.go create mode 100644 internal/db/push.go create mode 100644 internal/db/push_test.go create mode 100644 internal/db/repository.go create mode 100644 internal/db/repository_test.go create mode 100644 internal/db/upload.go create mode 100644 internal/db/upload_test.go create mode 100644 internal/oci/blob.go create mode 100644 internal/oci/blob_test.go create mode 100644 internal/oci/catalog.go create mode 100644 internal/oci/catalog_test.go create mode 100644 internal/oci/handler.go create mode 100644 internal/oci/handler_test.go create mode 100644 internal/oci/manifest.go create mode 100644 internal/oci/manifest_test.go create mode 100644 internal/oci/ocierror.go create mode 100644 internal/oci/push_test.go create mode 100644 internal/oci/routes.go create mode 100644 internal/oci/routes_test.go create mode 100644 internal/oci/tags.go create mode 100644 internal/oci/tags_test.go create mode 100644 internal/oci/upload.go create mode 100644 internal/oci/upload_test.go create mode 100644 internal/server/admin.go create mode 100644 internal/server/admin_audit.go create mode 100644 internal/server/admin_audit_test.go create mode 100644 internal/server/admin_auth.go create mode 100644 internal/server/admin_auth_test.go create mode 100644 internal/server/admin_gc.go create mode 100644 internal/server/admin_gc_test.go create mode 100644 internal/server/admin_policy.go create mode 100644 internal/server/admin_policy_test.go create mode 100644 internal/server/admin_repo.go create mode 100644 internal/server/admin_repo_test.go create mode 100644 internal/server/admin_routes.go create mode 100644 internal/server/admin_test.go diff --git a/PROGRESS.md b/PROGRESS.md index f543662..95a8eb7 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -6,7 +6,7 @@ See `PROJECT_PLAN.md` for the implementation roadmap and ## Current State -**Phase:** 4 complete, ready for Batch B (Phase 5 + Phase 8) +**Phase:** 6 complete, ready for Phase 7 **Last updated:** 2026-03-19 ### Completed @@ -16,6 +16,9 @@ See `PROJECT_PLAN.md` for the implementation roadmap and - Phase 2: Blob storage layer (all 2 steps) - Phase 3: MCIAS authentication (all 4 steps) - Phase 4: Policy engine (all 4 steps) +- Phase 5: OCI pull path (all 5 steps) +- Phase 6: OCI push path (all 3 steps) +- Phase 8: Admin REST API (all 5 steps) - `ARCHITECTURE.md` — Full design specification (18 sections) - `CLAUDE.md` — AI development guidance - `PROJECT_PLAN.md` — Implementation plan (14 phases, 40+ steps) @@ -23,14 +26,180 @@ See `PROJECT_PLAN.md` for the implementation roadmap and ### Next Steps -1. Batch B: Phase 5 (OCI pull) and Phase 8 (admin REST) — independent, - can be done in parallel -2. After Phase 5, Phase 6 (OCI push) then Phase 7 (OCI delete) +1. Phase 7 (OCI delete) +2. After Phase 7, Phase 9 (garbage collection) +3. Phase 10 (gRPC admin API) --- ## Log +### 2026-03-19 — Phase 6: OCI push path + +**Task:** Implement blob uploads (monolithic and chunked) and manifest +push per ARCHITECTURE.md §5 and OCI Distribution Spec. + +**Changes:** + +Step 6.1 — Blob upload initiation: +- `db/upload.go`: `UploadRow` type, `ErrUploadNotFound` sentinel; + `CreateUpload()`, `GetUpload()`, `UpdateUploadOffset()`, `DeleteUpload()` +- `db/push.go`: `GetOrCreateRepository()` (implicit repo creation on + first push), `BlobExists()`, `InsertBlob()` (INSERT OR IGNORE for + content-addressed dedup) +- `oci/upload.go`: `uploadManager` (sync.Mutex-protected map of in-progress + BlobWriters by UUID); `generateUUID()` via crypto/rand; `handleUploadInitiate()` + — policy check (registry:push), implicit repo creation, DB row + storage + temp file, returns 202 with Location/Docker-Upload-UUID/Range headers +- Extended `DBQuerier` interface with push/upload methods +- Changed `BlobOpener` to `BlobStore` adding `StartUpload(uuid)` method + +Step 6.2 — Blob upload chunked and monolithic: +- `oci/upload.go`: `handleUploadChunk()` (PATCH — append body, update offset), + `handleUploadComplete()` (PUT — optional final body, commit with digest + verification, insert blob row, cleanup upload, audit event), + `handleUploadStatus()` (GET — 204 with Range header), + `handleUploadCancel()` (DELETE — cancel BlobWriter, remove upload row) +- `oci/routes.go`: `parseOCIPath()` extended to handle `/blobs/uploads/` + and `/blobs/uploads/` patterns; `dispatchUpload()` routes by method + +Step 6.3 — Manifest push: +- `db/push.go`: `PushManifestParams` struct, `PushManifest()` — single + SQLite transaction: create repo if not exists, upsert manifest (ON + CONFLICT DO UPDATE), clear and repopulate manifest_blobs join table, + upsert tag if provided (atomic tag move) +- `oci/manifest.go`: `handleManifestPut()` — full push flow per §5: + parse JSON, compute SHA-256, verify digest if push-by-digest, collect + config+layer descriptors, verify all referenced blobs exist (400 + MANIFEST_BLOB_UNKNOWN if missing), call PushManifest(), audit event + with tag details, returns 201 with Docker-Content-Digest/Location/ + Content-Type headers + +**Verification:** +- `make all` passes: vet clean, lint 0 issues, 198 tests passing, + all 3 binaries built +- DB tests (15 new): upload CRUD (create/get/update/delete/not-found), + GetOrCreateRepository (new + existing), BlobExists (found/not-found), + InsertBlob (new + idempotent), PushManifest by tag (verify repo creation, + manifest, tag, manifest_blobs), by digest (no tag), tag move (atomic + update), idempotent re-push +- OCI upload tests (8 new): initiate (202, Location/UUID/Range headers, + implicit repo creation), unique UUIDs (5 initiates → 5 distinct UUIDs), + monolithic upload (POST → PUT with body → 201), chunked upload + (POST → PATCH → PATCH → PUT → 201), digest mismatch (400 DIGEST_INVALID), + upload status (GET → 204), cancel (DELETE → 204), nonexistent UUID + (PATCH → 404 BLOB_UPLOAD_UNKNOWN) +- OCI manifest push tests (7 new): push by tag (201, correct headers), + push by digest, digest mismatch (400), missing blob (400 + MANIFEST_BLOB_UNKNOWN), malformed JSON (400 MANIFEST_INVALID), empty + manifest, tag update (atomic move), re-push idempotent +- Route parsing tests (5 new subtests): upload initiate (trailing slash), + upload initiate (no trailing slash), upload with UUID, multi-segment + repo upload, multi-segment repo upload initiate + +--- + +### 2026-03-19 — Batch B: Phase 5 (OCI pull) + Phase 8 (admin REST) + +**Task:** Implement the OCI Distribution Spec pull path and the admin +REST API management endpoints. Both phases depend on Phase 4 (policy) +but not on each other — implemented in parallel. + +**Changes:** + +Phase 5 — `internal/db/` + `internal/oci/` (Steps 5.1–5.5): + +Step 5.1 — OCI handler scaffolding: +- `db/errors.go`: `ErrRepoNotFound`, `ErrManifestNotFound`, `ErrBlobNotFound` + sentinel errors +- `oci/handler.go`: `Handler` struct with `DBQuerier`, `BlobOpener`, + `PolicyEval`, `AuditFunc` interfaces; `NewHandler()` constructor; + `checkPolicy()` inline policy check; `audit()` helper +- `oci/ocierror.go`: `writeOCIError()` duplicated from server package + (15 lines, avoids coupling) +- `oci/routes.go`: `parseOCIPath()` splits from the right to handle + multi-segment repo names (e.g., `org/team/app/manifests/latest`); + `Router()` returns chi router with `/_catalog` and `/*` catch-all; + `dispatch()` routes to manifest/blob/tag handlers + +Step 5.2 — Manifest pull: +- `db/repository.go`: `ManifestRow` type, `GetRepositoryByName()`, + `GetManifestByTag()`, `GetManifestByDigest()` +- `oci/manifest.go`: GET/HEAD `/v2/{name}/manifests/{reference}`; + resolves by tag or digest; sets Content-Type, Docker-Content-Digest, + Content-Length headers + +Step 5.3 — Blob download: +- `db/repository.go`: `BlobExistsInRepo()` — joins blobs+manifest_blobs+ + manifests to verify blob belongs to repo +- `oci/blob.go`: GET/HEAD `/v2/{name}/blobs/{digest}`; validates blob + exists in repo before streaming from storage + +Step 5.4 — Tag listing: +- `db/repository.go`: `ListTags()` with cursor-based pagination + (after/limit) +- `oci/tags.go`: GET `/v2/{name}/tags/list` with OCI `?n=`/`?last=` + pagination and `Link` header for next page + +Step 5.5 — Catalog listing: +- `db/repository.go`: `ListRepositoryNames()` with cursor-based pagination +- `oci/catalog.go`: GET `/v2/_catalog` with same pagination pattern + +Phase 8 — `internal/db/` + `internal/server/` (Steps 8.1–8.5): + +Step 8.1 — Auth endpoints: +- `server/admin_auth.go`: POST `/v1/auth/login` (JSON body → MCIAS), + POST `/v1/auth/logout` (204 stub), GET `/v1/health` (no auth) + +Step 8.2 — Repository management: +- `db/admin.go`: `RepoMetadata`, `TagInfo`, `ManifestInfo`, `RepoDetail` + types; `ListRepositoriesWithMetadata()`, `GetRepositoryDetail()`, + `DeleteRepository()` +- `server/admin_repo.go`: GET/DELETE `/v1/repositories` and + `/v1/repositories/*` (wildcard for multi-segment names) + +Step 8.3 — Policy CRUD: +- `db/admin.go`: `PolicyRuleRow` type, `ErrPolicyRuleNotFound`; + `CreatePolicyRule()`, `GetPolicyRule()`, `ListPolicyRules()`, + `UpdatePolicyRule()`, `SetPolicyRuleEnabled()`, `DeletePolicyRule()` +- `server/admin_policy.go`: full CRUD on `/v1/policy/rules` and + `/v1/policy/rules/{id}`; `PolicyReloader` interface; input validation + (priority >= 1, valid effect/actions); mutations trigger engine reload + +Step 8.4 — Audit endpoint: +- `server/admin_audit.go`: GET `/v1/audit` with query parameter filters + (event_type, actor_id, repository, since, until, n, offset); + delegates to `db.ListAuditEvents` + +Step 8.5 — GC endpoints: +- `server/admin_gc.go`: `GCState` struct with `sync.Mutex`; POST `/v1/gc` + returns 202 (stub for Phase 9); GET `/v1/gc/status`; concurrent + trigger returns 409 + +Shared admin infrastructure: +- `server/admin.go`: `writeAdminError()` (platform `{"error":"..."}` + format), `writeJSON()`, `RequireAdmin()` middleware, `hasRole()` helper +- `server/admin_routes.go`: `AdminDeps` struct, `MountAdminRoutes()` + mounts all `/v1/*` endpoints with proper auth/admin middleware layers + +**Verification:** +- `make all` passes: vet clean, lint 0 issues, 168 tests passing, + all 3 binaries built +- Phase 5 (34 new tests): parseOCIPath (14 subtests covering simple/ + multi-segment/edge cases), manifest GET by tag/digest + HEAD + not + found, blob GET/HEAD + not in repo + repo not found, tags list + + pagination + empty + repo not found, catalog list + pagination + empty, + DB repository methods (15 tests covering all 6 query methods) +- Phase 8 (51 new tests): DB admin methods (19 tests covering CRUD, + pagination, cascade, not-found), admin auth (login ok/fail, health, + logout), admin repos (list/detail/delete/non-admin 403), admin policy + (full CRUD cycle, validation errors, non-admin 403, engine reload), + admin audit (list with filters, pagination), admin GC (trigger 202, + status, concurrent 409), RequireAdmin middleware (allowed/denied/ + no claims) + +--- + ### 2026-03-19 — Phase 4: Policy engine **Task:** Implement the registry-specific authorization engine with diff --git a/PROJECT_PLAN.md b/PROJECT_PLAN.md index 7e00e2e..3dd010b 100644 --- a/PROJECT_PLAN.md +++ b/PROJECT_PLAN.md @@ -14,10 +14,10 @@ design specification. | 2 | Blob storage layer | **Complete** | | 3 | MCIAS authentication | **Complete** | | 4 | Policy engine | **Complete** | -| 5 | OCI API — pull path | Not started | -| 6 | OCI API — push path | Not started | +| 5 | OCI API — pull path | **Complete** | +| 6 | OCI API — push path | **Complete** | | 7 | OCI API — delete path | Not started | -| 8 | Admin REST API | Not started | +| 8 | Admin REST API | **Complete** | | 9 | Garbage collection | Not started | | 10 | gRPC admin API | Not started | | 11 | CLI tool (mcrctl) | Not started | diff --git a/internal/db/admin.go b/internal/db/admin.go new file mode 100644 index 0000000..f24ec9e --- /dev/null +++ b/internal/db/admin.go @@ -0,0 +1,429 @@ +package db + +import ( + "database/sql" + "encoding/json" + "errors" + "fmt" + "strings" + "time" +) + +// ErrPolicyRuleNotFound is returned when a policy rule lookup finds no matching row. +var ErrPolicyRuleNotFound = errors.New("db: policy rule not found") + +// RepoMetadata is a repository with aggregate counts for listing. +type RepoMetadata struct { + Name string `json:"name"` + TagCount int `json:"tag_count"` + ManifestCount int `json:"manifest_count"` + TotalSize int64 `json:"total_size"` + CreatedAt string `json:"created_at"` +} + +// TagInfo is a tag with its manifest digest for repo detail. +type TagInfo struct { + Name string `json:"name"` + Digest string `json:"digest"` +} + +// ManifestInfo is a manifest summary for repo detail. +type ManifestInfo struct { + Digest string `json:"digest"` + MediaType string `json:"media_type"` + Size int64 `json:"size"` + CreatedAt string `json:"created_at"` +} + +// RepoDetail contains detailed info about a single repository. +type RepoDetail struct { + Name string `json:"name"` + Tags []TagInfo `json:"tags"` + Manifests []ManifestInfo `json:"manifests"` + TotalSize int64 `json:"total_size"` + CreatedAt string `json:"created_at"` +} + +// PolicyRuleRow represents a row from the policy_rules table with parsed JSON. +type PolicyRuleRow struct { + ID int64 `json:"id"` + Priority int `json:"priority"` + Description string `json:"description"` + Effect string `json:"effect"` + Roles []string `json:"roles,omitempty"` + AccountTypes []string `json:"account_types,omitempty"` + SubjectUUID string `json:"subject_uuid,omitempty"` + Actions []string `json:"actions"` + Repositories []string `json:"repositories,omitempty"` + Enabled bool `json:"enabled"` + CreatedBy string `json:"created_by,omitempty"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` +} + +// ListRepositoriesWithMetadata returns all repositories with tag count, +// manifest count, and total size. +func (d *DB) ListRepositoriesWithMetadata(limit, offset int) ([]RepoMetadata, error) { + if limit <= 0 { + limit = 50 + } + rows, err := d.Query( + `SELECT r.name, r.created_at, + (SELECT COUNT(*) FROM tags t WHERE t.repository_id = r.id) AS tag_count, + (SELECT COUNT(*) FROM manifests m WHERE m.repository_id = r.id) AS manifest_count, + COALESCE((SELECT SUM(m.size) FROM manifests m WHERE m.repository_id = r.id), 0) AS total_size + FROM repositories r + ORDER BY r.name ASC + LIMIT ? OFFSET ?`, + limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("db: list repositories: %w", err) + } + defer func() { _ = rows.Close() }() + + var repos []RepoMetadata + for rows.Next() { + var r RepoMetadata + if err := rows.Scan(&r.Name, &r.CreatedAt, &r.TagCount, &r.ManifestCount, &r.TotalSize); err != nil { + return nil, fmt.Errorf("db: scan repository: %w", err) + } + repos = append(repos, r) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("db: iterate repositories: %w", err) + } + return repos, nil +} + +// GetRepositoryDetail returns detailed information about a repository. +func (d *DB) GetRepositoryDetail(name string) (*RepoDetail, error) { + var repoID int64 + var createdAt string + err := d.QueryRow(`SELECT id, created_at FROM repositories WHERE name = ?`, name).Scan(&repoID, &createdAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrRepoNotFound + } + return nil, fmt.Errorf("db: get repository: %w", err) + } + + detail := &RepoDetail{Name: name, CreatedAt: createdAt} + + // Tags with manifest digests. + tagRows, err := d.Query( + `SELECT t.name, m.digest + FROM tags t JOIN manifests m ON m.id = t.manifest_id + WHERE t.repository_id = ? + ORDER BY t.name ASC`, repoID, + ) + if err != nil { + return nil, fmt.Errorf("db: list repo tags: %w", err) + } + defer func() { _ = tagRows.Close() }() + + for tagRows.Next() { + var ti TagInfo + if err := tagRows.Scan(&ti.Name, &ti.Digest); err != nil { + return nil, fmt.Errorf("db: scan tag: %w", err) + } + detail.Tags = append(detail.Tags, ti) + } + if err := tagRows.Err(); err != nil { + return nil, fmt.Errorf("db: iterate tags: %w", err) + } + if detail.Tags == nil { + detail.Tags = []TagInfo{} + } + + // Manifests. + mRows, err := d.Query( + `SELECT digest, media_type, size, created_at + FROM manifests WHERE repository_id = ? + ORDER BY created_at DESC`, repoID, + ) + if err != nil { + return nil, fmt.Errorf("db: list repo manifests: %w", err) + } + defer func() { _ = mRows.Close() }() + + for mRows.Next() { + var mi ManifestInfo + if err := mRows.Scan(&mi.Digest, &mi.MediaType, &mi.Size, &mi.CreatedAt); err != nil { + return nil, fmt.Errorf("db: scan manifest: %w", err) + } + detail.TotalSize += mi.Size + detail.Manifests = append(detail.Manifests, mi) + } + if err := mRows.Err(); err != nil { + return nil, fmt.Errorf("db: iterate manifests: %w", err) + } + if detail.Manifests == nil { + detail.Manifests = []ManifestInfo{} + } + + return detail, nil +} + +// DeleteRepository deletes a repository and all its manifests, tags, and +// manifest_blobs. CASCADE handles the dependent rows. +func (d *DB) DeleteRepository(name string) error { + result, err := d.Exec(`DELETE FROM repositories WHERE name = ?`, name) + if err != nil { + return fmt.Errorf("db: delete repository: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("db: delete repository rows affected: %w", err) + } + if n == 0 { + return ErrRepoNotFound + } + return nil +} + +// CreatePolicyRule inserts a new policy rule and returns its ID. +func (d *DB) CreatePolicyRule(rule PolicyRuleRow) (int64, error) { + body := ruleBody{ + Effect: rule.Effect, + Roles: rule.Roles, + AccountTypes: rule.AccountTypes, + SubjectUUID: rule.SubjectUUID, + Actions: rule.Actions, + Repositories: rule.Repositories, + } + ruleJSON, err := json.Marshal(body) + if err != nil { + return 0, fmt.Errorf("db: marshal rule body: %w", err) + } + + enabled := 0 + if rule.Enabled { + enabled = 1 + } + + result, err := d.Exec( + `INSERT INTO policy_rules (priority, description, rule_json, enabled, created_by) + VALUES (?, ?, ?, ?, ?)`, + rule.Priority, rule.Description, string(ruleJSON), enabled, nullIfEmpty(rule.CreatedBy), + ) + if err != nil { + return 0, fmt.Errorf("db: create policy rule: %w", err) + } + + id, err := result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("db: policy rule last insert id: %w", err) + } + return id, nil +} + +// GetPolicyRule returns a single policy rule by ID. +func (d *DB) GetPolicyRule(id int64) (*PolicyRuleRow, error) { + var row PolicyRuleRow + var ruleJSON string + var enabledInt int + var createdBy *string + + err := d.QueryRow( + `SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at + FROM policy_rules WHERE id = ?`, id, + ).Scan(&row.ID, &row.Priority, &row.Description, &ruleJSON, &enabledInt, &createdBy, &row.CreatedAt, &row.UpdatedAt) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrPolicyRuleNotFound + } + return nil, fmt.Errorf("db: get policy rule: %w", err) + } + + row.Enabled = enabledInt == 1 + if createdBy != nil { + row.CreatedBy = *createdBy + } + + var body ruleBody + if err := json.Unmarshal([]byte(ruleJSON), &body); err != nil { + return nil, fmt.Errorf("db: parse rule_json for rule %d: %w", id, err) + } + row.Effect = body.Effect + row.Roles = body.Roles + row.AccountTypes = body.AccountTypes + row.SubjectUUID = body.SubjectUUID + row.Actions = body.Actions + row.Repositories = body.Repositories + + return &row, nil +} + +// ListPolicyRules returns all policy rules ordered by priority ascending. +func (d *DB) ListPolicyRules(limit, offset int) ([]PolicyRuleRow, error) { + if limit <= 0 { + limit = 50 + } + rows, err := d.Query( + `SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at + FROM policy_rules + ORDER BY priority ASC + LIMIT ? OFFSET ?`, + limit, offset, + ) + if err != nil { + return nil, fmt.Errorf("db: list policy rules: %w", err) + } + defer func() { _ = rows.Close() }() + + var rules []PolicyRuleRow + for rows.Next() { + var row PolicyRuleRow + var ruleJSON string + var enabledInt int + var createdBy *string + + if err := rows.Scan(&row.ID, &row.Priority, &row.Description, &ruleJSON, &enabledInt, &createdBy, &row.CreatedAt, &row.UpdatedAt); err != nil { + return nil, fmt.Errorf("db: scan policy rule: %w", err) + } + + row.Enabled = enabledInt == 1 + if createdBy != nil { + row.CreatedBy = *createdBy + } + + var body ruleBody + if err := json.Unmarshal([]byte(ruleJSON), &body); err != nil { + return nil, fmt.Errorf("db: parse rule_json: %w", err) + } + row.Effect = body.Effect + row.Roles = body.Roles + row.AccountTypes = body.AccountTypes + row.SubjectUUID = body.SubjectUUID + row.Actions = body.Actions + row.Repositories = body.Repositories + + rules = append(rules, row) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("db: iterate policy rules: %w", err) + } + return rules, nil +} + +// UpdatePolicyRule performs a partial update of a policy rule. +// Only non-zero/non-empty fields in the input are updated. +// Always updates updated_at. +func (d *DB) UpdatePolicyRule(id int64, updates PolicyRuleRow) error { + // First check the rule exists. + var exists int + err := d.QueryRow(`SELECT COUNT(*) FROM policy_rules WHERE id = ?`, id).Scan(&exists) + if err != nil { + return fmt.Errorf("db: check policy rule: %w", err) + } + if exists == 0 { + return ErrPolicyRuleNotFound + } + + var setClauses []string + var args []any + + if updates.Priority != 0 { + setClauses = append(setClauses, "priority = ?") + args = append(args, updates.Priority) + } + if updates.Description != "" { + setClauses = append(setClauses, "description = ?") + args = append(args, updates.Description) + } + + // If any rule body fields are set, rebuild the full rule_json. + // Read the current value first, apply the update, then write back. + if updates.Effect != "" || updates.Actions != nil || updates.Roles != nil || + updates.AccountTypes != nil || updates.Repositories != nil || updates.SubjectUUID != "" { + var currentJSON string + err := d.QueryRow(`SELECT rule_json FROM policy_rules WHERE id = ?`, id).Scan(¤tJSON) + if err != nil { + return fmt.Errorf("db: read current rule_json: %w", err) + } + + var body ruleBody + if err := json.Unmarshal([]byte(currentJSON), &body); err != nil { + return fmt.Errorf("db: parse current rule_json: %w", err) + } + + if updates.Effect != "" { + body.Effect = updates.Effect + } + if updates.Actions != nil { + body.Actions = updates.Actions + } + if updates.Roles != nil { + body.Roles = updates.Roles + } + if updates.AccountTypes != nil { + body.AccountTypes = updates.AccountTypes + } + if updates.Repositories != nil { + body.Repositories = updates.Repositories + } + if updates.SubjectUUID != "" { + body.SubjectUUID = updates.SubjectUUID + } + + newJSON, err := json.Marshal(body) + if err != nil { + return fmt.Errorf("db: marshal updated rule_json: %w", err) + } + setClauses = append(setClauses, "rule_json = ?") + args = append(args, string(newJSON)) + } + + // Always update updated_at. + setClauses = append(setClauses, "updated_at = ?") + args = append(args, time.Now().UTC().Format("2006-01-02T15:04:05Z")) + + query := fmt.Sprintf("UPDATE policy_rules SET %s WHERE id = ?", strings.Join(setClauses, ", ")) + args = append(args, id) + + _, err = d.Exec(query, args...) + if err != nil { + return fmt.Errorf("db: update policy rule: %w", err) + } + return nil +} + +// SetPolicyRuleEnabled sets the enabled flag for a policy rule. +func (d *DB) SetPolicyRuleEnabled(id int64, enabled bool) error { + enabledInt := 0 + if enabled { + enabledInt = 1 + } + result, err := d.Exec( + `UPDATE policy_rules SET enabled = ?, updated_at = ? WHERE id = ?`, + enabledInt, time.Now().UTC().Format("2006-01-02T15:04:05Z"), id, + ) + if err != nil { + return fmt.Errorf("db: set policy rule enabled: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("db: set policy rule enabled rows affected: %w", err) + } + if n == 0 { + return ErrPolicyRuleNotFound + } + return nil +} + +// DeletePolicyRule deletes a policy rule by ID. +func (d *DB) DeletePolicyRule(id int64) error { + result, err := d.Exec(`DELETE FROM policy_rules WHERE id = ?`, id) + if err != nil { + return fmt.Errorf("db: delete policy rule: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("db: delete policy rule rows affected: %w", err) + } + if n == 0 { + return ErrPolicyRuleNotFound + } + return nil +} diff --git a/internal/db/admin_test.go b/internal/db/admin_test.go new file mode 100644 index 0000000..7928cf2 --- /dev/null +++ b/internal/db/admin_test.go @@ -0,0 +1,593 @@ +package db + +import ( + "errors" + "testing" +) + +// seedAdminRepo inserts a repository with manifests, tags, and blobs for admin tests. +func seedAdminRepo(t *testing.T, d *DB, name string, tagNames []string) int64 { + t.Helper() + + _, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name) + if err != nil { + t.Fatalf("insert repo %q: %v", name, err) + } + + var repoID int64 + if err := d.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&repoID); err != nil { + t.Fatalf("select repo id: %v", err) + } + + _, err = d.Exec( + `INSERT INTO manifests (repository_id, digest, media_type, content, size) + VALUES (?, ?, 'application/vnd.oci.image.manifest.v1+json', '{}', 1024)`, + repoID, "sha256:aaa-"+name, + ) + if err != nil { + t.Fatalf("insert manifest for %q: %v", name, err) + } + + var manifestID int64 + if err := d.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil { + t.Fatalf("select manifest id: %v", err) + } + + for _, tag := range tagNames { + _, err := d.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`, + repoID, tag, manifestID) + if err != nil { + t.Fatalf("insert tag %q: %v", tag, err) + } + } + + return repoID +} + +func TestListRepositoriesWithMetadata(t *testing.T) { + d := migratedTestDB(t) + + seedAdminRepo(t, d, "alpha/app", []string{"latest", "v1.0"}) + seedAdminRepo(t, d, "bravo/lib", []string{"latest"}) + + repos, err := d.ListRepositoriesWithMetadata(50, 0) + if err != nil { + t.Fatalf("ListRepositoriesWithMetadata: %v", err) + } + if len(repos) != 2 { + t.Fatalf("repo count: got %d, want 2", len(repos)) + } + + // Ordered by name ASC. + if repos[0].Name != "alpha/app" { + t.Fatalf("first repo name: got %q, want %q", repos[0].Name, "alpha/app") + } + if repos[0].TagCount != 2 { + t.Fatalf("alpha/app tag count: got %d, want 2", repos[0].TagCount) + } + if repos[0].ManifestCount != 1 { + t.Fatalf("alpha/app manifest count: got %d, want 1", repos[0].ManifestCount) + } + if repos[0].TotalSize != 1024 { + t.Fatalf("alpha/app total size: got %d, want 1024", repos[0].TotalSize) + } + if repos[0].CreatedAt == "" { + t.Fatal("alpha/app created_at: expected non-empty") + } + + if repos[1].Name != "bravo/lib" { + t.Fatalf("second repo name: got %q, want %q", repos[1].Name, "bravo/lib") + } + if repos[1].TagCount != 1 { + t.Fatalf("bravo/lib tag count: got %d, want 1", repos[1].TagCount) + } +} + +func TestListRepositoriesWithMetadataEmpty(t *testing.T) { + d := migratedTestDB(t) + + repos, err := d.ListRepositoriesWithMetadata(50, 0) + if err != nil { + t.Fatalf("ListRepositoriesWithMetadata: %v", err) + } + if repos != nil { + t.Fatalf("expected nil repos, got %v", repos) + } +} + +func TestListRepositoriesWithMetadataPagination(t *testing.T) { + d := migratedTestDB(t) + + seedAdminRepo(t, d, "alpha/app", []string{"latest"}) + seedAdminRepo(t, d, "bravo/lib", []string{"latest"}) + seedAdminRepo(t, d, "charlie/svc", []string{"latest"}) + + repos, err := d.ListRepositoriesWithMetadata(2, 0) + if err != nil { + t.Fatalf("ListRepositoriesWithMetadata page 1: %v", err) + } + if len(repos) != 2 { + t.Fatalf("page 1 count: got %d, want 2", len(repos)) + } + if repos[0].Name != "alpha/app" { + t.Fatalf("page 1 first: got %q, want %q", repos[0].Name, "alpha/app") + } + + repos, err = d.ListRepositoriesWithMetadata(2, 2) + if err != nil { + t.Fatalf("ListRepositoriesWithMetadata page 2: %v", err) + } + if len(repos) != 1 { + t.Fatalf("page 2 count: got %d, want 1", len(repos)) + } + if repos[0].Name != "charlie/svc" { + t.Fatalf("page 2 first: got %q, want %q", repos[0].Name, "charlie/svc") + } +} + +func TestGetRepositoryDetail(t *testing.T) { + d := migratedTestDB(t) + + seedAdminRepo(t, d, "myorg/myapp", []string{"latest", "v1.0"}) + + detail, err := d.GetRepositoryDetail("myorg/myapp") + if err != nil { + t.Fatalf("GetRepositoryDetail: %v", err) + } + if detail.Name != "myorg/myapp" { + t.Fatalf("name: got %q, want %q", detail.Name, "myorg/myapp") + } + if detail.CreatedAt == "" { + t.Fatal("created_at: expected non-empty") + } + if len(detail.Tags) != 2 { + t.Fatalf("tag count: got %d, want 2", len(detail.Tags)) + } + // Tags ordered by name ASC. + if detail.Tags[0].Name != "latest" { + t.Fatalf("first tag: got %q, want %q", detail.Tags[0].Name, "latest") + } + if detail.Tags[0].Digest == "" { + t.Fatal("first tag digest: expected non-empty") + } + if detail.Tags[1].Name != "v1.0" { + t.Fatalf("second tag: got %q, want %q", detail.Tags[1].Name, "v1.0") + } + + if len(detail.Manifests) != 1 { + t.Fatalf("manifest count: got %d, want 1", len(detail.Manifests)) + } + if detail.Manifests[0].Size != 1024 { + t.Fatalf("manifest size: got %d, want 1024", detail.Manifests[0].Size) + } + if detail.TotalSize != 1024 { + t.Fatalf("total size: got %d, want 1024", detail.TotalSize) + } +} + +func TestGetRepositoryDetailNotFound(t *testing.T) { + d := migratedTestDB(t) + + _, err := d.GetRepositoryDetail("nonexistent/repo") + if !errors.Is(err, ErrRepoNotFound) { + t.Fatalf("expected ErrRepoNotFound, got %v", err) + } +} + +func TestGetRepositoryDetailEmptyRepo(t *testing.T) { + d := migratedTestDB(t) + + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('empty/repo')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + + detail, err := d.GetRepositoryDetail("empty/repo") + if err != nil { + t.Fatalf("GetRepositoryDetail: %v", err) + } + if len(detail.Tags) != 0 { + t.Fatalf("tags: got %d, want 0", len(detail.Tags)) + } + if len(detail.Manifests) != 0 { + t.Fatalf("manifests: got %d, want 0", len(detail.Manifests)) + } + if detail.TotalSize != 0 { + t.Fatalf("total size: got %d, want 0", detail.TotalSize) + } +} + +func TestDeleteRepository(t *testing.T) { + d := migratedTestDB(t) + + seedAdminRepo(t, d, "myorg/myapp", []string{"latest"}) + + if err := d.DeleteRepository("myorg/myapp"); err != nil { + t.Fatalf("DeleteRepository: %v", err) + } + + // Verify it's gone. + _, err := d.GetRepositoryDetail("myorg/myapp") + if !errors.Is(err, ErrRepoNotFound) { + t.Fatalf("expected ErrRepoNotFound after delete, got %v", err) + } + + // Verify cascade: manifests and tags should be gone. + var manifestCount int + if err := d.QueryRow(`SELECT COUNT(*) FROM manifests`).Scan(&manifestCount); err != nil { + t.Fatalf("count manifests: %v", err) + } + if manifestCount != 0 { + t.Fatalf("manifests after delete: got %d, want 0", manifestCount) + } + + var tagCount int + if err := d.QueryRow(`SELECT COUNT(*) FROM tags`).Scan(&tagCount); err != nil { + t.Fatalf("count tags: %v", err) + } + if tagCount != 0 { + t.Fatalf("tags after delete: got %d, want 0", tagCount) + } +} + +func TestDeleteRepositoryNotFound(t *testing.T) { + d := migratedTestDB(t) + + err := d.DeleteRepository("nonexistent/repo") + if !errors.Is(err, ErrRepoNotFound) { + t.Fatalf("expected ErrRepoNotFound, got %v", err) + } +} + +func TestCreatePolicyRule(t *testing.T) { + d := migratedTestDB(t) + + rule := PolicyRuleRow{ + Priority: 50, + Description: "allow CI push", + Effect: "allow", + Roles: []string{"ci"}, + Actions: []string{"registry:push", "registry:pull"}, + Repositories: []string{"production/*"}, + Enabled: true, + CreatedBy: "admin-uuid", + } + + id, err := d.CreatePolicyRule(rule) + if err != nil { + t.Fatalf("CreatePolicyRule: %v", err) + } + if id == 0 { + t.Fatal("expected non-zero ID") + } + + got, err := d.GetPolicyRule(id) + if err != nil { + t.Fatalf("GetPolicyRule: %v", err) + } + if got.Priority != 50 { + t.Fatalf("priority: got %d, want 50", got.Priority) + } + if got.Description != "allow CI push" { + t.Fatalf("description: got %q, want %q", got.Description, "allow CI push") + } + if got.Effect != "allow" { + t.Fatalf("effect: got %q, want %q", got.Effect, "allow") + } + if len(got.Roles) != 1 || got.Roles[0] != "ci" { + t.Fatalf("roles: got %v, want [ci]", got.Roles) + } + if len(got.Actions) != 2 { + t.Fatalf("actions: got %d, want 2", len(got.Actions)) + } + if len(got.Repositories) != 1 || got.Repositories[0] != "production/*" { + t.Fatalf("repositories: got %v, want [production/*]", got.Repositories) + } + if !got.Enabled { + t.Fatal("enabled: got false, want true") + } + if got.CreatedBy != "admin-uuid" { + t.Fatalf("created_by: got %q, want %q", got.CreatedBy, "admin-uuid") + } + if got.CreatedAt == "" { + t.Fatal("created_at: expected non-empty") + } + if got.UpdatedAt == "" { + t.Fatal("updated_at: expected non-empty") + } +} + +func TestCreatePolicyRuleDisabled(t *testing.T) { + d := migratedTestDB(t) + + rule := PolicyRuleRow{ + Priority: 10, + Description: "disabled rule", + Effect: "deny", + Actions: []string{"registry:delete"}, + Enabled: false, + } + + id, err := d.CreatePolicyRule(rule) + if err != nil { + t.Fatalf("CreatePolicyRule: %v", err) + } + + got, err := d.GetPolicyRule(id) + if err != nil { + t.Fatalf("GetPolicyRule: %v", err) + } + if got.Enabled { + t.Fatal("enabled: got true, want false") + } +} + +func TestGetPolicyRuleNotFound(t *testing.T) { + d := migratedTestDB(t) + + _, err := d.GetPolicyRule(9999) + if !errors.Is(err, ErrPolicyRuleNotFound) { + t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err) + } +} + +func TestListPolicyRules(t *testing.T) { + d := migratedTestDB(t) + + // Insert rules with different priorities (out of order). + rule1 := PolicyRuleRow{ + Priority: 50, + Description: "rule A", + Effect: "allow", + Actions: []string{"registry:pull"}, + Enabled: true, + } + rule2 := PolicyRuleRow{ + Priority: 10, + Description: "rule B", + Effect: "deny", + Actions: []string{"registry:delete"}, + Enabled: true, + } + rule3 := PolicyRuleRow{ + Priority: 30, + Description: "rule C", + Effect: "allow", + Actions: []string{"registry:push"}, + Enabled: false, + } + + if _, err := d.CreatePolicyRule(rule1); err != nil { + t.Fatalf("CreatePolicyRule 1: %v", err) + } + if _, err := d.CreatePolicyRule(rule2); err != nil { + t.Fatalf("CreatePolicyRule 2: %v", err) + } + if _, err := d.CreatePolicyRule(rule3); err != nil { + t.Fatalf("CreatePolicyRule 3: %v", err) + } + + rules, err := d.ListPolicyRules(50, 0) + if err != nil { + t.Fatalf("ListPolicyRules: %v", err) + } + if len(rules) != 3 { + t.Fatalf("rule count: got %d, want 3", len(rules)) + } + + // Should be ordered by priority ASC: 10, 30, 50. + if rules[0].Priority != 10 { + t.Fatalf("first rule priority: got %d, want 10", rules[0].Priority) + } + if rules[0].Description != "rule B" { + t.Fatalf("first rule description: got %q, want %q", rules[0].Description, "rule B") + } + if rules[1].Priority != 30 { + t.Fatalf("second rule priority: got %d, want 30", rules[1].Priority) + } + if rules[2].Priority != 50 { + t.Fatalf("third rule priority: got %d, want 50", rules[2].Priority) + } + + // Verify enabled flags. + if !rules[0].Enabled { + t.Fatal("rule B enabled: got false, want true") + } + if rules[1].Enabled { + t.Fatal("rule C enabled: got true, want false") + } +} + +func TestListPolicyRulesEmpty(t *testing.T) { + d := migratedTestDB(t) + + rules, err := d.ListPolicyRules(50, 0) + if err != nil { + t.Fatalf("ListPolicyRules: %v", err) + } + if rules != nil { + t.Fatalf("expected nil rules, got %v", rules) + } +} + +func TestUpdatePolicyRule(t *testing.T) { + d := migratedTestDB(t) + + rule := PolicyRuleRow{ + Priority: 50, + Description: "original", + Effect: "allow", + Actions: []string{"registry:pull"}, + Enabled: true, + } + + id, err := d.CreatePolicyRule(rule) + if err != nil { + t.Fatalf("CreatePolicyRule: %v", err) + } + + // Update priority and description. + updates := PolicyRuleRow{ + Priority: 25, + Description: "updated", + } + if err := d.UpdatePolicyRule(id, updates); err != nil { + t.Fatalf("UpdatePolicyRule: %v", err) + } + + got, err := d.GetPolicyRule(id) + if err != nil { + t.Fatalf("GetPolicyRule: %v", err) + } + if got.Priority != 25 { + t.Fatalf("priority: got %d, want 25", got.Priority) + } + if got.Description != "updated" { + t.Fatalf("description: got %q, want %q", got.Description, "updated") + } + // Effect should be unchanged. + if got.Effect != "allow" { + t.Fatalf("effect: got %q, want %q (unchanged)", got.Effect, "allow") + } + // Actions should be unchanged. + if len(got.Actions) != 1 || got.Actions[0] != "registry:pull" { + t.Fatalf("actions: got %v, want [registry:pull] (unchanged)", got.Actions) + } +} + +func TestUpdatePolicyRuleBody(t *testing.T) { + d := migratedTestDB(t) + + rule := PolicyRuleRow{ + Priority: 50, + Description: "test", + Effect: "allow", + Actions: []string{"registry:pull"}, + Enabled: true, + } + + id, err := d.CreatePolicyRule(rule) + if err != nil { + t.Fatalf("CreatePolicyRule: %v", err) + } + + // Update rule body fields. + updates := PolicyRuleRow{ + Effect: "deny", + Actions: []string{"registry:push", "registry:delete"}, + Roles: []string{"ci"}, + } + if err := d.UpdatePolicyRule(id, updates); err != nil { + t.Fatalf("UpdatePolicyRule: %v", err) + } + + got, err := d.GetPolicyRule(id) + if err != nil { + t.Fatalf("GetPolicyRule: %v", err) + } + if got.Effect != "deny" { + t.Fatalf("effect: got %q, want %q", got.Effect, "deny") + } + if len(got.Actions) != 2 { + t.Fatalf("actions: got %d, want 2", len(got.Actions)) + } + if len(got.Roles) != 1 || got.Roles[0] != "ci" { + t.Fatalf("roles: got %v, want [ci]", got.Roles) + } +} + +func TestUpdatePolicyRuleNotFound(t *testing.T) { + d := migratedTestDB(t) + + err := d.UpdatePolicyRule(9999, PolicyRuleRow{Description: "nope"}) + if !errors.Is(err, ErrPolicyRuleNotFound) { + t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err) + } +} + +func TestSetPolicyRuleEnabled(t *testing.T) { + d := migratedTestDB(t) + + rule := PolicyRuleRow{ + Priority: 50, + Description: "test", + Effect: "allow", + Actions: []string{"registry:pull"}, + Enabled: true, + } + id, err := d.CreatePolicyRule(rule) + if err != nil { + t.Fatalf("CreatePolicyRule: %v", err) + } + + // Disable the rule. + if err := d.SetPolicyRuleEnabled(id, false); err != nil { + t.Fatalf("SetPolicyRuleEnabled(false): %v", err) + } + + got, err := d.GetPolicyRule(id) + if err != nil { + t.Fatalf("GetPolicyRule: %v", err) + } + if got.Enabled { + t.Fatal("enabled: got true, want false") + } + + // Re-enable. + if err := d.SetPolicyRuleEnabled(id, true); err != nil { + t.Fatalf("SetPolicyRuleEnabled(true): %v", err) + } + + got, err = d.GetPolicyRule(id) + if err != nil { + t.Fatalf("GetPolicyRule: %v", err) + } + if !got.Enabled { + t.Fatal("enabled: got false, want true") + } +} + +func TestSetPolicyRuleEnabledNotFound(t *testing.T) { + d := migratedTestDB(t) + + err := d.SetPolicyRuleEnabled(9999, true) + if !errors.Is(err, ErrPolicyRuleNotFound) { + t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err) + } +} + +func TestDeletePolicyRule(t *testing.T) { + d := migratedTestDB(t) + + rule := PolicyRuleRow{ + Priority: 50, + Description: "to delete", + Effect: "allow", + Actions: []string{"registry:pull"}, + Enabled: true, + } + + id, err := d.CreatePolicyRule(rule) + if err != nil { + t.Fatalf("CreatePolicyRule: %v", err) + } + + if err := d.DeletePolicyRule(id); err != nil { + t.Fatalf("DeletePolicyRule: %v", err) + } + + // Verify it's gone. + _, err = d.GetPolicyRule(id) + if !errors.Is(err, ErrPolicyRuleNotFound) { + t.Fatalf("expected ErrPolicyRuleNotFound after delete, got %v", err) + } +} + +func TestDeletePolicyRuleNotFound(t *testing.T) { + d := migratedTestDB(t) + + err := d.DeletePolicyRule(9999) + if !errors.Is(err, ErrPolicyRuleNotFound) { + t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err) + } +} diff --git a/internal/db/errors.go b/internal/db/errors.go new file mode 100644 index 0000000..b2cf84b --- /dev/null +++ b/internal/db/errors.go @@ -0,0 +1,9 @@ +package db + +import "errors" + +var ( + ErrRepoNotFound = errors.New("db: repository not found") + ErrManifestNotFound = errors.New("db: manifest not found") + ErrBlobNotFound = errors.New("db: blob not found") +) diff --git a/internal/db/push.go b/internal/db/push.go new file mode 100644 index 0000000..7129c48 --- /dev/null +++ b/internal/db/push.go @@ -0,0 +1,154 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" +) + +// GetOrCreateRepository returns the repository ID for the given name, +// creating it if it does not exist (implicit creation on first push). +func (d *DB) GetOrCreateRepository(name string) (int64, error) { + var id int64 + err := d.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&id) + if err == nil { + return id, nil + } + if !errors.Is(err, sql.ErrNoRows) { + return 0, fmt.Errorf("db: get repository: %w", err) + } + + result, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name) + if err != nil { + return 0, fmt.Errorf("db: create repository: %w", err) + } + id, err = result.LastInsertId() + if err != nil { + return 0, fmt.Errorf("db: repository last insert id: %w", err) + } + return id, nil +} + +// BlobExists checks whether a blob with the given digest exists in the blobs table. +func (d *DB) BlobExists(digest string) (bool, error) { + var count int + err := d.QueryRow(`SELECT COUNT(*) FROM blobs WHERE digest = ?`, digest).Scan(&count) + if err != nil { + return false, fmt.Errorf("db: blob exists: %w", err) + } + return count > 0, nil +} + +// InsertBlob inserts a blob row if it does not already exist. +// Returns without error if the blob already exists (content-addressed dedup). +func (d *DB) InsertBlob(digest string, size int64) error { + _, err := d.Exec( + `INSERT OR IGNORE INTO blobs (digest, size) VALUES (?, ?)`, + digest, size, + ) + if err != nil { + return fmt.Errorf("db: insert blob: %w", err) + } + return nil +} + +// PushManifestParams holds the parameters for a manifest push operation. +type PushManifestParams struct { + RepoName string + Digest string + MediaType string + Content []byte + Size int64 + Tag string // empty if push-by-digest + BlobDigests []string // referenced blob digests +} + +// PushManifest executes the full manifest push in a single transaction per +// ARCHITECTURE.md §5. It creates the repository if needed, inserts/updates +// the manifest, populates manifest_blobs, and updates the tag if provided. +func (d *DB) PushManifest(p PushManifestParams) error { + tx, err := d.Begin() + if err != nil { + return fmt.Errorf("db: begin push manifest: %w", err) + } + + // Step a: create repository if not exists. + var repoID int64 + err = tx.QueryRow(`SELECT id FROM repositories WHERE name = ?`, p.RepoName).Scan(&repoID) + if errors.Is(err, sql.ErrNoRows) { + result, insertErr := tx.Exec(`INSERT INTO repositories (name) VALUES (?)`, p.RepoName) + if insertErr != nil { + _ = tx.Rollback() + return fmt.Errorf("db: create repository: %w", insertErr) + } + repoID, err = result.LastInsertId() + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: repository last insert id: %w", err) + } + } else if err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: get repository: %w", err) + } + + // Step b: insert or update manifest. + // Use INSERT OR REPLACE on the UNIQUE(repository_id, digest) constraint. + result, err := tx.Exec( + `INSERT INTO manifests (repository_id, digest, media_type, content, size) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(repository_id, digest) DO UPDATE SET + media_type = excluded.media_type, + content = excluded.content, + size = excluded.size`, + repoID, p.Digest, p.MediaType, p.Content, p.Size, + ) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: insert manifest: %w", err) + } + manifestID, err := result.LastInsertId() + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: manifest last insert id: %w", err) + } + + // Step c: populate manifest_blobs join table. + // Delete existing entries for this manifest first (in case of re-push). + _, err = tx.Exec(`DELETE FROM manifest_blobs WHERE manifest_id = ?`, manifestID) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: clear manifest_blobs: %w", err) + } + for _, blobDigest := range p.BlobDigests { + _, err = tx.Exec( + `INSERT INTO manifest_blobs (manifest_id, blob_id) + SELECT ?, id FROM blobs WHERE digest = ?`, + manifestID, blobDigest, + ) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: insert manifest_blob: %w", err) + } + } + + // Step d: if reference is a tag, insert or update tag row. + if p.Tag != "" { + _, err = tx.Exec( + `INSERT INTO tags (repository_id, name, manifest_id) + VALUES (?, ?, ?) + ON CONFLICT(repository_id, name) DO UPDATE SET + manifest_id = excluded.manifest_id, + updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')`, + repoID, p.Tag, manifestID, + ) + if err != nil { + _ = tx.Rollback() + return fmt.Errorf("db: upsert tag: %w", err) + } + } + + if err := tx.Commit(); err != nil { + return fmt.Errorf("db: commit push manifest: %w", err) + } + return nil +} diff --git a/internal/db/push_test.go b/internal/db/push_test.go new file mode 100644 index 0000000..d06086b --- /dev/null +++ b/internal/db/push_test.go @@ -0,0 +1,278 @@ +package db + +import "testing" + +func TestGetOrCreateRepositoryNew(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + id, err := d.GetOrCreateRepository("newrepo") + if err != nil { + t.Fatalf("GetOrCreateRepository: %v", err) + } + if id <= 0 { + t.Fatalf("id: got %d, want > 0", id) + } + + // Second call should return the same ID. + id2, err := d.GetOrCreateRepository("newrepo") + if err != nil { + t.Fatalf("GetOrCreateRepository (second): %v", err) + } + if id2 != id { + t.Fatalf("id mismatch: got %d, want %d", id2, id) + } +} + +func TestGetOrCreateRepositoryExisting(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('existing')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + + id, err := d.GetOrCreateRepository("existing") + if err != nil { + t.Fatalf("GetOrCreateRepository: %v", err) + } + if id <= 0 { + t.Fatalf("id: got %d, want > 0", id) + } +} + +func TestBlobExists(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + _, err := d.Exec(`INSERT INTO blobs (digest, size) VALUES ('sha256:aaa', 100)`) + if err != nil { + t.Fatalf("insert blob: %v", err) + } + + exists, err := d.BlobExists("sha256:aaa") + if err != nil { + t.Fatalf("BlobExists: %v", err) + } + if !exists { + t.Fatal("expected blob to exist") + } + + exists, err = d.BlobExists("sha256:nonexistent") + if err != nil { + t.Fatalf("BlobExists (nonexistent): %v", err) + } + if exists { + t.Fatal("expected blob to not exist") + } +} + +func TestInsertBlob(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + if err := d.InsertBlob("sha256:bbb", 200); err != nil { + t.Fatalf("InsertBlob: %v", err) + } + + exists, err := d.BlobExists("sha256:bbb") + if err != nil { + t.Fatalf("BlobExists: %v", err) + } + if !exists { + t.Fatal("expected blob to exist after insert") + } + + // Insert again — should be a no-op (INSERT OR IGNORE). + if err := d.InsertBlob("sha256:bbb", 200); err != nil { + t.Fatalf("InsertBlob (dup): %v", err) + } +} + +func TestPushManifestByTag(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + // Insert blobs first. + if err := d.InsertBlob("sha256:config111", 50); err != nil { + t.Fatalf("insert config blob: %v", err) + } + if err := d.InsertBlob("sha256:layer111", 1000); err != nil { + t.Fatalf("insert layer blob: %v", err) + } + + content := []byte(`{"schemaVersion":2}`) + params := PushManifestParams{ + RepoName: "myrepo", + Digest: "sha256:manifest111", + MediaType: "application/vnd.oci.image.manifest.v1+json", + Content: content, + Size: int64(len(content)), + Tag: "latest", + BlobDigests: []string{"sha256:config111", "sha256:layer111"}, + } + + if err := d.PushManifest(params); err != nil { + t.Fatalf("PushManifest: %v", err) + } + + // Verify repository was created. + repoID, err := d.GetRepositoryByName("myrepo") + if err != nil { + t.Fatalf("GetRepositoryByName: %v", err) + } + if repoID <= 0 { + t.Fatalf("repo id: got %d, want > 0", repoID) + } + + // Verify manifest exists. + m, err := d.GetManifestByDigest(repoID, "sha256:manifest111") + if err != nil { + t.Fatalf("GetManifestByDigest: %v", err) + } + if m.MediaType != "application/vnd.oci.image.manifest.v1+json" { + t.Fatalf("media type: got %q", m.MediaType) + } + if m.Size != int64(len(content)) { + t.Fatalf("size: got %d, want %d", m.Size, len(content)) + } + + // Verify tag points to manifest. + m2, err := d.GetManifestByTag(repoID, "latest") + if err != nil { + t.Fatalf("GetManifestByTag: %v", err) + } + if m2.Digest != "sha256:manifest111" { + t.Fatalf("tag digest: got %q", m2.Digest) + } + + // Verify manifest_blobs join table. + var mbCount int + if err := d.QueryRow(`SELECT COUNT(*) FROM manifest_blobs WHERE manifest_id = ?`, m.ID).Scan(&mbCount); err != nil { + t.Fatalf("count manifest_blobs: %v", err) + } + if mbCount != 2 { + t.Fatalf("manifest_blobs count: got %d, want 2", mbCount) + } +} + +func TestPushManifestByDigest(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + content := []byte(`{"schemaVersion":2}`) + params := PushManifestParams{ + RepoName: "myrepo", + Digest: "sha256:manifest222", + MediaType: "application/vnd.oci.image.manifest.v1+json", + Content: content, + Size: int64(len(content)), + Tag: "", // push by digest — no tag + } + + if err := d.PushManifest(params); err != nil { + t.Fatalf("PushManifest: %v", err) + } + + // Verify no tag was created. + var tagCount int + if err := d.QueryRow(`SELECT COUNT(*) FROM tags`).Scan(&tagCount); err != nil { + t.Fatalf("count tags: %v", err) + } + if tagCount != 0 { + t.Fatalf("tag count: got %d, want 0", tagCount) + } +} + +func TestPushManifestTagMove(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + // Push first manifest with tag "latest". + content1 := []byte(`{"schemaVersion":2,"v":"1"}`) + if err := d.PushManifest(PushManifestParams{ + RepoName: "myrepo", + Digest: "sha256:first", + MediaType: "application/vnd.oci.image.manifest.v1+json", + Content: content1, + Size: int64(len(content1)), + Tag: "latest", + }); err != nil { + t.Fatalf("PushManifest (first): %v", err) + } + + // Push second manifest with same tag "latest" — should atomically move tag. + content2 := []byte(`{"schemaVersion":2,"v":"2"}`) + if err := d.PushManifest(PushManifestParams{ + RepoName: "myrepo", + Digest: "sha256:second", + MediaType: "application/vnd.oci.image.manifest.v1+json", + Content: content2, + Size: int64(len(content2)), + Tag: "latest", + }); err != nil { + t.Fatalf("PushManifest (second): %v", err) + } + + repoID, err := d.GetRepositoryByName("myrepo") + if err != nil { + t.Fatalf("GetRepositoryByName: %v", err) + } + + m, err := d.GetManifestByTag(repoID, "latest") + if err != nil { + t.Fatalf("GetManifestByTag: %v", err) + } + if m.Digest != "sha256:second" { + t.Fatalf("tag should point to second manifest, got %q", m.Digest) + } +} + +func TestPushManifestIdempotent(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + content := []byte(`{"schemaVersion":2}`) + params := PushManifestParams{ + RepoName: "myrepo", + Digest: "sha256:manifest333", + MediaType: "application/vnd.oci.image.manifest.v1+json", + Content: content, + Size: int64(len(content)), + Tag: "latest", + } + + // Push twice — should not fail. + if err := d.PushManifest(params); err != nil { + t.Fatalf("PushManifest (first): %v", err) + } + if err := d.PushManifest(params); err != nil { + t.Fatalf("PushManifest (second): %v", err) + } + + // Verify only one manifest exists. + var mCount int + if err := d.QueryRow(`SELECT COUNT(*) FROM manifests`).Scan(&mCount); err != nil { + t.Fatalf("count manifests: %v", err) + } + if mCount != 1 { + t.Fatalf("manifest count: got %d, want 1", mCount) + } +} diff --git a/internal/db/repository.go b/internal/db/repository.go new file mode 100644 index 0000000..b78bae3 --- /dev/null +++ b/internal/db/repository.go @@ -0,0 +1,154 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" +) + +// ManifestRow represents a manifest as stored in the database. +type ManifestRow struct { + ID int64 + RepositoryID int64 + Digest string + MediaType string + Content []byte + Size int64 +} + +// GetRepositoryByName returns the repository ID for the given name. +func (d *DB) GetRepositoryByName(name string) (int64, error) { + var id int64 + err := d.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&id) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return 0, ErrRepoNotFound + } + return 0, fmt.Errorf("db: get repository by name: %w", err) + } + return id, nil +} + +// GetManifestByTag returns the manifest associated with the given tag in a repository. +func (d *DB) GetManifestByTag(repoID int64, tag string) (*ManifestRow, error) { + var m ManifestRow + err := d.QueryRow( + `SELECT m.id, m.repository_id, m.digest, m.media_type, m.content, m.size + FROM manifests m + JOIN tags t ON t.manifest_id = m.id + WHERE t.repository_id = ? AND t.name = ?`, + repoID, tag, + ).Scan(&m.ID, &m.RepositoryID, &m.Digest, &m.MediaType, &m.Content, &m.Size) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrManifestNotFound + } + return nil, fmt.Errorf("db: get manifest by tag: %w", err) + } + return &m, nil +} + +// GetManifestByDigest returns the manifest with the given digest in a repository. +func (d *DB) GetManifestByDigest(repoID int64, digest string) (*ManifestRow, error) { + var m ManifestRow + err := d.QueryRow( + `SELECT id, repository_id, digest, media_type, content, size + FROM manifests + WHERE repository_id = ? AND digest = ?`, + repoID, digest, + ).Scan(&m.ID, &m.RepositoryID, &m.Digest, &m.MediaType, &m.Content, &m.Size) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrManifestNotFound + } + return nil, fmt.Errorf("db: get manifest by digest: %w", err) + } + return &m, nil +} + +// BlobExistsInRepo checks whether a blob with the given digest exists and is +// referenced by at least one manifest in the given repository. +func (d *DB) BlobExistsInRepo(repoID int64, digest string) (bool, error) { + var count int + err := d.QueryRow( + `SELECT COUNT(*) FROM blobs b + JOIN manifest_blobs mb ON mb.blob_id = b.id + JOIN manifests m ON m.id = mb.manifest_id + WHERE m.repository_id = ? AND b.digest = ?`, + repoID, digest, + ).Scan(&count) + if err != nil { + return false, fmt.Errorf("db: blob exists in repo: %w", err) + } + return count > 0, nil +} + +// ListTags returns tag names for a repository, ordered alphabetically. +// Pagination is cursor-based: after is the last tag name from the previous page, +// limit is the maximum number of tags to return. +func (d *DB) ListTags(repoID int64, after string, limit int) ([]string, error) { + var query string + var args []any + + if after != "" { + query = `SELECT name FROM tags WHERE repository_id = ? AND name > ? ORDER BY name ASC LIMIT ?` + args = []any{repoID, after, limit} + } else { + query = `SELECT name FROM tags WHERE repository_id = ? ORDER BY name ASC LIMIT ?` + args = []any{repoID, limit} + } + + rows, err := d.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("db: list tags: %w", err) + } + defer func() { _ = rows.Close() }() + + var tags []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, fmt.Errorf("db: scan tag: %w", err) + } + tags = append(tags, name) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("db: iterate tags: %w", err) + } + return tags, nil +} + +// ListRepositoryNames returns repository names ordered alphabetically. +// Pagination is cursor-based: after is the last repo name from the previous page, +// limit is the maximum number of names to return. +func (d *DB) ListRepositoryNames(after string, limit int) ([]string, error) { + var query string + var args []any + + if after != "" { + query = `SELECT name FROM repositories WHERE name > ? ORDER BY name ASC LIMIT ?` + args = []any{after, limit} + } else { + query = `SELECT name FROM repositories ORDER BY name ASC LIMIT ?` + args = []any{limit} + } + + rows, err := d.Query(query, args...) + if err != nil { + return nil, fmt.Errorf("db: list repository names: %w", err) + } + defer func() { _ = rows.Close() }() + + var names []string + for rows.Next() { + var name string + if err := rows.Scan(&name); err != nil { + return nil, fmt.Errorf("db: scan repository name: %w", err) + } + names = append(names, name) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("db: iterate repository names: %w", err) + } + return names, nil +} diff --git a/internal/db/repository_test.go b/internal/db/repository_test.go new file mode 100644 index 0000000..8fab3d7 --- /dev/null +++ b/internal/db/repository_test.go @@ -0,0 +1,429 @@ +package db + +import ( + "errors" + "testing" +) + +// seedTestRepo inserts a repository, manifest, tag, blob, and manifest_blob +// link for use in repository query tests. It returns the repository ID. +func seedTestRepo(t *testing.T, d *DB) int64 { + t.Helper() + + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('myorg/myapp')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + + var repoID int64 + if err := d.QueryRow(`SELECT id FROM repositories WHERE name = 'myorg/myapp'`).Scan(&repoID); err != nil { + t.Fatalf("select repo id: %v", err) + } + + _, err = d.Exec( + `INSERT INTO manifests (repository_id, digest, media_type, content, size) + VALUES (?, 'sha256:aaaa', 'application/vnd.oci.image.manifest.v1+json', '{"layers":[]}', 15)`, + repoID, + ) + if err != nil { + t.Fatalf("insert manifest: %v", err) + } + + var manifestID int64 + if err := d.QueryRow(`SELECT id FROM manifests WHERE digest = 'sha256:aaaa'`).Scan(&manifestID); err != nil { + t.Fatalf("select manifest id: %v", err) + } + + _, err = d.Exec( + `INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, 'latest', ?)`, + repoID, manifestID, + ) + if err != nil { + t.Fatalf("insert tag: %v", err) + } + + _, err = d.Exec(`INSERT INTO blobs (digest, size) VALUES ('sha256:bbbb', 2048)`) + if err != nil { + t.Fatalf("insert blob: %v", err) + } + + var blobID int64 + if err := d.QueryRow(`SELECT id FROM blobs WHERE digest = 'sha256:bbbb'`).Scan(&blobID); err != nil { + t.Fatalf("select blob id: %v", err) + } + + _, err = d.Exec( + `INSERT INTO manifest_blobs (manifest_id, blob_id) VALUES (?, ?)`, + manifestID, blobID, + ) + if err != nil { + t.Fatalf("insert manifest_blob: %v", err) + } + + return repoID +} + +func TestGetRepositoryByName_Found(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + seedTestRepo(t, d) + + id, err := d.GetRepositoryByName("myorg/myapp") + if err != nil { + t.Fatalf("GetRepositoryByName: %v", err) + } + if id == 0 { + t.Fatal("expected non-zero repository ID") + } +} + +func TestGetRepositoryByName_NotFound(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + _, err := d.GetRepositoryByName("nonexistent") + if !errors.Is(err, ErrRepoNotFound) { + t.Fatalf("expected ErrRepoNotFound, got %v", err) + } +} + +func TestGetManifestByTag_Found(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + m, err := d.GetManifestByTag(repoID, "latest") + if err != nil { + t.Fatalf("GetManifestByTag: %v", err) + } + if m.Digest != "sha256:aaaa" { + t.Fatalf("digest: got %q, want %q", m.Digest, "sha256:aaaa") + } + if m.MediaType != "application/vnd.oci.image.manifest.v1+json" { + t.Fatalf("media type: got %q, want OCI manifest", m.MediaType) + } + if m.Size != 15 { + t.Fatalf("size: got %d, want 15", m.Size) + } + if string(m.Content) != `{"layers":[]}` { + t.Fatalf("content: got %q, want {\"layers\":[]}", string(m.Content)) + } +} + +func TestGetManifestByTag_NotFound(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + _, err := d.GetManifestByTag(repoID, "v0.0.0-nonexistent") + if !errors.Is(err, ErrManifestNotFound) { + t.Fatalf("expected ErrManifestNotFound, got %v", err) + } +} + +func TestGetManifestByDigest_Found(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + m, err := d.GetManifestByDigest(repoID, "sha256:aaaa") + if err != nil { + t.Fatalf("GetManifestByDigest: %v", err) + } + if m.Digest != "sha256:aaaa" { + t.Fatalf("digest: got %q, want %q", m.Digest, "sha256:aaaa") + } + if m.RepositoryID != repoID { + t.Fatalf("repository_id: got %d, want %d", m.RepositoryID, repoID) + } +} + +func TestGetManifestByDigest_NotFound(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + _, err := d.GetManifestByDigest(repoID, "sha256:nonexistent") + if !errors.Is(err, ErrManifestNotFound) { + t.Fatalf("expected ErrManifestNotFound, got %v", err) + } +} + +func TestBlobExistsInRepo_Exists(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + exists, err := d.BlobExistsInRepo(repoID, "sha256:bbbb") + if err != nil { + t.Fatalf("BlobExistsInRepo: %v", err) + } + if !exists { + t.Fatal("expected blob to exist in repo") + } +} + +func TestBlobExistsInRepo_NotInThisRepo(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + seedTestRepo(t, d) // creates blob sha256:bbbb in myorg/myapp + + // Create a second repo with no manifests linking the blob. + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('other/repo')`) + if err != nil { + t.Fatalf("insert other repo: %v", err) + } + var otherRepoID int64 + if err := d.QueryRow(`SELECT id FROM repositories WHERE name = 'other/repo'`).Scan(&otherRepoID); err != nil { + t.Fatalf("select other repo id: %v", err) + } + + exists, err := d.BlobExistsInRepo(otherRepoID, "sha256:bbbb") + if err != nil { + t.Fatalf("BlobExistsInRepo: %v", err) + } + if exists { + t.Fatal("expected blob to NOT exist in other repo") + } +} + +func TestBlobExistsInRepo_BlobDoesNotExist(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + exists, err := d.BlobExistsInRepo(repoID, "sha256:nonexistent") + if err != nil { + t.Fatalf("BlobExistsInRepo: %v", err) + } + if exists { + t.Fatal("expected blob to not exist") + } +} + +func TestListTags_WithTags(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + // Add more tags pointing to the same manifest. + var manifestID int64 + if err := d.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil { + t.Fatalf("select manifest id: %v", err) + } + + for _, tag := range []string{"v1.0", "v2.0", "beta"} { + _, err := d.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`, + repoID, tag, manifestID) + if err != nil { + t.Fatalf("insert tag %q: %v", tag, err) + } + } + + tags, err := d.ListTags(repoID, "", 100) + if err != nil { + t.Fatalf("ListTags: %v", err) + } + + // Expect alphabetical: beta, latest, v1.0, v2.0 + want := []string{"beta", "latest", "v1.0", "v2.0"} + if len(tags) != len(want) { + t.Fatalf("tags count: got %d, want %d", len(tags), len(want)) + } + for i, tag := range tags { + if tag != want[i] { + t.Fatalf("tags[%d]: got %q, want %q", i, tag, want[i]) + } + } +} + +func TestListTags_Pagination(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + repoID := seedTestRepo(t, d) + + var manifestID int64 + if err := d.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil { + t.Fatalf("select manifest id: %v", err) + } + + for _, tag := range []string{"v1.0", "v2.0", "beta"} { + _, err := d.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`, + repoID, tag, manifestID) + if err != nil { + t.Fatalf("insert tag %q: %v", tag, err) + } + } + + // First page: 2 tags starting from beginning. + tags, err := d.ListTags(repoID, "", 2) + if err != nil { + t.Fatalf("ListTags page 1: %v", err) + } + if len(tags) != 2 { + t.Fatalf("page 1 count: got %d, want 2", len(tags)) + } + if tags[0] != "beta" || tags[1] != "latest" { + t.Fatalf("page 1: got %v, want [beta, latest]", tags) + } + + // Second page: after "latest". + tags, err = d.ListTags(repoID, "latest", 2) + if err != nil { + t.Fatalf("ListTags page 2: %v", err) + } + if len(tags) != 2 { + t.Fatalf("page 2 count: got %d, want 2", len(tags)) + } + if tags[0] != "v1.0" || tags[1] != "v2.0" { + t.Fatalf("page 2: got %v, want [v1.0, v2.0]", tags) + } + + // Third page: after "v2.0" — no more tags. + tags, err = d.ListTags(repoID, "v2.0", 2) + if err != nil { + t.Fatalf("ListTags page 3: %v", err) + } + if len(tags) != 0 { + t.Fatalf("page 3 count: got %d, want 0", len(tags)) + } +} + +func TestListTags_Empty(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + // Create a repo with no tags. + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('empty/repo')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + var repoID int64 + if err := d.QueryRow(`SELECT id FROM repositories WHERE name = 'empty/repo'`).Scan(&repoID); err != nil { + t.Fatalf("select repo id: %v", err) + } + + tags, err := d.ListTags(repoID, "", 100) + if err != nil { + t.Fatalf("ListTags: %v", err) + } + if tags != nil { + t.Fatalf("expected nil tags, got %v", tags) + } +} + +func TestListRepositoryNames_WithRepos(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + for _, name := range []string{"charlie/app", "alpha/lib", "bravo/svc"} { + _, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name) + if err != nil { + t.Fatalf("insert repo %q: %v", name, err) + } + } + + names, err := d.ListRepositoryNames("", 100) + if err != nil { + t.Fatalf("ListRepositoryNames: %v", err) + } + + want := []string{"alpha/lib", "bravo/svc", "charlie/app"} + if len(names) != len(want) { + t.Fatalf("names count: got %d, want %d", len(names), len(want)) + } + for i, n := range names { + if n != want[i] { + t.Fatalf("names[%d]: got %q, want %q", i, n, want[i]) + } + } +} + +func TestListRepositoryNames_Pagination(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + for _, name := range []string{"charlie/app", "alpha/lib", "bravo/svc"} { + _, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name) + if err != nil { + t.Fatalf("insert repo %q: %v", name, err) + } + } + + // First page: 2. + names, err := d.ListRepositoryNames("", 2) + if err != nil { + t.Fatalf("ListRepositoryNames page 1: %v", err) + } + if len(names) != 2 { + t.Fatalf("page 1 count: got %d, want 2", len(names)) + } + if names[0] != "alpha/lib" || names[1] != "bravo/svc" { + t.Fatalf("page 1: got %v", names) + } + + // Second page: after "bravo/svc". + names, err = d.ListRepositoryNames("bravo/svc", 2) + if err != nil { + t.Fatalf("ListRepositoryNames page 2: %v", err) + } + if len(names) != 1 { + t.Fatalf("page 2 count: got %d, want 1", len(names)) + } + if names[0] != "charlie/app" { + t.Fatalf("page 2: got %v", names) + } +} + +func TestListRepositoryNames_Empty(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + names, err := d.ListRepositoryNames("", 100) + if err != nil { + t.Fatalf("ListRepositoryNames: %v", err) + } + if names != nil { + t.Fatalf("expected nil names, got %v", names) + } +} diff --git a/internal/db/upload.go b/internal/db/upload.go new file mode 100644 index 0000000..b64425a --- /dev/null +++ b/internal/db/upload.go @@ -0,0 +1,80 @@ +package db + +import ( + "database/sql" + "errors" + "fmt" +) + +// ErrUploadNotFound indicates the requested upload UUID does not exist. +var ErrUploadNotFound = errors.New("db: upload not found") + +// UploadRow represents a row in the uploads table. +type UploadRow struct { + ID int64 + UUID string + RepositoryID int64 + ByteOffset int64 +} + +// CreateUpload inserts a new upload row and returns its ID. +func (d *DB) CreateUpload(uuid string, repoID int64) error { + _, err := d.Exec( + `INSERT INTO uploads (uuid, repository_id) VALUES (?, ?)`, + uuid, repoID, + ) + if err != nil { + return fmt.Errorf("db: create upload: %w", err) + } + return nil +} + +// GetUpload returns the upload with the given UUID. +func (d *DB) GetUpload(uuid string) (*UploadRow, error) { + var u UploadRow + err := d.QueryRow( + `SELECT id, uuid, repository_id, byte_offset FROM uploads WHERE uuid = ?`, uuid, + ).Scan(&u.ID, &u.UUID, &u.RepositoryID, &u.ByteOffset) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, ErrUploadNotFound + } + return nil, fmt.Errorf("db: get upload: %w", err) + } + return &u, nil +} + +// UpdateUploadOffset sets the byte_offset for an upload. +func (d *DB) UpdateUploadOffset(uuid string, offset int64) error { + result, err := d.Exec( + `UPDATE uploads SET byte_offset = ? WHERE uuid = ?`, + offset, uuid, + ) + if err != nil { + return fmt.Errorf("db: update upload offset: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("db: update upload offset rows affected: %w", err) + } + if n == 0 { + return ErrUploadNotFound + } + return nil +} + +// DeleteUpload removes the upload row with the given UUID. +func (d *DB) DeleteUpload(uuid string) error { + result, err := d.Exec(`DELETE FROM uploads WHERE uuid = ?`, uuid) + if err != nil { + return fmt.Errorf("db: delete upload: %w", err) + } + n, err := result.RowsAffected() + if err != nil { + return fmt.Errorf("db: delete upload rows affected: %w", err) + } + if n == 0 { + return ErrUploadNotFound + } + return nil +} diff --git a/internal/db/upload_test.go b/internal/db/upload_test.go new file mode 100644 index 0000000..0f6c824 --- /dev/null +++ b/internal/db/upload_test.go @@ -0,0 +1,124 @@ +package db + +import ( + "errors" + "testing" +) + +func TestCreateAndGetUpload(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + // Create a repository for the upload. + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + + if err := d.CreateUpload("test-uuid-1", 1); err != nil { + t.Fatalf("CreateUpload: %v", err) + } + + u, err := d.GetUpload("test-uuid-1") + if err != nil { + t.Fatalf("GetUpload: %v", err) + } + if u.UUID != "test-uuid-1" { + t.Fatalf("uuid: got %q, want %q", u.UUID, "test-uuid-1") + } + if u.RepositoryID != 1 { + t.Fatalf("repo id: got %d, want 1", u.RepositoryID) + } + if u.ByteOffset != 0 { + t.Fatalf("byte offset: got %d, want 0", u.ByteOffset) + } +} + +func TestGetUploadNotFound(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + _, err := d.GetUpload("nonexistent") + if !errors.Is(err, ErrUploadNotFound) { + t.Fatalf("err: got %v, want ErrUploadNotFound", err) + } +} + +func TestUpdateUploadOffset(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + if err := d.CreateUpload("test-uuid-2", 1); err != nil { + t.Fatalf("CreateUpload: %v", err) + } + + if err := d.UpdateUploadOffset("test-uuid-2", 1024); err != nil { + t.Fatalf("UpdateUploadOffset: %v", err) + } + + u, err := d.GetUpload("test-uuid-2") + if err != nil { + t.Fatalf("GetUpload: %v", err) + } + if u.ByteOffset != 1024 { + t.Fatalf("byte offset: got %d, want 1024", u.ByteOffset) + } +} + +func TestUpdateUploadOffsetNotFound(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + err := d.UpdateUploadOffset("nonexistent", 100) + if !errors.Is(err, ErrUploadNotFound) { + t.Fatalf("err: got %v, want ErrUploadNotFound", err) + } +} + +func TestDeleteUpload(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + _, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`) + if err != nil { + t.Fatalf("insert repo: %v", err) + } + if err := d.CreateUpload("test-uuid-3", 1); err != nil { + t.Fatalf("CreateUpload: %v", err) + } + + if err := d.DeleteUpload("test-uuid-3"); err != nil { + t.Fatalf("DeleteUpload: %v", err) + } + + _, err = d.GetUpload("test-uuid-3") + if !errors.Is(err, ErrUploadNotFound) { + t.Fatalf("after delete: got %v, want ErrUploadNotFound", err) + } +} + +func TestDeleteUploadNotFound(t *testing.T) { + d := openTestDB(t) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + + err := d.DeleteUpload("nonexistent") + if !errors.Is(err, ErrUploadNotFound) { + t.Fatalf("err: got %v, want ErrUploadNotFound", err) + } +} diff --git a/internal/oci/blob.go b/internal/oci/blob.go new file mode 100644 index 0000000..91da9e7 --- /dev/null +++ b/internal/oci/blob.go @@ -0,0 +1,98 @@ +package oci + +import ( + "errors" + "fmt" + "io" + "net/http" + "strconv" + + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" +) + +func (h *Handler) handleBlobGet(w http.ResponseWriter, r *http.Request, repo, digest string) { + if !h.checkPolicy(w, r, policy.ActionPull, repo) { + return + } + + repoID, err := h.db.GetRepositoryByName(repo) + if err != nil { + if errors.Is(err, db.ErrRepoNotFound) { + writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, + fmt.Sprintf("repository %q not found", repo)) + return + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + exists, err := h.db.BlobExistsInRepo(repoID, digest) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + if !exists { + writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound, + fmt.Sprintf("blob %q not found in repository", digest)) + return + } + + size, err := h.blobs.Stat(digest) + if err != nil { + writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound, "blob not found in storage") + return + } + + rc, err := h.blobs.Open(digest) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + defer func() { _ = rc.Close() }() + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Docker-Content-Digest", digest) + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + w.WriteHeader(http.StatusOK) + _, _ = io.Copy(w, rc) +} + +func (h *Handler) handleBlobHead(w http.ResponseWriter, r *http.Request, repo, digest string) { + if !h.checkPolicy(w, r, policy.ActionPull, repo) { + return + } + + repoID, err := h.db.GetRepositoryByName(repo) + if err != nil { + if errors.Is(err, db.ErrRepoNotFound) { + writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, + fmt.Sprintf("repository %q not found", repo)) + return + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + exists, err := h.db.BlobExistsInRepo(repoID, digest) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + if !exists { + writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound, + fmt.Sprintf("blob %q not found in repository", digest)) + return + } + + size, err := h.blobs.Stat(digest) + if err != nil { + writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound, "blob not found in storage") + return + } + + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Docker-Content-Digest", digest) + w.Header().Set("Content-Length", strconv.FormatInt(size, 10)) + w.WriteHeader(http.StatusOK) +} diff --git a/internal/oci/blob_test.go b/internal/oci/blob_test.go new file mode 100644 index 0000000..8d45896 --- /dev/null +++ b/internal/oci/blob_test.go @@ -0,0 +1,144 @@ +package oci + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestBlobGet(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addBlob(1, "sha256:layerdigest") + + blobs := newFakeBlobs() + blobs.data["sha256:layerdigest"] = []byte("layer-content-bytes") + + h := NewHandler(fdb, blobs, allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/blobs/sha256:layerdigest", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + if ct := rr.Header().Get("Content-Type"); ct != "application/octet-stream" { + t.Fatalf("Content-Type: got %q", ct) + } + if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:layerdigest" { + t.Fatalf("Docker-Content-Digest: got %q", dcd) + } + if cl := rr.Header().Get("Content-Length"); cl != "19" { + t.Fatalf("Content-Length: got %q, want %q", cl, "19") + } + if rr.Body.String() != "layer-content-bytes" { + t.Fatalf("body: got %q", rr.Body.String()) + } +} + +func TestBlobHead(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addBlob(1, "sha256:layerdigest") + + blobs := newFakeBlobs() + blobs.data["sha256:layerdigest"] = []byte("layer-content-bytes") + + h := NewHandler(fdb, blobs, allowAll(), nil) + router := testRouter(h) + + req := authedRequest("HEAD", "/v2/myrepo/blobs/sha256:layerdigest", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + if ct := rr.Header().Get("Content-Type"); ct != "application/octet-stream" { + t.Fatalf("Content-Type: got %q", ct) + } + if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:layerdigest" { + t.Fatalf("Docker-Content-Digest: got %q", dcd) + } + if cl := rr.Header().Get("Content-Length"); cl != "19" { + t.Fatalf("Content-Length: got %q, want %q", cl, "19") + } + if rr.Body.Len() != 0 { + t.Fatalf("HEAD body should be empty, got %d bytes", rr.Body.Len()) + } +} + +func TestBlobGetNotInRepo(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + // Blob NOT added to the repo. + + blobs := newFakeBlobs() + blobs.data["sha256:orphan"] = []byte("data") + + h := NewHandler(fdb, blobs, allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/blobs/sha256:orphan", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "BLOB_UNKNOWN" { + t.Fatalf("error code: got %+v, want BLOB_UNKNOWN", body.Errors) + } +} + +func TestBlobGetRepoNotFound(t *testing.T) { + fdb := newFakeDB() + // No repos. + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/nosuchrepo/blobs/sha256:abc", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "NAME_UNKNOWN" { + t.Fatalf("error code: got %+v, want NAME_UNKNOWN", body.Errors) + } +} + +func TestBlobGetMultiSegmentRepo(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("org/team/app", 1) + fdb.addBlob(1, "sha256:layerdigest") + + blobs := newFakeBlobs() + blobs.data["sha256:layerdigest"] = []byte("data") + + h := NewHandler(fdb, blobs, allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/org/team/app/blobs/sha256:layerdigest", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } +} diff --git a/internal/oci/catalog.go b/internal/oci/catalog.go new file mode 100644 index 0000000..a44ed73 --- /dev/null +++ b/internal/oci/catalog.go @@ -0,0 +1,68 @@ +package oci + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + + "git.wntrmute.dev/kyle/mcr/internal/policy" +) + +type catalogResponse struct { + Repositories []string `json:"repositories"` +} + +func (h *Handler) handleCatalog(w http.ResponseWriter, r *http.Request) { + if !h.checkPolicy(w, r, policy.ActionCatalog, "") { + return + } + + n := 0 + var err error + if nStr := r.URL.Query().Get("n"); nStr != "" { + n, err = strconv.Atoi(nStr) + if err != nil || n < 0 { + writeOCIError(w, "INVALID_PARAMETER", http.StatusBadRequest, "invalid 'n' parameter") + return + } + } + last := r.URL.Query().Get("last") + + if n == 0 { + repos, err := h.db.ListRepositoryNames(last, 10000) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + if repos == nil { + repos = []string{} + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(catalogResponse{Repositories: repos}) + return + } + + repos, err := h.db.ListRepositoryNames(last, n+1) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + if repos == nil { + repos = []string{} + } + + hasMore := len(repos) > n + if hasMore { + repos = repos[:n] + } + + if hasMore { + lastRepo := repos[len(repos)-1] + linkURL := fmt.Sprintf("/v2/_catalog?n=%d&last=%s", n, lastRepo) + w.Header().Set("Link", fmt.Sprintf(`<%s>; rel="next"`, linkURL)) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(catalogResponse{Repositories: repos}) +} diff --git a/internal/oci/catalog_test.go b/internal/oci/catalog_test.go new file mode 100644 index 0000000..1496664 --- /dev/null +++ b/internal/oci/catalog_test.go @@ -0,0 +1,124 @@ +package oci + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestCatalog(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("alpha/lib", 1) + fdb.addRepo("bravo/svc", 2) + fdb.addRepo("charlie/app", 3) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/_catalog", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + + var body catalogResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + + want := []string{"alpha/lib", "bravo/svc", "charlie/app"} + if len(body.Repositories) != len(want) { + t.Fatalf("repos count: got %d, want %d", len(body.Repositories), len(want)) + } + for i, r := range body.Repositories { + if r != want[i] { + t.Fatalf("repos[%d]: got %q, want %q", i, r, want[i]) + } + } +} + +func TestCatalogPagination(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("alpha/lib", 1) + fdb.addRepo("bravo/svc", 2) + fdb.addRepo("charlie/app", 3) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + // First page: n=2. + req := authedRequest("GET", "/v2/_catalog?n=2", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + + var body catalogResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(body.Repositories) != 2 { + t.Fatalf("page 1 count: got %d, want 2", len(body.Repositories)) + } + if body.Repositories[0] != "alpha/lib" || body.Repositories[1] != "bravo/svc" { + t.Fatalf("page 1: got %v", body.Repositories) + } + + // Check Link header. + link := rr.Header().Get("Link") + if link == "" { + t.Fatal("expected Link header for pagination") + } + + // Second page: n=2, last=bravo/svc. + req = authedRequest("GET", "/v2/_catalog?n=2&last=bravo/svc", nil) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("page 2 status: got %d, want %d", rr.Code, http.StatusOK) + } + + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode page 2: %v", err) + } + if len(body.Repositories) != 1 { + t.Fatalf("page 2 count: got %d, want 1", len(body.Repositories)) + } + if body.Repositories[0] != "charlie/app" { + t.Fatalf("page 2: got %v", body.Repositories) + } + + // No Link header on last page. + if link := rr.Header().Get("Link"); link != "" { + t.Fatalf("expected no Link header on last page, got %q", link) + } +} + +func TestCatalogEmpty(t *testing.T) { + fdb := newFakeDB() + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/_catalog", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + + var body catalogResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(body.Repositories) != 0 { + t.Fatalf("expected empty repositories, got %v", body.Repositories) + } +} diff --git a/internal/oci/handler.go b/internal/oci/handler.go new file mode 100644 index 0000000..21609b1 --- /dev/null +++ b/internal/oci/handler.go @@ -0,0 +1,119 @@ +package oci + +import ( + "io" + "net/http" + "strings" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" + "git.wntrmute.dev/kyle/mcr/internal/storage" +) + +// DBQuerier provides the database operations needed by OCI handlers. +type DBQuerier interface { + GetRepositoryByName(name string) (int64, error) + GetManifestByTag(repoID int64, tag string) (*db.ManifestRow, error) + GetManifestByDigest(repoID int64, digest string) (*db.ManifestRow, error) + BlobExistsInRepo(repoID int64, digest string) (bool, error) + ListTags(repoID int64, after string, limit int) ([]string, error) + ListRepositoryNames(after string, limit int) ([]string, error) + + // Push operations + GetOrCreateRepository(name string) (int64, error) + BlobExists(digest string) (bool, error) + InsertBlob(digest string, size int64) error + PushManifest(p db.PushManifestParams) error + + // Upload operations + CreateUpload(uuid string, repoID int64) error + GetUpload(uuid string) (*db.UploadRow, error) + UpdateUploadOffset(uuid string, offset int64) error + DeleteUpload(uuid string) error +} + +// BlobStore provides read and write access to blob storage. +type BlobStore interface { + Open(digest string) (io.ReadCloser, error) + Stat(digest string) (int64, error) + StartUpload(uuid string) (*storage.BlobWriter, error) +} + +// PolicyEval evaluates access control policies. +type PolicyEval interface { + Evaluate(input policy.PolicyInput) (policy.Effect, *policy.Rule) +} + +// AuditFunc records audit events. Follows the same signature pattern as +// db.WriteAuditEvent but without an error return — audit failures should +// not block request processing. +type AuditFunc func(eventType, actorID, repository, digest, ip string, details map[string]string) + +// Handler serves OCI Distribution Spec endpoints. +type Handler struct { + db DBQuerier + blobs BlobStore + policy PolicyEval + auditFn AuditFunc + uploads *uploadManager +} + +// NewHandler creates a new OCI handler. +func NewHandler(querier DBQuerier, blobs BlobStore, pol PolicyEval, auditFn AuditFunc) *Handler { + return &Handler{ + db: querier, + blobs: blobs, + policy: pol, + auditFn: auditFn, + uploads: newUploadManager(), + } +} + +// isDigest returns true if the reference looks like a digest (sha256:...). +func isDigest(ref string) bool { + return strings.HasPrefix(ref, "sha256:") +} + +// checkPolicy evaluates the policy for the given action and repository. +// Returns true if access is allowed, false if denied (and writes the OCI error). +func (h *Handler) checkPolicy(w http.ResponseWriter, r *http.Request, action policy.Action, repo string) bool { + claims := auth.ClaimsFromContext(r.Context()) + if claims == nil { + writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication required") + return false + } + + input := policy.PolicyInput{ + Subject: claims.Subject, + AccountType: claims.AccountType, + Roles: claims.Roles, + Action: action, + Repository: repo, + } + + effect, _ := h.policy.Evaluate(input) + if effect == policy.Deny { + if h.auditFn != nil { + h.auditFn("policy_deny", claims.Subject, repo, "", r.RemoteAddr, map[string]string{ + "action": string(action), + }) + } + writeOCIError(w, "DENIED", http.StatusForbidden, "access denied by policy") + return false + } + return true +} + +// audit records an audit event if an audit function is configured. +func (h *Handler) audit(r *http.Request, eventType, repo, digest string) { + if h.auditFn == nil { + return + } + claims := auth.ClaimsFromContext(r.Context()) + actorID := "" + if claims != nil { + actorID = claims.Subject + } + h.auditFn(eventType, actorID, repo, digest, r.RemoteAddr, nil) +} diff --git a/internal/oci/handler_test.go b/internal/oci/handler_test.go new file mode 100644 index 0000000..042f3ee --- /dev/null +++ b/internal/oci/handler_test.go @@ -0,0 +1,321 @@ +package oci + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "sort" + "sync" + + "github.com/go-chi/chi/v5" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" + "git.wntrmute.dev/kyle/mcr/internal/storage" +) + +// manifestKey uniquely identifies a manifest for test lookup. +type manifestKey struct { + repoID int64 + reference string // tag or digest +} + +// fakeDB implements DBQuerier for tests. +type fakeDB struct { + mu sync.Mutex + repos map[string]int64 // name -> id + manifests map[manifestKey]*db.ManifestRow // (repoID, ref) -> manifest + blobs map[int64]map[string]bool // repoID -> set of digests + allBlobs map[string]bool // global blob digests + tags map[int64][]string // repoID -> sorted tag names + repoNames []string // sorted repo names + uploads map[string]*db.UploadRow // uuid -> upload + nextID int64 // auto-increment counter + pushed []db.PushManifestParams // record of pushed manifests +} + +func newFakeDB() *fakeDB { + return &fakeDB{ + repos: make(map[string]int64), + manifests: make(map[manifestKey]*db.ManifestRow), + blobs: make(map[int64]map[string]bool), + allBlobs: make(map[string]bool), + tags: make(map[int64][]string), + uploads: make(map[string]*db.UploadRow), + nextID: 1, + } +} + +func (f *fakeDB) GetRepositoryByName(name string) (int64, error) { + id, ok := f.repos[name] + if !ok { + return 0, db.ErrRepoNotFound + } + return id, nil +} + +func (f *fakeDB) GetManifestByTag(repoID int64, tag string) (*db.ManifestRow, error) { + m, ok := f.manifests[manifestKey{repoID, tag}] + if !ok { + return nil, db.ErrManifestNotFound + } + return m, nil +} + +func (f *fakeDB) GetManifestByDigest(repoID int64, digest string) (*db.ManifestRow, error) { + m, ok := f.manifests[manifestKey{repoID, digest}] + if !ok { + return nil, db.ErrManifestNotFound + } + return m, nil +} + +func (f *fakeDB) BlobExistsInRepo(repoID int64, digest string) (bool, error) { + digests, ok := f.blobs[repoID] + if !ok { + return false, nil + } + return digests[digest], nil +} + +func (f *fakeDB) ListTags(repoID int64, after string, limit int) ([]string, error) { + allTags := f.tags[repoID] + var result []string + for _, t := range allTags { + if after != "" && t <= after { + continue + } + result = append(result, t) + if len(result) >= limit { + break + } + } + return result, nil +} + +func (f *fakeDB) ListRepositoryNames(after string, limit int) ([]string, error) { + var result []string + for _, n := range f.repoNames { + if after != "" && n <= after { + continue + } + result = append(result, n) + if len(result) >= limit { + break + } + } + return result, nil +} + +func (f *fakeDB) GetOrCreateRepository(name string) (int64, error) { + f.mu.Lock() + defer f.mu.Unlock() + id, ok := f.repos[name] + if ok { + return id, nil + } + id = f.nextID + f.nextID++ + f.repos[name] = id + f.repoNames = append(f.repoNames, name) + sort.Strings(f.repoNames) + return id, nil +} + +func (f *fakeDB) BlobExists(digest string) (bool, error) { + return f.allBlobs[digest], nil +} + +func (f *fakeDB) InsertBlob(digest string, size int64) error { + f.allBlobs[digest] = true + _ = size + return nil +} + +func (f *fakeDB) PushManifest(p db.PushManifestParams) error { + f.mu.Lock() + defer f.mu.Unlock() + f.pushed = append(f.pushed, p) + // Simulate creating the manifest in our fake data. + repoID, ok := f.repos[p.RepoName] + if !ok { + repoID = f.nextID + f.nextID++ + f.repos[p.RepoName] = repoID + f.repoNames = append(f.repoNames, p.RepoName) + sort.Strings(f.repoNames) + } + m := &db.ManifestRow{ + ID: f.nextID, + RepositoryID: repoID, + Digest: p.Digest, + MediaType: p.MediaType, + Content: p.Content, + Size: p.Size, + } + f.nextID++ + f.manifests[manifestKey{repoID, p.Digest}] = m + if p.Tag != "" { + f.manifests[manifestKey{repoID, p.Tag}] = m + } + return nil +} + +func (f *fakeDB) CreateUpload(uuid string, repoID int64) error { + f.mu.Lock() + defer f.mu.Unlock() + f.uploads[uuid] = &db.UploadRow{ + ID: f.nextID, + UUID: uuid, + RepositoryID: repoID, + ByteOffset: 0, + } + f.nextID++ + return nil +} + +func (f *fakeDB) GetUpload(uuid string) (*db.UploadRow, error) { + f.mu.Lock() + defer f.mu.Unlock() + u, ok := f.uploads[uuid] + if !ok { + return nil, db.ErrUploadNotFound + } + return u, nil +} + +func (f *fakeDB) UpdateUploadOffset(uuid string, offset int64) error { + f.mu.Lock() + defer f.mu.Unlock() + u, ok := f.uploads[uuid] + if !ok { + return db.ErrUploadNotFound + } + u.ByteOffset = offset + return nil +} + +func (f *fakeDB) DeleteUpload(uuid string) error { + f.mu.Lock() + defer f.mu.Unlock() + if _, ok := f.uploads[uuid]; !ok { + return db.ErrUploadNotFound + } + delete(f.uploads, uuid) + return nil +} + +// addRepo adds a repo to the fakeDB and returns its ID. +func (f *fakeDB) addRepo(name string, id int64) { + f.repos[name] = id + f.repoNames = append(f.repoNames, name) + sort.Strings(f.repoNames) + if id >= f.nextID { + f.nextID = id + 1 + } +} + +// addManifest adds a manifest accessible by both tag and digest. +func (f *fakeDB) addManifest(repoID int64, tag, digest, mediaType string, content []byte) { + m := &db.ManifestRow{ + ID: f.nextID, + RepositoryID: repoID, + Digest: digest, + MediaType: mediaType, + Content: content, + Size: int64(len(content)), + } + f.nextID++ + if tag != "" { + f.manifests[manifestKey{repoID, tag}] = m + } + f.manifests[manifestKey{repoID, digest}] = m +} + +// addBlob registers a blob digest in a repository. +func (f *fakeDB) addBlob(repoID int64, digest string) { + if f.blobs[repoID] == nil { + f.blobs[repoID] = make(map[string]bool) + } + f.blobs[repoID][digest] = true +} + +// addGlobalBlob registers a blob in the global blob table. +func (f *fakeDB) addGlobalBlob(digest string) { + f.allBlobs[digest] = true +} + +// addTag adds a tag to a repository's tag list. +func (f *fakeDB) addTag(repoID int64, tag string) { + f.tags[repoID] = append(f.tags[repoID], tag) + sort.Strings(f.tags[repoID]) +} + +// fakeBlobs implements BlobStore for tests. +type fakeBlobs struct { + data map[string][]byte // digest -> content + uploads map[string]*bytes.Buffer +} + +func newFakeBlobs() *fakeBlobs { + return &fakeBlobs{ + data: make(map[string][]byte), + uploads: make(map[string]*bytes.Buffer), + } +} + +func (f *fakeBlobs) Open(digest string) (io.ReadCloser, error) { + data, ok := f.data[digest] + if !ok { + return nil, io.ErrUnexpectedEOF + } + return io.NopCloser(bytes.NewReader(data)), nil +} + +func (f *fakeBlobs) Stat(digest string) (int64, error) { + data, ok := f.data[digest] + if !ok { + return 0, io.ErrUnexpectedEOF + } + return int64(len(data)), nil +} + +func (f *fakeBlobs) StartUpload(uuid string) (*storage.BlobWriter, error) { + // For tests that need real storage, use a real Store in t.TempDir(). + // This fake panics to catch unintended usage. + panic("fakeBlobs.StartUpload should not be called; use a real storage.Store for upload tests") +} + +// fakePolicy implements PolicyEval, always returning Allow. +type fakePolicy struct { + effect policy.Effect +} + +func (f *fakePolicy) Evaluate(_ policy.PolicyInput) (policy.Effect, *policy.Rule) { + return f.effect, nil +} + +// allowAll returns a fakePolicy that allows all requests. +func allowAll() *fakePolicy { + return &fakePolicy{effect: policy.Allow} +} + +// testRouter creates a chi.Mux with the OCI handler mounted at /v2. +func testRouter(h *Handler) *chi.Mux { + parent := chi.NewRouter() + parent.Mount("/v2", h.Router()) + return parent +} + +// authedRequest creates an HTTP request with authenticated claims in the context. +func authedRequest(method, path string, body io.Reader) *http.Request { + req := httptest.NewRequest(method, path, body) + claims := &auth.Claims{ + Subject: "test-user", + AccountType: "human", + Roles: []string{"user"}, + } + return req.WithContext(auth.ContextWithClaims(req.Context(), claims)) +} diff --git a/internal/oci/manifest.go b/internal/oci/manifest.go new file mode 100644 index 0000000..9ce62de --- /dev/null +++ b/internal/oci/manifest.go @@ -0,0 +1,222 @@ +package oci + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" +) + +func (h *Handler) handleManifestGet(w http.ResponseWriter, r *http.Request, repo, reference string) { + if !h.checkPolicy(w, r, policy.ActionPull, repo) { + return + } + + m, ok := h.resolveManifest(w, repo, reference) + if !ok { + return + } + + h.audit(r, "manifest_pulled", repo, m.Digest) + + w.Header().Set("Content-Type", m.MediaType) + w.Header().Set("Docker-Content-Digest", m.Digest) + w.Header().Set("Content-Length", strconv.FormatInt(m.Size, 10)) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(m.Content) +} + +func (h *Handler) handleManifestHead(w http.ResponseWriter, r *http.Request, repo, reference string) { + if !h.checkPolicy(w, r, policy.ActionPull, repo) { + return + } + + m, ok := h.resolveManifest(w, repo, reference) + if !ok { + return + } + + w.Header().Set("Content-Type", m.MediaType) + w.Header().Set("Docker-Content-Digest", m.Digest) + w.Header().Set("Content-Length", strconv.FormatInt(m.Size, 10)) + w.WriteHeader(http.StatusOK) +} + +// resolveManifest looks up a manifest by tag or digest, writing OCI errors on failure. +func (h *Handler) resolveManifest(w http.ResponseWriter, repo, reference string) (*db.ManifestRow, bool) { + repoID, err := h.db.GetRepositoryByName(repo) + if err != nil { + if errors.Is(err, db.ErrRepoNotFound) { + writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, + fmt.Sprintf("repository %q not found", repo)) + return nil, false + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return nil, false + } + + var m *db.ManifestRow + if isDigest(reference) { + m, err = h.db.GetManifestByDigest(repoID, reference) + } else { + m, err = h.db.GetManifestByTag(repoID, reference) + } + if err != nil { + if errors.Is(err, db.ErrManifestNotFound) { + writeOCIError(w, "MANIFEST_UNKNOWN", http.StatusNotFound, + fmt.Sprintf("manifest %q not found", reference)) + return nil, false + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return nil, false + } + + return m, true +} + +// ociManifest is the minimal structure for parsing an OCI image manifest. +type ociManifest struct { + SchemaVersion int `json:"schemaVersion"` + MediaType string `json:"mediaType,omitempty"` + Config ociDescriptor `json:"config"` + Layers []ociDescriptor `json:"layers"` +} + +// ociDescriptor is a content-addressed reference within a manifest. +type ociDescriptor struct { + MediaType string `json:"mediaType"` + Digest string `json:"digest"` + Size int64 `json:"size"` +} + +// supportedManifestMediaTypes lists the manifest media types MCR accepts. +var supportedManifestMediaTypes = map[string]bool{ + "application/vnd.oci.image.manifest.v1+json": true, + "application/vnd.docker.distribution.manifest.v2+json": true, +} + +// handleManifestPut handles PUT /v2//manifests/ +func (h *Handler) handleManifestPut(w http.ResponseWriter, r *http.Request, repo, reference string) { + if !h.checkPolicy(w, r, policy.ActionPush, repo) { + return + } + + // Step 1: Read and parse manifest JSON. + body, err := io.ReadAll(r.Body) + if err != nil { + writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "failed to read request body") + return + } + if len(body) == 0 { + writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "empty manifest") + return + } + + var manifest ociManifest + if err := json.Unmarshal(body, &manifest); err != nil { + writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "malformed manifest JSON") + return + } + if manifest.SchemaVersion != 2 { + writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "unsupported schema version") + return + } + + // Determine media type from Content-Type header, falling back to manifest body. + mediaType := r.Header.Get("Content-Type") + if mediaType == "" { + mediaType = manifest.MediaType + } + if mediaType == "" { + mediaType = "application/vnd.oci.image.manifest.v1+json" + } + if !supportedManifestMediaTypes[mediaType] { + writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, + fmt.Sprintf("unsupported media type: %s", mediaType)) + return + } + + // Step 2: Compute SHA-256 digest. + sum := sha256.Sum256(body) + computedDigest := "sha256:" + hex.EncodeToString(sum[:]) + + // Step 3: If reference is a digest, verify it matches. + tag := "" + if isDigest(reference) { + if reference != computedDigest { + writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest, + fmt.Sprintf("digest mismatch: computed %s, got %s", computedDigest, reference)) + return + } + } else { + tag = reference + } + + // Step 4: Collect all referenced blob digests. + var blobDigests []string + if manifest.Config.Digest != "" { + blobDigests = append(blobDigests, manifest.Config.Digest) + } + for _, layer := range manifest.Layers { + if layer.Digest != "" { + blobDigests = append(blobDigests, layer.Digest) + } + } + + // Step 5: Verify all referenced blobs exist. + for _, bd := range blobDigests { + exists, err := h.db.BlobExists(bd) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + if !exists { + writeOCIError(w, "MANIFEST_BLOB_UNKNOWN", http.StatusBadRequest, + fmt.Sprintf("blob %s not found", bd)) + return + } + } + + // Step 6: Single transaction — create repo, insert manifest, populate + // manifest_blobs, upsert tag. + params := db.PushManifestParams{ + RepoName: repo, + Digest: computedDigest, + MediaType: mediaType, + Content: body, + Size: int64(len(body)), + Tag: tag, + BlobDigests: blobDigests, + } + if err := h.db.PushManifest(params); err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + // Step 7: Audit and respond. + details := map[string]string{} + if tag != "" { + details["tag"] = tag + } + if h.auditFn != nil { + claims := auth.ClaimsFromContext(r.Context()) + actorID := "" + if claims != nil { + actorID = claims.Subject + } + h.auditFn("manifest_pushed", actorID, repo, computedDigest, r.RemoteAddr, details) + } + + w.Header().Set("Location", fmt.Sprintf("/v2/%s/manifests/%s", repo, computedDigest)) + w.Header().Set("Docker-Content-Digest", computedDigest) + w.Header().Set("Content-Type", mediaType) + w.WriteHeader(http.StatusCreated) +} diff --git a/internal/oci/manifest_test.go b/internal/oci/manifest_test.go new file mode 100644 index 0000000..0295d74 --- /dev/null +++ b/internal/oci/manifest_test.go @@ -0,0 +1,187 @@ +package oci + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestManifestGetByTag(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + content := []byte(`{"schemaVersion":2}`) + fdb.addManifest(1, "latest", "sha256:aaaa", "application/vnd.oci.image.manifest.v1+json", content) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/manifests/latest", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + if ct := rr.Header().Get("Content-Type"); ct != "application/vnd.oci.image.manifest.v1+json" { + t.Fatalf("Content-Type: got %q", ct) + } + if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:aaaa" { + t.Fatalf("Docker-Content-Digest: got %q", dcd) + } + if cl := rr.Header().Get("Content-Length"); cl != "19" { + t.Fatalf("Content-Length: got %q, want %q", cl, "19") + } + if rr.Body.String() != `{"schemaVersion":2}` { + t.Fatalf("body: got %q", rr.Body.String()) + } +} + +func TestManifestGetByDigest(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + content := []byte(`{"schemaVersion":2}`) + fdb.addManifest(1, "latest", "sha256:aaaa", "application/vnd.oci.image.manifest.v1+json", content) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/manifests/sha256:aaaa", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:aaaa" { + t.Fatalf("Docker-Content-Digest: got %q", dcd) + } + if rr.Body.String() != `{"schemaVersion":2}` { + t.Fatalf("body: got %q", rr.Body.String()) + } +} + +func TestManifestHead(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + content := []byte(`{"schemaVersion":2}`) + fdb.addManifest(1, "latest", "sha256:aaaa", "application/vnd.oci.image.manifest.v1+json", content) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("HEAD", "/v2/myrepo/manifests/latest", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + if ct := rr.Header().Get("Content-Type"); ct != "application/vnd.oci.image.manifest.v1+json" { + t.Fatalf("Content-Type: got %q", ct) + } + if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:aaaa" { + t.Fatalf("Docker-Content-Digest: got %q", dcd) + } + if cl := rr.Header().Get("Content-Length"); cl != "19" { + t.Fatalf("Content-Length: got %q, want %q", cl, "19") + } + if rr.Body.Len() != 0 { + t.Fatalf("HEAD body should be empty, got %d bytes", rr.Body.Len()) + } +} + +func TestManifestGetRepoNotFound(t *testing.T) { + fdb := newFakeDB() + // No repos added. + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/nosuchrepo/manifests/latest", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "NAME_UNKNOWN" { + t.Fatalf("error code: got %+v, want NAME_UNKNOWN", body.Errors) + } +} + +func TestManifestGetManifestNotFoundByTag(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + // No manifests added. + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/manifests/nonexistent", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_UNKNOWN" { + t.Fatalf("error code: got %+v, want MANIFEST_UNKNOWN", body.Errors) + } +} + +func TestManifestGetManifestNotFoundByDigest(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + // No manifests added. + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/manifests/sha256:nonexistent", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_UNKNOWN" { + t.Fatalf("error code: got %+v, want MANIFEST_UNKNOWN", body.Errors) + } +} + +func TestManifestGetMultiSegmentRepo(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("org/team/app", 1) + content := []byte(`{"layers":[]}`) + fdb.addManifest(1, "v1.0", "sha256:cccc", "application/vnd.oci.image.manifest.v1+json", content) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/org/team/app/manifests/v1.0", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:cccc" { + t.Fatalf("Docker-Content-Digest: got %q", dcd) + } +} diff --git a/internal/oci/ocierror.go b/internal/oci/ocierror.go new file mode 100644 index 0000000..5351902 --- /dev/null +++ b/internal/oci/ocierror.go @@ -0,0 +1,23 @@ +package oci + +import ( + "encoding/json" + "net/http" +) + +type ociErrorEntry struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type ociErrorResponse struct { + Errors []ociErrorEntry `json:"errors"` +} + +func writeOCIError(w http.ResponseWriter, code string, status int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(ociErrorResponse{ + Errors: []ociErrorEntry{{Code: code, Message: message}}, + }) +} diff --git a/internal/oci/push_test.go b/internal/oci/push_test.go new file mode 100644 index 0000000..8a00515 --- /dev/null +++ b/internal/oci/push_test.go @@ -0,0 +1,266 @@ +package oci + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "net/http/httptest" + "testing" +) + +func makeManifest(configDigest string, layerDigests []string) []byte { + layers := "" + for i, d := range layerDigests { + if i > 0 { + layers += "," + } + layers += fmt.Sprintf(`{"mediaType":"application/vnd.oci.image.layer.v1.tar+gzip","digest":%q,"size":1000}`, d) + } + return []byte(fmt.Sprintf(`{"schemaVersion":2,"mediaType":"application/vnd.oci.image.manifest.v1+json","config":{"mediaType":"application/vnd.oci.image.config.v1+json","digest":%q,"size":100},"layers":[%s]}`, configDigest, layers)) +} + +func manifestDigest(content []byte) string { + sum := sha256.Sum256(content) + return "sha256:" + hex.EncodeToString(sum[:]) +} + +func TestManifestPushByTag(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addGlobalBlob("sha256:config1") + fdb.addGlobalBlob("sha256:layer1") + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + content := makeManifest("sha256:config1", []string{"sha256:layer1"}) + digest := manifestDigest(content) + + req := authedRequest("PUT", "/v2/myrepo/manifests/latest", nil) + req.Body = http.NoBody + // Re-create with proper body. + req = authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("status: got %d, want %d, body: %s", rr.Code, http.StatusCreated, rr.Body.String()) + } + + dcd := rr.Header().Get("Docker-Content-Digest") + if dcd != digest { + t.Fatalf("Docker-Content-Digest: got %q, want %q", dcd, digest) + } + + loc := rr.Header().Get("Location") + if loc == "" { + t.Fatal("Location header missing") + } + + ct := rr.Header().Get("Content-Type") + if ct != "application/vnd.oci.image.manifest.v1+json" { + t.Fatalf("Content-Type: got %q", ct) + } + + // Verify manifest was stored. + if len(fdb.pushed) != 1 { + t.Fatalf("pushed count: got %d, want 1", len(fdb.pushed)) + } + if fdb.pushed[0].Tag != "latest" { + t.Fatalf("pushed tag: got %q, want %q", fdb.pushed[0].Tag, "latest") + } +} + +func TestManifestPushByDigest(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addGlobalBlob("sha256:config1") + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + content := makeManifest("sha256:config1", nil) + digest := manifestDigest(content) + + req := authedPushRequest("PUT", "/v2/myrepo/manifests/"+digest, content) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("status: got %d, body: %s", rr.Code, rr.Body.String()) + } + + // Verify no tag was set. + if len(fdb.pushed) != 1 { + t.Fatalf("pushed count: got %d", len(fdb.pushed)) + } + if fdb.pushed[0].Tag != "" { + t.Fatalf("pushed tag: got %q, want empty", fdb.pushed[0].Tag) + } +} + +func TestManifestPushDigestMismatch(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + content := []byte(`{"schemaVersion":2}`) + wrongDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000000" + + req := authedPushRequest("PUT", "/v2/myrepo/manifests/"+wrongDigest, content) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "DIGEST_INVALID" { + t.Fatalf("error code: got %+v, want DIGEST_INVALID", body.Errors) + } +} + +func TestManifestPushMissingBlob(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + // Config blob exists but layer blob does not. + fdb.addGlobalBlob("sha256:config1") + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + content := makeManifest("sha256:config1", []string{"sha256:missing_layer"}) + + req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("status: got %d, want %d, body: %s", rr.Code, http.StatusBadRequest, rr.Body.String()) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_BLOB_UNKNOWN" { + t.Fatalf("error code: got %+v, want MANIFEST_BLOB_UNKNOWN", body.Errors) + } +} + +func TestManifestPushMalformed(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", []byte("not valid json")) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_INVALID" { + t.Fatalf("error code: got %+v, want MANIFEST_INVALID", body.Errors) + } +} + +func TestManifestPushEmpty(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", nil) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest) + } +} + +func TestManifestPushUpdatesTag(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addGlobalBlob("sha256:config1") + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + // First push with tag "latest". + content1 := makeManifest("sha256:config1", nil) + req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content1) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + if rr.Code != http.StatusCreated { + t.Fatalf("first push status: got %d", rr.Code) + } + firstDigest := rr.Header().Get("Docker-Content-Digest") + + // Second push with same tag — different content. + fdb.addGlobalBlob("sha256:config2") + content2 := makeManifest("sha256:config2", nil) + req = authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content2) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + if rr.Code != http.StatusCreated { + t.Fatalf("second push status: got %d", rr.Code) + } + secondDigest := rr.Header().Get("Docker-Content-Digest") + + if firstDigest == secondDigest { + t.Fatal("two pushes should produce different digests") + } + + // Verify that the tag was atomically moved. + if len(fdb.pushed) != 2 { + t.Fatalf("pushed count: got %d, want 2", len(fdb.pushed)) + } +} + +func TestManifestPushRepushIdempotent(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addGlobalBlob("sha256:config1") + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + content := makeManifest("sha256:config1", nil) + + // Push the same manifest twice. + for i := range 2 { + req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content) + req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json") + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + if rr.Code != http.StatusCreated { + t.Fatalf("push %d status: got %d", i+1, rr.Code) + } + } +} diff --git a/internal/oci/routes.go b/internal/oci/routes.go new file mode 100644 index 0000000..db11060 --- /dev/null +++ b/internal/oci/routes.go @@ -0,0 +1,171 @@ +package oci + +import ( + "net/http" + "strings" + + "github.com/go-chi/chi/v5" +) + +// ociPathInfo holds the parsed components of an OCI API path. +type ociPathInfo struct { + name string // repository name (may contain slashes) + kind string // "manifests", "blobs", or "tags" + reference string // tag, digest, or "list" +} + +// parseOCIPath extracts the repository name and operation from an OCI path. +// The path should NOT include the /v2/ prefix. +// Examples: +// +// "myrepo/manifests/latest" -> {name:"myrepo", kind:"manifests", reference:"latest"} +// "org/team/app/blobs/sha256:abc" -> {name:"org/team/app", kind:"blobs", reference:"sha256:abc"} +// "myrepo/tags/list" -> {name:"myrepo", kind:"tags", reference:"list"} +// "myrepo/blobs/uploads/" -> {name:"myrepo", kind:"uploads", reference:""} +// "myrepo/blobs/uploads/uuid-here" -> {name:"myrepo", kind:"uploads", reference:"uuid-here"} +func parseOCIPath(path string) (ociPathInfo, bool) { + // Check for /tags/list suffix. + if strings.HasSuffix(path, "/tags/list") { + name := path[:len(path)-len("/tags/list")] + if name == "" { + return ociPathInfo{}, false + } + return ociPathInfo{name: name, kind: "tags", reference: "list"}, true + } + + // Check for /blobs/uploads/ (must come before /blobs/). + if idx := strings.LastIndex(path, "/blobs/uploads/"); idx >= 0 { + name := path[:idx] + uuid := path[idx+len("/blobs/uploads/"):] + if name == "" { + return ociPathInfo{}, false + } + return ociPathInfo{name: name, kind: "uploads", reference: uuid}, true + } + + // Check for /blobs/uploads (trailing-slash-trimmed form for POST initiate). + if strings.HasSuffix(path, "/blobs/uploads") { + name := path[:len(path)-len("/blobs/uploads")] + if name == "" { + return ociPathInfo{}, false + } + return ociPathInfo{name: name, kind: "uploads", reference: ""}, true + } + + // Check for /manifests/. + if idx := strings.LastIndex(path, "/manifests/"); idx >= 0 { + name := path[:idx] + ref := path[idx+len("/manifests/"):] + if name == "" || ref == "" { + return ociPathInfo{}, false + } + return ociPathInfo{name: name, kind: "manifests", reference: ref}, true + } + + // Check for /blobs/. + if idx := strings.LastIndex(path, "/blobs/"); idx >= 0 { + name := path[:idx] + ref := path[idx+len("/blobs/"):] + if name == "" || ref == "" { + return ociPathInfo{}, false + } + return ociPathInfo{name: name, kind: "blobs", reference: ref}, true + } + + return ociPathInfo{}, false +} + +// Router returns a chi router for OCI Distribution Spec endpoints. +// It should be mounted at /v2 on the parent router. +func (h *Handler) Router() chi.Router { + r := chi.NewRouter() + + // Catalog endpoint: GET /v2/_catalog + r.Get("/_catalog", h.handleCatalog) + + // All other OCI endpoints use a catch-all to support multi-segment repo names. + r.HandleFunc("/*", h.dispatch) + + return r +} + +// dispatch routes requests to the appropriate handler based on the parsed path. +func (h *Handler) dispatch(w http.ResponseWriter, r *http.Request) { + // Get the path after /v2/ + path := chi.URLParam(r, "*") + if path == "" { + writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, "repository name required") + return + } + + info, ok := parseOCIPath(path) + if !ok { + writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, "invalid OCI path") + return + } + + switch info.kind { + case "manifests": + switch r.Method { + case http.MethodGet: + h.handleManifestGet(w, r, info.name, info.reference) + case http.MethodHead: + h.handleManifestHead(w, r, info.name, info.reference) + case http.MethodPut: + h.handleManifestPut(w, r, info.name, info.reference) + default: + w.Header().Set("Allow", "GET, HEAD, PUT") + writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed") + } + case "blobs": + switch r.Method { + case http.MethodGet: + h.handleBlobGet(w, r, info.name, info.reference) + case http.MethodHead: + h.handleBlobHead(w, r, info.name, info.reference) + default: + w.Header().Set("Allow", "GET, HEAD") + writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed") + } + case "uploads": + h.dispatchUpload(w, r, info.name, info.reference) + case "tags": + if r.Method != http.MethodGet { + w.Header().Set("Allow", "GET") + writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed") + return + } + h.handleTagsList(w, r, info.name) + default: + writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, "unknown operation") + } +} + +// dispatchUpload routes upload requests to the appropriate handler. +func (h *Handler) dispatchUpload(w http.ResponseWriter, r *http.Request, repo, uuid string) { + if uuid == "" { + // POST /v2//blobs/uploads/ — initiate + if r.Method != http.MethodPost { + w.Header().Set("Allow", "POST") + writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed") + return + } + h.handleUploadInitiate(w, r, repo) + return + } + + // Operations on existing upload UUID. + switch r.Method { + case http.MethodPatch: + h.handleUploadChunk(w, r, repo, uuid) + case http.MethodPut: + h.handleUploadComplete(w, r, repo, uuid) + case http.MethodGet: + h.handleUploadStatus(w, r, repo, uuid) + case http.MethodDelete: + h.handleUploadCancel(w, r, repo, uuid) + default: + w.Header().Set("Allow", "PATCH, PUT, GET, DELETE") + writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed") + } +} diff --git a/internal/oci/routes_test.go b/internal/oci/routes_test.go new file mode 100644 index 0000000..3de9760 --- /dev/null +++ b/internal/oci/routes_test.go @@ -0,0 +1,141 @@ +package oci + +import "testing" + +func TestParseOCIPath(t *testing.T) { + tests := []struct { + name string + path string + want ociPathInfo + wantOK bool + }{ + { + name: "simple repo manifest by tag", + path: "myrepo/manifests/latest", + want: ociPathInfo{name: "myrepo", kind: "manifests", reference: "latest"}, + wantOK: true, + }, + { + name: "multi-segment repo manifest by tag", + path: "org/team/app/manifests/v1.0", + want: ociPathInfo{name: "org/team/app", kind: "manifests", reference: "v1.0"}, + wantOK: true, + }, + { + name: "manifest by digest", + path: "myrepo/manifests/sha256:abc123def456", + want: ociPathInfo{name: "myrepo", kind: "manifests", reference: "sha256:abc123def456"}, + wantOK: true, + }, + { + name: "simple repo blob", + path: "myrepo/blobs/sha256:abc123", + want: ociPathInfo{name: "myrepo", kind: "blobs", reference: "sha256:abc123"}, + wantOK: true, + }, + { + name: "multi-segment repo blob", + path: "org/team/app/blobs/sha256:abc123", + want: ociPathInfo{name: "org/team/app", kind: "blobs", reference: "sha256:abc123"}, + wantOK: true, + }, + { + name: "simple repo tags list", + path: "myrepo/tags/list", + want: ociPathInfo{name: "myrepo", kind: "tags", reference: "list"}, + wantOK: true, + }, + { + name: "multi-segment repo tags list", + path: "org/app/tags/list", + want: ociPathInfo{name: "org/app", kind: "tags", reference: "list"}, + wantOK: true, + }, + { + name: "empty path", + path: "", + wantOK: false, + }, + { + name: "just repo name", + path: "myrepo", + wantOK: false, + }, + { + name: "unknown operation", + path: "myrepo/unknown/ref", + wantOK: false, + }, + { + name: "manifests with no ref", + path: "myrepo/manifests/", + wantOK: false, + }, + { + name: "blobs with no digest", + path: "myrepo/blobs/", + wantOK: false, + }, + { + name: "tags without list suffix", + path: "myrepo/tags/something", + wantOK: false, + }, + { + name: "no repo name before manifests", + path: "/manifests/latest", + wantOK: false, + }, + { + name: "upload initiate (trailing slash)", + path: "myrepo/blobs/uploads/", + want: ociPathInfo{name: "myrepo", kind: "uploads", reference: ""}, + wantOK: true, + }, + { + name: "upload initiate (no trailing slash)", + path: "myrepo/blobs/uploads", + want: ociPathInfo{name: "myrepo", kind: "uploads", reference: ""}, + wantOK: true, + }, + { + name: "upload with uuid", + path: "myrepo/blobs/uploads/abc-123-def", + want: ociPathInfo{name: "myrepo", kind: "uploads", reference: "abc-123-def"}, + wantOK: true, + }, + { + name: "multi-segment repo upload", + path: "org/team/app/blobs/uploads/uuid-456", + want: ociPathInfo{name: "org/team/app", kind: "uploads", reference: "uuid-456"}, + wantOK: true, + }, + { + name: "multi-segment repo upload initiate", + path: "org/team/app/blobs/uploads/", + want: ociPathInfo{name: "org/team/app", kind: "uploads", reference: ""}, + wantOK: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := parseOCIPath(tt.path) + if ok != tt.wantOK { + t.Fatalf("parseOCIPath(%q) ok = %v, want %v", tt.path, ok, tt.wantOK) + } + if !ok { + return + } + if got.name != tt.want.name { + t.Errorf("name: got %q, want %q", got.name, tt.want.name) + } + if got.kind != tt.want.kind { + t.Errorf("kind: got %q, want %q", got.kind, tt.want.kind) + } + if got.reference != tt.want.reference { + t.Errorf("reference: got %q, want %q", got.reference, tt.want.reference) + } + }) + } +} diff --git a/internal/oci/tags.go b/internal/oci/tags.go new file mode 100644 index 0000000..5df524f --- /dev/null +++ b/internal/oci/tags.go @@ -0,0 +1,84 @@ +package oci + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" +) + +type tagListResponse struct { + Name string `json:"name"` + Tags []string `json:"tags"` +} + +func (h *Handler) handleTagsList(w http.ResponseWriter, r *http.Request, repo string) { + if !h.checkPolicy(w, r, policy.ActionPull, repo) { + return + } + + repoID, err := h.db.GetRepositoryByName(repo) + if err != nil { + if errors.Is(err, db.ErrRepoNotFound) { + writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, + fmt.Sprintf("repository %q not found", repo)) + return + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + // Parse pagination params. + n := 0 + if nStr := r.URL.Query().Get("n"); nStr != "" { + n, err = strconv.Atoi(nStr) + if err != nil || n < 0 { + writeOCIError(w, "INVALID_PARAMETER", http.StatusBadRequest, "invalid 'n' parameter") + return + } + } + last := r.URL.Query().Get("last") + + // Default: no limit (return all). + if n == 0 { + tags, err := h.db.ListTags(repoID, last, 10000) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + if tags == nil { + tags = []string{} + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(tagListResponse{Name: repo, Tags: tags}) + return + } + + // Request n+1 to detect if there are more results. + tags, err := h.db.ListTags(repoID, last, n+1) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + if tags == nil { + tags = []string{} + } + + hasMore := len(tags) > n + if hasMore { + tags = tags[:n] + } + + if hasMore { + lastTag := tags[len(tags)-1] + linkURL := fmt.Sprintf("/v2/%s/tags/list?n=%d&last=%s", repo, n, lastTag) + w.Header().Set("Link", fmt.Sprintf(`<%s>; rel="next"`, linkURL)) + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(tagListResponse{Name: repo, Tags: tags}) +} diff --git a/internal/oci/tags_test.go b/internal/oci/tags_test.go new file mode 100644 index 0000000..6ac8418 --- /dev/null +++ b/internal/oci/tags_test.go @@ -0,0 +1,154 @@ +package oci + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +func TestTagsList(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addTag(1, "latest") + fdb.addTag(1, "v1.0") + fdb.addTag(1, "v2.0") + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/tags/list", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + + var body tagListResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if body.Name != "myrepo" { + t.Fatalf("name: got %q, want %q", body.Name, "myrepo") + } + want := []string{"latest", "v1.0", "v2.0"} + if len(body.Tags) != len(want) { + t.Fatalf("tags count: got %d, want %d", len(body.Tags), len(want)) + } + for i, tag := range body.Tags { + if tag != want[i] { + t.Fatalf("tags[%d]: got %q, want %q", i, tag, want[i]) + } + } +} + +func TestTagsListPagination(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + fdb.addTag(1, "alpha") + fdb.addTag(1, "beta") + fdb.addTag(1, "gamma") + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + // First page: n=2. + req := authedRequest("GET", "/v2/myrepo/tags/list?n=2", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + + var body tagListResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(body.Tags) != 2 { + t.Fatalf("page 1 tags count: got %d, want 2", len(body.Tags)) + } + if body.Tags[0] != "alpha" || body.Tags[1] != "beta" { + t.Fatalf("page 1 tags: got %v", body.Tags) + } + + // Check Link header. + link := rr.Header().Get("Link") + if link == "" { + t.Fatal("expected Link header for pagination") + } + + // Second page: n=2, last=beta. + req = authedRequest("GET", "/v2/myrepo/tags/list?n=2&last=beta", nil) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("page 2 status: got %d, want %d", rr.Code, http.StatusOK) + } + + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode page 2: %v", err) + } + if len(body.Tags) != 1 { + t.Fatalf("page 2 tags count: got %d, want 1", len(body.Tags)) + } + if body.Tags[0] != "gamma" { + t.Fatalf("page 2 tags: got %v", body.Tags) + } + + // No Link header on last page. + if link := rr.Header().Get("Link"); link != "" { + t.Fatalf("expected no Link header on last page, got %q", link) + } +} + +func TestTagsListEmpty(t *testing.T) { + fdb := newFakeDB() + fdb.addRepo("myrepo", 1) + // No tags. + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/myrepo/tags/list", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK) + } + + var body tagListResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode body: %v", err) + } + if len(body.Tags) != 0 { + t.Fatalf("expected empty tags, got %v", body.Tags) + } +} + +func TestTagsListRepoNotFound(t *testing.T) { + fdb := newFakeDB() + // No repos. + + h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil) + router := testRouter(h) + + req := authedRequest("GET", "/v2/nosuchrepo/tags/list", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error body: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "NAME_UNKNOWN" { + t.Fatalf("error code: got %+v, want NAME_UNKNOWN", body.Errors) + } +} diff --git a/internal/oci/upload.go b/internal/oci/upload.go new file mode 100644 index 0000000..50a8ce1 --- /dev/null +++ b/internal/oci/upload.go @@ -0,0 +1,241 @@ +package oci + +import ( + "crypto/rand" + "errors" + "fmt" + "io" + "net/http" + "sync" + + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" + "git.wntrmute.dev/kyle/mcr/internal/storage" +) + +// uploadManager tracks in-progress blob writers by UUID. +type uploadManager struct { + mu sync.Mutex + writers map[string]*storage.BlobWriter +} + +func newUploadManager() *uploadManager { + return &uploadManager{writers: make(map[string]*storage.BlobWriter)} +} + +func (m *uploadManager) set(uuid string, bw *storage.BlobWriter) { + m.mu.Lock() + m.writers[uuid] = bw + m.mu.Unlock() +} + +func (m *uploadManager) get(uuid string) (*storage.BlobWriter, bool) { + m.mu.Lock() + bw, ok := m.writers[uuid] + m.mu.Unlock() + return bw, ok +} + +func (m *uploadManager) remove(uuid string) { + m.mu.Lock() + delete(m.writers, uuid) + m.mu.Unlock() +} + +// generateUUID creates a random UUID (v4) string. +func generateUUID() (string, error) { + var buf [16]byte + if _, err := rand.Read(buf[:]); err != nil { + return "", fmt.Errorf("oci: generate uuid: %w", err) + } + // Set version 4 and variant bits. + buf[6] = (buf[6] & 0x0f) | 0x40 + buf[8] = (buf[8] & 0x3f) | 0x80 + return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x", + buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16]), nil +} + +// handleUploadInitiate handles POST /v2//blobs/uploads/ +func (h *Handler) handleUploadInitiate(w http.ResponseWriter, r *http.Request, repo string) { + if !h.checkPolicy(w, r, policy.ActionPush, repo) { + return + } + + // Create repository if it doesn't exist (implicit creation). + repoID, err := h.db.GetOrCreateRepository(repo) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + uuid, err := generateUUID() + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + // Insert upload row in DB. + if err := h.db.CreateUpload(uuid, repoID); err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + // Create temp file via storage. + bw, err := h.blobs.StartUpload(uuid) + if err != nil { + // Clean up DB row on storage failure. + _ = h.db.DeleteUpload(uuid) + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + h.uploads.set(uuid, bw) + + w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/uploads/%s", repo, uuid)) + w.Header().Set("Docker-Upload-UUID", uuid) + w.Header().Set("Range", "0-0") + w.WriteHeader(http.StatusAccepted) +} + +// handleUploadChunk handles PATCH /v2//blobs/uploads/ +func (h *Handler) handleUploadChunk(w http.ResponseWriter, r *http.Request, repo, uuid string) { + if !h.checkPolicy(w, r, policy.ActionPush, repo) { + return + } + + bw, ok := h.uploads.get(uuid) + if !ok { + writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found") + return + } + + // Append request body to upload file. + n, err := io.Copy(bw, r.Body) + if err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "write failed") + return + } + + // Update offset in DB. + newOffset := bw.BytesWritten() + if err := h.db.UpdateUploadOffset(uuid, newOffset); err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + _ = n // bytes written this chunk + + w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/uploads/%s", repo, uuid)) + w.Header().Set("Docker-Upload-UUID", uuid) + w.Header().Set("Range", fmt.Sprintf("0-%d", newOffset)) + w.WriteHeader(http.StatusAccepted) +} + +// handleUploadComplete handles PUT /v2//blobs/uploads/?digest= +func (h *Handler) handleUploadComplete(w http.ResponseWriter, r *http.Request, repo, uuid string) { + if !h.checkPolicy(w, r, policy.ActionPush, repo) { + return + } + + digest := r.URL.Query().Get("digest") + if digest == "" { + writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest, "digest parameter required") + return + } + + bw, ok := h.uploads.get(uuid) + if !ok { + writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found") + return + } + + // If request body is non-empty, append it first (monolithic upload). + if r.ContentLength != 0 { + if _, err := io.Copy(bw, r.Body); err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "write failed") + return + } + } + + // Commit the blob: verify digest, move to final location. + size := bw.BytesWritten() + _, err := bw.Commit(digest) + if err != nil { + h.uploads.remove(uuid) + if errors.Is(err, storage.ErrDigestMismatch) { + _ = h.db.DeleteUpload(uuid) + writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest, "digest mismatch") + return + } + if errors.Is(err, storage.ErrInvalidDigest) { + _ = h.db.DeleteUpload(uuid) + writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest, "invalid digest format") + return + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "commit failed") + return + } + + h.uploads.remove(uuid) + + // Insert blob row (no-op if already exists — content-addressed dedup). + if err := h.db.InsertBlob(digest, size); err != nil { + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + // Delete upload row. + _ = h.db.DeleteUpload(uuid) + + h.audit(r, "blob_uploaded", repo, digest) + + w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/%s", repo, digest)) + w.Header().Set("Docker-Content-Digest", digest) + w.WriteHeader(http.StatusCreated) +} + +// handleUploadStatus handles GET /v2//blobs/uploads/ +func (h *Handler) handleUploadStatus(w http.ResponseWriter, r *http.Request, repo, uuid string) { + if !h.checkPolicy(w, r, policy.ActionPush, repo) { + return + } + + upload, err := h.db.GetUpload(uuid) + if err != nil { + if errors.Is(err, db.ErrUploadNotFound) { + writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found") + return + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/uploads/%s", repo, uuid)) + w.Header().Set("Docker-Upload-UUID", uuid) + w.Header().Set("Range", fmt.Sprintf("0-%d", upload.ByteOffset)) + w.WriteHeader(http.StatusNoContent) +} + +// handleUploadCancel handles DELETE /v2//blobs/uploads/ +func (h *Handler) handleUploadCancel(w http.ResponseWriter, r *http.Request, repo, uuid string) { + if !h.checkPolicy(w, r, policy.ActionPush, repo) { + return + } + + bw, ok := h.uploads.get(uuid) + if ok { + _ = bw.Cancel() + h.uploads.remove(uuid) + } + + if err := h.db.DeleteUpload(uuid); err != nil { + if errors.Is(err, db.ErrUploadNotFound) { + writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found") + return + } + writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error") + return + } + + w.WriteHeader(http.StatusNoContent) +} diff --git a/internal/oci/upload_test.go b/internal/oci/upload_test.go new file mode 100644 index 0000000..4d524df --- /dev/null +++ b/internal/oci/upload_test.go @@ -0,0 +1,291 @@ +package oci + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/storage" +) + +// testHandlerWithStorage creates a handler with real storage in t.TempDir(). +func testHandlerWithStorage(t *testing.T, fdb *fakeDB) (*Handler, *chi.Mux) { + t.Helper() + dir := t.TempDir() + store := storage.New(dir+"/layers", dir+"/uploads") + h := NewHandler(fdb, store, allowAll(), nil) + router := chi.NewRouter() + router.Mount("/v2", h.Router()) + return h, router +} + +func authedPushRequest(method, path string, body []byte) *http.Request { + var reader *bytes.Reader + if body != nil { + reader = bytes.NewReader(body) + } else { + reader = bytes.NewReader(nil) + } + req := httptest.NewRequest(method, path, reader) + claims := &auth.Claims{ + Subject: "pusher", + AccountType: "human", + Roles: []string{"user"}, + } + return req.WithContext(auth.ContextWithClaims(req.Context(), claims)) +} + +func TestUploadInitiate(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusAccepted) + } + + loc := rr.Header().Get("Location") + if !strings.HasPrefix(loc, "/v2/myrepo/blobs/uploads/") { + t.Fatalf("Location: got %q", loc) + } + + uuid := rr.Header().Get("Docker-Upload-UUID") + if uuid == "" { + t.Fatal("Docker-Upload-UUID header missing") + } + + rng := rr.Header().Get("Range") + if rng != "0-0" { + t.Fatalf("Range: got %q, want %q", rng, "0-0") + } + + // Verify repo was implicitly created. + if _, ok := fdb.repos["myrepo"]; !ok { + t.Fatal("repository should have been implicitly created") + } +} + +func TestUploadInitiateUniqueUUIDs(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + uuids := make(map[string]bool) + for range 5 { + req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("status: got %d", rr.Code) + } + + uuid := rr.Header().Get("Docker-Upload-UUID") + if uuids[uuid] { + t.Fatalf("duplicate UUID: %s", uuid) + } + uuids[uuid] = true + } +} + +func TestMonolithicUpload(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + // Step 1: Initiate upload. + req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("initiate status: got %d", rr.Code) + } + uuid := rr.Header().Get("Docker-Upload-UUID") + + // Step 2: Complete upload with body and digest in a single PUT. + blobData := []byte("hello world blob data") + sum := sha256.Sum256(blobData) + digest := "sha256:" + hex.EncodeToString(sum[:]) + + putURL := "/v2/myrepo/blobs/uploads/" + uuid + "?digest=" + digest + req = authedPushRequest("PUT", putURL, blobData) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("complete status: got %d, body: %s", rr.Code, rr.Body.String()) + } + + loc := rr.Header().Get("Location") + if !strings.Contains(loc, digest) { + t.Fatalf("Location should contain digest: got %q", loc) + } + + dcd := rr.Header().Get("Docker-Content-Digest") + if dcd != digest { + t.Fatalf("Docker-Content-Digest: got %q, want %q", dcd, digest) + } + + // Verify blob was inserted in fake DB. + if !fdb.allBlobs[digest] { + t.Fatal("blob should exist in DB after upload") + } +} + +func TestChunkedUpload(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + // Step 1: Initiate. + req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + uuid := rr.Header().Get("Docker-Upload-UUID") + + // Step 2: PATCH chunk 1. + chunk1 := []byte("chunk-one-data-") + req = authedPushRequest("PATCH", "/v2/myrepo/blobs/uploads/"+uuid, chunk1) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("patch 1 status: got %d", rr.Code) + } + + // Step 3: PATCH chunk 2. + chunk2 := []byte("chunk-two-data") + req = authedPushRequest("PATCH", "/v2/myrepo/blobs/uploads/"+uuid, chunk2) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusAccepted { + t.Fatalf("patch 2 status: got %d", rr.Code) + } + + // Step 4: Complete with PUT. + allData := append(chunk1, chunk2...) + sum := sha256.Sum256(allData) + digest := "sha256:" + hex.EncodeToString(sum[:]) + + req = authedPushRequest("PUT", "/v2/myrepo/blobs/uploads/"+uuid+"?digest="+digest, nil) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusCreated { + t.Fatalf("complete status: got %d, body: %s", rr.Code, rr.Body.String()) + } + if rr.Header().Get("Docker-Content-Digest") != digest { + t.Fatalf("Docker-Content-Digest: got %q, want %q", rr.Header().Get("Docker-Content-Digest"), digest) + } +} + +func TestUploadDigestMismatch(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + // Initiate. + req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + uuid := rr.Header().Get("Docker-Upload-UUID") + + // Complete with wrong digest. + blobData := []byte("some data") + wrongDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000000" + + req = authedPushRequest("PUT", "/v2/myrepo/blobs/uploads/"+uuid+"?digest="+wrongDigest, blobData) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusBadRequest { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "DIGEST_INVALID" { + t.Fatalf("error code: got %+v, want DIGEST_INVALID", body.Errors) + } +} + +func TestUploadStatus(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + // Initiate. + req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + uuid := rr.Header().Get("Docker-Upload-UUID") + + // Check status. + req = authedPushRequest("GET", "/v2/myrepo/blobs/uploads/"+uuid, nil) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNoContent) + } + + if rr.Header().Get("Docker-Upload-UUID") != uuid { + t.Fatalf("Docker-Upload-UUID: got %q", rr.Header().Get("Docker-Upload-UUID")) + } +} + +func TestUploadCancel(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + // Initiate. + req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + uuid := rr.Header().Get("Docker-Upload-UUID") + + // Cancel. + req = authedPushRequest("DELETE", "/v2/myrepo/blobs/uploads/"+uuid, nil) + rr = httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNoContent { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNoContent) + } + + // Verify upload was removed from DB. + if _, ok := fdb.uploads[uuid]; ok { + t.Fatal("upload should have been deleted from DB") + } +} + +func TestUploadNonexistentUUID(t *testing.T) { + fdb := newFakeDB() + _, router := testHandlerWithStorage(t, fdb) + + req := authedPushRequest("PATCH", "/v2/myrepo/blobs/uploads/nonexistent-uuid", []byte("data")) + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + + if rr.Code != http.StatusNotFound { + t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound) + } + + var body ociErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&body); err != nil { + t.Fatalf("decode error: %v", err) + } + if len(body.Errors) != 1 || body.Errors[0].Code != "BLOB_UPLOAD_UNKNOWN" { + t.Fatalf("error code: got %+v, want BLOB_UPLOAD_UNKNOWN", body.Errors) + } +} diff --git a/internal/server/admin.go b/internal/server/admin.go new file mode 100644 index 0000000..4146712 --- /dev/null +++ b/internal/server/admin.go @@ -0,0 +1,52 @@ +package server + +import ( + "encoding/json" + "net/http" + + "git.wntrmute.dev/kyle/mcr/internal/auth" +) + +type adminErrorResponse struct { + Error string `json:"error"` +} + +func writeAdminError(w http.ResponseWriter, status int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(adminErrorResponse{Error: message}) +} + +func writeJSON(w http.ResponseWriter, status int, v any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(v) +} + +// RequireAdmin returns middleware that checks for the admin role. +// Returns 403 with an admin error format if the caller is not an admin. +func RequireAdmin() func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims := auth.ClaimsFromContext(r.Context()) + if claims == nil { + writeAdminError(w, http.StatusUnauthorized, "authentication required") + return + } + if !hasRole(claims.Roles, "admin") { + writeAdminError(w, http.StatusForbidden, "admin role required") + return + } + next.ServeHTTP(w, r) + }) + } +} + +func hasRole(roles []string, target string) bool { + for _, r := range roles { + if r == target { + return true + } + } + return false +} diff --git a/internal/server/admin_audit.go b/internal/server/admin_audit.go new file mode 100644 index 0000000..b6ba860 --- /dev/null +++ b/internal/server/admin_audit.go @@ -0,0 +1,48 @@ +package server + +import ( + "net/http" + "strconv" + + "git.wntrmute.dev/kyle/mcr/internal/db" +) + +// AdminListAuditHandler handles GET /v1/audit. +func AdminListAuditHandler(database *db.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + q := r.URL.Query() + + limit := 50 + if n := q.Get("n"); n != "" { + if v, err := strconv.Atoi(n); err == nil && v > 0 { + limit = v + } + } + offset := 0 + if o := q.Get("offset"); o != "" { + if v, err := strconv.Atoi(o); err == nil && v >= 0 { + offset = v + } + } + + filter := db.AuditFilter{ + EventType: q.Get("event_type"), + ActorID: q.Get("actor_id"), + Repository: q.Get("repository"), + Since: q.Get("since"), + Until: q.Get("until"), + Limit: limit, + Offset: offset, + } + + events, err := database.ListAuditEvents(filter) + if err != nil { + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + if events == nil { + events = []db.AuditEvent{} + } + writeJSON(w, http.StatusOK, events) + } +} diff --git a/internal/server/admin_audit_test.go b/internal/server/admin_audit_test.go new file mode 100644 index 0000000..ebc15f5 --- /dev/null +++ b/internal/server/admin_audit_test.go @@ -0,0 +1,152 @@ +package server + +import ( + "encoding/json" + "testing" + + "git.wntrmute.dev/kyle/mcr/internal/db" +) + +func TestAdminListAuditEvents(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + // Seed some audit events. + if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "10.0.0.1", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + if err := database.WriteAuditEvent("manifest_pushed", "user-1", "myapp", "sha256:abc", "10.0.0.1", + map[string]string{"tag": "latest"}); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + if err := database.WriteAuditEvent("login_ok", "user-2", "", "", "10.0.0.2", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + + rr := adminReq(t, router, "GET", "/v1/audit", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var events []db.AuditEvent + if err := json.NewDecoder(rr.Body).Decode(&events); err != nil { + t.Fatalf("decode: %v", err) + } + if len(events) != 3 { + t.Fatalf("event count: got %d, want 3", len(events)) + } +} + +func TestAdminListAuditEventsWithFilter(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + if err := database.WriteAuditEvent("manifest_pushed", "user-1", "myapp", "", "", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + if err := database.WriteAuditEvent("login_ok", "user-2", "", "", "", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + + // Filter by event_type. + rr := adminReq(t, router, "GET", "/v1/audit?event_type=login_ok", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var events []db.AuditEvent + if err := json.NewDecoder(rr.Body).Decode(&events); err != nil { + t.Fatalf("decode: %v", err) + } + if len(events) != 2 { + t.Fatalf("event count: got %d, want 2", len(events)) + } + for _, e := range events { + if e.EventType != "login_ok" { + t.Fatalf("unexpected event type: %q", e.EventType) + } + } +} + +func TestAdminListAuditEventsFilterByActor(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + if err := database.WriteAuditEvent("login_ok", "actor-a", "", "", "", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + if err := database.WriteAuditEvent("login_ok", "actor-b", "", "", "", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + + rr := adminReq(t, router, "GET", "/v1/audit?actor_id=actor-a", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var events []db.AuditEvent + if err := json.NewDecoder(rr.Body).Decode(&events); err != nil { + t.Fatalf("decode: %v", err) + } + if len(events) != 1 { + t.Fatalf("event count: got %d, want 1", len(events)) + } + if events[0].ActorID != "actor-a" { + t.Fatalf("actor_id: got %q, want %q", events[0].ActorID, "actor-a") + } +} + +func TestAdminListAuditEventsPagination(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + for range 5 { + if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "", nil); err != nil { + t.Fatalf("WriteAuditEvent: %v", err) + } + } + + rr := adminReq(t, router, "GET", "/v1/audit?n=2&offset=0", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var events []db.AuditEvent + if err := json.NewDecoder(rr.Body).Decode(&events); err != nil { + t.Fatalf("decode: %v", err) + } + if len(events) != 2 { + t.Fatalf("event count: got %d, want 2", len(events)) + } +} + +func TestAdminListAuditEventsEmpty(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "GET", "/v1/audit", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var events []db.AuditEvent + if err := json.NewDecoder(rr.Body).Decode(&events); err != nil { + t.Fatalf("decode: %v", err) + } + if len(events) != 0 { + t.Fatalf("event count: got %d, want 0", len(events)) + } +} + +func TestAdminListAuditEventsNonAdmin(t *testing.T) { + database := openAdminTestDB(t) + router := buildNonAdminRouter(t, database) + + rr := adminReq(t, router, "GET", "/v1/audit", "") + if rr.Code != 403 { + t.Fatalf("status: got %d, want 403", rr.Code) + } +} diff --git a/internal/server/admin_auth.go b/internal/server/admin_auth.go new file mode 100644 index 0000000..c4d1203 --- /dev/null +++ b/internal/server/admin_auth.go @@ -0,0 +1,60 @@ +package server + +import ( + "encoding/json" + "net/http" + "time" +) + +type adminLoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type adminLoginResponse struct { + Token string `json:"token"` + ExpiresAt string `json:"expires_at"` +} + +// AdminLoginHandler handles POST /v1/auth/login. +func AdminLoginHandler(loginClient LoginClient) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req adminLoginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeAdminError(w, http.StatusBadRequest, "invalid request body") + return + } + if req.Username == "" || req.Password == "" { + writeAdminError(w, http.StatusBadRequest, "username and password required") + return + } + + token, expiresIn, err := loginClient.Login(req.Username, req.Password) + if err != nil { + writeAdminError(w, http.StatusUnauthorized, "authentication failed") + return + } + + expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339) + writeJSON(w, http.StatusOK, adminLoginResponse{ + Token: token, + ExpiresAt: expiresAt, + }) + } +} + +// AdminLogoutHandler handles POST /v1/auth/logout. +func AdminLogoutHandler() http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + // MCIAS token revocation is not currently supported. + // The client should discard the token. + w.WriteHeader(http.StatusNoContent) + } +} + +// AdminHealthHandler handles GET /v1/health. +func AdminHealthHandler() http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + writeJSON(w, http.StatusOK, map[string]string{"status": "ok"}) + } +} diff --git a/internal/server/admin_auth_test.go b/internal/server/admin_auth_test.go new file mode 100644 index 0000000..6b0f083 --- /dev/null +++ b/internal/server/admin_auth_test.go @@ -0,0 +1,116 @@ +package server + +import ( + "encoding/json" + "testing" + + "github.com/go-chi/chi/v5" + + "git.wntrmute.dev/kyle/mcr/internal/auth" +) + +func TestAdminHealthHandler(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + // Health endpoint does not require auth. + rr := adminReq(t, router, "GET", "/v1/health", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var resp map[string]string + if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp["status"] != "ok" { + t.Fatalf("status field: got %q, want %q", resp["status"], "ok") + } +} + +func TestAdminLoginSuccess(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + body := `{"username":"admin","password":"secret"}` + rr := adminReq(t, router, "POST", "/v1/auth/login", body) + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200; body: %s", rr.Code, rr.Body.String()) + } + + var resp adminLoginResponse + if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.Token != "test-token" { + t.Fatalf("token: got %q, want %q", resp.Token, "test-token") + } + if resp.ExpiresAt == "" { + t.Fatal("expires_at: expected non-empty") + } +} + +func TestAdminLoginInvalidCreds(t *testing.T) { + database := openAdminTestDB(t) + + validator := &fakeValidator{ + claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}}, + } + login := &fakeLoginClient{err: auth.ErrUnauthorized} + reloader := &fakePolicyReloader{} + gcState := &GCState{} + + r := chi.NewRouter() + MountAdminRoutes(r, validator, "mcr-test", AdminDeps{ + DB: database, + Login: login, + Engine: reloader, + AuditFn: nil, + GCState: gcState, + }) + + body := `{"username":"admin","password":"wrong"}` + rr := adminReq(t, r, "POST", "/v1/auth/login", body) + if rr.Code != 401 { + t.Fatalf("status: got %d, want 401", rr.Code) + } + + var errResp adminErrorResponse + if err := json.NewDecoder(rr.Body).Decode(&errResp); err != nil { + t.Fatalf("decode: %v", err) + } + if errResp.Error != "authentication failed" { + t.Fatalf("error: got %q, want %q", errResp.Error, "authentication failed") + } +} + +func TestAdminLoginMissingFields(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + body := `{"username":"admin"}` + rr := adminReq(t, router, "POST", "/v1/auth/login", body) + if rr.Code != 400 { + t.Fatalf("status: got %d, want 400", rr.Code) + } +} + +func TestAdminLoginBadJSON(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "POST", "/v1/auth/login", "not json") + if rr.Code != 400 { + t.Fatalf("status: got %d, want 400", rr.Code) + } +} + +func TestAdminLogout(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "POST", "/v1/auth/logout", "") + if rr.Code != 204 { + t.Fatalf("status: got %d, want 204", rr.Code) + } +} diff --git a/internal/server/admin_gc.go b/internal/server/admin_gc.go new file mode 100644 index 0000000..dd0e40f --- /dev/null +++ b/internal/server/admin_gc.go @@ -0,0 +1,66 @@ +package server + +import ( + "net/http" + "sync" + + "github.com/google/uuid" +) + +// GCLastRun records the result of the last GC run. +type GCLastRun struct { + StartedAt string `json:"started_at"` + CompletedAt string `json:"completed_at,omitempty"` + BlobsRemoved int `json:"blobs_removed"` + BytesFreed int64 `json:"bytes_freed"` +} + +// GCState tracks the current state of garbage collection. +type GCState struct { + mu sync.Mutex + Running bool `json:"running"` + LastRun *GCLastRun `json:"last_run,omitempty"` +} + +type gcStatusResponse struct { + Running bool `json:"running"` + LastRun *GCLastRun `json:"last_run,omitempty"` +} + +type gcTriggerResponse struct { + ID string `json:"id"` +} + +// AdminTriggerGCHandler handles POST /v1/gc. +func AdminTriggerGCHandler(state *GCState) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + state.mu.Lock() + if state.Running { + state.mu.Unlock() + writeAdminError(w, http.StatusConflict, "garbage collection already running") + return + } + state.Running = true + state.mu.Unlock() + + // GC engine is Phase 9 -- for now, just mark as running and return. + // The actual GC goroutine will be wired up in Phase 9. + gcID := uuid.New().String() + + writeJSON(w, http.StatusAccepted, gcTriggerResponse{ID: gcID}) + } +} + +// AdminGCStatusHandler handles GET /v1/gc/status. +func AdminGCStatusHandler(state *GCState) http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + state.mu.Lock() + resp := gcStatusResponse{ + Running: state.Running, + LastRun: state.LastRun, + } + state.mu.Unlock() + + writeJSON(w, http.StatusOK, resp) + } +} diff --git a/internal/server/admin_gc_test.go b/internal/server/admin_gc_test.go new file mode 100644 index 0000000..bd0ab2e --- /dev/null +++ b/internal/server/admin_gc_test.go @@ -0,0 +1,94 @@ +package server + +import ( + "encoding/json" + "testing" +) + +func TestAdminTriggerGC(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "POST", "/v1/gc", "") + if rr.Code != 202 { + t.Fatalf("status: got %d, want 202; body: %s", rr.Code, rr.Body.String()) + } + + var resp gcTriggerResponse + if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.ID == "" { + t.Fatal("expected non-empty GC ID") + } +} + +func TestAdminTriggerGCConflict(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + // First trigger should succeed. + rr := adminReq(t, router, "POST", "/v1/gc", "") + if rr.Code != 202 { + t.Fatalf("first trigger status: got %d, want 202", rr.Code) + } + + // Second trigger should conflict because GC is still "running". + rr = adminReq(t, router, "POST", "/v1/gc", "") + if rr.Code != 409 { + t.Fatalf("second trigger status: got %d, want 409; body: %s", rr.Code, rr.Body.String()) + } +} + +func TestAdminGCStatus(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + // Before triggering GC. + rr := adminReq(t, router, "GET", "/v1/gc/status", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var resp gcStatusResponse + if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if resp.Running { + t.Fatal("expected running=false before trigger") + } + + // Trigger GC. + rr = adminReq(t, router, "POST", "/v1/gc", "") + if rr.Code != 202 { + t.Fatalf("trigger status: got %d, want 202", rr.Code) + } + + // After triggering GC. + rr = adminReq(t, router, "GET", "/v1/gc/status", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil { + t.Fatalf("decode: %v", err) + } + if !resp.Running { + t.Fatal("expected running=true after trigger") + } +} + +func TestAdminGCNonAdmin(t *testing.T) { + database := openAdminTestDB(t) + router := buildNonAdminRouter(t, database) + + rr := adminReq(t, router, "POST", "/v1/gc", "") + if rr.Code != 403 { + t.Fatalf("trigger status: got %d, want 403", rr.Code) + } + + rr = adminReq(t, router, "GET", "/v1/gc/status", "") + if rr.Code != 403 { + t.Fatalf("status status: got %d, want 403", rr.Code) + } +} diff --git a/internal/server/admin_policy.go b/internal/server/admin_policy.go new file mode 100644 index 0000000..2fae888 --- /dev/null +++ b/internal/server/admin_policy.go @@ -0,0 +1,343 @@ +package server + +import ( + "encoding/json" + "errors" + "fmt" + "net/http" + "strconv" + + "github.com/go-chi/chi/v5" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" +) + +// PolicyReloader can reload policy rules from a store. +type PolicyReloader interface { + Reload(store policy.RuleStore) error +} + +// policyCreateRequest is the JSON body for creating a policy rule. +type policyCreateRequest struct { + Priority int `json:"priority"` + Description string `json:"description"` + Effect string `json:"effect"` + Roles []string `json:"roles,omitempty"` + AccountTypes []string `json:"account_types,omitempty"` + SubjectUUID string `json:"subject_uuid,omitempty"` + Actions []string `json:"actions"` + Repositories []string `json:"repositories,omitempty"` + Enabled *bool `json:"enabled,omitempty"` // pointer to distinguish unset from false +} + +// policyUpdateRequest is the JSON body for updating a policy rule. +type policyUpdateRequest struct { + Priority *int `json:"priority,omitempty"` + Description *string `json:"description,omitempty"` + Effect *string `json:"effect,omitempty"` + Roles []string `json:"roles,omitempty"` + AccountTypes []string `json:"account_types,omitempty"` + SubjectUUID *string `json:"subject_uuid,omitempty"` + Actions []string `json:"actions,omitempty"` + Repositories []string `json:"repositories,omitempty"` + Enabled *bool `json:"enabled,omitempty"` +} + +var validActions = map[string]bool{ + string(policy.ActionVersionCheck): true, + string(policy.ActionPull): true, + string(policy.ActionPush): true, + string(policy.ActionDelete): true, + string(policy.ActionCatalog): true, + string(policy.ActionPolicyManage): true, +} + +func validateActions(actions []string) error { + for _, a := range actions { + if !validActions[a] { + return fmt.Errorf("invalid action: %q", a) + } + } + return nil +} + +func validateEffect(effect string) error { + if effect != "allow" && effect != "deny" { + return fmt.Errorf("invalid effect: %q (must be 'allow' or 'deny')", effect) + } + return nil +} + +// AdminListPolicyRulesHandler handles GET /v1/policy/rules. +func AdminListPolicyRulesHandler(database *db.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + limit := 50 + offset := 0 + if n := r.URL.Query().Get("n"); n != "" { + if v, err := strconv.Atoi(n); err == nil && v > 0 { + limit = v + } + } + if o := r.URL.Query().Get("offset"); o != "" { + if v, err := strconv.Atoi(o); err == nil && v >= 0 { + offset = v + } + } + + rules, err := database.ListPolicyRules(limit, offset) + if err != nil { + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + if rules == nil { + rules = []db.PolicyRuleRow{} + } + writeJSON(w, http.StatusOK, rules) + } +} + +// AdminCreatePolicyRuleHandler handles POST /v1/policy/rules. +func AdminCreatePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var req policyCreateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeAdminError(w, http.StatusBadRequest, "invalid request body") + return + } + + if req.Priority < 1 { + writeAdminError(w, http.StatusBadRequest, "priority must be >= 1 (0 is reserved for built-ins)") + return + } + if req.Description == "" { + writeAdminError(w, http.StatusBadRequest, "description is required") + return + } + if err := validateEffect(req.Effect); err != nil { + writeAdminError(w, http.StatusBadRequest, err.Error()) + return + } + if len(req.Actions) == 0 { + writeAdminError(w, http.StatusBadRequest, "at least one action is required") + return + } + if err := validateActions(req.Actions); err != nil { + writeAdminError(w, http.StatusBadRequest, err.Error()) + return + } + + enabled := true + if req.Enabled != nil { + enabled = *req.Enabled + } + + claims := auth.ClaimsFromContext(r.Context()) + createdBy := "" + if claims != nil { + createdBy = claims.Subject + } + + row := db.PolicyRuleRow{ + Priority: req.Priority, + Description: req.Description, + Effect: req.Effect, + Roles: req.Roles, + AccountTypes: req.AccountTypes, + SubjectUUID: req.SubjectUUID, + Actions: req.Actions, + Repositories: req.Repositories, + Enabled: enabled, + CreatedBy: createdBy, + } + + id, err := database.CreatePolicyRule(row) + if err != nil { + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + + // Reload policy engine. + if engine != nil { + _ = engine.Reload(database) + } + + if auditFn != nil { + auditFn("policy_rule_created", createdBy, "", "", r.RemoteAddr, map[string]string{ + "rule_id": strconv.FormatInt(id, 10), + }) + } + + created, err := database.GetPolicyRule(id) + if err != nil { + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + writeJSON(w, http.StatusCreated, created) + } +} + +// AdminGetPolicyRuleHandler handles GET /v1/policy/rules/{id}. +func AdminGetPolicyRuleHandler(database *db.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeAdminError(w, http.StatusBadRequest, "invalid rule ID") + return + } + + rule, err := database.GetPolicyRule(id) + if err != nil { + if errors.Is(err, db.ErrPolicyRuleNotFound) { + writeAdminError(w, http.StatusNotFound, "policy rule not found") + return + } + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + writeJSON(w, http.StatusOK, rule) + } +} + +// AdminUpdatePolicyRuleHandler handles PATCH /v1/policy/rules/{id}. +func AdminUpdatePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeAdminError(w, http.StatusBadRequest, "invalid rule ID") + return + } + + var req policyUpdateRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + writeAdminError(w, http.StatusBadRequest, "invalid request body") + return + } + + if req.Priority != nil && *req.Priority < 1 { + writeAdminError(w, http.StatusBadRequest, "priority must be >= 1 (0 is reserved for built-ins)") + return + } + if req.Effect != nil { + if err := validateEffect(*req.Effect); err != nil { + writeAdminError(w, http.StatusBadRequest, err.Error()) + return + } + } + if req.Actions != nil { + if len(req.Actions) == 0 { + writeAdminError(w, http.StatusBadRequest, "at least one action is required") + return + } + if err := validateActions(req.Actions); err != nil { + writeAdminError(w, http.StatusBadRequest, err.Error()) + return + } + } + + updates := db.PolicyRuleRow{} + if req.Priority != nil { + updates.Priority = *req.Priority + } + if req.Description != nil { + updates.Description = *req.Description + } + if req.Effect != nil { + updates.Effect = *req.Effect + } + if req.Roles != nil { + updates.Roles = req.Roles + } + if req.AccountTypes != nil { + updates.AccountTypes = req.AccountTypes + } + if req.SubjectUUID != nil { + updates.SubjectUUID = *req.SubjectUUID + } + if req.Actions != nil { + updates.Actions = req.Actions + } + if req.Repositories != nil { + updates.Repositories = req.Repositories + } + + if err := database.UpdatePolicyRule(id, updates); err != nil { + if errors.Is(err, db.ErrPolicyRuleNotFound) { + writeAdminError(w, http.StatusNotFound, "policy rule not found") + return + } + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + + // Handle enabled separately since it's a bool. + if req.Enabled != nil { + if err := database.SetPolicyRuleEnabled(id, *req.Enabled); err != nil { + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + } + + // Reload policy engine. + if engine != nil { + _ = engine.Reload(database) + } + + if auditFn != nil { + claims := auth.ClaimsFromContext(r.Context()) + actorID := "" + if claims != nil { + actorID = claims.Subject + } + auditFn("policy_rule_updated", actorID, "", "", r.RemoteAddr, map[string]string{ + "rule_id": strconv.FormatInt(id, 10), + }) + } + + updated, err := database.GetPolicyRule(id) + if err != nil { + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + writeJSON(w, http.StatusOK, updated) + } +} + +// AdminDeletePolicyRuleHandler handles DELETE /v1/policy/rules/{id}. +func AdminDeletePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64) + if err != nil { + writeAdminError(w, http.StatusBadRequest, "invalid rule ID") + return + } + + if err := database.DeletePolicyRule(id); err != nil { + if errors.Is(err, db.ErrPolicyRuleNotFound) { + writeAdminError(w, http.StatusNotFound, "policy rule not found") + return + } + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + + // Reload policy engine. + if engine != nil { + _ = engine.Reload(database) + } + + if auditFn != nil { + claims := auth.ClaimsFromContext(r.Context()) + actorID := "" + if claims != nil { + actorID = claims.Subject + } + auditFn("policy_rule_deleted", actorID, "", "", r.RemoteAddr, map[string]string{ + "rule_id": strconv.FormatInt(id, 10), + }) + } + + w.WriteHeader(http.StatusNoContent) + } +} diff --git a/internal/server/admin_policy_test.go b/internal/server/admin_policy_test.go new file mode 100644 index 0000000..95d9c25 --- /dev/null +++ b/internal/server/admin_policy_test.go @@ -0,0 +1,337 @@ +package server + +import ( + "encoding/json" + "fmt" + "testing" + + "git.wntrmute.dev/kyle/mcr/internal/db" +) + +func TestAdminPolicyCRUDCycle(t *testing.T) { + database := openAdminTestDB(t) + router, reloader := buildAdminRouter(t, database) + + // Create a rule. + createBody := `{ + "priority": 50, + "description": "allow CI push", + "effect": "allow", + "roles": ["ci"], + "actions": ["registry:push", "registry:pull"], + "repositories": ["production/*"], + "enabled": true + }` + rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody) + if rr.Code != 201 { + t.Fatalf("create status: got %d, want 201; body: %s", rr.Code, rr.Body.String()) + } + if reloader.reloadCount != 1 { + t.Fatalf("reload count after create: got %d, want 1", reloader.reloadCount) + } + + var created db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&created); err != nil { + t.Fatalf("decode created: %v", err) + } + if created.ID == 0 { + t.Fatal("expected non-zero ID") + } + if created.Effect != "allow" { + t.Fatalf("effect: got %q, want %q", created.Effect, "allow") + } + if len(created.Roles) != 1 || created.Roles[0] != "ci" { + t.Fatalf("roles: got %v, want [ci]", created.Roles) + } + if created.CreatedBy != "admin-uuid" { + t.Fatalf("created_by: got %q, want %q", created.CreatedBy, "admin-uuid") + } + + // Get the rule. + rr = adminReq(t, router, "GET", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "") + if rr.Code != 200 { + t.Fatalf("get status: got %d, want 200", rr.Code) + } + + var got db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&got); err != nil { + t.Fatalf("decode got: %v", err) + } + if got.ID != created.ID { + t.Fatalf("id: got %d, want %d", got.ID, created.ID) + } + + // List rules. + rr = adminReq(t, router, "GET", "/v1/policy/rules/", "") + if rr.Code != 200 { + t.Fatalf("list status: got %d, want 200", rr.Code) + } + + var rules []db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&rules); err != nil { + t.Fatalf("decode list: %v", err) + } + if len(rules) != 1 { + t.Fatalf("rule count: got %d, want 1", len(rules)) + } + + // Update the rule. + updateBody := `{"priority": 25, "description": "updated CI push"}` + rr = adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), updateBody) + if rr.Code != 200 { + t.Fatalf("update status: got %d, want 200; body: %s", rr.Code, rr.Body.String()) + } + if reloader.reloadCount != 2 { + t.Fatalf("reload count after update: got %d, want 2", reloader.reloadCount) + } + + var updated db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&updated); err != nil { + t.Fatalf("decode updated: %v", err) + } + if updated.Priority != 25 { + t.Fatalf("updated priority: got %d, want 25", updated.Priority) + } + if updated.Description != "updated CI push" { + t.Fatalf("updated description: got %q, want %q", updated.Description, "updated CI push") + } + // Effect should be unchanged. + if updated.Effect != "allow" { + t.Fatalf("updated effect: got %q, want %q (unchanged)", updated.Effect, "allow") + } + + // Delete the rule. + rr = adminReq(t, router, "DELETE", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "") + if rr.Code != 204 { + t.Fatalf("delete status: got %d, want 204", rr.Code) + } + if reloader.reloadCount != 3 { + t.Fatalf("reload count after delete: got %d, want 3", reloader.reloadCount) + } + + // Verify it's gone. + rr = adminReq(t, router, "GET", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "") + if rr.Code != 404 { + t.Fatalf("after delete status: got %d, want 404", rr.Code) + } +} + +func TestAdminCreatePolicyRuleValidation(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + tests := []struct { + name string + body string + want int + }{ + { + name: "priority too low", + body: `{"priority":0,"description":"test","effect":"allow","actions":["registry:pull"]}`, + want: 400, + }, + { + name: "missing description", + body: `{"priority":1,"effect":"allow","actions":["registry:pull"]}`, + want: 400, + }, + { + name: "invalid effect", + body: `{"priority":1,"description":"test","effect":"maybe","actions":["registry:pull"]}`, + want: 400, + }, + { + name: "no actions", + body: `{"priority":1,"description":"test","effect":"allow","actions":[]}`, + want: 400, + }, + { + name: "invalid action", + body: `{"priority":1,"description":"test","effect":"allow","actions":["bogus:action"]}`, + want: 400, + }, + { + name: "bad JSON", + body: `not json`, + want: 400, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + rr := adminReq(t, router, "POST", "/v1/policy/rules/", tc.body) + if rr.Code != tc.want { + t.Fatalf("status: got %d, want %d; body: %s", rr.Code, tc.want, rr.Body.String()) + } + }) + } +} + +func TestAdminUpdatePolicyRuleValidation(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + // Create a rule to update. + createBody := `{"priority":50,"description":"test","effect":"allow","actions":["registry:pull"]}` + rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody) + if rr.Code != 201 { + t.Fatalf("create status: got %d, want 201", rr.Code) + } + var created db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&created); err != nil { + t.Fatalf("decode: %v", err) + } + + tests := []struct { + name string + body string + want int + }{ + { + name: "priority too low", + body: `{"priority":0}`, + want: 400, + }, + { + name: "invalid effect", + body: `{"effect":"maybe"}`, + want: 400, + }, + { + name: "empty actions", + body: `{"actions":[]}`, + want: 400, + }, + { + name: "invalid action", + body: `{"actions":["bogus:action"]}`, + want: 400, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + rr := adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), tc.body) + if rr.Code != tc.want { + t.Fatalf("status: got %d, want %d; body: %s", rr.Code, tc.want, rr.Body.String()) + } + }) + } +} + +func TestAdminUpdatePolicyRuleNotFound(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "PATCH", "/v1/policy/rules/9999", `{"description":"nope"}`) + if rr.Code != 404 { + t.Fatalf("status: got %d, want 404", rr.Code) + } +} + +func TestAdminDeletePolicyRuleNotFound(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "DELETE", "/v1/policy/rules/9999", "") + if rr.Code != 404 { + t.Fatalf("status: got %d, want 404", rr.Code) + } +} + +func TestAdminGetPolicyRuleNotFound(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "GET", "/v1/policy/rules/9999", "") + if rr.Code != 404 { + t.Fatalf("status: got %d, want 404", rr.Code) + } +} + +func TestAdminGetPolicyRuleInvalidID(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "GET", "/v1/policy/rules/not-a-number", "") + if rr.Code != 400 { + t.Fatalf("status: got %d, want 400", rr.Code) + } +} + +func TestAdminPolicyRulesNonAdmin(t *testing.T) { + database := openAdminTestDB(t) + router := buildNonAdminRouter(t, database) + + // All policy rule endpoints require admin. + endpoints := []struct { + method string + path string + body string + }{ + {"GET", "/v1/policy/rules/", ""}, + {"POST", "/v1/policy/rules/", `{"priority":1,"description":"test","effect":"allow","actions":["registry:pull"]}`}, + {"GET", "/v1/policy/rules/1", ""}, + {"PATCH", "/v1/policy/rules/1", `{"description":"updated"}`}, + {"DELETE", "/v1/policy/rules/1", ""}, + } + + for _, ep := range endpoints { + t.Run(ep.method+" "+ep.path, func(t *testing.T) { + rr := adminReq(t, router, ep.method, ep.path, ep.body) + if rr.Code != 403 { + t.Fatalf("status: got %d, want 403; body: %s", rr.Code, rr.Body.String()) + } + }) + } +} + +func TestAdminUpdatePolicyRuleEnabled(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + // Create an enabled rule. + createBody := `{"priority":50,"description":"test","effect":"allow","actions":["registry:pull"],"enabled":true}` + rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody) + if rr.Code != 201 { + t.Fatalf("create status: got %d, want 201", rr.Code) + } + var created db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&created); err != nil { + t.Fatalf("decode: %v", err) + } + if !created.Enabled { + t.Fatal("expected created rule to be enabled") + } + + // Disable it. + rr = adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), `{"enabled":false}`) + if rr.Code != 200 { + t.Fatalf("update status: got %d, want 200; body: %s", rr.Code, rr.Body.String()) + } + var updated db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&updated); err != nil { + t.Fatalf("decode: %v", err) + } + if updated.Enabled { + t.Fatal("expected updated rule to be disabled") + } +} + +func TestAdminListPolicyRulesEmpty(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "GET", "/v1/policy/rules/", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var rules []db.PolicyRuleRow + if err := json.NewDecoder(rr.Body).Decode(&rules); err != nil { + t.Fatalf("decode: %v", err) + } + if len(rules) != 0 { + t.Fatalf("rule count: got %d, want 0", len(rules)) + } +} diff --git a/internal/server/admin_repo.go b/internal/server/admin_repo.go new file mode 100644 index 0000000..b61f090 --- /dev/null +++ b/internal/server/admin_repo.go @@ -0,0 +1,94 @@ +package server + +import ( + "errors" + "net/http" + "strconv" + + "github.com/go-chi/chi/v5" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/db" +) + +// AdminListReposHandler handles GET /v1/repositories. +func AdminListReposHandler(database *db.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + limit := 50 + offset := 0 + if n := r.URL.Query().Get("n"); n != "" { + if v, err := strconv.Atoi(n); err == nil && v > 0 { + limit = v + } + } + if o := r.URL.Query().Get("offset"); o != "" { + if v, err := strconv.Atoi(o); err == nil && v >= 0 { + offset = v + } + } + + repos, err := database.ListRepositoriesWithMetadata(limit, offset) + if err != nil { + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + if repos == nil { + repos = []db.RepoMetadata{} + } + writeJSON(w, http.StatusOK, repos) + } +} + +// AdminGetRepoHandler handles GET /v1/repositories/*. +func AdminGetRepoHandler(database *db.DB) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "*") + if name == "" { + writeAdminError(w, http.StatusBadRequest, "repository name required") + return + } + + detail, err := database.GetRepositoryDetail(name) + if err != nil { + if errors.Is(err, db.ErrRepoNotFound) { + writeAdminError(w, http.StatusNotFound, "repository not found") + return + } + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + writeJSON(w, http.StatusOK, detail) + } +} + +// AdminDeleteRepoHandler handles DELETE /v1/repositories/*. +// Requires admin role (enforced by RequireAdmin middleware). +func AdminDeleteRepoHandler(database *db.DB, auditFn AuditFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + name := chi.URLParam(r, "*") + if name == "" { + writeAdminError(w, http.StatusBadRequest, "repository name required") + return + } + + if err := database.DeleteRepository(name); err != nil { + if errors.Is(err, db.ErrRepoNotFound) { + writeAdminError(w, http.StatusNotFound, "repository not found") + return + } + writeAdminError(w, http.StatusInternalServerError, "internal error") + return + } + + if auditFn != nil { + claims := auth.ClaimsFromContext(r.Context()) + actorID := "" + if claims != nil { + actorID = claims.Subject + } + auditFn("repo_deleted", actorID, name, "", r.RemoteAddr, nil) + } + + w.WriteHeader(http.StatusNoContent) + } +} diff --git a/internal/server/admin_repo_test.go b/internal/server/admin_repo_test.go new file mode 100644 index 0000000..9581b41 --- /dev/null +++ b/internal/server/admin_repo_test.go @@ -0,0 +1,186 @@ +package server + +import ( + "encoding/json" + "testing" + + "git.wntrmute.dev/kyle/mcr/internal/db" +) + +// seedRepoForAdmin inserts a repository with a manifest and tags into the test DB. +func seedRepoForAdmin(t *testing.T, database *db.DB, name string, tags []string) { + t.Helper() + _, err := database.Exec(`INSERT INTO repositories (name) VALUES (?)`, name) + if err != nil { + t.Fatalf("insert repo %q: %v", name, err) + } + var repoID int64 + if err := database.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&repoID); err != nil { + t.Fatalf("select repo id: %v", err) + } + _, err = database.Exec( + `INSERT INTO manifests (repository_id, digest, media_type, content, size) + VALUES (?, ?, 'application/vnd.oci.image.manifest.v1+json', '{}', 512)`, + repoID, "sha256:manifest-"+name, + ) + if err != nil { + t.Fatalf("insert manifest: %v", err) + } + var manifestID int64 + if err := database.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil { + t.Fatalf("select manifest id: %v", err) + } + for _, tag := range tags { + _, err := database.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`, + repoID, tag, manifestID) + if err != nil { + t.Fatalf("insert tag %q: %v", tag, err) + } + } +} + +func TestAdminListRepos(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + seedRepoForAdmin(t, database, "alpha/app", []string{"latest", "v1.0"}) + seedRepoForAdmin(t, database, "bravo/lib", []string{"latest"}) + + rr := adminReq(t, router, "GET", "/v1/repositories", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var repos []db.RepoMetadata + if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil { + t.Fatalf("decode: %v", err) + } + if len(repos) != 2 { + t.Fatalf("repo count: got %d, want 2", len(repos)) + } + if repos[0].Name != "alpha/app" { + t.Fatalf("first repo: got %q, want %q", repos[0].Name, "alpha/app") + } + if repos[0].TagCount != 2 { + t.Fatalf("tag count: got %d, want 2", repos[0].TagCount) + } + if repos[0].ManifestCount != 1 { + t.Fatalf("manifest count: got %d, want 1", repos[0].ManifestCount) + } +} + +func TestAdminListReposEmpty(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "GET", "/v1/repositories", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var repos []db.RepoMetadata + if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil { + t.Fatalf("decode: %v", err) + } + if len(repos) != 0 { + t.Fatalf("repo count: got %d, want 0", len(repos)) + } +} + +func TestAdminListReposPagination(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + seedRepoForAdmin(t, database, "alpha/app", []string{"latest"}) + seedRepoForAdmin(t, database, "bravo/lib", []string{"latest"}) + seedRepoForAdmin(t, database, "charlie/svc", []string{"latest"}) + + rr := adminReq(t, router, "GET", "/v1/repositories?n=2&offset=0", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200", rr.Code) + } + + var repos []db.RepoMetadata + if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil { + t.Fatalf("decode: %v", err) + } + if len(repos) != 2 { + t.Fatalf("page 1 count: got %d, want 2", len(repos)) + } +} + +func TestAdminGetRepoDetail(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest", "v1.0"}) + + rr := adminReq(t, router, "GET", "/v1/repositories/myorg/myapp", "") + if rr.Code != 200 { + t.Fatalf("status: got %d, want 200; body: %s", rr.Code, rr.Body.String()) + } + + var detail db.RepoDetail + if err := json.NewDecoder(rr.Body).Decode(&detail); err != nil { + t.Fatalf("decode: %v", err) + } + if detail.Name != "myorg/myapp" { + t.Fatalf("name: got %q, want %q", detail.Name, "myorg/myapp") + } + if len(detail.Tags) != 2 { + t.Fatalf("tag count: got %d, want 2", len(detail.Tags)) + } + if len(detail.Manifests) != 1 { + t.Fatalf("manifest count: got %d, want 1", len(detail.Manifests)) + } +} + +func TestAdminGetRepoDetailNotFound(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "GET", "/v1/repositories/nonexistent/repo", "") + if rr.Code != 404 { + t.Fatalf("status: got %d, want 404", rr.Code) + } +} + +func TestAdminDeleteRepo(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest"}) + + rr := adminReq(t, router, "DELETE", "/v1/repositories/myorg/myapp", "") + if rr.Code != 204 { + t.Fatalf("status: got %d, want 204; body: %s", rr.Code, rr.Body.String()) + } + + // Verify it's gone. + rr = adminReq(t, router, "GET", "/v1/repositories/myorg/myapp", "") + if rr.Code != 404 { + t.Fatalf("after delete status: got %d, want 404", rr.Code) + } +} + +func TestAdminDeleteRepoNotFound(t *testing.T) { + database := openAdminTestDB(t) + router, _ := buildAdminRouter(t, database) + + rr := adminReq(t, router, "DELETE", "/v1/repositories/nonexistent/repo", "") + if rr.Code != 404 { + t.Fatalf("status: got %d, want 404", rr.Code) + } +} + +func TestAdminDeleteRepoNonAdmin(t *testing.T) { + database := openAdminTestDB(t) + router := buildNonAdminRouter(t, database) + + seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest"}) + + rr := adminReq(t, router, "DELETE", "/v1/repositories/myorg/myapp", "") + if rr.Code != 403 { + t.Fatalf("status: got %d, want 403", rr.Code) + } +} diff --git a/internal/server/admin_routes.go b/internal/server/admin_routes.go new file mode 100644 index 0000000..c691f3e --- /dev/null +++ b/internal/server/admin_routes.go @@ -0,0 +1,56 @@ +package server + +import ( + "github.com/go-chi/chi/v5" + + "git.wntrmute.dev/kyle/mcr/internal/db" +) + +// AdminDeps holds the dependencies needed by admin routes. +type AdminDeps struct { + DB *db.DB + Login LoginClient + Engine PolicyReloader + AuditFn AuditFunc + GCState *GCState +} + +// MountAdminRoutes adds admin REST endpoints to the router. +// Auth middleware is applied at the route group level. +func MountAdminRoutes(r chi.Router, validator TokenValidator, serviceName string, deps AdminDeps) { + // Health endpoint - no auth required. + r.Get("/v1/health", AdminHealthHandler()) + + // Auth endpoints - no bearer auth required (login uses credentials). + r.Post("/v1/auth/login", AdminLoginHandler(deps.Login)) + + // Authenticated endpoints. + r.Route("/v1", func(v1 chi.Router) { + v1.Use(RequireAuth(validator, serviceName)) + + // Logout. + v1.Post("/auth/logout", AdminLogoutHandler()) + + // Repositories - list and detail require auth, delete requires admin. + v1.Get("/repositories", AdminListReposHandler(deps.DB)) + v1.Get("/repositories/*", AdminGetRepoHandler(deps.DB)) + v1.With(RequireAdmin()).Delete("/repositories/*", AdminDeleteRepoHandler(deps.DB, deps.AuditFn)) + + // Policy - all require admin. + v1.Route("/policy/rules", func(pr chi.Router) { + pr.Use(RequireAdmin()) + pr.Get("/", AdminListPolicyRulesHandler(deps.DB)) + pr.Post("/", AdminCreatePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn)) + pr.Get("/{id}", AdminGetPolicyRuleHandler(deps.DB)) + pr.Patch("/{id}", AdminUpdatePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn)) + pr.Delete("/{id}", AdminDeletePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn)) + }) + + // Audit - requires admin. + v1.With(RequireAdmin()).Get("/audit", AdminListAuditHandler(deps.DB)) + + // GC - requires admin. + v1.With(RequireAdmin()).Post("/gc", AdminTriggerGCHandler(deps.GCState)) + v1.With(RequireAdmin()).Get("/gc/status", AdminGCStatusHandler(deps.GCState)) + }) +} diff --git a/internal/server/admin_test.go b/internal/server/admin_test.go new file mode 100644 index 0000000..79d2528 --- /dev/null +++ b/internal/server/admin_test.go @@ -0,0 +1,148 @@ +package server + +import ( + "io" + "net/http" + "net/http/httptest" + "path/filepath" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + + "git.wntrmute.dev/kyle/mcr/internal/auth" + "git.wntrmute.dev/kyle/mcr/internal/db" + "git.wntrmute.dev/kyle/mcr/internal/policy" +) + +func openAdminTestDB(t *testing.T) *db.DB { + t.Helper() + path := filepath.Join(t.TempDir(), "test.db") + d, err := db.Open(path) + if err != nil { + t.Fatalf("Open: %v", err) + } + t.Cleanup(func() { _ = d.Close() }) + if err := d.Migrate(); err != nil { + t.Fatalf("Migrate: %v", err) + } + return d +} + +type fakePolicyReloader struct { + reloadCount int +} + +func (f *fakePolicyReloader) Reload(_ policy.RuleStore) error { + f.reloadCount++ + return nil +} + +// buildAdminRouter creates a chi router with admin routes wired up, +// using a fakeValidator that returns admin claims for any bearer token. +func buildAdminRouter(t *testing.T, database *db.DB) (chi.Router, *fakePolicyReloader) { + t.Helper() + + validator := &fakeValidator{ + claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}}, + } + login := &fakeLoginClient{token: "test-token", expiresIn: 3600} + reloader := &fakePolicyReloader{} + gcState := &GCState{} + + r := chi.NewRouter() + MountAdminRoutes(r, validator, "mcr-test", AdminDeps{ + DB: database, + Login: login, + Engine: reloader, + AuditFn: nil, + GCState: gcState, + }) + return r, reloader +} + +// buildNonAdminRouter creates a chi router that returns non-admin claims. +func buildNonAdminRouter(t *testing.T, database *db.DB) chi.Router { + t.Helper() + + validator := &fakeValidator{ + claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}}, + } + login := &fakeLoginClient{token: "test-token", expiresIn: 3600} + reloader := &fakePolicyReloader{} + gcState := &GCState{} + + r := chi.NewRouter() + MountAdminRoutes(r, validator, "mcr-test", AdminDeps{ + DB: database, + Login: login, + Engine: reloader, + AuditFn: nil, + GCState: gcState, + }) + return r +} + +// adminReq is a convenience helper for making HTTP requests against the admin +// router, automatically including the Authorization header. +func adminReq(t *testing.T, router http.Handler, method, path string, body string) *httptest.ResponseRecorder { + t.Helper() + var bodyReader io.Reader + if body != "" { + bodyReader = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, bodyReader) + req.Header.Set("Authorization", "Bearer valid-token") + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + rr := httptest.NewRecorder() + router.ServeHTTP(rr, req) + return rr +} + +func TestRequireAdminAllowed(t *testing.T) { + claims := &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}} + handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + + req := httptest.NewRequest("GET", "/test", nil) + req = req.WithContext(auth.ContextWithClaims(req.Context(), claims)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusOK { + t.Fatalf("admin allowed: got %d, want 200", rr.Code) + } +} + +func TestRequireAdminDenied(t *testing.T) { + claims := &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}} + handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Fatal("inner handler should not be called") + })) + + req := httptest.NewRequest("GET", "/test", nil) + req = req.WithContext(auth.ContextWithClaims(req.Context(), claims)) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusForbidden { + t.Fatalf("non-admin denied: got %d, want 403", rr.Code) + } +} + +func TestRequireAdminNoClaims(t *testing.T) { + handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Fatal("inner handler should not be called") + })) + + req := httptest.NewRequest("GET", "/test", nil) + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + if rr.Code != http.StatusUnauthorized { + t.Fatalf("no claims: got %d, want 401", rr.Code) + } +}