Phases 5, 6, 8: OCI pull/push paths and admin REST API
Phase 5 (OCI pull): internal/oci/ package with manifest GET/HEAD by tag/digest, blob GET/HEAD with repo membership check, tag listing with OCI pagination, catalog listing. Multi-segment repo names via parseOCIPath() right-split routing. DB query layer in internal/db/repository.go. Phase 6 (OCI push): blob uploads (monolithic and chunked) with uploadManager tracking in-progress BlobWriters, manifest push implementing full ARCHITECTURE.md §5 flow in a single SQLite transaction (create repo, upsert manifest, populate manifest_blobs, atomic tag move). Digest verification on both blob commit and manifest push-by-digest. Phase 8 (admin REST): /v1 endpoints for auth (login/logout/health), repository management (list/detail/delete), policy CRUD with engine reload, audit log listing with filters, GC trigger/status stubs. RequireAdmin middleware, platform-standard error format. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
177
PROGRESS.md
177
PROGRESS.md
@@ -6,7 +6,7 @@ See `PROJECT_PLAN.md` for the implementation roadmap and
|
|||||||
|
|
||||||
## Current State
|
## Current State
|
||||||
|
|
||||||
**Phase:** 4 complete, ready for Batch B (Phase 5 + Phase 8)
|
**Phase:** 6 complete, ready for Phase 7
|
||||||
**Last updated:** 2026-03-19
|
**Last updated:** 2026-03-19
|
||||||
|
|
||||||
### Completed
|
### Completed
|
||||||
@@ -16,6 +16,9 @@ See `PROJECT_PLAN.md` for the implementation roadmap and
|
|||||||
- Phase 2: Blob storage layer (all 2 steps)
|
- Phase 2: Blob storage layer (all 2 steps)
|
||||||
- Phase 3: MCIAS authentication (all 4 steps)
|
- Phase 3: MCIAS authentication (all 4 steps)
|
||||||
- Phase 4: Policy engine (all 4 steps)
|
- Phase 4: Policy engine (all 4 steps)
|
||||||
|
- Phase 5: OCI pull path (all 5 steps)
|
||||||
|
- Phase 6: OCI push path (all 3 steps)
|
||||||
|
- Phase 8: Admin REST API (all 5 steps)
|
||||||
- `ARCHITECTURE.md` — Full design specification (18 sections)
|
- `ARCHITECTURE.md` — Full design specification (18 sections)
|
||||||
- `CLAUDE.md` — AI development guidance
|
- `CLAUDE.md` — AI development guidance
|
||||||
- `PROJECT_PLAN.md` — Implementation plan (14 phases, 40+ steps)
|
- `PROJECT_PLAN.md` — Implementation plan (14 phases, 40+ steps)
|
||||||
@@ -23,14 +26,180 @@ See `PROJECT_PLAN.md` for the implementation roadmap and
|
|||||||
|
|
||||||
### Next Steps
|
### Next Steps
|
||||||
|
|
||||||
1. Batch B: Phase 5 (OCI pull) and Phase 8 (admin REST) — independent,
|
1. Phase 7 (OCI delete)
|
||||||
can be done in parallel
|
2. After Phase 7, Phase 9 (garbage collection)
|
||||||
2. After Phase 5, Phase 6 (OCI push) then Phase 7 (OCI delete)
|
3. Phase 10 (gRPC admin API)
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Log
|
## Log
|
||||||
|
|
||||||
|
### 2026-03-19 — Phase 6: OCI push path
|
||||||
|
|
||||||
|
**Task:** Implement blob uploads (monolithic and chunked) and manifest
|
||||||
|
push per ARCHITECTURE.md §5 and OCI Distribution Spec.
|
||||||
|
|
||||||
|
**Changes:**
|
||||||
|
|
||||||
|
Step 6.1 — Blob upload initiation:
|
||||||
|
- `db/upload.go`: `UploadRow` type, `ErrUploadNotFound` sentinel;
|
||||||
|
`CreateUpload()`, `GetUpload()`, `UpdateUploadOffset()`, `DeleteUpload()`
|
||||||
|
- `db/push.go`: `GetOrCreateRepository()` (implicit repo creation on
|
||||||
|
first push), `BlobExists()`, `InsertBlob()` (INSERT OR IGNORE for
|
||||||
|
content-addressed dedup)
|
||||||
|
- `oci/upload.go`: `uploadManager` (sync.Mutex-protected map of in-progress
|
||||||
|
BlobWriters by UUID); `generateUUID()` via crypto/rand; `handleUploadInitiate()`
|
||||||
|
— policy check (registry:push), implicit repo creation, DB row + storage
|
||||||
|
temp file, returns 202 with Location/Docker-Upload-UUID/Range headers
|
||||||
|
- Extended `DBQuerier` interface with push/upload methods
|
||||||
|
- Changed `BlobOpener` to `BlobStore` adding `StartUpload(uuid)` method
|
||||||
|
|
||||||
|
Step 6.2 — Blob upload chunked and monolithic:
|
||||||
|
- `oci/upload.go`: `handleUploadChunk()` (PATCH — append body, update offset),
|
||||||
|
`handleUploadComplete()` (PUT — optional final body, commit with digest
|
||||||
|
verification, insert blob row, cleanup upload, audit event),
|
||||||
|
`handleUploadStatus()` (GET — 204 with Range header),
|
||||||
|
`handleUploadCancel()` (DELETE — cancel BlobWriter, remove upload row)
|
||||||
|
- `oci/routes.go`: `parseOCIPath()` extended to handle `/blobs/uploads/`
|
||||||
|
and `/blobs/uploads/<uuid>` patterns; `dispatchUpload()` routes by method
|
||||||
|
|
||||||
|
Step 6.3 — Manifest push:
|
||||||
|
- `db/push.go`: `PushManifestParams` struct, `PushManifest()` — single
|
||||||
|
SQLite transaction: create repo if not exists, upsert manifest (ON
|
||||||
|
CONFLICT DO UPDATE), clear and repopulate manifest_blobs join table,
|
||||||
|
upsert tag if provided (atomic tag move)
|
||||||
|
- `oci/manifest.go`: `handleManifestPut()` — full push flow per §5:
|
||||||
|
parse JSON, compute SHA-256, verify digest if push-by-digest, collect
|
||||||
|
config+layer descriptors, verify all referenced blobs exist (400
|
||||||
|
MANIFEST_BLOB_UNKNOWN if missing), call PushManifest(), audit event
|
||||||
|
with tag details, returns 201 with Docker-Content-Digest/Location/
|
||||||
|
Content-Type headers
|
||||||
|
|
||||||
|
**Verification:**
|
||||||
|
- `make all` passes: vet clean, lint 0 issues, 198 tests passing,
|
||||||
|
all 3 binaries built
|
||||||
|
- DB tests (15 new): upload CRUD (create/get/update/delete/not-found),
|
||||||
|
GetOrCreateRepository (new + existing), BlobExists (found/not-found),
|
||||||
|
InsertBlob (new + idempotent), PushManifest by tag (verify repo creation,
|
||||||
|
manifest, tag, manifest_blobs), by digest (no tag), tag move (atomic
|
||||||
|
update), idempotent re-push
|
||||||
|
- OCI upload tests (8 new): initiate (202, Location/UUID/Range headers,
|
||||||
|
implicit repo creation), unique UUIDs (5 initiates → 5 distinct UUIDs),
|
||||||
|
monolithic upload (POST → PUT with body → 201), chunked upload
|
||||||
|
(POST → PATCH → PATCH → PUT → 201), digest mismatch (400 DIGEST_INVALID),
|
||||||
|
upload status (GET → 204), cancel (DELETE → 204), nonexistent UUID
|
||||||
|
(PATCH → 404 BLOB_UPLOAD_UNKNOWN)
|
||||||
|
- OCI manifest push tests (7 new): push by tag (201, correct headers),
|
||||||
|
push by digest, digest mismatch (400), missing blob (400
|
||||||
|
MANIFEST_BLOB_UNKNOWN), malformed JSON (400 MANIFEST_INVALID), empty
|
||||||
|
manifest, tag update (atomic move), re-push idempotent
|
||||||
|
- Route parsing tests (5 new subtests): upload initiate (trailing slash),
|
||||||
|
upload initiate (no trailing slash), upload with UUID, multi-segment
|
||||||
|
repo upload, multi-segment repo upload initiate
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 2026-03-19 — Batch B: Phase 5 (OCI pull) + Phase 8 (admin REST)
|
||||||
|
|
||||||
|
**Task:** Implement the OCI Distribution Spec pull path and the admin
|
||||||
|
REST API management endpoints. Both phases depend on Phase 4 (policy)
|
||||||
|
but not on each other — implemented in parallel.
|
||||||
|
|
||||||
|
**Changes:**
|
||||||
|
|
||||||
|
Phase 5 — `internal/db/` + `internal/oci/` (Steps 5.1–5.5):
|
||||||
|
|
||||||
|
Step 5.1 — OCI handler scaffolding:
|
||||||
|
- `db/errors.go`: `ErrRepoNotFound`, `ErrManifestNotFound`, `ErrBlobNotFound`
|
||||||
|
sentinel errors
|
||||||
|
- `oci/handler.go`: `Handler` struct with `DBQuerier`, `BlobOpener`,
|
||||||
|
`PolicyEval`, `AuditFunc` interfaces; `NewHandler()` constructor;
|
||||||
|
`checkPolicy()` inline policy check; `audit()` helper
|
||||||
|
- `oci/ocierror.go`: `writeOCIError()` duplicated from server package
|
||||||
|
(15 lines, avoids coupling)
|
||||||
|
- `oci/routes.go`: `parseOCIPath()` splits from the right to handle
|
||||||
|
multi-segment repo names (e.g., `org/team/app/manifests/latest`);
|
||||||
|
`Router()` returns chi router with `/_catalog` and `/*` catch-all;
|
||||||
|
`dispatch()` routes to manifest/blob/tag handlers
|
||||||
|
|
||||||
|
Step 5.2 — Manifest pull:
|
||||||
|
- `db/repository.go`: `ManifestRow` type, `GetRepositoryByName()`,
|
||||||
|
`GetManifestByTag()`, `GetManifestByDigest()`
|
||||||
|
- `oci/manifest.go`: GET/HEAD `/v2/{name}/manifests/{reference}`;
|
||||||
|
resolves by tag or digest; sets Content-Type, Docker-Content-Digest,
|
||||||
|
Content-Length headers
|
||||||
|
|
||||||
|
Step 5.3 — Blob download:
|
||||||
|
- `db/repository.go`: `BlobExistsInRepo()` — joins blobs+manifest_blobs+
|
||||||
|
manifests to verify blob belongs to repo
|
||||||
|
- `oci/blob.go`: GET/HEAD `/v2/{name}/blobs/{digest}`; validates blob
|
||||||
|
exists in repo before streaming from storage
|
||||||
|
|
||||||
|
Step 5.4 — Tag listing:
|
||||||
|
- `db/repository.go`: `ListTags()` with cursor-based pagination
|
||||||
|
(after/limit)
|
||||||
|
- `oci/tags.go`: GET `/v2/{name}/tags/list` with OCI `?n=`/`?last=`
|
||||||
|
pagination and `Link` header for next page
|
||||||
|
|
||||||
|
Step 5.5 — Catalog listing:
|
||||||
|
- `db/repository.go`: `ListRepositoryNames()` with cursor-based pagination
|
||||||
|
- `oci/catalog.go`: GET `/v2/_catalog` with same pagination pattern
|
||||||
|
|
||||||
|
Phase 8 — `internal/db/` + `internal/server/` (Steps 8.1–8.5):
|
||||||
|
|
||||||
|
Step 8.1 — Auth endpoints:
|
||||||
|
- `server/admin_auth.go`: POST `/v1/auth/login` (JSON body → MCIAS),
|
||||||
|
POST `/v1/auth/logout` (204 stub), GET `/v1/health` (no auth)
|
||||||
|
|
||||||
|
Step 8.2 — Repository management:
|
||||||
|
- `db/admin.go`: `RepoMetadata`, `TagInfo`, `ManifestInfo`, `RepoDetail`
|
||||||
|
types; `ListRepositoriesWithMetadata()`, `GetRepositoryDetail()`,
|
||||||
|
`DeleteRepository()`
|
||||||
|
- `server/admin_repo.go`: GET/DELETE `/v1/repositories` and
|
||||||
|
`/v1/repositories/*` (wildcard for multi-segment names)
|
||||||
|
|
||||||
|
Step 8.3 — Policy CRUD:
|
||||||
|
- `db/admin.go`: `PolicyRuleRow` type, `ErrPolicyRuleNotFound`;
|
||||||
|
`CreatePolicyRule()`, `GetPolicyRule()`, `ListPolicyRules()`,
|
||||||
|
`UpdatePolicyRule()`, `SetPolicyRuleEnabled()`, `DeletePolicyRule()`
|
||||||
|
- `server/admin_policy.go`: full CRUD on `/v1/policy/rules` and
|
||||||
|
`/v1/policy/rules/{id}`; `PolicyReloader` interface; input validation
|
||||||
|
(priority >= 1, valid effect/actions); mutations trigger engine reload
|
||||||
|
|
||||||
|
Step 8.4 — Audit endpoint:
|
||||||
|
- `server/admin_audit.go`: GET `/v1/audit` with query parameter filters
|
||||||
|
(event_type, actor_id, repository, since, until, n, offset);
|
||||||
|
delegates to `db.ListAuditEvents`
|
||||||
|
|
||||||
|
Step 8.5 — GC endpoints:
|
||||||
|
- `server/admin_gc.go`: `GCState` struct with `sync.Mutex`; POST `/v1/gc`
|
||||||
|
returns 202 (stub for Phase 9); GET `/v1/gc/status`; concurrent
|
||||||
|
trigger returns 409
|
||||||
|
|
||||||
|
Shared admin infrastructure:
|
||||||
|
- `server/admin.go`: `writeAdminError()` (platform `{"error":"..."}`
|
||||||
|
format), `writeJSON()`, `RequireAdmin()` middleware, `hasRole()` helper
|
||||||
|
- `server/admin_routes.go`: `AdminDeps` struct, `MountAdminRoutes()`
|
||||||
|
mounts all `/v1/*` endpoints with proper auth/admin middleware layers
|
||||||
|
|
||||||
|
**Verification:**
|
||||||
|
- `make all` passes: vet clean, lint 0 issues, 168 tests passing,
|
||||||
|
all 3 binaries built
|
||||||
|
- Phase 5 (34 new tests): parseOCIPath (14 subtests covering simple/
|
||||||
|
multi-segment/edge cases), manifest GET by tag/digest + HEAD + not
|
||||||
|
found, blob GET/HEAD + not in repo + repo not found, tags list +
|
||||||
|
pagination + empty + repo not found, catalog list + pagination + empty,
|
||||||
|
DB repository methods (15 tests covering all 6 query methods)
|
||||||
|
- Phase 8 (51 new tests): DB admin methods (19 tests covering CRUD,
|
||||||
|
pagination, cascade, not-found), admin auth (login ok/fail, health,
|
||||||
|
logout), admin repos (list/detail/delete/non-admin 403), admin policy
|
||||||
|
(full CRUD cycle, validation errors, non-admin 403, engine reload),
|
||||||
|
admin audit (list with filters, pagination), admin GC (trigger 202,
|
||||||
|
status, concurrent 409), RequireAdmin middleware (allowed/denied/
|
||||||
|
no claims)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
### 2026-03-19 — Phase 4: Policy engine
|
### 2026-03-19 — Phase 4: Policy engine
|
||||||
|
|
||||||
**Task:** Implement the registry-specific authorization engine with
|
**Task:** Implement the registry-specific authorization engine with
|
||||||
|
|||||||
@@ -14,10 +14,10 @@ design specification.
|
|||||||
| 2 | Blob storage layer | **Complete** |
|
| 2 | Blob storage layer | **Complete** |
|
||||||
| 3 | MCIAS authentication | **Complete** |
|
| 3 | MCIAS authentication | **Complete** |
|
||||||
| 4 | Policy engine | **Complete** |
|
| 4 | Policy engine | **Complete** |
|
||||||
| 5 | OCI API — pull path | Not started |
|
| 5 | OCI API — pull path | **Complete** |
|
||||||
| 6 | OCI API — push path | Not started |
|
| 6 | OCI API — push path | **Complete** |
|
||||||
| 7 | OCI API — delete path | Not started |
|
| 7 | OCI API — delete path | Not started |
|
||||||
| 8 | Admin REST API | Not started |
|
| 8 | Admin REST API | **Complete** |
|
||||||
| 9 | Garbage collection | Not started |
|
| 9 | Garbage collection | Not started |
|
||||||
| 10 | gRPC admin API | Not started |
|
| 10 | gRPC admin API | Not started |
|
||||||
| 11 | CLI tool (mcrctl) | Not started |
|
| 11 | CLI tool (mcrctl) | Not started |
|
||||||
|
|||||||
429
internal/db/admin.go
Normal file
429
internal/db/admin.go
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrPolicyRuleNotFound is returned when a policy rule lookup finds no matching row.
|
||||||
|
var ErrPolicyRuleNotFound = errors.New("db: policy rule not found")
|
||||||
|
|
||||||
|
// RepoMetadata is a repository with aggregate counts for listing.
|
||||||
|
type RepoMetadata struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
TagCount int `json:"tag_count"`
|
||||||
|
ManifestCount int `json:"manifest_count"`
|
||||||
|
TotalSize int64 `json:"total_size"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// TagInfo is a tag with its manifest digest for repo detail.
|
||||||
|
type TagInfo struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Digest string `json:"digest"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ManifestInfo is a manifest summary for repo detail.
|
||||||
|
type ManifestInfo struct {
|
||||||
|
Digest string `json:"digest"`
|
||||||
|
MediaType string `json:"media_type"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// RepoDetail contains detailed info about a single repository.
|
||||||
|
type RepoDetail struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Tags []TagInfo `json:"tags"`
|
||||||
|
Manifests []ManifestInfo `json:"manifests"`
|
||||||
|
TotalSize int64 `json:"total_size"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PolicyRuleRow represents a row from the policy_rules table with parsed JSON.
|
||||||
|
type PolicyRuleRow struct {
|
||||||
|
ID int64 `json:"id"`
|
||||||
|
Priority int `json:"priority"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Effect string `json:"effect"`
|
||||||
|
Roles []string `json:"roles,omitempty"`
|
||||||
|
AccountTypes []string `json:"account_types,omitempty"`
|
||||||
|
SubjectUUID string `json:"subject_uuid,omitempty"`
|
||||||
|
Actions []string `json:"actions"`
|
||||||
|
Repositories []string `json:"repositories,omitempty"`
|
||||||
|
Enabled bool `json:"enabled"`
|
||||||
|
CreatedBy string `json:"created_by,omitempty"`
|
||||||
|
CreatedAt string `json:"created_at"`
|
||||||
|
UpdatedAt string `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRepositoriesWithMetadata returns all repositories with tag count,
|
||||||
|
// manifest count, and total size.
|
||||||
|
func (d *DB) ListRepositoriesWithMetadata(limit, offset int) ([]RepoMetadata, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
rows, err := d.Query(
|
||||||
|
`SELECT r.name, r.created_at,
|
||||||
|
(SELECT COUNT(*) FROM tags t WHERE t.repository_id = r.id) AS tag_count,
|
||||||
|
(SELECT COUNT(*) FROM manifests m WHERE m.repository_id = r.id) AS manifest_count,
|
||||||
|
COALESCE((SELECT SUM(m.size) FROM manifests m WHERE m.repository_id = r.id), 0) AS total_size
|
||||||
|
FROM repositories r
|
||||||
|
ORDER BY r.name ASC
|
||||||
|
LIMIT ? OFFSET ?`,
|
||||||
|
limit, offset,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: list repositories: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var repos []RepoMetadata
|
||||||
|
for rows.Next() {
|
||||||
|
var r RepoMetadata
|
||||||
|
if err := rows.Scan(&r.Name, &r.CreatedAt, &r.TagCount, &r.ManifestCount, &r.TotalSize); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan repository: %w", err)
|
||||||
|
}
|
||||||
|
repos = append(repos, r)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: iterate repositories: %w", err)
|
||||||
|
}
|
||||||
|
return repos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRepositoryDetail returns detailed information about a repository.
|
||||||
|
func (d *DB) GetRepositoryDetail(name string) (*RepoDetail, error) {
|
||||||
|
var repoID int64
|
||||||
|
var createdAt string
|
||||||
|
err := d.QueryRow(`SELECT id, created_at FROM repositories WHERE name = ?`, name).Scan(&repoID, &createdAt)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrRepoNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("db: get repository: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
detail := &RepoDetail{Name: name, CreatedAt: createdAt}
|
||||||
|
|
||||||
|
// Tags with manifest digests.
|
||||||
|
tagRows, err := d.Query(
|
||||||
|
`SELECT t.name, m.digest
|
||||||
|
FROM tags t JOIN manifests m ON m.id = t.manifest_id
|
||||||
|
WHERE t.repository_id = ?
|
||||||
|
ORDER BY t.name ASC`, repoID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: list repo tags: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = tagRows.Close() }()
|
||||||
|
|
||||||
|
for tagRows.Next() {
|
||||||
|
var ti TagInfo
|
||||||
|
if err := tagRows.Scan(&ti.Name, &ti.Digest); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan tag: %w", err)
|
||||||
|
}
|
||||||
|
detail.Tags = append(detail.Tags, ti)
|
||||||
|
}
|
||||||
|
if err := tagRows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: iterate tags: %w", err)
|
||||||
|
}
|
||||||
|
if detail.Tags == nil {
|
||||||
|
detail.Tags = []TagInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manifests.
|
||||||
|
mRows, err := d.Query(
|
||||||
|
`SELECT digest, media_type, size, created_at
|
||||||
|
FROM manifests WHERE repository_id = ?
|
||||||
|
ORDER BY created_at DESC`, repoID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: list repo manifests: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = mRows.Close() }()
|
||||||
|
|
||||||
|
for mRows.Next() {
|
||||||
|
var mi ManifestInfo
|
||||||
|
if err := mRows.Scan(&mi.Digest, &mi.MediaType, &mi.Size, &mi.CreatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan manifest: %w", err)
|
||||||
|
}
|
||||||
|
detail.TotalSize += mi.Size
|
||||||
|
detail.Manifests = append(detail.Manifests, mi)
|
||||||
|
}
|
||||||
|
if err := mRows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: iterate manifests: %w", err)
|
||||||
|
}
|
||||||
|
if detail.Manifests == nil {
|
||||||
|
detail.Manifests = []ManifestInfo{}
|
||||||
|
}
|
||||||
|
|
||||||
|
return detail, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRepository deletes a repository and all its manifests, tags, and
|
||||||
|
// manifest_blobs. CASCADE handles the dependent rows.
|
||||||
|
func (d *DB) DeleteRepository(name string) error {
|
||||||
|
result, err := d.Exec(`DELETE FROM repositories WHERE name = ?`, name)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: delete repository: %w", err)
|
||||||
|
}
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: delete repository rows affected: %w", err)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return ErrRepoNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreatePolicyRule inserts a new policy rule and returns its ID.
|
||||||
|
func (d *DB) CreatePolicyRule(rule PolicyRuleRow) (int64, error) {
|
||||||
|
body := ruleBody{
|
||||||
|
Effect: rule.Effect,
|
||||||
|
Roles: rule.Roles,
|
||||||
|
AccountTypes: rule.AccountTypes,
|
||||||
|
SubjectUUID: rule.SubjectUUID,
|
||||||
|
Actions: rule.Actions,
|
||||||
|
Repositories: rule.Repositories,
|
||||||
|
}
|
||||||
|
ruleJSON, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("db: marshal rule body: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled := 0
|
||||||
|
if rule.Enabled {
|
||||||
|
enabled = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := d.Exec(
|
||||||
|
`INSERT INTO policy_rules (priority, description, rule_json, enabled, created_by)
|
||||||
|
VALUES (?, ?, ?, ?, ?)`,
|
||||||
|
rule.Priority, rule.Description, string(ruleJSON), enabled, nullIfEmpty(rule.CreatedBy),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("db: create policy rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("db: policy rule last insert id: %w", err)
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPolicyRule returns a single policy rule by ID.
|
||||||
|
func (d *DB) GetPolicyRule(id int64) (*PolicyRuleRow, error) {
|
||||||
|
var row PolicyRuleRow
|
||||||
|
var ruleJSON string
|
||||||
|
var enabledInt int
|
||||||
|
var createdBy *string
|
||||||
|
|
||||||
|
err := d.QueryRow(
|
||||||
|
`SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at
|
||||||
|
FROM policy_rules WHERE id = ?`, id,
|
||||||
|
).Scan(&row.ID, &row.Priority, &row.Description, &ruleJSON, &enabledInt, &createdBy, &row.CreatedAt, &row.UpdatedAt)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrPolicyRuleNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("db: get policy rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row.Enabled = enabledInt == 1
|
||||||
|
if createdBy != nil {
|
||||||
|
row.CreatedBy = *createdBy
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ruleBody
|
||||||
|
if err := json.Unmarshal([]byte(ruleJSON), &body); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: parse rule_json for rule %d: %w", id, err)
|
||||||
|
}
|
||||||
|
row.Effect = body.Effect
|
||||||
|
row.Roles = body.Roles
|
||||||
|
row.AccountTypes = body.AccountTypes
|
||||||
|
row.SubjectUUID = body.SubjectUUID
|
||||||
|
row.Actions = body.Actions
|
||||||
|
row.Repositories = body.Repositories
|
||||||
|
|
||||||
|
return &row, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListPolicyRules returns all policy rules ordered by priority ascending.
|
||||||
|
func (d *DB) ListPolicyRules(limit, offset int) ([]PolicyRuleRow, error) {
|
||||||
|
if limit <= 0 {
|
||||||
|
limit = 50
|
||||||
|
}
|
||||||
|
rows, err := d.Query(
|
||||||
|
`SELECT id, priority, description, rule_json, enabled, created_by, created_at, updated_at
|
||||||
|
FROM policy_rules
|
||||||
|
ORDER BY priority ASC
|
||||||
|
LIMIT ? OFFSET ?`,
|
||||||
|
limit, offset,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: list policy rules: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var rules []PolicyRuleRow
|
||||||
|
for rows.Next() {
|
||||||
|
var row PolicyRuleRow
|
||||||
|
var ruleJSON string
|
||||||
|
var enabledInt int
|
||||||
|
var createdBy *string
|
||||||
|
|
||||||
|
if err := rows.Scan(&row.ID, &row.Priority, &row.Description, &ruleJSON, &enabledInt, &createdBy, &row.CreatedAt, &row.UpdatedAt); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan policy rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
row.Enabled = enabledInt == 1
|
||||||
|
if createdBy != nil {
|
||||||
|
row.CreatedBy = *createdBy
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ruleBody
|
||||||
|
if err := json.Unmarshal([]byte(ruleJSON), &body); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: parse rule_json: %w", err)
|
||||||
|
}
|
||||||
|
row.Effect = body.Effect
|
||||||
|
row.Roles = body.Roles
|
||||||
|
row.AccountTypes = body.AccountTypes
|
||||||
|
row.SubjectUUID = body.SubjectUUID
|
||||||
|
row.Actions = body.Actions
|
||||||
|
row.Repositories = body.Repositories
|
||||||
|
|
||||||
|
rules = append(rules, row)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: iterate policy rules: %w", err)
|
||||||
|
}
|
||||||
|
return rules, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatePolicyRule performs a partial update of a policy rule.
|
||||||
|
// Only non-zero/non-empty fields in the input are updated.
|
||||||
|
// Always updates updated_at.
|
||||||
|
func (d *DB) UpdatePolicyRule(id int64, updates PolicyRuleRow) error {
|
||||||
|
// First check the rule exists.
|
||||||
|
var exists int
|
||||||
|
err := d.QueryRow(`SELECT COUNT(*) FROM policy_rules WHERE id = ?`, id).Scan(&exists)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: check policy rule: %w", err)
|
||||||
|
}
|
||||||
|
if exists == 0 {
|
||||||
|
return ErrPolicyRuleNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
var setClauses []string
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
if updates.Priority != 0 {
|
||||||
|
setClauses = append(setClauses, "priority = ?")
|
||||||
|
args = append(args, updates.Priority)
|
||||||
|
}
|
||||||
|
if updates.Description != "" {
|
||||||
|
setClauses = append(setClauses, "description = ?")
|
||||||
|
args = append(args, updates.Description)
|
||||||
|
}
|
||||||
|
|
||||||
|
// If any rule body fields are set, rebuild the full rule_json.
|
||||||
|
// Read the current value first, apply the update, then write back.
|
||||||
|
if updates.Effect != "" || updates.Actions != nil || updates.Roles != nil ||
|
||||||
|
updates.AccountTypes != nil || updates.Repositories != nil || updates.SubjectUUID != "" {
|
||||||
|
var currentJSON string
|
||||||
|
err := d.QueryRow(`SELECT rule_json FROM policy_rules WHERE id = ?`, id).Scan(¤tJSON)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: read current rule_json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ruleBody
|
||||||
|
if err := json.Unmarshal([]byte(currentJSON), &body); err != nil {
|
||||||
|
return fmt.Errorf("db: parse current rule_json: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if updates.Effect != "" {
|
||||||
|
body.Effect = updates.Effect
|
||||||
|
}
|
||||||
|
if updates.Actions != nil {
|
||||||
|
body.Actions = updates.Actions
|
||||||
|
}
|
||||||
|
if updates.Roles != nil {
|
||||||
|
body.Roles = updates.Roles
|
||||||
|
}
|
||||||
|
if updates.AccountTypes != nil {
|
||||||
|
body.AccountTypes = updates.AccountTypes
|
||||||
|
}
|
||||||
|
if updates.Repositories != nil {
|
||||||
|
body.Repositories = updates.Repositories
|
||||||
|
}
|
||||||
|
if updates.SubjectUUID != "" {
|
||||||
|
body.SubjectUUID = updates.SubjectUUID
|
||||||
|
}
|
||||||
|
|
||||||
|
newJSON, err := json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: marshal updated rule_json: %w", err)
|
||||||
|
}
|
||||||
|
setClauses = append(setClauses, "rule_json = ?")
|
||||||
|
args = append(args, string(newJSON))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Always update updated_at.
|
||||||
|
setClauses = append(setClauses, "updated_at = ?")
|
||||||
|
args = append(args, time.Now().UTC().Format("2006-01-02T15:04:05Z"))
|
||||||
|
|
||||||
|
query := fmt.Sprintf("UPDATE policy_rules SET %s WHERE id = ?", strings.Join(setClauses, ", "))
|
||||||
|
args = append(args, id)
|
||||||
|
|
||||||
|
_, err = d.Exec(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: update policy rule: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPolicyRuleEnabled sets the enabled flag for a policy rule.
|
||||||
|
func (d *DB) SetPolicyRuleEnabled(id int64, enabled bool) error {
|
||||||
|
enabledInt := 0
|
||||||
|
if enabled {
|
||||||
|
enabledInt = 1
|
||||||
|
}
|
||||||
|
result, err := d.Exec(
|
||||||
|
`UPDATE policy_rules SET enabled = ?, updated_at = ? WHERE id = ?`,
|
||||||
|
enabledInt, time.Now().UTC().Format("2006-01-02T15:04:05Z"), id,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: set policy rule enabled: %w", err)
|
||||||
|
}
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: set policy rule enabled rows affected: %w", err)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return ErrPolicyRuleNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeletePolicyRule deletes a policy rule by ID.
|
||||||
|
func (d *DB) DeletePolicyRule(id int64) error {
|
||||||
|
result, err := d.Exec(`DELETE FROM policy_rules WHERE id = ?`, id)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: delete policy rule: %w", err)
|
||||||
|
}
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: delete policy rule rows affected: %w", err)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return ErrPolicyRuleNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
593
internal/db/admin_test.go
Normal file
593
internal/db/admin_test.go
Normal file
@@ -0,0 +1,593 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// seedAdminRepo inserts a repository with manifests, tags, and blobs for admin tests.
|
||||||
|
func seedAdminRepo(t *testing.T, d *DB, name string, tagNames []string) int64 {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo %q: %v", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var repoID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&repoID); err != nil {
|
||||||
|
t.Fatalf("select repo id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = d.Exec(
|
||||||
|
`INSERT INTO manifests (repository_id, digest, media_type, content, size)
|
||||||
|
VALUES (?, ?, 'application/vnd.oci.image.manifest.v1+json', '{}', 1024)`,
|
||||||
|
repoID, "sha256:aaa-"+name,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert manifest for %q: %v", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var manifestID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil {
|
||||||
|
t.Fatalf("select manifest id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tag := range tagNames {
|
||||||
|
_, err := d.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`,
|
||||||
|
repoID, tag, manifestID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert tag %q: %v", tag, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return repoID
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListRepositoriesWithMetadata(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
seedAdminRepo(t, d, "alpha/app", []string{"latest", "v1.0"})
|
||||||
|
seedAdminRepo(t, d, "bravo/lib", []string{"latest"})
|
||||||
|
|
||||||
|
repos, err := d.ListRepositoriesWithMetadata(50, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoriesWithMetadata: %v", err)
|
||||||
|
}
|
||||||
|
if len(repos) != 2 {
|
||||||
|
t.Fatalf("repo count: got %d, want 2", len(repos))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ordered by name ASC.
|
||||||
|
if repos[0].Name != "alpha/app" {
|
||||||
|
t.Fatalf("first repo name: got %q, want %q", repos[0].Name, "alpha/app")
|
||||||
|
}
|
||||||
|
if repos[0].TagCount != 2 {
|
||||||
|
t.Fatalf("alpha/app tag count: got %d, want 2", repos[0].TagCount)
|
||||||
|
}
|
||||||
|
if repos[0].ManifestCount != 1 {
|
||||||
|
t.Fatalf("alpha/app manifest count: got %d, want 1", repos[0].ManifestCount)
|
||||||
|
}
|
||||||
|
if repos[0].TotalSize != 1024 {
|
||||||
|
t.Fatalf("alpha/app total size: got %d, want 1024", repos[0].TotalSize)
|
||||||
|
}
|
||||||
|
if repos[0].CreatedAt == "" {
|
||||||
|
t.Fatal("alpha/app created_at: expected non-empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if repos[1].Name != "bravo/lib" {
|
||||||
|
t.Fatalf("second repo name: got %q, want %q", repos[1].Name, "bravo/lib")
|
||||||
|
}
|
||||||
|
if repos[1].TagCount != 1 {
|
||||||
|
t.Fatalf("bravo/lib tag count: got %d, want 1", repos[1].TagCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListRepositoriesWithMetadataEmpty(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
repos, err := d.ListRepositoriesWithMetadata(50, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoriesWithMetadata: %v", err)
|
||||||
|
}
|
||||||
|
if repos != nil {
|
||||||
|
t.Fatalf("expected nil repos, got %v", repos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListRepositoriesWithMetadataPagination(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
seedAdminRepo(t, d, "alpha/app", []string{"latest"})
|
||||||
|
seedAdminRepo(t, d, "bravo/lib", []string{"latest"})
|
||||||
|
seedAdminRepo(t, d, "charlie/svc", []string{"latest"})
|
||||||
|
|
||||||
|
repos, err := d.ListRepositoriesWithMetadata(2, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoriesWithMetadata page 1: %v", err)
|
||||||
|
}
|
||||||
|
if len(repos) != 2 {
|
||||||
|
t.Fatalf("page 1 count: got %d, want 2", len(repos))
|
||||||
|
}
|
||||||
|
if repos[0].Name != "alpha/app" {
|
||||||
|
t.Fatalf("page 1 first: got %q, want %q", repos[0].Name, "alpha/app")
|
||||||
|
}
|
||||||
|
|
||||||
|
repos, err = d.ListRepositoriesWithMetadata(2, 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoriesWithMetadata page 2: %v", err)
|
||||||
|
}
|
||||||
|
if len(repos) != 1 {
|
||||||
|
t.Fatalf("page 2 count: got %d, want 1", len(repos))
|
||||||
|
}
|
||||||
|
if repos[0].Name != "charlie/svc" {
|
||||||
|
t.Fatalf("page 2 first: got %q, want %q", repos[0].Name, "charlie/svc")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRepositoryDetail(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
seedAdminRepo(t, d, "myorg/myapp", []string{"latest", "v1.0"})
|
||||||
|
|
||||||
|
detail, err := d.GetRepositoryDetail("myorg/myapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRepositoryDetail: %v", err)
|
||||||
|
}
|
||||||
|
if detail.Name != "myorg/myapp" {
|
||||||
|
t.Fatalf("name: got %q, want %q", detail.Name, "myorg/myapp")
|
||||||
|
}
|
||||||
|
if detail.CreatedAt == "" {
|
||||||
|
t.Fatal("created_at: expected non-empty")
|
||||||
|
}
|
||||||
|
if len(detail.Tags) != 2 {
|
||||||
|
t.Fatalf("tag count: got %d, want 2", len(detail.Tags))
|
||||||
|
}
|
||||||
|
// Tags ordered by name ASC.
|
||||||
|
if detail.Tags[0].Name != "latest" {
|
||||||
|
t.Fatalf("first tag: got %q, want %q", detail.Tags[0].Name, "latest")
|
||||||
|
}
|
||||||
|
if detail.Tags[0].Digest == "" {
|
||||||
|
t.Fatal("first tag digest: expected non-empty")
|
||||||
|
}
|
||||||
|
if detail.Tags[1].Name != "v1.0" {
|
||||||
|
t.Fatalf("second tag: got %q, want %q", detail.Tags[1].Name, "v1.0")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(detail.Manifests) != 1 {
|
||||||
|
t.Fatalf("manifest count: got %d, want 1", len(detail.Manifests))
|
||||||
|
}
|
||||||
|
if detail.Manifests[0].Size != 1024 {
|
||||||
|
t.Fatalf("manifest size: got %d, want 1024", detail.Manifests[0].Size)
|
||||||
|
}
|
||||||
|
if detail.TotalSize != 1024 {
|
||||||
|
t.Fatalf("total size: got %d, want 1024", detail.TotalSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRepositoryDetailNotFound(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
_, err := d.GetRepositoryDetail("nonexistent/repo")
|
||||||
|
if !errors.Is(err, ErrRepoNotFound) {
|
||||||
|
t.Fatalf("expected ErrRepoNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRepositoryDetailEmptyRepo(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('empty/repo')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
detail, err := d.GetRepositoryDetail("empty/repo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRepositoryDetail: %v", err)
|
||||||
|
}
|
||||||
|
if len(detail.Tags) != 0 {
|
||||||
|
t.Fatalf("tags: got %d, want 0", len(detail.Tags))
|
||||||
|
}
|
||||||
|
if len(detail.Manifests) != 0 {
|
||||||
|
t.Fatalf("manifests: got %d, want 0", len(detail.Manifests))
|
||||||
|
}
|
||||||
|
if detail.TotalSize != 0 {
|
||||||
|
t.Fatalf("total size: got %d, want 0", detail.TotalSize)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteRepository(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
seedAdminRepo(t, d, "myorg/myapp", []string{"latest"})
|
||||||
|
|
||||||
|
if err := d.DeleteRepository("myorg/myapp"); err != nil {
|
||||||
|
t.Fatalf("DeleteRepository: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's gone.
|
||||||
|
_, err := d.GetRepositoryDetail("myorg/myapp")
|
||||||
|
if !errors.Is(err, ErrRepoNotFound) {
|
||||||
|
t.Fatalf("expected ErrRepoNotFound after delete, got %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify cascade: manifests and tags should be gone.
|
||||||
|
var manifestCount int
|
||||||
|
if err := d.QueryRow(`SELECT COUNT(*) FROM manifests`).Scan(&manifestCount); err != nil {
|
||||||
|
t.Fatalf("count manifests: %v", err)
|
||||||
|
}
|
||||||
|
if manifestCount != 0 {
|
||||||
|
t.Fatalf("manifests after delete: got %d, want 0", manifestCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
var tagCount int
|
||||||
|
if err := d.QueryRow(`SELECT COUNT(*) FROM tags`).Scan(&tagCount); err != nil {
|
||||||
|
t.Fatalf("count tags: %v", err)
|
||||||
|
}
|
||||||
|
if tagCount != 0 {
|
||||||
|
t.Fatalf("tags after delete: got %d, want 0", tagCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteRepositoryNotFound(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
err := d.DeleteRepository("nonexistent/repo")
|
||||||
|
if !errors.Is(err, ErrRepoNotFound) {
|
||||||
|
t.Fatalf("expected ErrRepoNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreatePolicyRule(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
rule := PolicyRuleRow{
|
||||||
|
Priority: 50,
|
||||||
|
Description: "allow CI push",
|
||||||
|
Effect: "allow",
|
||||||
|
Roles: []string{"ci"},
|
||||||
|
Actions: []string{"registry:push", "registry:pull"},
|
||||||
|
Repositories: []string{"production/*"},
|
||||||
|
Enabled: true,
|
||||||
|
CreatedBy: "admin-uuid",
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := d.CreatePolicyRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
if id == 0 {
|
||||||
|
t.Fatal("expected non-zero ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := d.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
if got.Priority != 50 {
|
||||||
|
t.Fatalf("priority: got %d, want 50", got.Priority)
|
||||||
|
}
|
||||||
|
if got.Description != "allow CI push" {
|
||||||
|
t.Fatalf("description: got %q, want %q", got.Description, "allow CI push")
|
||||||
|
}
|
||||||
|
if got.Effect != "allow" {
|
||||||
|
t.Fatalf("effect: got %q, want %q", got.Effect, "allow")
|
||||||
|
}
|
||||||
|
if len(got.Roles) != 1 || got.Roles[0] != "ci" {
|
||||||
|
t.Fatalf("roles: got %v, want [ci]", got.Roles)
|
||||||
|
}
|
||||||
|
if len(got.Actions) != 2 {
|
||||||
|
t.Fatalf("actions: got %d, want 2", len(got.Actions))
|
||||||
|
}
|
||||||
|
if len(got.Repositories) != 1 || got.Repositories[0] != "production/*" {
|
||||||
|
t.Fatalf("repositories: got %v, want [production/*]", got.Repositories)
|
||||||
|
}
|
||||||
|
if !got.Enabled {
|
||||||
|
t.Fatal("enabled: got false, want true")
|
||||||
|
}
|
||||||
|
if got.CreatedBy != "admin-uuid" {
|
||||||
|
t.Fatalf("created_by: got %q, want %q", got.CreatedBy, "admin-uuid")
|
||||||
|
}
|
||||||
|
if got.CreatedAt == "" {
|
||||||
|
t.Fatal("created_at: expected non-empty")
|
||||||
|
}
|
||||||
|
if got.UpdatedAt == "" {
|
||||||
|
t.Fatal("updated_at: expected non-empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCreatePolicyRuleDisabled(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
rule := PolicyRuleRow{
|
||||||
|
Priority: 10,
|
||||||
|
Description: "disabled rule",
|
||||||
|
Effect: "deny",
|
||||||
|
Actions: []string{"registry:delete"},
|
||||||
|
Enabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := d.CreatePolicyRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := d.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
if got.Enabled {
|
||||||
|
t.Fatal("enabled: got true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetPolicyRuleNotFound(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
_, err := d.GetPolicyRule(9999)
|
||||||
|
if !errors.Is(err, ErrPolicyRuleNotFound) {
|
||||||
|
t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListPolicyRules(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
// Insert rules with different priorities (out of order).
|
||||||
|
rule1 := PolicyRuleRow{
|
||||||
|
Priority: 50,
|
||||||
|
Description: "rule A",
|
||||||
|
Effect: "allow",
|
||||||
|
Actions: []string{"registry:pull"},
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
rule2 := PolicyRuleRow{
|
||||||
|
Priority: 10,
|
||||||
|
Description: "rule B",
|
||||||
|
Effect: "deny",
|
||||||
|
Actions: []string{"registry:delete"},
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
rule3 := PolicyRuleRow{
|
||||||
|
Priority: 30,
|
||||||
|
Description: "rule C",
|
||||||
|
Effect: "allow",
|
||||||
|
Actions: []string{"registry:push"},
|
||||||
|
Enabled: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := d.CreatePolicyRule(rule1); err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule 1: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := d.CreatePolicyRule(rule2); err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule 2: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := d.CreatePolicyRule(rule3); err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule 3: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := d.ListPolicyRules(50, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListPolicyRules: %v", err)
|
||||||
|
}
|
||||||
|
if len(rules) != 3 {
|
||||||
|
t.Fatalf("rule count: got %d, want 3", len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should be ordered by priority ASC: 10, 30, 50.
|
||||||
|
if rules[0].Priority != 10 {
|
||||||
|
t.Fatalf("first rule priority: got %d, want 10", rules[0].Priority)
|
||||||
|
}
|
||||||
|
if rules[0].Description != "rule B" {
|
||||||
|
t.Fatalf("first rule description: got %q, want %q", rules[0].Description, "rule B")
|
||||||
|
}
|
||||||
|
if rules[1].Priority != 30 {
|
||||||
|
t.Fatalf("second rule priority: got %d, want 30", rules[1].Priority)
|
||||||
|
}
|
||||||
|
if rules[2].Priority != 50 {
|
||||||
|
t.Fatalf("third rule priority: got %d, want 50", rules[2].Priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify enabled flags.
|
||||||
|
if !rules[0].Enabled {
|
||||||
|
t.Fatal("rule B enabled: got false, want true")
|
||||||
|
}
|
||||||
|
if rules[1].Enabled {
|
||||||
|
t.Fatal("rule C enabled: got true, want false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListPolicyRulesEmpty(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
rules, err := d.ListPolicyRules(50, 0)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListPolicyRules: %v", err)
|
||||||
|
}
|
||||||
|
if rules != nil {
|
||||||
|
t.Fatalf("expected nil rules, got %v", rules)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePolicyRule(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
rule := PolicyRuleRow{
|
||||||
|
Priority: 50,
|
||||||
|
Description: "original",
|
||||||
|
Effect: "allow",
|
||||||
|
Actions: []string{"registry:pull"},
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := d.CreatePolicyRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update priority and description.
|
||||||
|
updates := PolicyRuleRow{
|
||||||
|
Priority: 25,
|
||||||
|
Description: "updated",
|
||||||
|
}
|
||||||
|
if err := d.UpdatePolicyRule(id, updates); err != nil {
|
||||||
|
t.Fatalf("UpdatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := d.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
if got.Priority != 25 {
|
||||||
|
t.Fatalf("priority: got %d, want 25", got.Priority)
|
||||||
|
}
|
||||||
|
if got.Description != "updated" {
|
||||||
|
t.Fatalf("description: got %q, want %q", got.Description, "updated")
|
||||||
|
}
|
||||||
|
// Effect should be unchanged.
|
||||||
|
if got.Effect != "allow" {
|
||||||
|
t.Fatalf("effect: got %q, want %q (unchanged)", got.Effect, "allow")
|
||||||
|
}
|
||||||
|
// Actions should be unchanged.
|
||||||
|
if len(got.Actions) != 1 || got.Actions[0] != "registry:pull" {
|
||||||
|
t.Fatalf("actions: got %v, want [registry:pull] (unchanged)", got.Actions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePolicyRuleBody(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
rule := PolicyRuleRow{
|
||||||
|
Priority: 50,
|
||||||
|
Description: "test",
|
||||||
|
Effect: "allow",
|
||||||
|
Actions: []string{"registry:pull"},
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := d.CreatePolicyRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update rule body fields.
|
||||||
|
updates := PolicyRuleRow{
|
||||||
|
Effect: "deny",
|
||||||
|
Actions: []string{"registry:push", "registry:delete"},
|
||||||
|
Roles: []string{"ci"},
|
||||||
|
}
|
||||||
|
if err := d.UpdatePolicyRule(id, updates); err != nil {
|
||||||
|
t.Fatalf("UpdatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := d.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
if got.Effect != "deny" {
|
||||||
|
t.Fatalf("effect: got %q, want %q", got.Effect, "deny")
|
||||||
|
}
|
||||||
|
if len(got.Actions) != 2 {
|
||||||
|
t.Fatalf("actions: got %d, want 2", len(got.Actions))
|
||||||
|
}
|
||||||
|
if len(got.Roles) != 1 || got.Roles[0] != "ci" {
|
||||||
|
t.Fatalf("roles: got %v, want [ci]", got.Roles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdatePolicyRuleNotFound(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
err := d.UpdatePolicyRule(9999, PolicyRuleRow{Description: "nope"})
|
||||||
|
if !errors.Is(err, ErrPolicyRuleNotFound) {
|
||||||
|
t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetPolicyRuleEnabled(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
rule := PolicyRuleRow{
|
||||||
|
Priority: 50,
|
||||||
|
Description: "test",
|
||||||
|
Effect: "allow",
|
||||||
|
Actions: []string{"registry:pull"},
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
id, err := d.CreatePolicyRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable the rule.
|
||||||
|
if err := d.SetPolicyRuleEnabled(id, false); err != nil {
|
||||||
|
t.Fatalf("SetPolicyRuleEnabled(false): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := d.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
if got.Enabled {
|
||||||
|
t.Fatal("enabled: got true, want false")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Re-enable.
|
||||||
|
if err := d.SetPolicyRuleEnabled(id, true); err != nil {
|
||||||
|
t.Fatalf("SetPolicyRuleEnabled(true): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err = d.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetPolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
if !got.Enabled {
|
||||||
|
t.Fatal("enabled: got false, want true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetPolicyRuleEnabledNotFound(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
err := d.SetPolicyRuleEnabled(9999, true)
|
||||||
|
if !errors.Is(err, ErrPolicyRuleNotFound) {
|
||||||
|
t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeletePolicyRule(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
rule := PolicyRuleRow{
|
||||||
|
Priority: 50,
|
||||||
|
Description: "to delete",
|
||||||
|
Effect: "allow",
|
||||||
|
Actions: []string{"registry:pull"},
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := d.CreatePolicyRule(rule)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("CreatePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.DeletePolicyRule(id); err != nil {
|
||||||
|
t.Fatalf("DeletePolicyRule: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's gone.
|
||||||
|
_, err = d.GetPolicyRule(id)
|
||||||
|
if !errors.Is(err, ErrPolicyRuleNotFound) {
|
||||||
|
t.Fatalf("expected ErrPolicyRuleNotFound after delete, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeletePolicyRuleNotFound(t *testing.T) {
|
||||||
|
d := migratedTestDB(t)
|
||||||
|
|
||||||
|
err := d.DeletePolicyRule(9999)
|
||||||
|
if !errors.Is(err, ErrPolicyRuleNotFound) {
|
||||||
|
t.Fatalf("expected ErrPolicyRuleNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
9
internal/db/errors.go
Normal file
9
internal/db/errors.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrRepoNotFound = errors.New("db: repository not found")
|
||||||
|
ErrManifestNotFound = errors.New("db: manifest not found")
|
||||||
|
ErrBlobNotFound = errors.New("db: blob not found")
|
||||||
|
)
|
||||||
154
internal/db/push.go
Normal file
154
internal/db/push.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetOrCreateRepository returns the repository ID for the given name,
|
||||||
|
// creating it if it does not exist (implicit creation on first push).
|
||||||
|
func (d *DB) GetOrCreateRepository(name string) (int64, error) {
|
||||||
|
var id int64
|
||||||
|
err := d.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&id)
|
||||||
|
if err == nil {
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
if !errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return 0, fmt.Errorf("db: get repository: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("db: create repository: %w", err)
|
||||||
|
}
|
||||||
|
id, err = result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("db: repository last insert id: %w", err)
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlobExists checks whether a blob with the given digest exists in the blobs table.
|
||||||
|
func (d *DB) BlobExists(digest string) (bool, error) {
|
||||||
|
var count int
|
||||||
|
err := d.QueryRow(`SELECT COUNT(*) FROM blobs WHERE digest = ?`, digest).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("db: blob exists: %w", err)
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertBlob inserts a blob row if it does not already exist.
|
||||||
|
// Returns without error if the blob already exists (content-addressed dedup).
|
||||||
|
func (d *DB) InsertBlob(digest string, size int64) error {
|
||||||
|
_, err := d.Exec(
|
||||||
|
`INSERT OR IGNORE INTO blobs (digest, size) VALUES (?, ?)`,
|
||||||
|
digest, size,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: insert blob: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushManifestParams holds the parameters for a manifest push operation.
|
||||||
|
type PushManifestParams struct {
|
||||||
|
RepoName string
|
||||||
|
Digest string
|
||||||
|
MediaType string
|
||||||
|
Content []byte
|
||||||
|
Size int64
|
||||||
|
Tag string // empty if push-by-digest
|
||||||
|
BlobDigests []string // referenced blob digests
|
||||||
|
}
|
||||||
|
|
||||||
|
// PushManifest executes the full manifest push in a single transaction per
|
||||||
|
// ARCHITECTURE.md §5. It creates the repository if needed, inserts/updates
|
||||||
|
// the manifest, populates manifest_blobs, and updates the tag if provided.
|
||||||
|
func (d *DB) PushManifest(p PushManifestParams) error {
|
||||||
|
tx, err := d.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: begin push manifest: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step a: create repository if not exists.
|
||||||
|
var repoID int64
|
||||||
|
err = tx.QueryRow(`SELECT id FROM repositories WHERE name = ?`, p.RepoName).Scan(&repoID)
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
result, insertErr := tx.Exec(`INSERT INTO repositories (name) VALUES (?)`, p.RepoName)
|
||||||
|
if insertErr != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: create repository: %w", insertErr)
|
||||||
|
}
|
||||||
|
repoID, err = result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: repository last insert id: %w", err)
|
||||||
|
}
|
||||||
|
} else if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: get repository: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step b: insert or update manifest.
|
||||||
|
// Use INSERT OR REPLACE on the UNIQUE(repository_id, digest) constraint.
|
||||||
|
result, err := tx.Exec(
|
||||||
|
`INSERT INTO manifests (repository_id, digest, media_type, content, size)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(repository_id, digest) DO UPDATE SET
|
||||||
|
media_type = excluded.media_type,
|
||||||
|
content = excluded.content,
|
||||||
|
size = excluded.size`,
|
||||||
|
repoID, p.Digest, p.MediaType, p.Content, p.Size,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: insert manifest: %w", err)
|
||||||
|
}
|
||||||
|
manifestID, err := result.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: manifest last insert id: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step c: populate manifest_blobs join table.
|
||||||
|
// Delete existing entries for this manifest first (in case of re-push).
|
||||||
|
_, err = tx.Exec(`DELETE FROM manifest_blobs WHERE manifest_id = ?`, manifestID)
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: clear manifest_blobs: %w", err)
|
||||||
|
}
|
||||||
|
for _, blobDigest := range p.BlobDigests {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
`INSERT INTO manifest_blobs (manifest_id, blob_id)
|
||||||
|
SELECT ?, id FROM blobs WHERE digest = ?`,
|
||||||
|
manifestID, blobDigest,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: insert manifest_blob: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step d: if reference is a tag, insert or update tag row.
|
||||||
|
if p.Tag != "" {
|
||||||
|
_, err = tx.Exec(
|
||||||
|
`INSERT INTO tags (repository_id, name, manifest_id)
|
||||||
|
VALUES (?, ?, ?)
|
||||||
|
ON CONFLICT(repository_id, name) DO UPDATE SET
|
||||||
|
manifest_id = excluded.manifest_id,
|
||||||
|
updated_at = strftime('%Y-%m-%dT%H:%M:%SZ','now')`,
|
||||||
|
repoID, p.Tag, manifestID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
return fmt.Errorf("db: upsert tag: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := tx.Commit(); err != nil {
|
||||||
|
return fmt.Errorf("db: commit push manifest: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
278
internal/db/push_test.go
Normal file
278
internal/db/push_test.go
Normal file
@@ -0,0 +1,278 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestGetOrCreateRepositoryNew(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := d.GetOrCreateRepository("newrepo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetOrCreateRepository: %v", err)
|
||||||
|
}
|
||||||
|
if id <= 0 {
|
||||||
|
t.Fatalf("id: got %d, want > 0", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second call should return the same ID.
|
||||||
|
id2, err := d.GetOrCreateRepository("newrepo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetOrCreateRepository (second): %v", err)
|
||||||
|
}
|
||||||
|
if id2 != id {
|
||||||
|
t.Fatalf("id mismatch: got %d, want %d", id2, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetOrCreateRepositoryExisting(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('existing')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := d.GetOrCreateRepository("existing")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetOrCreateRepository: %v", err)
|
||||||
|
}
|
||||||
|
if id <= 0 {
|
||||||
|
t.Fatalf("id: got %d, want > 0", id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobExists(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := d.Exec(`INSERT INTO blobs (digest, size) VALUES ('sha256:aaa', 100)`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert blob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := d.BlobExists("sha256:aaa")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlobExists: %v", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("expected blob to exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err = d.BlobExists("sha256:nonexistent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlobExists (nonexistent): %v", err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
t.Fatal("expected blob to not exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInsertBlob(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.InsertBlob("sha256:bbb", 200); err != nil {
|
||||||
|
t.Fatalf("InsertBlob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := d.BlobExists("sha256:bbb")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlobExists: %v", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("expected blob to exist after insert")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert again — should be a no-op (INSERT OR IGNORE).
|
||||||
|
if err := d.InsertBlob("sha256:bbb", 200); err != nil {
|
||||||
|
t.Fatalf("InsertBlob (dup): %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPushManifestByTag(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert blobs first.
|
||||||
|
if err := d.InsertBlob("sha256:config111", 50); err != nil {
|
||||||
|
t.Fatalf("insert config blob: %v", err)
|
||||||
|
}
|
||||||
|
if err := d.InsertBlob("sha256:layer111", 1000); err != nil {
|
||||||
|
t.Fatalf("insert layer blob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := []byte(`{"schemaVersion":2}`)
|
||||||
|
params := PushManifestParams{
|
||||||
|
RepoName: "myrepo",
|
||||||
|
Digest: "sha256:manifest111",
|
||||||
|
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||||
|
Content: content,
|
||||||
|
Size: int64(len(content)),
|
||||||
|
Tag: "latest",
|
||||||
|
BlobDigests: []string{"sha256:config111", "sha256:layer111"},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.PushManifest(params); err != nil {
|
||||||
|
t.Fatalf("PushManifest: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify repository was created.
|
||||||
|
repoID, err := d.GetRepositoryByName("myrepo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRepositoryByName: %v", err)
|
||||||
|
}
|
||||||
|
if repoID <= 0 {
|
||||||
|
t.Fatalf("repo id: got %d, want > 0", repoID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify manifest exists.
|
||||||
|
m, err := d.GetManifestByDigest(repoID, "sha256:manifest111")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetManifestByDigest: %v", err)
|
||||||
|
}
|
||||||
|
if m.MediaType != "application/vnd.oci.image.manifest.v1+json" {
|
||||||
|
t.Fatalf("media type: got %q", m.MediaType)
|
||||||
|
}
|
||||||
|
if m.Size != int64(len(content)) {
|
||||||
|
t.Fatalf("size: got %d, want %d", m.Size, len(content))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify tag points to manifest.
|
||||||
|
m2, err := d.GetManifestByTag(repoID, "latest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetManifestByTag: %v", err)
|
||||||
|
}
|
||||||
|
if m2.Digest != "sha256:manifest111" {
|
||||||
|
t.Fatalf("tag digest: got %q", m2.Digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify manifest_blobs join table.
|
||||||
|
var mbCount int
|
||||||
|
if err := d.QueryRow(`SELECT COUNT(*) FROM manifest_blobs WHERE manifest_id = ?`, m.ID).Scan(&mbCount); err != nil {
|
||||||
|
t.Fatalf("count manifest_blobs: %v", err)
|
||||||
|
}
|
||||||
|
if mbCount != 2 {
|
||||||
|
t.Fatalf("manifest_blobs count: got %d, want 2", mbCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPushManifestByDigest(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := []byte(`{"schemaVersion":2}`)
|
||||||
|
params := PushManifestParams{
|
||||||
|
RepoName: "myrepo",
|
||||||
|
Digest: "sha256:manifest222",
|
||||||
|
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||||
|
Content: content,
|
||||||
|
Size: int64(len(content)),
|
||||||
|
Tag: "", // push by digest — no tag
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.PushManifest(params); err != nil {
|
||||||
|
t.Fatalf("PushManifest: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no tag was created.
|
||||||
|
var tagCount int
|
||||||
|
if err := d.QueryRow(`SELECT COUNT(*) FROM tags`).Scan(&tagCount); err != nil {
|
||||||
|
t.Fatalf("count tags: %v", err)
|
||||||
|
}
|
||||||
|
if tagCount != 0 {
|
||||||
|
t.Fatalf("tag count: got %d, want 0", tagCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPushManifestTagMove(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push first manifest with tag "latest".
|
||||||
|
content1 := []byte(`{"schemaVersion":2,"v":"1"}`)
|
||||||
|
if err := d.PushManifest(PushManifestParams{
|
||||||
|
RepoName: "myrepo",
|
||||||
|
Digest: "sha256:first",
|
||||||
|
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||||
|
Content: content1,
|
||||||
|
Size: int64(len(content1)),
|
||||||
|
Tag: "latest",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("PushManifest (first): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push second manifest with same tag "latest" — should atomically move tag.
|
||||||
|
content2 := []byte(`{"schemaVersion":2,"v":"2"}`)
|
||||||
|
if err := d.PushManifest(PushManifestParams{
|
||||||
|
RepoName: "myrepo",
|
||||||
|
Digest: "sha256:second",
|
||||||
|
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||||
|
Content: content2,
|
||||||
|
Size: int64(len(content2)),
|
||||||
|
Tag: "latest",
|
||||||
|
}); err != nil {
|
||||||
|
t.Fatalf("PushManifest (second): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID, err := d.GetRepositoryByName("myrepo")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRepositoryByName: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := d.GetManifestByTag(repoID, "latest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetManifestByTag: %v", err)
|
||||||
|
}
|
||||||
|
if m.Digest != "sha256:second" {
|
||||||
|
t.Fatalf("tag should point to second manifest, got %q", m.Digest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPushManifestIdempotent(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
content := []byte(`{"schemaVersion":2}`)
|
||||||
|
params := PushManifestParams{
|
||||||
|
RepoName: "myrepo",
|
||||||
|
Digest: "sha256:manifest333",
|
||||||
|
MediaType: "application/vnd.oci.image.manifest.v1+json",
|
||||||
|
Content: content,
|
||||||
|
Size: int64(len(content)),
|
||||||
|
Tag: "latest",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push twice — should not fail.
|
||||||
|
if err := d.PushManifest(params); err != nil {
|
||||||
|
t.Fatalf("PushManifest (first): %v", err)
|
||||||
|
}
|
||||||
|
if err := d.PushManifest(params); err != nil {
|
||||||
|
t.Fatalf("PushManifest (second): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify only one manifest exists.
|
||||||
|
var mCount int
|
||||||
|
if err := d.QueryRow(`SELECT COUNT(*) FROM manifests`).Scan(&mCount); err != nil {
|
||||||
|
t.Fatalf("count manifests: %v", err)
|
||||||
|
}
|
||||||
|
if mCount != 1 {
|
||||||
|
t.Fatalf("manifest count: got %d, want 1", mCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
154
internal/db/repository.go
Normal file
154
internal/db/repository.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ManifestRow represents a manifest as stored in the database.
|
||||||
|
type ManifestRow struct {
|
||||||
|
ID int64
|
||||||
|
RepositoryID int64
|
||||||
|
Digest string
|
||||||
|
MediaType string
|
||||||
|
Content []byte
|
||||||
|
Size int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRepositoryByName returns the repository ID for the given name.
|
||||||
|
func (d *DB) GetRepositoryByName(name string) (int64, error) {
|
||||||
|
var id int64
|
||||||
|
err := d.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return 0, ErrRepoNotFound
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("db: get repository by name: %w", err)
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetManifestByTag returns the manifest associated with the given tag in a repository.
|
||||||
|
func (d *DB) GetManifestByTag(repoID int64, tag string) (*ManifestRow, error) {
|
||||||
|
var m ManifestRow
|
||||||
|
err := d.QueryRow(
|
||||||
|
`SELECT m.id, m.repository_id, m.digest, m.media_type, m.content, m.size
|
||||||
|
FROM manifests m
|
||||||
|
JOIN tags t ON t.manifest_id = m.id
|
||||||
|
WHERE t.repository_id = ? AND t.name = ?`,
|
||||||
|
repoID, tag,
|
||||||
|
).Scan(&m.ID, &m.RepositoryID, &m.Digest, &m.MediaType, &m.Content, &m.Size)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrManifestNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("db: get manifest by tag: %w", err)
|
||||||
|
}
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetManifestByDigest returns the manifest with the given digest in a repository.
|
||||||
|
func (d *DB) GetManifestByDigest(repoID int64, digest string) (*ManifestRow, error) {
|
||||||
|
var m ManifestRow
|
||||||
|
err := d.QueryRow(
|
||||||
|
`SELECT id, repository_id, digest, media_type, content, size
|
||||||
|
FROM manifests
|
||||||
|
WHERE repository_id = ? AND digest = ?`,
|
||||||
|
repoID, digest,
|
||||||
|
).Scan(&m.ID, &m.RepositoryID, &m.Digest, &m.MediaType, &m.Content, &m.Size)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrManifestNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("db: get manifest by digest: %w", err)
|
||||||
|
}
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlobExistsInRepo checks whether a blob with the given digest exists and is
|
||||||
|
// referenced by at least one manifest in the given repository.
|
||||||
|
func (d *DB) BlobExistsInRepo(repoID int64, digest string) (bool, error) {
|
||||||
|
var count int
|
||||||
|
err := d.QueryRow(
|
||||||
|
`SELECT COUNT(*) FROM blobs b
|
||||||
|
JOIN manifest_blobs mb ON mb.blob_id = b.id
|
||||||
|
JOIN manifests m ON m.id = mb.manifest_id
|
||||||
|
WHERE m.repository_id = ? AND b.digest = ?`,
|
||||||
|
repoID, digest,
|
||||||
|
).Scan(&count)
|
||||||
|
if err != nil {
|
||||||
|
return false, fmt.Errorf("db: blob exists in repo: %w", err)
|
||||||
|
}
|
||||||
|
return count > 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListTags returns tag names for a repository, ordered alphabetically.
|
||||||
|
// Pagination is cursor-based: after is the last tag name from the previous page,
|
||||||
|
// limit is the maximum number of tags to return.
|
||||||
|
func (d *DB) ListTags(repoID int64, after string, limit int) ([]string, error) {
|
||||||
|
var query string
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
if after != "" {
|
||||||
|
query = `SELECT name FROM tags WHERE repository_id = ? AND name > ? ORDER BY name ASC LIMIT ?`
|
||||||
|
args = []any{repoID, after, limit}
|
||||||
|
} else {
|
||||||
|
query = `SELECT name FROM tags WHERE repository_id = ? ORDER BY name ASC LIMIT ?`
|
||||||
|
args = []any{repoID, limit}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := d.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: list tags: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var tags []string
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan tag: %w", err)
|
||||||
|
}
|
||||||
|
tags = append(tags, name)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: iterate tags: %w", err)
|
||||||
|
}
|
||||||
|
return tags, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListRepositoryNames returns repository names ordered alphabetically.
|
||||||
|
// Pagination is cursor-based: after is the last repo name from the previous page,
|
||||||
|
// limit is the maximum number of names to return.
|
||||||
|
func (d *DB) ListRepositoryNames(after string, limit int) ([]string, error) {
|
||||||
|
var query string
|
||||||
|
var args []any
|
||||||
|
|
||||||
|
if after != "" {
|
||||||
|
query = `SELECT name FROM repositories WHERE name > ? ORDER BY name ASC LIMIT ?`
|
||||||
|
args = []any{after, limit}
|
||||||
|
} else {
|
||||||
|
query = `SELECT name FROM repositories ORDER BY name ASC LIMIT ?`
|
||||||
|
args = []any{limit}
|
||||||
|
}
|
||||||
|
|
||||||
|
rows, err := d.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("db: list repository names: %w", err)
|
||||||
|
}
|
||||||
|
defer func() { _ = rows.Close() }()
|
||||||
|
|
||||||
|
var names []string
|
||||||
|
for rows.Next() {
|
||||||
|
var name string
|
||||||
|
if err := rows.Scan(&name); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: scan repository name: %w", err)
|
||||||
|
}
|
||||||
|
names = append(names, name)
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, fmt.Errorf("db: iterate repository names: %w", err)
|
||||||
|
}
|
||||||
|
return names, nil
|
||||||
|
}
|
||||||
429
internal/db/repository_test.go
Normal file
429
internal/db/repository_test.go
Normal file
@@ -0,0 +1,429 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// seedTestRepo inserts a repository, manifest, tag, blob, and manifest_blob
|
||||||
|
// link for use in repository query tests. It returns the repository ID.
|
||||||
|
func seedTestRepo(t *testing.T, d *DB) int64 {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('myorg/myapp')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var repoID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM repositories WHERE name = 'myorg/myapp'`).Scan(&repoID); err != nil {
|
||||||
|
t.Fatalf("select repo id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = d.Exec(
|
||||||
|
`INSERT INTO manifests (repository_id, digest, media_type, content, size)
|
||||||
|
VALUES (?, 'sha256:aaaa', 'application/vnd.oci.image.manifest.v1+json', '{"layers":[]}', 15)`,
|
||||||
|
repoID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert manifest: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var manifestID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM manifests WHERE digest = 'sha256:aaaa'`).Scan(&manifestID); err != nil {
|
||||||
|
t.Fatalf("select manifest id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = d.Exec(
|
||||||
|
`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, 'latest', ?)`,
|
||||||
|
repoID, manifestID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert tag: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = d.Exec(`INSERT INTO blobs (digest, size) VALUES ('sha256:bbbb', 2048)`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert blob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var blobID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM blobs WHERE digest = 'sha256:bbbb'`).Scan(&blobID); err != nil {
|
||||||
|
t.Fatalf("select blob id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = d.Exec(
|
||||||
|
`INSERT INTO manifest_blobs (manifest_id, blob_id) VALUES (?, ?)`,
|
||||||
|
manifestID, blobID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert manifest_blob: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return repoID
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRepositoryByName_Found(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedTestRepo(t, d)
|
||||||
|
|
||||||
|
id, err := d.GetRepositoryByName("myorg/myapp")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRepositoryByName: %v", err)
|
||||||
|
}
|
||||||
|
if id == 0 {
|
||||||
|
t.Fatal("expected non-zero repository ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRepositoryByName_NotFound(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := d.GetRepositoryByName("nonexistent")
|
||||||
|
if !errors.Is(err, ErrRepoNotFound) {
|
||||||
|
t.Fatalf("expected ErrRepoNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetManifestByTag_Found(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
m, err := d.GetManifestByTag(repoID, "latest")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetManifestByTag: %v", err)
|
||||||
|
}
|
||||||
|
if m.Digest != "sha256:aaaa" {
|
||||||
|
t.Fatalf("digest: got %q, want %q", m.Digest, "sha256:aaaa")
|
||||||
|
}
|
||||||
|
if m.MediaType != "application/vnd.oci.image.manifest.v1+json" {
|
||||||
|
t.Fatalf("media type: got %q, want OCI manifest", m.MediaType)
|
||||||
|
}
|
||||||
|
if m.Size != 15 {
|
||||||
|
t.Fatalf("size: got %d, want 15", m.Size)
|
||||||
|
}
|
||||||
|
if string(m.Content) != `{"layers":[]}` {
|
||||||
|
t.Fatalf("content: got %q, want {\"layers\":[]}", string(m.Content))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetManifestByTag_NotFound(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
_, err := d.GetManifestByTag(repoID, "v0.0.0-nonexistent")
|
||||||
|
if !errors.Is(err, ErrManifestNotFound) {
|
||||||
|
t.Fatalf("expected ErrManifestNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetManifestByDigest_Found(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
m, err := d.GetManifestByDigest(repoID, "sha256:aaaa")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetManifestByDigest: %v", err)
|
||||||
|
}
|
||||||
|
if m.Digest != "sha256:aaaa" {
|
||||||
|
t.Fatalf("digest: got %q, want %q", m.Digest, "sha256:aaaa")
|
||||||
|
}
|
||||||
|
if m.RepositoryID != repoID {
|
||||||
|
t.Fatalf("repository_id: got %d, want %d", m.RepositoryID, repoID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetManifestByDigest_NotFound(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
_, err := d.GetManifestByDigest(repoID, "sha256:nonexistent")
|
||||||
|
if !errors.Is(err, ErrManifestNotFound) {
|
||||||
|
t.Fatalf("expected ErrManifestNotFound, got %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobExistsInRepo_Exists(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
exists, err := d.BlobExistsInRepo(repoID, "sha256:bbbb")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlobExistsInRepo: %v", err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("expected blob to exist in repo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobExistsInRepo_NotInThisRepo(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
seedTestRepo(t, d) // creates blob sha256:bbbb in myorg/myapp
|
||||||
|
|
||||||
|
// Create a second repo with no manifests linking the blob.
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('other/repo')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert other repo: %v", err)
|
||||||
|
}
|
||||||
|
var otherRepoID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM repositories WHERE name = 'other/repo'`).Scan(&otherRepoID); err != nil {
|
||||||
|
t.Fatalf("select other repo id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := d.BlobExistsInRepo(otherRepoID, "sha256:bbbb")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlobExistsInRepo: %v", err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
t.Fatal("expected blob to NOT exist in other repo")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobExistsInRepo_BlobDoesNotExist(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
exists, err := d.BlobExistsInRepo(repoID, "sha256:nonexistent")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("BlobExistsInRepo: %v", err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
t.Fatal("expected blob to not exist")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListTags_WithTags(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
// Add more tags pointing to the same manifest.
|
||||||
|
var manifestID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil {
|
||||||
|
t.Fatalf("select manifest id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tag := range []string{"v1.0", "v2.0", "beta"} {
|
||||||
|
_, err := d.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`,
|
||||||
|
repoID, tag, manifestID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert tag %q: %v", tag, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tags, err := d.ListTags(repoID, "", 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTags: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expect alphabetical: beta, latest, v1.0, v2.0
|
||||||
|
want := []string{"beta", "latest", "v1.0", "v2.0"}
|
||||||
|
if len(tags) != len(want) {
|
||||||
|
t.Fatalf("tags count: got %d, want %d", len(tags), len(want))
|
||||||
|
}
|
||||||
|
for i, tag := range tags {
|
||||||
|
if tag != want[i] {
|
||||||
|
t.Fatalf("tags[%d]: got %q, want %q", i, tag, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListTags_Pagination(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID := seedTestRepo(t, d)
|
||||||
|
|
||||||
|
var manifestID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil {
|
||||||
|
t.Fatalf("select manifest id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tag := range []string{"v1.0", "v2.0", "beta"} {
|
||||||
|
_, err := d.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`,
|
||||||
|
repoID, tag, manifestID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert tag %q: %v", tag, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First page: 2 tags starting from beginning.
|
||||||
|
tags, err := d.ListTags(repoID, "", 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTags page 1: %v", err)
|
||||||
|
}
|
||||||
|
if len(tags) != 2 {
|
||||||
|
t.Fatalf("page 1 count: got %d, want 2", len(tags))
|
||||||
|
}
|
||||||
|
if tags[0] != "beta" || tags[1] != "latest" {
|
||||||
|
t.Fatalf("page 1: got %v, want [beta, latest]", tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second page: after "latest".
|
||||||
|
tags, err = d.ListTags(repoID, "latest", 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTags page 2: %v", err)
|
||||||
|
}
|
||||||
|
if len(tags) != 2 {
|
||||||
|
t.Fatalf("page 2 count: got %d, want 2", len(tags))
|
||||||
|
}
|
||||||
|
if tags[0] != "v1.0" || tags[1] != "v2.0" {
|
||||||
|
t.Fatalf("page 2: got %v, want [v1.0, v2.0]", tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third page: after "v2.0" — no more tags.
|
||||||
|
tags, err = d.ListTags(repoID, "v2.0", 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTags page 3: %v", err)
|
||||||
|
}
|
||||||
|
if len(tags) != 0 {
|
||||||
|
t.Fatalf("page 3 count: got %d, want 0", len(tags))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListTags_Empty(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a repo with no tags.
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('empty/repo')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo: %v", err)
|
||||||
|
}
|
||||||
|
var repoID int64
|
||||||
|
if err := d.QueryRow(`SELECT id FROM repositories WHERE name = 'empty/repo'`).Scan(&repoID); err != nil {
|
||||||
|
t.Fatalf("select repo id: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tags, err := d.ListTags(repoID, "", 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListTags: %v", err)
|
||||||
|
}
|
||||||
|
if tags != nil {
|
||||||
|
t.Fatalf("expected nil tags, got %v", tags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListRepositoryNames_WithRepos(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range []string{"charlie/app", "alpha/lib", "bravo/svc"} {
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo %q: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
names, err := d.ListRepositoryNames("", 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoryNames: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{"alpha/lib", "bravo/svc", "charlie/app"}
|
||||||
|
if len(names) != len(want) {
|
||||||
|
t.Fatalf("names count: got %d, want %d", len(names), len(want))
|
||||||
|
}
|
||||||
|
for i, n := range names {
|
||||||
|
if n != want[i] {
|
||||||
|
t.Fatalf("names[%d]: got %q, want %q", i, n, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListRepositoryNames_Pagination(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, name := range []string{"charlie/app", "alpha/lib", "bravo/svc"} {
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES (?)`, name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo %q: %v", name, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// First page: 2.
|
||||||
|
names, err := d.ListRepositoryNames("", 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoryNames page 1: %v", err)
|
||||||
|
}
|
||||||
|
if len(names) != 2 {
|
||||||
|
t.Fatalf("page 1 count: got %d, want 2", len(names))
|
||||||
|
}
|
||||||
|
if names[0] != "alpha/lib" || names[1] != "bravo/svc" {
|
||||||
|
t.Fatalf("page 1: got %v", names)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second page: after "bravo/svc".
|
||||||
|
names, err = d.ListRepositoryNames("bravo/svc", 2)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoryNames page 2: %v", err)
|
||||||
|
}
|
||||||
|
if len(names) != 1 {
|
||||||
|
t.Fatalf("page 2 count: got %d, want 1", len(names))
|
||||||
|
}
|
||||||
|
if names[0] != "charlie/app" {
|
||||||
|
t.Fatalf("page 2: got %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestListRepositoryNames_Empty(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
names, err := d.ListRepositoryNames("", 100)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListRepositoryNames: %v", err)
|
||||||
|
}
|
||||||
|
if names != nil {
|
||||||
|
t.Fatalf("expected nil names, got %v", names)
|
||||||
|
}
|
||||||
|
}
|
||||||
80
internal/db/upload.go
Normal file
80
internal/db/upload.go
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrUploadNotFound indicates the requested upload UUID does not exist.
|
||||||
|
var ErrUploadNotFound = errors.New("db: upload not found")
|
||||||
|
|
||||||
|
// UploadRow represents a row in the uploads table.
|
||||||
|
type UploadRow struct {
|
||||||
|
ID int64
|
||||||
|
UUID string
|
||||||
|
RepositoryID int64
|
||||||
|
ByteOffset int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateUpload inserts a new upload row and returns its ID.
|
||||||
|
func (d *DB) CreateUpload(uuid string, repoID int64) error {
|
||||||
|
_, err := d.Exec(
|
||||||
|
`INSERT INTO uploads (uuid, repository_id) VALUES (?, ?)`,
|
||||||
|
uuid, repoID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: create upload: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUpload returns the upload with the given UUID.
|
||||||
|
func (d *DB) GetUpload(uuid string) (*UploadRow, error) {
|
||||||
|
var u UploadRow
|
||||||
|
err := d.QueryRow(
|
||||||
|
`SELECT id, uuid, repository_id, byte_offset FROM uploads WHERE uuid = ?`, uuid,
|
||||||
|
).Scan(&u.ID, &u.UUID, &u.RepositoryID, &u.ByteOffset)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return nil, ErrUploadNotFound
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("db: get upload: %w", err)
|
||||||
|
}
|
||||||
|
return &u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateUploadOffset sets the byte_offset for an upload.
|
||||||
|
func (d *DB) UpdateUploadOffset(uuid string, offset int64) error {
|
||||||
|
result, err := d.Exec(
|
||||||
|
`UPDATE uploads SET byte_offset = ? WHERE uuid = ?`,
|
||||||
|
offset, uuid,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: update upload offset: %w", err)
|
||||||
|
}
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: update upload offset rows affected: %w", err)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return ErrUploadNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteUpload removes the upload row with the given UUID.
|
||||||
|
func (d *DB) DeleteUpload(uuid string) error {
|
||||||
|
result, err := d.Exec(`DELETE FROM uploads WHERE uuid = ?`, uuid)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: delete upload: %w", err)
|
||||||
|
}
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("db: delete upload rows affected: %w", err)
|
||||||
|
}
|
||||||
|
if n == 0 {
|
||||||
|
return ErrUploadNotFound
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
124
internal/db/upload_test.go
Normal file
124
internal/db/upload_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCreateAndGetUpload(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a repository for the upload.
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.CreateUpload("test-uuid-1", 1); err != nil {
|
||||||
|
t.Fatalf("CreateUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := d.GetUpload("test-uuid-1")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUpload: %v", err)
|
||||||
|
}
|
||||||
|
if u.UUID != "test-uuid-1" {
|
||||||
|
t.Fatalf("uuid: got %q, want %q", u.UUID, "test-uuid-1")
|
||||||
|
}
|
||||||
|
if u.RepositoryID != 1 {
|
||||||
|
t.Fatalf("repo id: got %d, want 1", u.RepositoryID)
|
||||||
|
}
|
||||||
|
if u.ByteOffset != 0 {
|
||||||
|
t.Fatalf("byte offset: got %d, want 0", u.ByteOffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetUploadNotFound(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := d.GetUpload("nonexistent")
|
||||||
|
if !errors.Is(err, ErrUploadNotFound) {
|
||||||
|
t.Fatalf("err: got %v, want ErrUploadNotFound", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateUploadOffset(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo: %v", err)
|
||||||
|
}
|
||||||
|
if err := d.CreateUpload("test-uuid-2", 1); err != nil {
|
||||||
|
t.Fatalf("CreateUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.UpdateUploadOffset("test-uuid-2", 1024); err != nil {
|
||||||
|
t.Fatalf("UpdateUploadOffset: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := d.GetUpload("test-uuid-2")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetUpload: %v", err)
|
||||||
|
}
|
||||||
|
if u.ByteOffset != 1024 {
|
||||||
|
t.Fatalf("byte offset: got %d, want 1024", u.ByteOffset)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateUploadOffsetNotFound(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := d.UpdateUploadOffset("nonexistent", 100)
|
||||||
|
if !errors.Is(err, ErrUploadNotFound) {
|
||||||
|
t.Fatalf("err: got %v, want ErrUploadNotFound", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteUpload(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := d.Exec(`INSERT INTO repositories (name) VALUES ('testrepo')`)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo: %v", err)
|
||||||
|
}
|
||||||
|
if err := d.CreateUpload("test-uuid-3", 1); err != nil {
|
||||||
|
t.Fatalf("CreateUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := d.DeleteUpload("test-uuid-3"); err != nil {
|
||||||
|
t.Fatalf("DeleteUpload: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = d.GetUpload("test-uuid-3")
|
||||||
|
if !errors.Is(err, ErrUploadNotFound) {
|
||||||
|
t.Fatalf("after delete: got %v, want ErrUploadNotFound", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteUploadNotFound(t *testing.T) {
|
||||||
|
d := openTestDB(t)
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err := d.DeleteUpload("nonexistent")
|
||||||
|
if !errors.Is(err, ErrUploadNotFound) {
|
||||||
|
t.Fatalf("err: got %v, want ErrUploadNotFound", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
98
internal/oci/blob.go
Normal file
98
internal/oci/blob.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) handleBlobGet(w http.ResponseWriter, r *http.Request, repo, digest string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPull, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID, err := h.db.GetRepositoryByName(repo)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrRepoNotFound) {
|
||||||
|
writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound,
|
||||||
|
fmt.Sprintf("repository %q not found", repo))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := h.db.BlobExistsInRepo(repoID, digest)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound,
|
||||||
|
fmt.Sprintf("blob %q not found in repository", digest))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
size, err := h.blobs.Stat(digest)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound, "blob not found in storage")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rc, err := h.blobs.Open(digest)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = rc.Close() }()
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/octet-stream")
|
||||||
|
w.Header().Set("Docker-Content-Digest", digest)
|
||||||
|
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = io.Copy(w, rc)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleBlobHead(w http.ResponseWriter, r *http.Request, repo, digest string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPull, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID, err := h.db.GetRepositoryByName(repo)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrRepoNotFound) {
|
||||||
|
writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound,
|
||||||
|
fmt.Sprintf("repository %q not found", repo))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
exists, err := h.db.BlobExistsInRepo(repoID, digest)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound,
|
||||||
|
fmt.Sprintf("blob %q not found in repository", digest))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
size, err := h.blobs.Stat(digest)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "BLOB_UNKNOWN", http.StatusNotFound, "blob not found in storage")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/octet-stream")
|
||||||
|
w.Header().Set("Docker-Content-Digest", digest)
|
||||||
|
w.Header().Set("Content-Length", strconv.FormatInt(size, 10))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
144
internal/oci/blob_test.go
Normal file
144
internal/oci/blob_test.go
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBlobGet(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addBlob(1, "sha256:layerdigest")
|
||||||
|
|
||||||
|
blobs := newFakeBlobs()
|
||||||
|
blobs.data["sha256:layerdigest"] = []byte("layer-content-bytes")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, blobs, allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/blobs/sha256:layerdigest", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if ct := rr.Header().Get("Content-Type"); ct != "application/octet-stream" {
|
||||||
|
t.Fatalf("Content-Type: got %q", ct)
|
||||||
|
}
|
||||||
|
if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:layerdigest" {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q", dcd)
|
||||||
|
}
|
||||||
|
if cl := rr.Header().Get("Content-Length"); cl != "19" {
|
||||||
|
t.Fatalf("Content-Length: got %q, want %q", cl, "19")
|
||||||
|
}
|
||||||
|
if rr.Body.String() != "layer-content-bytes" {
|
||||||
|
t.Fatalf("body: got %q", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobHead(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addBlob(1, "sha256:layerdigest")
|
||||||
|
|
||||||
|
blobs := newFakeBlobs()
|
||||||
|
blobs.data["sha256:layerdigest"] = []byte("layer-content-bytes")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, blobs, allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("HEAD", "/v2/myrepo/blobs/sha256:layerdigest", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if ct := rr.Header().Get("Content-Type"); ct != "application/octet-stream" {
|
||||||
|
t.Fatalf("Content-Type: got %q", ct)
|
||||||
|
}
|
||||||
|
if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:layerdigest" {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q", dcd)
|
||||||
|
}
|
||||||
|
if cl := rr.Header().Get("Content-Length"); cl != "19" {
|
||||||
|
t.Fatalf("Content-Length: got %q, want %q", cl, "19")
|
||||||
|
}
|
||||||
|
if rr.Body.Len() != 0 {
|
||||||
|
t.Fatalf("HEAD body should be empty, got %d bytes", rr.Body.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobGetNotInRepo(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
// Blob NOT added to the repo.
|
||||||
|
|
||||||
|
blobs := newFakeBlobs()
|
||||||
|
blobs.data["sha256:orphan"] = []byte("data")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, blobs, allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/blobs/sha256:orphan", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "BLOB_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want BLOB_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobGetRepoNotFound(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
// No repos.
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/nosuchrepo/blobs/sha256:abc", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "NAME_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want NAME_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBlobGetMultiSegmentRepo(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("org/team/app", 1)
|
||||||
|
fdb.addBlob(1, "sha256:layerdigest")
|
||||||
|
|
||||||
|
blobs := newFakeBlobs()
|
||||||
|
blobs.data["sha256:layerdigest"] = []byte("data")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, blobs, allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/org/team/app/blobs/sha256:layerdigest", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
}
|
||||||
68
internal/oci/catalog.go
Normal file
68
internal/oci/catalog.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type catalogResponse struct {
|
||||||
|
Repositories []string `json:"repositories"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleCatalog(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionCatalog, "") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
n := 0
|
||||||
|
var err error
|
||||||
|
if nStr := r.URL.Query().Get("n"); nStr != "" {
|
||||||
|
n, err = strconv.Atoi(nStr)
|
||||||
|
if err != nil || n < 0 {
|
||||||
|
writeOCIError(w, "INVALID_PARAMETER", http.StatusBadRequest, "invalid 'n' parameter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
last := r.URL.Query().Get("last")
|
||||||
|
|
||||||
|
if n == 0 {
|
||||||
|
repos, err := h.db.ListRepositoryNames(last, 10000)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if repos == nil {
|
||||||
|
repos = []string{}
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(catalogResponse{Repositories: repos})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
repos, err := h.db.ListRepositoryNames(last, n+1)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if repos == nil {
|
||||||
|
repos = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasMore := len(repos) > n
|
||||||
|
if hasMore {
|
||||||
|
repos = repos[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasMore {
|
||||||
|
lastRepo := repos[len(repos)-1]
|
||||||
|
linkURL := fmt.Sprintf("/v2/_catalog?n=%d&last=%s", n, lastRepo)
|
||||||
|
w.Header().Set("Link", fmt.Sprintf(`<%s>; rel="next"`, linkURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(catalogResponse{Repositories: repos})
|
||||||
|
}
|
||||||
124
internal/oci/catalog_test.go
Normal file
124
internal/oci/catalog_test.go
Normal file
@@ -0,0 +1,124 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCatalog(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("alpha/lib", 1)
|
||||||
|
fdb.addRepo("bravo/svc", 2)
|
||||||
|
fdb.addRepo("charlie/app", 3)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/_catalog", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body catalogResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{"alpha/lib", "bravo/svc", "charlie/app"}
|
||||||
|
if len(body.Repositories) != len(want) {
|
||||||
|
t.Fatalf("repos count: got %d, want %d", len(body.Repositories), len(want))
|
||||||
|
}
|
||||||
|
for i, r := range body.Repositories {
|
||||||
|
if r != want[i] {
|
||||||
|
t.Fatalf("repos[%d]: got %q, want %q", i, r, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCatalogPagination(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("alpha/lib", 1)
|
||||||
|
fdb.addRepo("bravo/svc", 2)
|
||||||
|
fdb.addRepo("charlie/app", 3)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
// First page: n=2.
|
||||||
|
req := authedRequest("GET", "/v2/_catalog?n=2", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body catalogResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Repositories) != 2 {
|
||||||
|
t.Fatalf("page 1 count: got %d, want 2", len(body.Repositories))
|
||||||
|
}
|
||||||
|
if body.Repositories[0] != "alpha/lib" || body.Repositories[1] != "bravo/svc" {
|
||||||
|
t.Fatalf("page 1: got %v", body.Repositories)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Link header.
|
||||||
|
link := rr.Header().Get("Link")
|
||||||
|
if link == "" {
|
||||||
|
t.Fatal("expected Link header for pagination")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second page: n=2, last=bravo/svc.
|
||||||
|
req = authedRequest("GET", "/v2/_catalog?n=2&last=bravo/svc", nil)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("page 2 status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode page 2: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Repositories) != 1 {
|
||||||
|
t.Fatalf("page 2 count: got %d, want 1", len(body.Repositories))
|
||||||
|
}
|
||||||
|
if body.Repositories[0] != "charlie/app" {
|
||||||
|
t.Fatalf("page 2: got %v", body.Repositories)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No Link header on last page.
|
||||||
|
if link := rr.Header().Get("Link"); link != "" {
|
||||||
|
t.Fatalf("expected no Link header on last page, got %q", link)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCatalogEmpty(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/_catalog", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body catalogResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Repositories) != 0 {
|
||||||
|
t.Fatalf("expected empty repositories, got %v", body.Repositories)
|
||||||
|
}
|
||||||
|
}
|
||||||
119
internal/oci/handler.go
Normal file
119
internal/oci/handler.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DBQuerier provides the database operations needed by OCI handlers.
|
||||||
|
type DBQuerier interface {
|
||||||
|
GetRepositoryByName(name string) (int64, error)
|
||||||
|
GetManifestByTag(repoID int64, tag string) (*db.ManifestRow, error)
|
||||||
|
GetManifestByDigest(repoID int64, digest string) (*db.ManifestRow, error)
|
||||||
|
BlobExistsInRepo(repoID int64, digest string) (bool, error)
|
||||||
|
ListTags(repoID int64, after string, limit int) ([]string, error)
|
||||||
|
ListRepositoryNames(after string, limit int) ([]string, error)
|
||||||
|
|
||||||
|
// Push operations
|
||||||
|
GetOrCreateRepository(name string) (int64, error)
|
||||||
|
BlobExists(digest string) (bool, error)
|
||||||
|
InsertBlob(digest string, size int64) error
|
||||||
|
PushManifest(p db.PushManifestParams) error
|
||||||
|
|
||||||
|
// Upload operations
|
||||||
|
CreateUpload(uuid string, repoID int64) error
|
||||||
|
GetUpload(uuid string) (*db.UploadRow, error)
|
||||||
|
UpdateUploadOffset(uuid string, offset int64) error
|
||||||
|
DeleteUpload(uuid string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// BlobStore provides read and write access to blob storage.
|
||||||
|
type BlobStore interface {
|
||||||
|
Open(digest string) (io.ReadCloser, error)
|
||||||
|
Stat(digest string) (int64, error)
|
||||||
|
StartUpload(uuid string) (*storage.BlobWriter, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PolicyEval evaluates access control policies.
|
||||||
|
type PolicyEval interface {
|
||||||
|
Evaluate(input policy.PolicyInput) (policy.Effect, *policy.Rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuditFunc records audit events. Follows the same signature pattern as
|
||||||
|
// db.WriteAuditEvent but without an error return — audit failures should
|
||||||
|
// not block request processing.
|
||||||
|
type AuditFunc func(eventType, actorID, repository, digest, ip string, details map[string]string)
|
||||||
|
|
||||||
|
// Handler serves OCI Distribution Spec endpoints.
|
||||||
|
type Handler struct {
|
||||||
|
db DBQuerier
|
||||||
|
blobs BlobStore
|
||||||
|
policy PolicyEval
|
||||||
|
auditFn AuditFunc
|
||||||
|
uploads *uploadManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHandler creates a new OCI handler.
|
||||||
|
func NewHandler(querier DBQuerier, blobs BlobStore, pol PolicyEval, auditFn AuditFunc) *Handler {
|
||||||
|
return &Handler{
|
||||||
|
db: querier,
|
||||||
|
blobs: blobs,
|
||||||
|
policy: pol,
|
||||||
|
auditFn: auditFn,
|
||||||
|
uploads: newUploadManager(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isDigest returns true if the reference looks like a digest (sha256:...).
|
||||||
|
func isDigest(ref string) bool {
|
||||||
|
return strings.HasPrefix(ref, "sha256:")
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPolicy evaluates the policy for the given action and repository.
|
||||||
|
// Returns true if access is allowed, false if denied (and writes the OCI error).
|
||||||
|
func (h *Handler) checkPolicy(w http.ResponseWriter, r *http.Request, action policy.Action, repo string) bool {
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
if claims == nil {
|
||||||
|
writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication required")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
input := policy.PolicyInput{
|
||||||
|
Subject: claims.Subject,
|
||||||
|
AccountType: claims.AccountType,
|
||||||
|
Roles: claims.Roles,
|
||||||
|
Action: action,
|
||||||
|
Repository: repo,
|
||||||
|
}
|
||||||
|
|
||||||
|
effect, _ := h.policy.Evaluate(input)
|
||||||
|
if effect == policy.Deny {
|
||||||
|
if h.auditFn != nil {
|
||||||
|
h.auditFn("policy_deny", claims.Subject, repo, "", r.RemoteAddr, map[string]string{
|
||||||
|
"action": string(action),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
writeOCIError(w, "DENIED", http.StatusForbidden, "access denied by policy")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// audit records an audit event if an audit function is configured.
|
||||||
|
func (h *Handler) audit(r *http.Request, eventType, repo, digest string) {
|
||||||
|
if h.auditFn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
actorID := ""
|
||||||
|
if claims != nil {
|
||||||
|
actorID = claims.Subject
|
||||||
|
}
|
||||||
|
h.auditFn(eventType, actorID, repo, digest, r.RemoteAddr, nil)
|
||||||
|
}
|
||||||
321
internal/oci/handler_test.go
Normal file
321
internal/oci/handler_test.go
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// manifestKey uniquely identifies a manifest for test lookup.
|
||||||
|
type manifestKey struct {
|
||||||
|
repoID int64
|
||||||
|
reference string // tag or digest
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeDB implements DBQuerier for tests.
|
||||||
|
type fakeDB struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
repos map[string]int64 // name -> id
|
||||||
|
manifests map[manifestKey]*db.ManifestRow // (repoID, ref) -> manifest
|
||||||
|
blobs map[int64]map[string]bool // repoID -> set of digests
|
||||||
|
allBlobs map[string]bool // global blob digests
|
||||||
|
tags map[int64][]string // repoID -> sorted tag names
|
||||||
|
repoNames []string // sorted repo names
|
||||||
|
uploads map[string]*db.UploadRow // uuid -> upload
|
||||||
|
nextID int64 // auto-increment counter
|
||||||
|
pushed []db.PushManifestParams // record of pushed manifests
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeDB() *fakeDB {
|
||||||
|
return &fakeDB{
|
||||||
|
repos: make(map[string]int64),
|
||||||
|
manifests: make(map[manifestKey]*db.ManifestRow),
|
||||||
|
blobs: make(map[int64]map[string]bool),
|
||||||
|
allBlobs: make(map[string]bool),
|
||||||
|
tags: make(map[int64][]string),
|
||||||
|
uploads: make(map[string]*db.UploadRow),
|
||||||
|
nextID: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) GetRepositoryByName(name string) (int64, error) {
|
||||||
|
id, ok := f.repos[name]
|
||||||
|
if !ok {
|
||||||
|
return 0, db.ErrRepoNotFound
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) GetManifestByTag(repoID int64, tag string) (*db.ManifestRow, error) {
|
||||||
|
m, ok := f.manifests[manifestKey{repoID, tag}]
|
||||||
|
if !ok {
|
||||||
|
return nil, db.ErrManifestNotFound
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) GetManifestByDigest(repoID int64, digest string) (*db.ManifestRow, error) {
|
||||||
|
m, ok := f.manifests[manifestKey{repoID, digest}]
|
||||||
|
if !ok {
|
||||||
|
return nil, db.ErrManifestNotFound
|
||||||
|
}
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) BlobExistsInRepo(repoID int64, digest string) (bool, error) {
|
||||||
|
digests, ok := f.blobs[repoID]
|
||||||
|
if !ok {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return digests[digest], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) ListTags(repoID int64, after string, limit int) ([]string, error) {
|
||||||
|
allTags := f.tags[repoID]
|
||||||
|
var result []string
|
||||||
|
for _, t := range allTags {
|
||||||
|
if after != "" && t <= after {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, t)
|
||||||
|
if len(result) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) ListRepositoryNames(after string, limit int) ([]string, error) {
|
||||||
|
var result []string
|
||||||
|
for _, n := range f.repoNames {
|
||||||
|
if after != "" && n <= after {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result = append(result, n)
|
||||||
|
if len(result) >= limit {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) GetOrCreateRepository(name string) (int64, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
id, ok := f.repos[name]
|
||||||
|
if ok {
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
id = f.nextID
|
||||||
|
f.nextID++
|
||||||
|
f.repos[name] = id
|
||||||
|
f.repoNames = append(f.repoNames, name)
|
||||||
|
sort.Strings(f.repoNames)
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) BlobExists(digest string) (bool, error) {
|
||||||
|
return f.allBlobs[digest], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) InsertBlob(digest string, size int64) error {
|
||||||
|
f.allBlobs[digest] = true
|
||||||
|
_ = size
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) PushManifest(p db.PushManifestParams) error {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.pushed = append(f.pushed, p)
|
||||||
|
// Simulate creating the manifest in our fake data.
|
||||||
|
repoID, ok := f.repos[p.RepoName]
|
||||||
|
if !ok {
|
||||||
|
repoID = f.nextID
|
||||||
|
f.nextID++
|
||||||
|
f.repos[p.RepoName] = repoID
|
||||||
|
f.repoNames = append(f.repoNames, p.RepoName)
|
||||||
|
sort.Strings(f.repoNames)
|
||||||
|
}
|
||||||
|
m := &db.ManifestRow{
|
||||||
|
ID: f.nextID,
|
||||||
|
RepositoryID: repoID,
|
||||||
|
Digest: p.Digest,
|
||||||
|
MediaType: p.MediaType,
|
||||||
|
Content: p.Content,
|
||||||
|
Size: p.Size,
|
||||||
|
}
|
||||||
|
f.nextID++
|
||||||
|
f.manifests[manifestKey{repoID, p.Digest}] = m
|
||||||
|
if p.Tag != "" {
|
||||||
|
f.manifests[manifestKey{repoID, p.Tag}] = m
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) CreateUpload(uuid string, repoID int64) error {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
f.uploads[uuid] = &db.UploadRow{
|
||||||
|
ID: f.nextID,
|
||||||
|
UUID: uuid,
|
||||||
|
RepositoryID: repoID,
|
||||||
|
ByteOffset: 0,
|
||||||
|
}
|
||||||
|
f.nextID++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) GetUpload(uuid string) (*db.UploadRow, error) {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
u, ok := f.uploads[uuid]
|
||||||
|
if !ok {
|
||||||
|
return nil, db.ErrUploadNotFound
|
||||||
|
}
|
||||||
|
return u, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) UpdateUploadOffset(uuid string, offset int64) error {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
u, ok := f.uploads[uuid]
|
||||||
|
if !ok {
|
||||||
|
return db.ErrUploadNotFound
|
||||||
|
}
|
||||||
|
u.ByteOffset = offset
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeDB) DeleteUpload(uuid string) error {
|
||||||
|
f.mu.Lock()
|
||||||
|
defer f.mu.Unlock()
|
||||||
|
if _, ok := f.uploads[uuid]; !ok {
|
||||||
|
return db.ErrUploadNotFound
|
||||||
|
}
|
||||||
|
delete(f.uploads, uuid)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// addRepo adds a repo to the fakeDB and returns its ID.
|
||||||
|
func (f *fakeDB) addRepo(name string, id int64) {
|
||||||
|
f.repos[name] = id
|
||||||
|
f.repoNames = append(f.repoNames, name)
|
||||||
|
sort.Strings(f.repoNames)
|
||||||
|
if id >= f.nextID {
|
||||||
|
f.nextID = id + 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// addManifest adds a manifest accessible by both tag and digest.
|
||||||
|
func (f *fakeDB) addManifest(repoID int64, tag, digest, mediaType string, content []byte) {
|
||||||
|
m := &db.ManifestRow{
|
||||||
|
ID: f.nextID,
|
||||||
|
RepositoryID: repoID,
|
||||||
|
Digest: digest,
|
||||||
|
MediaType: mediaType,
|
||||||
|
Content: content,
|
||||||
|
Size: int64(len(content)),
|
||||||
|
}
|
||||||
|
f.nextID++
|
||||||
|
if tag != "" {
|
||||||
|
f.manifests[manifestKey{repoID, tag}] = m
|
||||||
|
}
|
||||||
|
f.manifests[manifestKey{repoID, digest}] = m
|
||||||
|
}
|
||||||
|
|
||||||
|
// addBlob registers a blob digest in a repository.
|
||||||
|
func (f *fakeDB) addBlob(repoID int64, digest string) {
|
||||||
|
if f.blobs[repoID] == nil {
|
||||||
|
f.blobs[repoID] = make(map[string]bool)
|
||||||
|
}
|
||||||
|
f.blobs[repoID][digest] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// addGlobalBlob registers a blob in the global blob table.
|
||||||
|
func (f *fakeDB) addGlobalBlob(digest string) {
|
||||||
|
f.allBlobs[digest] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// addTag adds a tag to a repository's tag list.
|
||||||
|
func (f *fakeDB) addTag(repoID int64, tag string) {
|
||||||
|
f.tags[repoID] = append(f.tags[repoID], tag)
|
||||||
|
sort.Strings(f.tags[repoID])
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakeBlobs implements BlobStore for tests.
|
||||||
|
type fakeBlobs struct {
|
||||||
|
data map[string][]byte // digest -> content
|
||||||
|
uploads map[string]*bytes.Buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
func newFakeBlobs() *fakeBlobs {
|
||||||
|
return &fakeBlobs{
|
||||||
|
data: make(map[string][]byte),
|
||||||
|
uploads: make(map[string]*bytes.Buffer),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeBlobs) Open(digest string) (io.ReadCloser, error) {
|
||||||
|
data, ok := f.data[digest]
|
||||||
|
if !ok {
|
||||||
|
return nil, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return io.NopCloser(bytes.NewReader(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeBlobs) Stat(digest string) (int64, error) {
|
||||||
|
data, ok := f.data[digest]
|
||||||
|
if !ok {
|
||||||
|
return 0, io.ErrUnexpectedEOF
|
||||||
|
}
|
||||||
|
return int64(len(data)), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakeBlobs) StartUpload(uuid string) (*storage.BlobWriter, error) {
|
||||||
|
// For tests that need real storage, use a real Store in t.TempDir().
|
||||||
|
// This fake panics to catch unintended usage.
|
||||||
|
panic("fakeBlobs.StartUpload should not be called; use a real storage.Store for upload tests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// fakePolicy implements PolicyEval, always returning Allow.
|
||||||
|
type fakePolicy struct {
|
||||||
|
effect policy.Effect
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakePolicy) Evaluate(_ policy.PolicyInput) (policy.Effect, *policy.Rule) {
|
||||||
|
return f.effect, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// allowAll returns a fakePolicy that allows all requests.
|
||||||
|
func allowAll() *fakePolicy {
|
||||||
|
return &fakePolicy{effect: policy.Allow}
|
||||||
|
}
|
||||||
|
|
||||||
|
// testRouter creates a chi.Mux with the OCI handler mounted at /v2.
|
||||||
|
func testRouter(h *Handler) *chi.Mux {
|
||||||
|
parent := chi.NewRouter()
|
||||||
|
parent.Mount("/v2", h.Router())
|
||||||
|
return parent
|
||||||
|
}
|
||||||
|
|
||||||
|
// authedRequest creates an HTTP request with authenticated claims in the context.
|
||||||
|
func authedRequest(method, path string, body io.Reader) *http.Request {
|
||||||
|
req := httptest.NewRequest(method, path, body)
|
||||||
|
claims := &auth.Claims{
|
||||||
|
Subject: "test-user",
|
||||||
|
AccountType: "human",
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}
|
||||||
|
return req.WithContext(auth.ContextWithClaims(req.Context(), claims))
|
||||||
|
}
|
||||||
222
internal/oci/manifest.go
Normal file
222
internal/oci/manifest.go
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (h *Handler) handleManifestGet(w http.ResponseWriter, r *http.Request, repo, reference string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPull, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m, ok := h.resolveManifest(w, repo, reference)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.audit(r, "manifest_pulled", repo, m.Digest)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", m.MediaType)
|
||||||
|
w.Header().Set("Docker-Content-Digest", m.Digest)
|
||||||
|
w.Header().Set("Content-Length", strconv.FormatInt(m.Size, 10))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(m.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleManifestHead(w http.ResponseWriter, r *http.Request, repo, reference string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPull, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m, ok := h.resolveManifest(w, repo, reference)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", m.MediaType)
|
||||||
|
w.Header().Set("Docker-Content-Digest", m.Digest)
|
||||||
|
w.Header().Set("Content-Length", strconv.FormatInt(m.Size, 10))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveManifest looks up a manifest by tag or digest, writing OCI errors on failure.
|
||||||
|
func (h *Handler) resolveManifest(w http.ResponseWriter, repo, reference string) (*db.ManifestRow, bool) {
|
||||||
|
repoID, err := h.db.GetRepositoryByName(repo)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrRepoNotFound) {
|
||||||
|
writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound,
|
||||||
|
fmt.Sprintf("repository %q not found", repo))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
var m *db.ManifestRow
|
||||||
|
if isDigest(reference) {
|
||||||
|
m, err = h.db.GetManifestByDigest(repoID, reference)
|
||||||
|
} else {
|
||||||
|
m, err = h.db.GetManifestByTag(repoID, reference)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrManifestNotFound) {
|
||||||
|
writeOCIError(w, "MANIFEST_UNKNOWN", http.StatusNotFound,
|
||||||
|
fmt.Sprintf("manifest %q not found", reference))
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// ociManifest is the minimal structure for parsing an OCI image manifest.
|
||||||
|
type ociManifest struct {
|
||||||
|
SchemaVersion int `json:"schemaVersion"`
|
||||||
|
MediaType string `json:"mediaType,omitempty"`
|
||||||
|
Config ociDescriptor `json:"config"`
|
||||||
|
Layers []ociDescriptor `json:"layers"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// ociDescriptor is a content-addressed reference within a manifest.
|
||||||
|
type ociDescriptor struct {
|
||||||
|
MediaType string `json:"mediaType"`
|
||||||
|
Digest string `json:"digest"`
|
||||||
|
Size int64 `json:"size"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// supportedManifestMediaTypes lists the manifest media types MCR accepts.
|
||||||
|
var supportedManifestMediaTypes = map[string]bool{
|
||||||
|
"application/vnd.oci.image.manifest.v1+json": true,
|
||||||
|
"application/vnd.docker.distribution.manifest.v2+json": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleManifestPut handles PUT /v2/<name>/manifests/<reference>
|
||||||
|
func (h *Handler) handleManifestPut(w http.ResponseWriter, r *http.Request, repo, reference string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPush, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: Read and parse manifest JSON.
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "failed to read request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(body) == 0 {
|
||||||
|
writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "empty manifest")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var manifest ociManifest
|
||||||
|
if err := json.Unmarshal(body, &manifest); err != nil {
|
||||||
|
writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "malformed manifest JSON")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if manifest.SchemaVersion != 2 {
|
||||||
|
writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest, "unsupported schema version")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine media type from Content-Type header, falling back to manifest body.
|
||||||
|
mediaType := r.Header.Get("Content-Type")
|
||||||
|
if mediaType == "" {
|
||||||
|
mediaType = manifest.MediaType
|
||||||
|
}
|
||||||
|
if mediaType == "" {
|
||||||
|
mediaType = "application/vnd.oci.image.manifest.v1+json"
|
||||||
|
}
|
||||||
|
if !supportedManifestMediaTypes[mediaType] {
|
||||||
|
writeOCIError(w, "MANIFEST_INVALID", http.StatusBadRequest,
|
||||||
|
fmt.Sprintf("unsupported media type: %s", mediaType))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Compute SHA-256 digest.
|
||||||
|
sum := sha256.Sum256(body)
|
||||||
|
computedDigest := "sha256:" + hex.EncodeToString(sum[:])
|
||||||
|
|
||||||
|
// Step 3: If reference is a digest, verify it matches.
|
||||||
|
tag := ""
|
||||||
|
if isDigest(reference) {
|
||||||
|
if reference != computedDigest {
|
||||||
|
writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest,
|
||||||
|
fmt.Sprintf("digest mismatch: computed %s, got %s", computedDigest, reference))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tag = reference
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Collect all referenced blob digests.
|
||||||
|
var blobDigests []string
|
||||||
|
if manifest.Config.Digest != "" {
|
||||||
|
blobDigests = append(blobDigests, manifest.Config.Digest)
|
||||||
|
}
|
||||||
|
for _, layer := range manifest.Layers {
|
||||||
|
if layer.Digest != "" {
|
||||||
|
blobDigests = append(blobDigests, layer.Digest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 5: Verify all referenced blobs exist.
|
||||||
|
for _, bd := range blobDigests {
|
||||||
|
exists, err := h.db.BlobExists(bd)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
writeOCIError(w, "MANIFEST_BLOB_UNKNOWN", http.StatusBadRequest,
|
||||||
|
fmt.Sprintf("blob %s not found", bd))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 6: Single transaction — create repo, insert manifest, populate
|
||||||
|
// manifest_blobs, upsert tag.
|
||||||
|
params := db.PushManifestParams{
|
||||||
|
RepoName: repo,
|
||||||
|
Digest: computedDigest,
|
||||||
|
MediaType: mediaType,
|
||||||
|
Content: body,
|
||||||
|
Size: int64(len(body)),
|
||||||
|
Tag: tag,
|
||||||
|
BlobDigests: blobDigests,
|
||||||
|
}
|
||||||
|
if err := h.db.PushManifest(params); err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 7: Audit and respond.
|
||||||
|
details := map[string]string{}
|
||||||
|
if tag != "" {
|
||||||
|
details["tag"] = tag
|
||||||
|
}
|
||||||
|
if h.auditFn != nil {
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
actorID := ""
|
||||||
|
if claims != nil {
|
||||||
|
actorID = claims.Subject
|
||||||
|
}
|
||||||
|
h.auditFn("manifest_pushed", actorID, repo, computedDigest, r.RemoteAddr, details)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", fmt.Sprintf("/v2/%s/manifests/%s", repo, computedDigest))
|
||||||
|
w.Header().Set("Docker-Content-Digest", computedDigest)
|
||||||
|
w.Header().Set("Content-Type", mediaType)
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
}
|
||||||
187
internal/oci/manifest_test.go
Normal file
187
internal/oci/manifest_test.go
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManifestGetByTag(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
content := []byte(`{"schemaVersion":2}`)
|
||||||
|
fdb.addManifest(1, "latest", "sha256:aaaa", "application/vnd.oci.image.manifest.v1+json", content)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/manifests/latest", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if ct := rr.Header().Get("Content-Type"); ct != "application/vnd.oci.image.manifest.v1+json" {
|
||||||
|
t.Fatalf("Content-Type: got %q", ct)
|
||||||
|
}
|
||||||
|
if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:aaaa" {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q", dcd)
|
||||||
|
}
|
||||||
|
if cl := rr.Header().Get("Content-Length"); cl != "19" {
|
||||||
|
t.Fatalf("Content-Length: got %q, want %q", cl, "19")
|
||||||
|
}
|
||||||
|
if rr.Body.String() != `{"schemaVersion":2}` {
|
||||||
|
t.Fatalf("body: got %q", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestGetByDigest(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
content := []byte(`{"schemaVersion":2}`)
|
||||||
|
fdb.addManifest(1, "latest", "sha256:aaaa", "application/vnd.oci.image.manifest.v1+json", content)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/manifests/sha256:aaaa", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:aaaa" {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q", dcd)
|
||||||
|
}
|
||||||
|
if rr.Body.String() != `{"schemaVersion":2}` {
|
||||||
|
t.Fatalf("body: got %q", rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestHead(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
content := []byte(`{"schemaVersion":2}`)
|
||||||
|
fdb.addManifest(1, "latest", "sha256:aaaa", "application/vnd.oci.image.manifest.v1+json", content)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("HEAD", "/v2/myrepo/manifests/latest", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if ct := rr.Header().Get("Content-Type"); ct != "application/vnd.oci.image.manifest.v1+json" {
|
||||||
|
t.Fatalf("Content-Type: got %q", ct)
|
||||||
|
}
|
||||||
|
if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:aaaa" {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q", dcd)
|
||||||
|
}
|
||||||
|
if cl := rr.Header().Get("Content-Length"); cl != "19" {
|
||||||
|
t.Fatalf("Content-Length: got %q, want %q", cl, "19")
|
||||||
|
}
|
||||||
|
if rr.Body.Len() != 0 {
|
||||||
|
t.Fatalf("HEAD body should be empty, got %d bytes", rr.Body.Len())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestGetRepoNotFound(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
// No repos added.
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/nosuchrepo/manifests/latest", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "NAME_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want NAME_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestGetManifestNotFoundByTag(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
// No manifests added.
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/manifests/nonexistent", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want MANIFEST_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestGetManifestNotFoundByDigest(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
// No manifests added.
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/manifests/sha256:nonexistent", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want MANIFEST_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestGetMultiSegmentRepo(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("org/team/app", 1)
|
||||||
|
content := []byte(`{"layers":[]}`)
|
||||||
|
fdb.addManifest(1, "v1.0", "sha256:cccc", "application/vnd.oci.image.manifest.v1+json", content)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/org/team/app/manifests/v1.0", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if dcd := rr.Header().Get("Docker-Content-Digest"); dcd != "sha256:cccc" {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q", dcd)
|
||||||
|
}
|
||||||
|
}
|
||||||
23
internal/oci/ocierror.go
Normal file
23
internal/oci/ocierror.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ociErrorEntry struct {
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ociErrorResponse struct {
|
||||||
|
Errors []ociErrorEntry `json:"errors"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeOCIError(w http.ResponseWriter, code string, status int, message string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(ociErrorResponse{
|
||||||
|
Errors: []ociErrorEntry{{Code: code, Message: message}},
|
||||||
|
})
|
||||||
|
}
|
||||||
266
internal/oci/push_test.go
Normal file
266
internal/oci/push_test.go
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func makeManifest(configDigest string, layerDigests []string) []byte {
|
||||||
|
layers := ""
|
||||||
|
for i, d := range layerDigests {
|
||||||
|
if i > 0 {
|
||||||
|
layers += ","
|
||||||
|
}
|
||||||
|
layers += fmt.Sprintf(`{"mediaType":"application/vnd.oci.image.layer.v1.tar+gzip","digest":%q,"size":1000}`, d)
|
||||||
|
}
|
||||||
|
return []byte(fmt.Sprintf(`{"schemaVersion":2,"mediaType":"application/vnd.oci.image.manifest.v1+json","config":{"mediaType":"application/vnd.oci.image.config.v1+json","digest":%q,"size":100},"layers":[%s]}`, configDigest, layers))
|
||||||
|
}
|
||||||
|
|
||||||
|
func manifestDigest(content []byte) string {
|
||||||
|
sum := sha256.Sum256(content)
|
||||||
|
return "sha256:" + hex.EncodeToString(sum[:])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushByTag(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addGlobalBlob("sha256:config1")
|
||||||
|
fdb.addGlobalBlob("sha256:layer1")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
content := makeManifest("sha256:config1", []string{"sha256:layer1"})
|
||||||
|
digest := manifestDigest(content)
|
||||||
|
|
||||||
|
req := authedRequest("PUT", "/v2/myrepo/manifests/latest", nil)
|
||||||
|
req.Body = http.NoBody
|
||||||
|
// Re-create with proper body.
|
||||||
|
req = authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("status: got %d, want %d, body: %s", rr.Code, http.StatusCreated, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
dcd := rr.Header().Get("Docker-Content-Digest")
|
||||||
|
if dcd != digest {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q, want %q", dcd, digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
loc := rr.Header().Get("Location")
|
||||||
|
if loc == "" {
|
||||||
|
t.Fatal("Location header missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
ct := rr.Header().Get("Content-Type")
|
||||||
|
if ct != "application/vnd.oci.image.manifest.v1+json" {
|
||||||
|
t.Fatalf("Content-Type: got %q", ct)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify manifest was stored.
|
||||||
|
if len(fdb.pushed) != 1 {
|
||||||
|
t.Fatalf("pushed count: got %d, want 1", len(fdb.pushed))
|
||||||
|
}
|
||||||
|
if fdb.pushed[0].Tag != "latest" {
|
||||||
|
t.Fatalf("pushed tag: got %q, want %q", fdb.pushed[0].Tag, "latest")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushByDigest(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addGlobalBlob("sha256:config1")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
content := makeManifest("sha256:config1", nil)
|
||||||
|
digest := manifestDigest(content)
|
||||||
|
|
||||||
|
req := authedPushRequest("PUT", "/v2/myrepo/manifests/"+digest, content)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("status: got %d, body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no tag was set.
|
||||||
|
if len(fdb.pushed) != 1 {
|
||||||
|
t.Fatalf("pushed count: got %d", len(fdb.pushed))
|
||||||
|
}
|
||||||
|
if fdb.pushed[0].Tag != "" {
|
||||||
|
t.Fatalf("pushed tag: got %q, want empty", fdb.pushed[0].Tag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushDigestMismatch(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
content := []byte(`{"schemaVersion":2}`)
|
||||||
|
wrongDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000000"
|
||||||
|
|
||||||
|
req := authedPushRequest("PUT", "/v2/myrepo/manifests/"+wrongDigest, content)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "DIGEST_INVALID" {
|
||||||
|
t.Fatalf("error code: got %+v, want DIGEST_INVALID", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushMissingBlob(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
// Config blob exists but layer blob does not.
|
||||||
|
fdb.addGlobalBlob("sha256:config1")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
content := makeManifest("sha256:config1", []string{"sha256:missing_layer"})
|
||||||
|
|
||||||
|
req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status: got %d, want %d, body: %s", rr.Code, http.StatusBadRequest, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_BLOB_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want MANIFEST_BLOB_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushMalformed(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", []byte("not valid json"))
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "MANIFEST_INVALID" {
|
||||||
|
t.Fatalf("error code: got %+v, want MANIFEST_INVALID", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushEmpty(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", nil)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushUpdatesTag(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addGlobalBlob("sha256:config1")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
// First push with tag "latest".
|
||||||
|
content1 := makeManifest("sha256:config1", nil)
|
||||||
|
req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content1)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("first push status: got %d", rr.Code)
|
||||||
|
}
|
||||||
|
firstDigest := rr.Header().Get("Docker-Content-Digest")
|
||||||
|
|
||||||
|
// Second push with same tag — different content.
|
||||||
|
fdb.addGlobalBlob("sha256:config2")
|
||||||
|
content2 := makeManifest("sha256:config2", nil)
|
||||||
|
req = authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content2)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("second push status: got %d", rr.Code)
|
||||||
|
}
|
||||||
|
secondDigest := rr.Header().Get("Docker-Content-Digest")
|
||||||
|
|
||||||
|
if firstDigest == secondDigest {
|
||||||
|
t.Fatal("two pushes should produce different digests")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that the tag was atomically moved.
|
||||||
|
if len(fdb.pushed) != 2 {
|
||||||
|
t.Fatalf("pushed count: got %d, want 2", len(fdb.pushed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManifestPushRepushIdempotent(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addGlobalBlob("sha256:config1")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
content := makeManifest("sha256:config1", nil)
|
||||||
|
|
||||||
|
// Push the same manifest twice.
|
||||||
|
for i := range 2 {
|
||||||
|
req := authedPushRequest("PUT", "/v2/myrepo/manifests/latest", content)
|
||||||
|
req.Header.Set("Content-Type", "application/vnd.oci.image.manifest.v1+json")
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("push %d status: got %d", i+1, rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
171
internal/oci/routes.go
Normal file
171
internal/oci/routes.go
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ociPathInfo holds the parsed components of an OCI API path.
|
||||||
|
type ociPathInfo struct {
|
||||||
|
name string // repository name (may contain slashes)
|
||||||
|
kind string // "manifests", "blobs", or "tags"
|
||||||
|
reference string // tag, digest, or "list"
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOCIPath extracts the repository name and operation from an OCI path.
|
||||||
|
// The path should NOT include the /v2/ prefix.
|
||||||
|
// Examples:
|
||||||
|
//
|
||||||
|
// "myrepo/manifests/latest" -> {name:"myrepo", kind:"manifests", reference:"latest"}
|
||||||
|
// "org/team/app/blobs/sha256:abc" -> {name:"org/team/app", kind:"blobs", reference:"sha256:abc"}
|
||||||
|
// "myrepo/tags/list" -> {name:"myrepo", kind:"tags", reference:"list"}
|
||||||
|
// "myrepo/blobs/uploads/" -> {name:"myrepo", kind:"uploads", reference:""}
|
||||||
|
// "myrepo/blobs/uploads/uuid-here" -> {name:"myrepo", kind:"uploads", reference:"uuid-here"}
|
||||||
|
func parseOCIPath(path string) (ociPathInfo, bool) {
|
||||||
|
// Check for /tags/list suffix.
|
||||||
|
if strings.HasSuffix(path, "/tags/list") {
|
||||||
|
name := path[:len(path)-len("/tags/list")]
|
||||||
|
if name == "" {
|
||||||
|
return ociPathInfo{}, false
|
||||||
|
}
|
||||||
|
return ociPathInfo{name: name, kind: "tags", reference: "list"}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for /blobs/uploads/ (must come before /blobs/).
|
||||||
|
if idx := strings.LastIndex(path, "/blobs/uploads/"); idx >= 0 {
|
||||||
|
name := path[:idx]
|
||||||
|
uuid := path[idx+len("/blobs/uploads/"):]
|
||||||
|
if name == "" {
|
||||||
|
return ociPathInfo{}, false
|
||||||
|
}
|
||||||
|
return ociPathInfo{name: name, kind: "uploads", reference: uuid}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for /blobs/uploads (trailing-slash-trimmed form for POST initiate).
|
||||||
|
if strings.HasSuffix(path, "/blobs/uploads") {
|
||||||
|
name := path[:len(path)-len("/blobs/uploads")]
|
||||||
|
if name == "" {
|
||||||
|
return ociPathInfo{}, false
|
||||||
|
}
|
||||||
|
return ociPathInfo{name: name, kind: "uploads", reference: ""}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for /manifests/<ref>.
|
||||||
|
if idx := strings.LastIndex(path, "/manifests/"); idx >= 0 {
|
||||||
|
name := path[:idx]
|
||||||
|
ref := path[idx+len("/manifests/"):]
|
||||||
|
if name == "" || ref == "" {
|
||||||
|
return ociPathInfo{}, false
|
||||||
|
}
|
||||||
|
return ociPathInfo{name: name, kind: "manifests", reference: ref}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for /blobs/<digest>.
|
||||||
|
if idx := strings.LastIndex(path, "/blobs/"); idx >= 0 {
|
||||||
|
name := path[:idx]
|
||||||
|
ref := path[idx+len("/blobs/"):]
|
||||||
|
if name == "" || ref == "" {
|
||||||
|
return ociPathInfo{}, false
|
||||||
|
}
|
||||||
|
return ociPathInfo{name: name, kind: "blobs", reference: ref}, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return ociPathInfo{}, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Router returns a chi router for OCI Distribution Spec endpoints.
|
||||||
|
// It should be mounted at /v2 on the parent router.
|
||||||
|
func (h *Handler) Router() chi.Router {
|
||||||
|
r := chi.NewRouter()
|
||||||
|
|
||||||
|
// Catalog endpoint: GET /v2/_catalog
|
||||||
|
r.Get("/_catalog", h.handleCatalog)
|
||||||
|
|
||||||
|
// All other OCI endpoints use a catch-all to support multi-segment repo names.
|
||||||
|
r.HandleFunc("/*", h.dispatch)
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatch routes requests to the appropriate handler based on the parsed path.
|
||||||
|
func (h *Handler) dispatch(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Get the path after /v2/
|
||||||
|
path := chi.URLParam(r, "*")
|
||||||
|
if path == "" {
|
||||||
|
writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, "repository name required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
info, ok := parseOCIPath(path)
|
||||||
|
if !ok {
|
||||||
|
writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, "invalid OCI path")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
switch info.kind {
|
||||||
|
case "manifests":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.handleManifestGet(w, r, info.name, info.reference)
|
||||||
|
case http.MethodHead:
|
||||||
|
h.handleManifestHead(w, r, info.name, info.reference)
|
||||||
|
case http.MethodPut:
|
||||||
|
h.handleManifestPut(w, r, info.name, info.reference)
|
||||||
|
default:
|
||||||
|
w.Header().Set("Allow", "GET, HEAD, PUT")
|
||||||
|
writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
}
|
||||||
|
case "blobs":
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodGet:
|
||||||
|
h.handleBlobGet(w, r, info.name, info.reference)
|
||||||
|
case http.MethodHead:
|
||||||
|
h.handleBlobHead(w, r, info.name, info.reference)
|
||||||
|
default:
|
||||||
|
w.Header().Set("Allow", "GET, HEAD")
|
||||||
|
writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
}
|
||||||
|
case "uploads":
|
||||||
|
h.dispatchUpload(w, r, info.name, info.reference)
|
||||||
|
case "tags":
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
w.Header().Set("Allow", "GET")
|
||||||
|
writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleTagsList(w, r, info.name)
|
||||||
|
default:
|
||||||
|
writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound, "unknown operation")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dispatchUpload routes upload requests to the appropriate handler.
|
||||||
|
func (h *Handler) dispatchUpload(w http.ResponseWriter, r *http.Request, repo, uuid string) {
|
||||||
|
if uuid == "" {
|
||||||
|
// POST /v2/<name>/blobs/uploads/ — initiate
|
||||||
|
if r.Method != http.MethodPost {
|
||||||
|
w.Header().Set("Allow", "POST")
|
||||||
|
writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
h.handleUploadInitiate(w, r, repo)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Operations on existing upload UUID.
|
||||||
|
switch r.Method {
|
||||||
|
case http.MethodPatch:
|
||||||
|
h.handleUploadChunk(w, r, repo, uuid)
|
||||||
|
case http.MethodPut:
|
||||||
|
h.handleUploadComplete(w, r, repo, uuid)
|
||||||
|
case http.MethodGet:
|
||||||
|
h.handleUploadStatus(w, r, repo, uuid)
|
||||||
|
case http.MethodDelete:
|
||||||
|
h.handleUploadCancel(w, r, repo, uuid)
|
||||||
|
default:
|
||||||
|
w.Header().Set("Allow", "PATCH, PUT, GET, DELETE")
|
||||||
|
writeOCIError(w, "UNSUPPORTED", http.StatusMethodNotAllowed, "method not allowed")
|
||||||
|
}
|
||||||
|
}
|
||||||
141
internal/oci/routes_test.go
Normal file
141
internal/oci/routes_test.go
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestParseOCIPath(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
path string
|
||||||
|
want ociPathInfo
|
||||||
|
wantOK bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple repo manifest by tag",
|
||||||
|
path: "myrepo/manifests/latest",
|
||||||
|
want: ociPathInfo{name: "myrepo", kind: "manifests", reference: "latest"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-segment repo manifest by tag",
|
||||||
|
path: "org/team/app/manifests/v1.0",
|
||||||
|
want: ociPathInfo{name: "org/team/app", kind: "manifests", reference: "v1.0"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "manifest by digest",
|
||||||
|
path: "myrepo/manifests/sha256:abc123def456",
|
||||||
|
want: ociPathInfo{name: "myrepo", kind: "manifests", reference: "sha256:abc123def456"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple repo blob",
|
||||||
|
path: "myrepo/blobs/sha256:abc123",
|
||||||
|
want: ociPathInfo{name: "myrepo", kind: "blobs", reference: "sha256:abc123"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-segment repo blob",
|
||||||
|
path: "org/team/app/blobs/sha256:abc123",
|
||||||
|
want: ociPathInfo{name: "org/team/app", kind: "blobs", reference: "sha256:abc123"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple repo tags list",
|
||||||
|
path: "myrepo/tags/list",
|
||||||
|
want: ociPathInfo{name: "myrepo", kind: "tags", reference: "list"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-segment repo tags list",
|
||||||
|
path: "org/app/tags/list",
|
||||||
|
want: ociPathInfo{name: "org/app", kind: "tags", reference: "list"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty path",
|
||||||
|
path: "",
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "just repo name",
|
||||||
|
path: "myrepo",
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown operation",
|
||||||
|
path: "myrepo/unknown/ref",
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "manifests with no ref",
|
||||||
|
path: "myrepo/manifests/",
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blobs with no digest",
|
||||||
|
path: "myrepo/blobs/",
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tags without list suffix",
|
||||||
|
path: "myrepo/tags/something",
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no repo name before manifests",
|
||||||
|
path: "/manifests/latest",
|
||||||
|
wantOK: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "upload initiate (trailing slash)",
|
||||||
|
path: "myrepo/blobs/uploads/",
|
||||||
|
want: ociPathInfo{name: "myrepo", kind: "uploads", reference: ""},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "upload initiate (no trailing slash)",
|
||||||
|
path: "myrepo/blobs/uploads",
|
||||||
|
want: ociPathInfo{name: "myrepo", kind: "uploads", reference: ""},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "upload with uuid",
|
||||||
|
path: "myrepo/blobs/uploads/abc-123-def",
|
||||||
|
want: ociPathInfo{name: "myrepo", kind: "uploads", reference: "abc-123-def"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-segment repo upload",
|
||||||
|
path: "org/team/app/blobs/uploads/uuid-456",
|
||||||
|
want: ociPathInfo{name: "org/team/app", kind: "uploads", reference: "uuid-456"},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-segment repo upload initiate",
|
||||||
|
path: "org/team/app/blobs/uploads/",
|
||||||
|
want: ociPathInfo{name: "org/team/app", kind: "uploads", reference: ""},
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, ok := parseOCIPath(tt.path)
|
||||||
|
if ok != tt.wantOK {
|
||||||
|
t.Fatalf("parseOCIPath(%q) ok = %v, want %v", tt.path, ok, tt.wantOK)
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if got.name != tt.want.name {
|
||||||
|
t.Errorf("name: got %q, want %q", got.name, tt.want.name)
|
||||||
|
}
|
||||||
|
if got.kind != tt.want.kind {
|
||||||
|
t.Errorf("kind: got %q, want %q", got.kind, tt.want.kind)
|
||||||
|
}
|
||||||
|
if got.reference != tt.want.reference {
|
||||||
|
t.Errorf("reference: got %q, want %q", got.reference, tt.want.reference)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
84
internal/oci/tags.go
Normal file
84
internal/oci/tags.go
Normal file
@@ -0,0 +1,84 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
type tagListResponse struct {
|
||||||
|
Name string `json:"name"`
|
||||||
|
Tags []string `json:"tags"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *Handler) handleTagsList(w http.ResponseWriter, r *http.Request, repo string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPull, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
repoID, err := h.db.GetRepositoryByName(repo)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrRepoNotFound) {
|
||||||
|
writeOCIError(w, "NAME_UNKNOWN", http.StatusNotFound,
|
||||||
|
fmt.Sprintf("repository %q not found", repo))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse pagination params.
|
||||||
|
n := 0
|
||||||
|
if nStr := r.URL.Query().Get("n"); nStr != "" {
|
||||||
|
n, err = strconv.Atoi(nStr)
|
||||||
|
if err != nil || n < 0 {
|
||||||
|
writeOCIError(w, "INVALID_PARAMETER", http.StatusBadRequest, "invalid 'n' parameter")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
last := r.URL.Query().Get("last")
|
||||||
|
|
||||||
|
// Default: no limit (return all).
|
||||||
|
if n == 0 {
|
||||||
|
tags, err := h.db.ListTags(repoID, last, 10000)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tags == nil {
|
||||||
|
tags = []string{}
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(tagListResponse{Name: repo, Tags: tags})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request n+1 to detect if there are more results.
|
||||||
|
tags, err := h.db.ListTags(repoID, last, n+1)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if tags == nil {
|
||||||
|
tags = []string{}
|
||||||
|
}
|
||||||
|
|
||||||
|
hasMore := len(tags) > n
|
||||||
|
if hasMore {
|
||||||
|
tags = tags[:n]
|
||||||
|
}
|
||||||
|
|
||||||
|
if hasMore {
|
||||||
|
lastTag := tags[len(tags)-1]
|
||||||
|
linkURL := fmt.Sprintf("/v2/%s/tags/list?n=%d&last=%s", repo, n, lastTag)
|
||||||
|
w.Header().Set("Link", fmt.Sprintf(`<%s>; rel="next"`, linkURL))
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
_ = json.NewEncoder(w).Encode(tagListResponse{Name: repo, Tags: tags})
|
||||||
|
}
|
||||||
154
internal/oci/tags_test.go
Normal file
154
internal/oci/tags_test.go
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTagsList(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addTag(1, "latest")
|
||||||
|
fdb.addTag(1, "v1.0")
|
||||||
|
fdb.addTag(1, "v2.0")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/tags/list", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body tagListResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
if body.Name != "myrepo" {
|
||||||
|
t.Fatalf("name: got %q, want %q", body.Name, "myrepo")
|
||||||
|
}
|
||||||
|
want := []string{"latest", "v1.0", "v2.0"}
|
||||||
|
if len(body.Tags) != len(want) {
|
||||||
|
t.Fatalf("tags count: got %d, want %d", len(body.Tags), len(want))
|
||||||
|
}
|
||||||
|
for i, tag := range body.Tags {
|
||||||
|
if tag != want[i] {
|
||||||
|
t.Fatalf("tags[%d]: got %q, want %q", i, tag, want[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTagsListPagination(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
fdb.addTag(1, "alpha")
|
||||||
|
fdb.addTag(1, "beta")
|
||||||
|
fdb.addTag(1, "gamma")
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
// First page: n=2.
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/tags/list?n=2", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body tagListResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Tags) != 2 {
|
||||||
|
t.Fatalf("page 1 tags count: got %d, want 2", len(body.Tags))
|
||||||
|
}
|
||||||
|
if body.Tags[0] != "alpha" || body.Tags[1] != "beta" {
|
||||||
|
t.Fatalf("page 1 tags: got %v", body.Tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Link header.
|
||||||
|
link := rr.Header().Get("Link")
|
||||||
|
if link == "" {
|
||||||
|
t.Fatal("expected Link header for pagination")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second page: n=2, last=beta.
|
||||||
|
req = authedRequest("GET", "/v2/myrepo/tags/list?n=2&last=beta", nil)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("page 2 status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode page 2: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Tags) != 1 {
|
||||||
|
t.Fatalf("page 2 tags count: got %d, want 1", len(body.Tags))
|
||||||
|
}
|
||||||
|
if body.Tags[0] != "gamma" {
|
||||||
|
t.Fatalf("page 2 tags: got %v", body.Tags)
|
||||||
|
}
|
||||||
|
|
||||||
|
// No Link header on last page.
|
||||||
|
if link := rr.Header().Get("Link"); link != "" {
|
||||||
|
t.Fatalf("expected no Link header on last page, got %q", link)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTagsListEmpty(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
fdb.addRepo("myrepo", 1)
|
||||||
|
// No tags.
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/myrepo/tags/list", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body tagListResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Tags) != 0 {
|
||||||
|
t.Fatalf("expected empty tags, got %v", body.Tags)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTagsListRepoNotFound(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
// No repos.
|
||||||
|
|
||||||
|
h := NewHandler(fdb, newFakeBlobs(), allowAll(), nil)
|
||||||
|
router := testRouter(h)
|
||||||
|
|
||||||
|
req := authedRequest("GET", "/v2/nosuchrepo/tags/list", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error body: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "NAME_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want NAME_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
241
internal/oci/upload.go
Normal file
241
internal/oci/upload.go
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// uploadManager tracks in-progress blob writers by UUID.
|
||||||
|
type uploadManager struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
writers map[string]*storage.BlobWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUploadManager() *uploadManager {
|
||||||
|
return &uploadManager{writers: make(map[string]*storage.BlobWriter)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadManager) set(uuid string, bw *storage.BlobWriter) {
|
||||||
|
m.mu.Lock()
|
||||||
|
m.writers[uuid] = bw
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadManager) get(uuid string) (*storage.BlobWriter, bool) {
|
||||||
|
m.mu.Lock()
|
||||||
|
bw, ok := m.writers[uuid]
|
||||||
|
m.mu.Unlock()
|
||||||
|
return bw, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *uploadManager) remove(uuid string) {
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.writers, uuid)
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateUUID creates a random UUID (v4) string.
|
||||||
|
func generateUUID() (string, error) {
|
||||||
|
var buf [16]byte
|
||||||
|
if _, err := rand.Read(buf[:]); err != nil {
|
||||||
|
return "", fmt.Errorf("oci: generate uuid: %w", err)
|
||||||
|
}
|
||||||
|
// Set version 4 and variant bits.
|
||||||
|
buf[6] = (buf[6] & 0x0f) | 0x40
|
||||||
|
buf[8] = (buf[8] & 0x3f) | 0x80
|
||||||
|
return fmt.Sprintf("%08x-%04x-%04x-%04x-%012x",
|
||||||
|
buf[0:4], buf[4:6], buf[6:8], buf[8:10], buf[10:16]), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUploadInitiate handles POST /v2/<name>/blobs/uploads/
|
||||||
|
func (h *Handler) handleUploadInitiate(w http.ResponseWriter, r *http.Request, repo string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPush, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create repository if it doesn't exist (implicit creation).
|
||||||
|
repoID, err := h.db.GetOrCreateRepository(repo)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
uuid, err := generateUUID()
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert upload row in DB.
|
||||||
|
if err := h.db.CreateUpload(uuid, repoID); err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create temp file via storage.
|
||||||
|
bw, err := h.blobs.StartUpload(uuid)
|
||||||
|
if err != nil {
|
||||||
|
// Clean up DB row on storage failure.
|
||||||
|
_ = h.db.DeleteUpload(uuid)
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.uploads.set(uuid, bw)
|
||||||
|
|
||||||
|
w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/uploads/%s", repo, uuid))
|
||||||
|
w.Header().Set("Docker-Upload-UUID", uuid)
|
||||||
|
w.Header().Set("Range", "0-0")
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUploadChunk handles PATCH /v2/<name>/blobs/uploads/<uuid>
|
||||||
|
func (h *Handler) handleUploadChunk(w http.ResponseWriter, r *http.Request, repo, uuid string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPush, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bw, ok := h.uploads.get(uuid)
|
||||||
|
if !ok {
|
||||||
|
writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append request body to upload file.
|
||||||
|
n, err := io.Copy(bw, r.Body)
|
||||||
|
if err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "write failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update offset in DB.
|
||||||
|
newOffset := bw.BytesWritten()
|
||||||
|
if err := h.db.UpdateUploadOffset(uuid, newOffset); err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = n // bytes written this chunk
|
||||||
|
|
||||||
|
w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/uploads/%s", repo, uuid))
|
||||||
|
w.Header().Set("Docker-Upload-UUID", uuid)
|
||||||
|
w.Header().Set("Range", fmt.Sprintf("0-%d", newOffset))
|
||||||
|
w.WriteHeader(http.StatusAccepted)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUploadComplete handles PUT /v2/<name>/blobs/uploads/<uuid>?digest=<digest>
|
||||||
|
func (h *Handler) handleUploadComplete(w http.ResponseWriter, r *http.Request, repo, uuid string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPush, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
digest := r.URL.Query().Get("digest")
|
||||||
|
if digest == "" {
|
||||||
|
writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest, "digest parameter required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bw, ok := h.uploads.get(uuid)
|
||||||
|
if !ok {
|
||||||
|
writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If request body is non-empty, append it first (monolithic upload).
|
||||||
|
if r.ContentLength != 0 {
|
||||||
|
if _, err := io.Copy(bw, r.Body); err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "write failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit the blob: verify digest, move to final location.
|
||||||
|
size := bw.BytesWritten()
|
||||||
|
_, err := bw.Commit(digest)
|
||||||
|
if err != nil {
|
||||||
|
h.uploads.remove(uuid)
|
||||||
|
if errors.Is(err, storage.ErrDigestMismatch) {
|
||||||
|
_ = h.db.DeleteUpload(uuid)
|
||||||
|
writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest, "digest mismatch")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, storage.ErrInvalidDigest) {
|
||||||
|
_ = h.db.DeleteUpload(uuid)
|
||||||
|
writeOCIError(w, "DIGEST_INVALID", http.StatusBadRequest, "invalid digest format")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "commit failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
h.uploads.remove(uuid)
|
||||||
|
|
||||||
|
// Insert blob row (no-op if already exists — content-addressed dedup).
|
||||||
|
if err := h.db.InsertBlob(digest, size); err != nil {
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete upload row.
|
||||||
|
_ = h.db.DeleteUpload(uuid)
|
||||||
|
|
||||||
|
h.audit(r, "blob_uploaded", repo, digest)
|
||||||
|
|
||||||
|
w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/%s", repo, digest))
|
||||||
|
w.Header().Set("Docker-Content-Digest", digest)
|
||||||
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUploadStatus handles GET /v2/<name>/blobs/uploads/<uuid>
|
||||||
|
func (h *Handler) handleUploadStatus(w http.ResponseWriter, r *http.Request, repo, uuid string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPush, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
upload, err := h.db.GetUpload(uuid)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrUploadNotFound) {
|
||||||
|
writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Location", fmt.Sprintf("/v2/%s/blobs/uploads/%s", repo, uuid))
|
||||||
|
w.Header().Set("Docker-Upload-UUID", uuid)
|
||||||
|
w.Header().Set("Range", fmt.Sprintf("0-%d", upload.ByteOffset))
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUploadCancel handles DELETE /v2/<name>/blobs/uploads/<uuid>
|
||||||
|
func (h *Handler) handleUploadCancel(w http.ResponseWriter, r *http.Request, repo, uuid string) {
|
||||||
|
if !h.checkPolicy(w, r, policy.ActionPush, repo) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
bw, ok := h.uploads.get(uuid)
|
||||||
|
if ok {
|
||||||
|
_ = bw.Cancel()
|
||||||
|
h.uploads.remove(uuid)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := h.db.DeleteUpload(uuid); err != nil {
|
||||||
|
if errors.Is(err, db.ErrUploadNotFound) {
|
||||||
|
writeOCIError(w, "BLOB_UPLOAD_UNKNOWN", http.StatusNotFound, "upload not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeOCIError(w, "UNKNOWN", http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
291
internal/oci/upload_test.go
Normal file
291
internal/oci/upload_test.go
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
package oci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testHandlerWithStorage creates a handler with real storage in t.TempDir().
|
||||||
|
func testHandlerWithStorage(t *testing.T, fdb *fakeDB) (*Handler, *chi.Mux) {
|
||||||
|
t.Helper()
|
||||||
|
dir := t.TempDir()
|
||||||
|
store := storage.New(dir+"/layers", dir+"/uploads")
|
||||||
|
h := NewHandler(fdb, store, allowAll(), nil)
|
||||||
|
router := chi.NewRouter()
|
||||||
|
router.Mount("/v2", h.Router())
|
||||||
|
return h, router
|
||||||
|
}
|
||||||
|
|
||||||
|
func authedPushRequest(method, path string, body []byte) *http.Request {
|
||||||
|
var reader *bytes.Reader
|
||||||
|
if body != nil {
|
||||||
|
reader = bytes.NewReader(body)
|
||||||
|
} else {
|
||||||
|
reader = bytes.NewReader(nil)
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest(method, path, reader)
|
||||||
|
claims := &auth.Claims{
|
||||||
|
Subject: "pusher",
|
||||||
|
AccountType: "human",
|
||||||
|
Roles: []string{"user"},
|
||||||
|
}
|
||||||
|
return req.WithContext(auth.ContextWithClaims(req.Context(), claims))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadInitiate(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusAccepted {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusAccepted)
|
||||||
|
}
|
||||||
|
|
||||||
|
loc := rr.Header().Get("Location")
|
||||||
|
if !strings.HasPrefix(loc, "/v2/myrepo/blobs/uploads/") {
|
||||||
|
t.Fatalf("Location: got %q", loc)
|
||||||
|
}
|
||||||
|
|
||||||
|
uuid := rr.Header().Get("Docker-Upload-UUID")
|
||||||
|
if uuid == "" {
|
||||||
|
t.Fatal("Docker-Upload-UUID header missing")
|
||||||
|
}
|
||||||
|
|
||||||
|
rng := rr.Header().Get("Range")
|
||||||
|
if rng != "0-0" {
|
||||||
|
t.Fatalf("Range: got %q, want %q", rng, "0-0")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify repo was implicitly created.
|
||||||
|
if _, ok := fdb.repos["myrepo"]; !ok {
|
||||||
|
t.Fatal("repository should have been implicitly created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadInitiateUniqueUUIDs(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
uuids := make(map[string]bool)
|
||||||
|
for range 5 {
|
||||||
|
req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusAccepted {
|
||||||
|
t.Fatalf("status: got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
uuid := rr.Header().Get("Docker-Upload-UUID")
|
||||||
|
if uuids[uuid] {
|
||||||
|
t.Fatalf("duplicate UUID: %s", uuid)
|
||||||
|
}
|
||||||
|
uuids[uuid] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMonolithicUpload(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
// Step 1: Initiate upload.
|
||||||
|
req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusAccepted {
|
||||||
|
t.Fatalf("initiate status: got %d", rr.Code)
|
||||||
|
}
|
||||||
|
uuid := rr.Header().Get("Docker-Upload-UUID")
|
||||||
|
|
||||||
|
// Step 2: Complete upload with body and digest in a single PUT.
|
||||||
|
blobData := []byte("hello world blob data")
|
||||||
|
sum := sha256.Sum256(blobData)
|
||||||
|
digest := "sha256:" + hex.EncodeToString(sum[:])
|
||||||
|
|
||||||
|
putURL := "/v2/myrepo/blobs/uploads/" + uuid + "?digest=" + digest
|
||||||
|
req = authedPushRequest("PUT", putURL, blobData)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("complete status: got %d, body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
loc := rr.Header().Get("Location")
|
||||||
|
if !strings.Contains(loc, digest) {
|
||||||
|
t.Fatalf("Location should contain digest: got %q", loc)
|
||||||
|
}
|
||||||
|
|
||||||
|
dcd := rr.Header().Get("Docker-Content-Digest")
|
||||||
|
if dcd != digest {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q, want %q", dcd, digest)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify blob was inserted in fake DB.
|
||||||
|
if !fdb.allBlobs[digest] {
|
||||||
|
t.Fatal("blob should exist in DB after upload")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestChunkedUpload(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
// Step 1: Initiate.
|
||||||
|
req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
uuid := rr.Header().Get("Docker-Upload-UUID")
|
||||||
|
|
||||||
|
// Step 2: PATCH chunk 1.
|
||||||
|
chunk1 := []byte("chunk-one-data-")
|
||||||
|
req = authedPushRequest("PATCH", "/v2/myrepo/blobs/uploads/"+uuid, chunk1)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusAccepted {
|
||||||
|
t.Fatalf("patch 1 status: got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: PATCH chunk 2.
|
||||||
|
chunk2 := []byte("chunk-two-data")
|
||||||
|
req = authedPushRequest("PATCH", "/v2/myrepo/blobs/uploads/"+uuid, chunk2)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusAccepted {
|
||||||
|
t.Fatalf("patch 2 status: got %d", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 4: Complete with PUT.
|
||||||
|
allData := append(chunk1, chunk2...)
|
||||||
|
sum := sha256.Sum256(allData)
|
||||||
|
digest := "sha256:" + hex.EncodeToString(sum[:])
|
||||||
|
|
||||||
|
req = authedPushRequest("PUT", "/v2/myrepo/blobs/uploads/"+uuid+"?digest="+digest, nil)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusCreated {
|
||||||
|
t.Fatalf("complete status: got %d, body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
if rr.Header().Get("Docker-Content-Digest") != digest {
|
||||||
|
t.Fatalf("Docker-Content-Digest: got %q, want %q", rr.Header().Get("Docker-Content-Digest"), digest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadDigestMismatch(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
// Initiate.
|
||||||
|
req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
uuid := rr.Header().Get("Docker-Upload-UUID")
|
||||||
|
|
||||||
|
// Complete with wrong digest.
|
||||||
|
blobData := []byte("some data")
|
||||||
|
wrongDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000000"
|
||||||
|
|
||||||
|
req = authedPushRequest("PUT", "/v2/myrepo/blobs/uploads/"+uuid+"?digest="+wrongDigest, blobData)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "DIGEST_INVALID" {
|
||||||
|
t.Fatalf("error code: got %+v, want DIGEST_INVALID", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadStatus(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
// Initiate.
|
||||||
|
req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
uuid := rr.Header().Get("Docker-Upload-UUID")
|
||||||
|
|
||||||
|
// Check status.
|
||||||
|
req = authedPushRequest("GET", "/v2/myrepo/blobs/uploads/"+uuid, nil)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rr.Header().Get("Docker-Upload-UUID") != uuid {
|
||||||
|
t.Fatalf("Docker-Upload-UUID: got %q", rr.Header().Get("Docker-Upload-UUID"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadCancel(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
// Initiate.
|
||||||
|
req := authedPushRequest("POST", "/v2/myrepo/blobs/uploads/", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
uuid := rr.Header().Get("Docker-Upload-UUID")
|
||||||
|
|
||||||
|
// Cancel.
|
||||||
|
req = authedPushRequest("DELETE", "/v2/myrepo/blobs/uploads/"+uuid, nil)
|
||||||
|
rr = httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNoContent {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNoContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify upload was removed from DB.
|
||||||
|
if _, ok := fdb.uploads[uuid]; ok {
|
||||||
|
t.Fatal("upload should have been deleted from DB")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUploadNonexistentUUID(t *testing.T) {
|
||||||
|
fdb := newFakeDB()
|
||||||
|
_, router := testHandlerWithStorage(t, fdb)
|
||||||
|
|
||||||
|
req := authedPushRequest("PATCH", "/v2/myrepo/blobs/uploads/nonexistent-uuid", []byte("data"))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusNotFound {
|
||||||
|
t.Fatalf("status: got %d, want %d", rr.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
var body ociErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&body); err != nil {
|
||||||
|
t.Fatalf("decode error: %v", err)
|
||||||
|
}
|
||||||
|
if len(body.Errors) != 1 || body.Errors[0].Code != "BLOB_UPLOAD_UNKNOWN" {
|
||||||
|
t.Fatalf("error code: got %+v, want BLOB_UPLOAD_UNKNOWN", body.Errors)
|
||||||
|
}
|
||||||
|
}
|
||||||
52
internal/server/admin.go
Normal file
52
internal/server/admin.go
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
type adminErrorResponse struct {
|
||||||
|
Error string `json:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeAdminError(w http.ResponseWriter, status int, message string) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(adminErrorResponse{Error: message})
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(status)
|
||||||
|
_ = json.NewEncoder(w).Encode(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequireAdmin returns middleware that checks for the admin role.
|
||||||
|
// Returns 403 with an admin error format if the caller is not an admin.
|
||||||
|
func RequireAdmin() func(http.Handler) http.Handler {
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
if claims == nil {
|
||||||
|
writeAdminError(w, http.StatusUnauthorized, "authentication required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !hasRole(claims.Roles, "admin") {
|
||||||
|
writeAdminError(w, http.StatusForbidden, "admin role required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func hasRole(roles []string, target string) bool {
|
||||||
|
for _, r := range roles {
|
||||||
|
if r == target {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
48
internal/server/admin_audit.go
Normal file
48
internal/server/admin_audit.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AdminListAuditHandler handles GET /v1/audit.
|
||||||
|
func AdminListAuditHandler(database *db.DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
q := r.URL.Query()
|
||||||
|
|
||||||
|
limit := 50
|
||||||
|
if n := q.Get("n"); n != "" {
|
||||||
|
if v, err := strconv.Atoi(n); err == nil && v > 0 {
|
||||||
|
limit = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
offset := 0
|
||||||
|
if o := q.Get("offset"); o != "" {
|
||||||
|
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
|
||||||
|
offset = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filter := db.AuditFilter{
|
||||||
|
EventType: q.Get("event_type"),
|
||||||
|
ActorID: q.Get("actor_id"),
|
||||||
|
Repository: q.Get("repository"),
|
||||||
|
Since: q.Get("since"),
|
||||||
|
Until: q.Get("until"),
|
||||||
|
Limit: limit,
|
||||||
|
Offset: offset,
|
||||||
|
}
|
||||||
|
|
||||||
|
events, err := database.ListAuditEvents(filter)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if events == nil {
|
||||||
|
events = []db.AuditEvent{}
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, events)
|
||||||
|
}
|
||||||
|
}
|
||||||
152
internal/server/admin_audit_test.go
Normal file
152
internal/server/admin_audit_test.go
Normal file
@@ -0,0 +1,152 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAdminListAuditEvents(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
// Seed some audit events.
|
||||||
|
if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "10.0.0.1", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
if err := database.WriteAuditEvent("manifest_pushed", "user-1", "myapp", "sha256:abc", "10.0.0.1",
|
||||||
|
map[string]string{"tag": "latest"}); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
if err := database.WriteAuditEvent("login_ok", "user-2", "", "", "10.0.0.2", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/audit", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var events []db.AuditEvent
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 3 {
|
||||||
|
t.Fatalf("event count: got %d, want 3", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListAuditEventsWithFilter(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
if err := database.WriteAuditEvent("manifest_pushed", "user-1", "myapp", "", "", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
if err := database.WriteAuditEvent("login_ok", "user-2", "", "", "", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter by event_type.
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/audit?event_type=login_ok", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var events []db.AuditEvent
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Fatalf("event count: got %d, want 2", len(events))
|
||||||
|
}
|
||||||
|
for _, e := range events {
|
||||||
|
if e.EventType != "login_ok" {
|
||||||
|
t.Fatalf("unexpected event type: %q", e.EventType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListAuditEventsFilterByActor(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
if err := database.WriteAuditEvent("login_ok", "actor-a", "", "", "", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
if err := database.WriteAuditEvent("login_ok", "actor-b", "", "", "", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/audit?actor_id=actor-a", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var events []db.AuditEvent
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 1 {
|
||||||
|
t.Fatalf("event count: got %d, want 1", len(events))
|
||||||
|
}
|
||||||
|
if events[0].ActorID != "actor-a" {
|
||||||
|
t.Fatalf("actor_id: got %q, want %q", events[0].ActorID, "actor-a")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListAuditEventsPagination(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
for range 5 {
|
||||||
|
if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "", nil); err != nil {
|
||||||
|
t.Fatalf("WriteAuditEvent: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/audit?n=2&offset=0", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var events []db.AuditEvent
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 2 {
|
||||||
|
t.Fatalf("event count: got %d, want 2", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListAuditEventsEmpty(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/audit", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var events []db.AuditEvent
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(events) != 0 {
|
||||||
|
t.Fatalf("event count: got %d, want 0", len(events))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListAuditEventsNonAdmin(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router := buildNonAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/audit", "")
|
||||||
|
if rr.Code != 403 {
|
||||||
|
t.Fatalf("status: got %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
60
internal/server/admin_auth.go
Normal file
60
internal/server/admin_auth.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type adminLoginRequest struct {
|
||||||
|
Username string `json:"username"`
|
||||||
|
Password string `json:"password"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type adminLoginResponse struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
ExpiresAt string `json:"expires_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminLoginHandler handles POST /v1/auth/login.
|
||||||
|
func AdminLoginHandler(loginClient LoginClient) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req adminLoginRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Username == "" || req.Password == "" {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "username and password required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token, expiresIn, err := loginClient.Login(req.Username, req.Password)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusUnauthorized, "authentication failed")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
|
||||||
|
writeJSON(w, http.StatusOK, adminLoginResponse{
|
||||||
|
Token: token,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminLogoutHandler handles POST /v1/auth/logout.
|
||||||
|
func AdminLogoutHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
// MCIAS token revocation is not currently supported.
|
||||||
|
// The client should discard the token.
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminHealthHandler handles GET /v1/health.
|
||||||
|
func AdminHealthHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||||
|
}
|
||||||
|
}
|
||||||
116
internal/server/admin_auth_test.go
Normal file
116
internal/server/admin_auth_test.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAdminHealthHandler(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
// Health endpoint does not require auth.
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/health", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp map[string]string
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if resp["status"] != "ok" {
|
||||||
|
t.Fatalf("status field: got %q, want %q", resp["status"], "ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminLoginSuccess(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
body := `{"username":"admin","password":"secret"}`
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/auth/login", body)
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp adminLoginResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Token != "test-token" {
|
||||||
|
t.Fatalf("token: got %q, want %q", resp.Token, "test-token")
|
||||||
|
}
|
||||||
|
if resp.ExpiresAt == "" {
|
||||||
|
t.Fatal("expires_at: expected non-empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminLoginInvalidCreds(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
|
||||||
|
validator := &fakeValidator{
|
||||||
|
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
|
||||||
|
}
|
||||||
|
login := &fakeLoginClient{err: auth.ErrUnauthorized}
|
||||||
|
reloader := &fakePolicyReloader{}
|
||||||
|
gcState := &GCState{}
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
MountAdminRoutes(r, validator, "mcr-test", AdminDeps{
|
||||||
|
DB: database,
|
||||||
|
Login: login,
|
||||||
|
Engine: reloader,
|
||||||
|
AuditFn: nil,
|
||||||
|
GCState: gcState,
|
||||||
|
})
|
||||||
|
|
||||||
|
body := `{"username":"admin","password":"wrong"}`
|
||||||
|
rr := adminReq(t, r, "POST", "/v1/auth/login", body)
|
||||||
|
if rr.Code != 401 {
|
||||||
|
t.Fatalf("status: got %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var errResp adminErrorResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&errResp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if errResp.Error != "authentication failed" {
|
||||||
|
t.Fatalf("error: got %q, want %q", errResp.Error, "authentication failed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminLoginMissingFields(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
body := `{"username":"admin"}`
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/auth/login", body)
|
||||||
|
if rr.Code != 400 {
|
||||||
|
t.Fatalf("status: got %d, want 400", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminLoginBadJSON(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/auth/login", "not json")
|
||||||
|
if rr.Code != 400 {
|
||||||
|
t.Fatalf("status: got %d, want 400", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminLogout(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/auth/logout", "")
|
||||||
|
if rr.Code != 204 {
|
||||||
|
t.Fatalf("status: got %d, want 204", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
66
internal/server/admin_gc.go
Normal file
66
internal/server/admin_gc.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GCLastRun records the result of the last GC run.
|
||||||
|
type GCLastRun struct {
|
||||||
|
StartedAt string `json:"started_at"`
|
||||||
|
CompletedAt string `json:"completed_at,omitempty"`
|
||||||
|
BlobsRemoved int `json:"blobs_removed"`
|
||||||
|
BytesFreed int64 `json:"bytes_freed"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// GCState tracks the current state of garbage collection.
|
||||||
|
type GCState struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
Running bool `json:"running"`
|
||||||
|
LastRun *GCLastRun `json:"last_run,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type gcStatusResponse struct {
|
||||||
|
Running bool `json:"running"`
|
||||||
|
LastRun *GCLastRun `json:"last_run,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type gcTriggerResponse struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminTriggerGCHandler handles POST /v1/gc.
|
||||||
|
func AdminTriggerGCHandler(state *GCState) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
state.mu.Lock()
|
||||||
|
if state.Running {
|
||||||
|
state.mu.Unlock()
|
||||||
|
writeAdminError(w, http.StatusConflict, "garbage collection already running")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
state.Running = true
|
||||||
|
state.mu.Unlock()
|
||||||
|
|
||||||
|
// GC engine is Phase 9 -- for now, just mark as running and return.
|
||||||
|
// The actual GC goroutine will be wired up in Phase 9.
|
||||||
|
gcID := uuid.New().String()
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusAccepted, gcTriggerResponse{ID: gcID})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminGCStatusHandler handles GET /v1/gc/status.
|
||||||
|
func AdminGCStatusHandler(state *GCState) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
state.mu.Lock()
|
||||||
|
resp := gcStatusResponse{
|
||||||
|
Running: state.Running,
|
||||||
|
LastRun: state.LastRun,
|
||||||
|
}
|
||||||
|
state.mu.Unlock()
|
||||||
|
|
||||||
|
writeJSON(w, http.StatusOK, resp)
|
||||||
|
}
|
||||||
|
}
|
||||||
94
internal/server/admin_gc_test.go
Normal file
94
internal/server/admin_gc_test.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAdminTriggerGC(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/gc", "")
|
||||||
|
if rr.Code != 202 {
|
||||||
|
t.Fatalf("status: got %d, want 202; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp gcTriggerResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if resp.ID == "" {
|
||||||
|
t.Fatal("expected non-empty GC ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminTriggerGCConflict(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
// First trigger should succeed.
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/gc", "")
|
||||||
|
if rr.Code != 202 {
|
||||||
|
t.Fatalf("first trigger status: got %d, want 202", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second trigger should conflict because GC is still "running".
|
||||||
|
rr = adminReq(t, router, "POST", "/v1/gc", "")
|
||||||
|
if rr.Code != 409 {
|
||||||
|
t.Fatalf("second trigger status: got %d, want 409; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminGCStatus(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
// Before triggering GC.
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/gc/status", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resp gcStatusResponse
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if resp.Running {
|
||||||
|
t.Fatal("expected running=false before trigger")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger GC.
|
||||||
|
rr = adminReq(t, router, "POST", "/v1/gc", "")
|
||||||
|
if rr.Code != 202 {
|
||||||
|
t.Fatalf("trigger status: got %d, want 202", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// After triggering GC.
|
||||||
|
rr = adminReq(t, router, "GET", "/v1/gc/status", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if !resp.Running {
|
||||||
|
t.Fatal("expected running=true after trigger")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminGCNonAdmin(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router := buildNonAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/gc", "")
|
||||||
|
if rr.Code != 403 {
|
||||||
|
t.Fatalf("trigger status: got %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
rr = adminReq(t, router, "GET", "/v1/gc/status", "")
|
||||||
|
if rr.Code != 403 {
|
||||||
|
t.Fatalf("status status: got %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
343
internal/server/admin_policy.go
Normal file
343
internal/server/admin_policy.go
Normal file
@@ -0,0 +1,343 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PolicyReloader can reload policy rules from a store.
|
||||||
|
type PolicyReloader interface {
|
||||||
|
Reload(store policy.RuleStore) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// policyCreateRequest is the JSON body for creating a policy rule.
|
||||||
|
type policyCreateRequest struct {
|
||||||
|
Priority int `json:"priority"`
|
||||||
|
Description string `json:"description"`
|
||||||
|
Effect string `json:"effect"`
|
||||||
|
Roles []string `json:"roles,omitempty"`
|
||||||
|
AccountTypes []string `json:"account_types,omitempty"`
|
||||||
|
SubjectUUID string `json:"subject_uuid,omitempty"`
|
||||||
|
Actions []string `json:"actions"`
|
||||||
|
Repositories []string `json:"repositories,omitempty"`
|
||||||
|
Enabled *bool `json:"enabled,omitempty"` // pointer to distinguish unset from false
|
||||||
|
}
|
||||||
|
|
||||||
|
// policyUpdateRequest is the JSON body for updating a policy rule.
|
||||||
|
type policyUpdateRequest struct {
|
||||||
|
Priority *int `json:"priority,omitempty"`
|
||||||
|
Description *string `json:"description,omitempty"`
|
||||||
|
Effect *string `json:"effect,omitempty"`
|
||||||
|
Roles []string `json:"roles,omitempty"`
|
||||||
|
AccountTypes []string `json:"account_types,omitempty"`
|
||||||
|
SubjectUUID *string `json:"subject_uuid,omitempty"`
|
||||||
|
Actions []string `json:"actions,omitempty"`
|
||||||
|
Repositories []string `json:"repositories,omitempty"`
|
||||||
|
Enabled *bool `json:"enabled,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var validActions = map[string]bool{
|
||||||
|
string(policy.ActionVersionCheck): true,
|
||||||
|
string(policy.ActionPull): true,
|
||||||
|
string(policy.ActionPush): true,
|
||||||
|
string(policy.ActionDelete): true,
|
||||||
|
string(policy.ActionCatalog): true,
|
||||||
|
string(policy.ActionPolicyManage): true,
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateActions(actions []string) error {
|
||||||
|
for _, a := range actions {
|
||||||
|
if !validActions[a] {
|
||||||
|
return fmt.Errorf("invalid action: %q", a)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateEffect(effect string) error {
|
||||||
|
if effect != "allow" && effect != "deny" {
|
||||||
|
return fmt.Errorf("invalid effect: %q (must be 'allow' or 'deny')", effect)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminListPolicyRulesHandler handles GET /v1/policy/rules.
|
||||||
|
func AdminListPolicyRulesHandler(database *db.DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
limit := 50
|
||||||
|
offset := 0
|
||||||
|
if n := r.URL.Query().Get("n"); n != "" {
|
||||||
|
if v, err := strconv.Atoi(n); err == nil && v > 0 {
|
||||||
|
limit = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if o := r.URL.Query().Get("offset"); o != "" {
|
||||||
|
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
|
||||||
|
offset = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rules, err := database.ListPolicyRules(limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rules == nil {
|
||||||
|
rules = []db.PolicyRuleRow{}
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, rules)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminCreatePolicyRuleHandler handles POST /v1/policy/rules.
|
||||||
|
func AdminCreatePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var req policyCreateRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Priority < 1 {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "priority must be >= 1 (0 is reserved for built-ins)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Description == "" {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "description is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := validateEffect(req.Effect); err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if len(req.Actions) == 0 {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "at least one action is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := validateActions(req.Actions); err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
enabled := true
|
||||||
|
if req.Enabled != nil {
|
||||||
|
enabled = *req.Enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
createdBy := ""
|
||||||
|
if claims != nil {
|
||||||
|
createdBy = claims.Subject
|
||||||
|
}
|
||||||
|
|
||||||
|
row := db.PolicyRuleRow{
|
||||||
|
Priority: req.Priority,
|
||||||
|
Description: req.Description,
|
||||||
|
Effect: req.Effect,
|
||||||
|
Roles: req.Roles,
|
||||||
|
AccountTypes: req.AccountTypes,
|
||||||
|
SubjectUUID: req.SubjectUUID,
|
||||||
|
Actions: req.Actions,
|
||||||
|
Repositories: req.Repositories,
|
||||||
|
Enabled: enabled,
|
||||||
|
CreatedBy: createdBy,
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := database.CreatePolicyRule(row)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload policy engine.
|
||||||
|
if engine != nil {
|
||||||
|
_ = engine.Reload(database)
|
||||||
|
}
|
||||||
|
|
||||||
|
if auditFn != nil {
|
||||||
|
auditFn("policy_rule_created", createdBy, "", "", r.RemoteAddr, map[string]string{
|
||||||
|
"rule_id": strconv.FormatInt(id, 10),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
created, err := database.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusCreated, created)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminGetPolicyRuleHandler handles GET /v1/policy/rules/{id}.
|
||||||
|
func AdminGetPolicyRuleHandler(database *db.DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "invalid rule ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := database.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||||
|
writeAdminError(w, http.StatusNotFound, "policy rule not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminUpdatePolicyRuleHandler handles PATCH /v1/policy/rules/{id}.
|
||||||
|
func AdminUpdatePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "invalid rule ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var req policyUpdateRequest
|
||||||
|
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "invalid request body")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if req.Priority != nil && *req.Priority < 1 {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "priority must be >= 1 (0 is reserved for built-ins)")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if req.Effect != nil {
|
||||||
|
if err := validateEffect(*req.Effect); err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if req.Actions != nil {
|
||||||
|
if len(req.Actions) == 0 {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "at least one action is required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if err := validateActions(req.Actions); err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
updates := db.PolicyRuleRow{}
|
||||||
|
if req.Priority != nil {
|
||||||
|
updates.Priority = *req.Priority
|
||||||
|
}
|
||||||
|
if req.Description != nil {
|
||||||
|
updates.Description = *req.Description
|
||||||
|
}
|
||||||
|
if req.Effect != nil {
|
||||||
|
updates.Effect = *req.Effect
|
||||||
|
}
|
||||||
|
if req.Roles != nil {
|
||||||
|
updates.Roles = req.Roles
|
||||||
|
}
|
||||||
|
if req.AccountTypes != nil {
|
||||||
|
updates.AccountTypes = req.AccountTypes
|
||||||
|
}
|
||||||
|
if req.SubjectUUID != nil {
|
||||||
|
updates.SubjectUUID = *req.SubjectUUID
|
||||||
|
}
|
||||||
|
if req.Actions != nil {
|
||||||
|
updates.Actions = req.Actions
|
||||||
|
}
|
||||||
|
if req.Repositories != nil {
|
||||||
|
updates.Repositories = req.Repositories
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := database.UpdatePolicyRule(id, updates); err != nil {
|
||||||
|
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||||
|
writeAdminError(w, http.StatusNotFound, "policy rule not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle enabled separately since it's a bool.
|
||||||
|
if req.Enabled != nil {
|
||||||
|
if err := database.SetPolicyRuleEnabled(id, *req.Enabled); err != nil {
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload policy engine.
|
||||||
|
if engine != nil {
|
||||||
|
_ = engine.Reload(database)
|
||||||
|
}
|
||||||
|
|
||||||
|
if auditFn != nil {
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
actorID := ""
|
||||||
|
if claims != nil {
|
||||||
|
actorID = claims.Subject
|
||||||
|
}
|
||||||
|
auditFn("policy_rule_updated", actorID, "", "", r.RemoteAddr, map[string]string{
|
||||||
|
"rule_id": strconv.FormatInt(id, 10),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
updated, err := database.GetPolicyRule(id)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, updated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminDeletePolicyRuleHandler handles DELETE /v1/policy/rules/{id}.
|
||||||
|
func AdminDeletePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "invalid rule ID")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := database.DeletePolicyRule(id); err != nil {
|
||||||
|
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||||
|
writeAdminError(w, http.StatusNotFound, "policy rule not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reload policy engine.
|
||||||
|
if engine != nil {
|
||||||
|
_ = engine.Reload(database)
|
||||||
|
}
|
||||||
|
|
||||||
|
if auditFn != nil {
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
actorID := ""
|
||||||
|
if claims != nil {
|
||||||
|
actorID = claims.Subject
|
||||||
|
}
|
||||||
|
auditFn("policy_rule_deleted", actorID, "", "", r.RemoteAddr, map[string]string{
|
||||||
|
"rule_id": strconv.FormatInt(id, 10),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
337
internal/server/admin_policy_test.go
Normal file
337
internal/server/admin_policy_test.go
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAdminPolicyCRUDCycle(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, reloader := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
// Create a rule.
|
||||||
|
createBody := `{
|
||||||
|
"priority": 50,
|
||||||
|
"description": "allow CI push",
|
||||||
|
"effect": "allow",
|
||||||
|
"roles": ["ci"],
|
||||||
|
"actions": ["registry:push", "registry:pull"],
|
||||||
|
"repositories": ["production/*"],
|
||||||
|
"enabled": true
|
||||||
|
}`
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody)
|
||||||
|
if rr.Code != 201 {
|
||||||
|
t.Fatalf("create status: got %d, want 201; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
if reloader.reloadCount != 1 {
|
||||||
|
t.Fatalf("reload count after create: got %d, want 1", reloader.reloadCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
var created db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&created); err != nil {
|
||||||
|
t.Fatalf("decode created: %v", err)
|
||||||
|
}
|
||||||
|
if created.ID == 0 {
|
||||||
|
t.Fatal("expected non-zero ID")
|
||||||
|
}
|
||||||
|
if created.Effect != "allow" {
|
||||||
|
t.Fatalf("effect: got %q, want %q", created.Effect, "allow")
|
||||||
|
}
|
||||||
|
if len(created.Roles) != 1 || created.Roles[0] != "ci" {
|
||||||
|
t.Fatalf("roles: got %v, want [ci]", created.Roles)
|
||||||
|
}
|
||||||
|
if created.CreatedBy != "admin-uuid" {
|
||||||
|
t.Fatalf("created_by: got %q, want %q", created.CreatedBy, "admin-uuid")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the rule.
|
||||||
|
rr = adminReq(t, router, "GET", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("get status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var got db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&got); err != nil {
|
||||||
|
t.Fatalf("decode got: %v", err)
|
||||||
|
}
|
||||||
|
if got.ID != created.ID {
|
||||||
|
t.Fatalf("id: got %d, want %d", got.ID, created.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List rules.
|
||||||
|
rr = adminReq(t, router, "GET", "/v1/policy/rules/", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("list status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&rules); err != nil {
|
||||||
|
t.Fatalf("decode list: %v", err)
|
||||||
|
}
|
||||||
|
if len(rules) != 1 {
|
||||||
|
t.Fatalf("rule count: got %d, want 1", len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the rule.
|
||||||
|
updateBody := `{"priority": 25, "description": "updated CI push"}`
|
||||||
|
rr = adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), updateBody)
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("update status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
if reloader.reloadCount != 2 {
|
||||||
|
t.Fatalf("reload count after update: got %d, want 2", reloader.reloadCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
var updated db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&updated); err != nil {
|
||||||
|
t.Fatalf("decode updated: %v", err)
|
||||||
|
}
|
||||||
|
if updated.Priority != 25 {
|
||||||
|
t.Fatalf("updated priority: got %d, want 25", updated.Priority)
|
||||||
|
}
|
||||||
|
if updated.Description != "updated CI push" {
|
||||||
|
t.Fatalf("updated description: got %q, want %q", updated.Description, "updated CI push")
|
||||||
|
}
|
||||||
|
// Effect should be unchanged.
|
||||||
|
if updated.Effect != "allow" {
|
||||||
|
t.Fatalf("updated effect: got %q, want %q (unchanged)", updated.Effect, "allow")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete the rule.
|
||||||
|
rr = adminReq(t, router, "DELETE", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "")
|
||||||
|
if rr.Code != 204 {
|
||||||
|
t.Fatalf("delete status: got %d, want 204", rr.Code)
|
||||||
|
}
|
||||||
|
if reloader.reloadCount != 3 {
|
||||||
|
t.Fatalf("reload count after delete: got %d, want 3", reloader.reloadCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's gone.
|
||||||
|
rr = adminReq(t, router, "GET", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "")
|
||||||
|
if rr.Code != 404 {
|
||||||
|
t.Fatalf("after delete status: got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminCreatePolicyRuleValidation(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "priority too low",
|
||||||
|
body: `{"priority":0,"description":"test","effect":"allow","actions":["registry:pull"]}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing description",
|
||||||
|
body: `{"priority":1,"effect":"allow","actions":["registry:pull"]}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid effect",
|
||||||
|
body: `{"priority":1,"description":"test","effect":"maybe","actions":["registry:pull"]}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no actions",
|
||||||
|
body: `{"priority":1,"description":"test","effect":"allow","actions":[]}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid action",
|
||||||
|
body: `{"priority":1,"description":"test","effect":"allow","actions":["bogus:action"]}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "bad JSON",
|
||||||
|
body: `not json`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/policy/rules/", tc.body)
|
||||||
|
if rr.Code != tc.want {
|
||||||
|
t.Fatalf("status: got %d, want %d; body: %s", rr.Code, tc.want, rr.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUpdatePolicyRuleValidation(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
// Create a rule to update.
|
||||||
|
createBody := `{"priority":50,"description":"test","effect":"allow","actions":["registry:pull"]}`
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody)
|
||||||
|
if rr.Code != 201 {
|
||||||
|
t.Fatalf("create status: got %d, want 201", rr.Code)
|
||||||
|
}
|
||||||
|
var created db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&created); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body string
|
||||||
|
want int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "priority too low",
|
||||||
|
body: `{"priority":0}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid effect",
|
||||||
|
body: `{"effect":"maybe"}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty actions",
|
||||||
|
body: `{"actions":[]}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid action",
|
||||||
|
body: `{"actions":["bogus:action"]}`,
|
||||||
|
want: 400,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
rr := adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), tc.body)
|
||||||
|
if rr.Code != tc.want {
|
||||||
|
t.Fatalf("status: got %d, want %d; body: %s", rr.Code, tc.want, rr.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUpdatePolicyRuleNotFound(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "PATCH", "/v1/policy/rules/9999", `{"description":"nope"}`)
|
||||||
|
if rr.Code != 404 {
|
||||||
|
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminDeletePolicyRuleNotFound(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "DELETE", "/v1/policy/rules/9999", "")
|
||||||
|
if rr.Code != 404 {
|
||||||
|
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminGetPolicyRuleNotFound(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/policy/rules/9999", "")
|
||||||
|
if rr.Code != 404 {
|
||||||
|
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminGetPolicyRuleInvalidID(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/policy/rules/not-a-number", "")
|
||||||
|
if rr.Code != 400 {
|
||||||
|
t.Fatalf("status: got %d, want 400", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminPolicyRulesNonAdmin(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router := buildNonAdminRouter(t, database)
|
||||||
|
|
||||||
|
// All policy rule endpoints require admin.
|
||||||
|
endpoints := []struct {
|
||||||
|
method string
|
||||||
|
path string
|
||||||
|
body string
|
||||||
|
}{
|
||||||
|
{"GET", "/v1/policy/rules/", ""},
|
||||||
|
{"POST", "/v1/policy/rules/", `{"priority":1,"description":"test","effect":"allow","actions":["registry:pull"]}`},
|
||||||
|
{"GET", "/v1/policy/rules/1", ""},
|
||||||
|
{"PATCH", "/v1/policy/rules/1", `{"description":"updated"}`},
|
||||||
|
{"DELETE", "/v1/policy/rules/1", ""},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ep := range endpoints {
|
||||||
|
t.Run(ep.method+" "+ep.path, func(t *testing.T) {
|
||||||
|
rr := adminReq(t, router, ep.method, ep.path, ep.body)
|
||||||
|
if rr.Code != 403 {
|
||||||
|
t.Fatalf("status: got %d, want 403; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminUpdatePolicyRuleEnabled(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
// Create an enabled rule.
|
||||||
|
createBody := `{"priority":50,"description":"test","effect":"allow","actions":["registry:pull"],"enabled":true}`
|
||||||
|
rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody)
|
||||||
|
if rr.Code != 201 {
|
||||||
|
t.Fatalf("create status: got %d, want 201", rr.Code)
|
||||||
|
}
|
||||||
|
var created db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&created); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if !created.Enabled {
|
||||||
|
t.Fatal("expected created rule to be enabled")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disable it.
|
||||||
|
rr = adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), `{"enabled":false}`)
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("update status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
var updated db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&updated); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if updated.Enabled {
|
||||||
|
t.Fatal("expected updated rule to be disabled")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListPolicyRulesEmpty(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/policy/rules/", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rules []db.PolicyRuleRow
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&rules); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(rules) != 0 {
|
||||||
|
t.Fatalf("rule count: got %d, want 0", len(rules))
|
||||||
|
}
|
||||||
|
}
|
||||||
94
internal/server/admin_repo.go
Normal file
94
internal/server/admin_repo.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AdminListReposHandler handles GET /v1/repositories.
|
||||||
|
func AdminListReposHandler(database *db.DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
limit := 50
|
||||||
|
offset := 0
|
||||||
|
if n := r.URL.Query().Get("n"); n != "" {
|
||||||
|
if v, err := strconv.Atoi(n); err == nil && v > 0 {
|
||||||
|
limit = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if o := r.URL.Query().Get("offset"); o != "" {
|
||||||
|
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
|
||||||
|
offset = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
repos, err := database.ListRepositoriesWithMetadata(limit, offset)
|
||||||
|
if err != nil {
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if repos == nil {
|
||||||
|
repos = []db.RepoMetadata{}
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, repos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminGetRepoHandler handles GET /v1/repositories/*.
|
||||||
|
func AdminGetRepoHandler(database *db.DB) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
name := chi.URLParam(r, "*")
|
||||||
|
if name == "" {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "repository name required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
detail, err := database.GetRepositoryDetail(name)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, db.ErrRepoNotFound) {
|
||||||
|
writeAdminError(w, http.StatusNotFound, "repository not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeJSON(w, http.StatusOK, detail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AdminDeleteRepoHandler handles DELETE /v1/repositories/*.
|
||||||
|
// Requires admin role (enforced by RequireAdmin middleware).
|
||||||
|
func AdminDeleteRepoHandler(database *db.DB, auditFn AuditFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
name := chi.URLParam(r, "*")
|
||||||
|
if name == "" {
|
||||||
|
writeAdminError(w, http.StatusBadRequest, "repository name required")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := database.DeleteRepository(name); err != nil {
|
||||||
|
if errors.Is(err, db.ErrRepoNotFound) {
|
||||||
|
writeAdminError(w, http.StatusNotFound, "repository not found")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if auditFn != nil {
|
||||||
|
claims := auth.ClaimsFromContext(r.Context())
|
||||||
|
actorID := ""
|
||||||
|
if claims != nil {
|
||||||
|
actorID = claims.Subject
|
||||||
|
}
|
||||||
|
auditFn("repo_deleted", actorID, name, "", r.RemoteAddr, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusNoContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
186
internal/server/admin_repo_test.go
Normal file
186
internal/server/admin_repo_test.go
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// seedRepoForAdmin inserts a repository with a manifest and tags into the test DB.
|
||||||
|
func seedRepoForAdmin(t *testing.T, database *db.DB, name string, tags []string) {
|
||||||
|
t.Helper()
|
||||||
|
_, err := database.Exec(`INSERT INTO repositories (name) VALUES (?)`, name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert repo %q: %v", name, err)
|
||||||
|
}
|
||||||
|
var repoID int64
|
||||||
|
if err := database.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&repoID); err != nil {
|
||||||
|
t.Fatalf("select repo id: %v", err)
|
||||||
|
}
|
||||||
|
_, err = database.Exec(
|
||||||
|
`INSERT INTO manifests (repository_id, digest, media_type, content, size)
|
||||||
|
VALUES (?, ?, 'application/vnd.oci.image.manifest.v1+json', '{}', 512)`,
|
||||||
|
repoID, "sha256:manifest-"+name,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert manifest: %v", err)
|
||||||
|
}
|
||||||
|
var manifestID int64
|
||||||
|
if err := database.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil {
|
||||||
|
t.Fatalf("select manifest id: %v", err)
|
||||||
|
}
|
||||||
|
for _, tag := range tags {
|
||||||
|
_, err := database.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`,
|
||||||
|
repoID, tag, manifestID)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("insert tag %q: %v", tag, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListRepos(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
seedRepoForAdmin(t, database, "alpha/app", []string{"latest", "v1.0"})
|
||||||
|
seedRepoForAdmin(t, database, "bravo/lib", []string{"latest"})
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/repositories", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var repos []db.RepoMetadata
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(repos) != 2 {
|
||||||
|
t.Fatalf("repo count: got %d, want 2", len(repos))
|
||||||
|
}
|
||||||
|
if repos[0].Name != "alpha/app" {
|
||||||
|
t.Fatalf("first repo: got %q, want %q", repos[0].Name, "alpha/app")
|
||||||
|
}
|
||||||
|
if repos[0].TagCount != 2 {
|
||||||
|
t.Fatalf("tag count: got %d, want 2", repos[0].TagCount)
|
||||||
|
}
|
||||||
|
if repos[0].ManifestCount != 1 {
|
||||||
|
t.Fatalf("manifest count: got %d, want 1", repos[0].ManifestCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListReposEmpty(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/repositories", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var repos []db.RepoMetadata
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(repos) != 0 {
|
||||||
|
t.Fatalf("repo count: got %d, want 0", len(repos))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminListReposPagination(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
seedRepoForAdmin(t, database, "alpha/app", []string{"latest"})
|
||||||
|
seedRepoForAdmin(t, database, "bravo/lib", []string{"latest"})
|
||||||
|
seedRepoForAdmin(t, database, "charlie/svc", []string{"latest"})
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/repositories?n=2&offset=0", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
var repos []db.RepoMetadata
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if len(repos) != 2 {
|
||||||
|
t.Fatalf("page 1 count: got %d, want 2", len(repos))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminGetRepoDetail(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest", "v1.0"})
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/repositories/myorg/myapp", "")
|
||||||
|
if rr.Code != 200 {
|
||||||
|
t.Fatalf("status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
var detail db.RepoDetail
|
||||||
|
if err := json.NewDecoder(rr.Body).Decode(&detail); err != nil {
|
||||||
|
t.Fatalf("decode: %v", err)
|
||||||
|
}
|
||||||
|
if detail.Name != "myorg/myapp" {
|
||||||
|
t.Fatalf("name: got %q, want %q", detail.Name, "myorg/myapp")
|
||||||
|
}
|
||||||
|
if len(detail.Tags) != 2 {
|
||||||
|
t.Fatalf("tag count: got %d, want 2", len(detail.Tags))
|
||||||
|
}
|
||||||
|
if len(detail.Manifests) != 1 {
|
||||||
|
t.Fatalf("manifest count: got %d, want 1", len(detail.Manifests))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminGetRepoDetailNotFound(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "GET", "/v1/repositories/nonexistent/repo", "")
|
||||||
|
if rr.Code != 404 {
|
||||||
|
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminDeleteRepo(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest"})
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "DELETE", "/v1/repositories/myorg/myapp", "")
|
||||||
|
if rr.Code != 204 {
|
||||||
|
t.Fatalf("status: got %d, want 204; body: %s", rr.Code, rr.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify it's gone.
|
||||||
|
rr = adminReq(t, router, "GET", "/v1/repositories/myorg/myapp", "")
|
||||||
|
if rr.Code != 404 {
|
||||||
|
t.Fatalf("after delete status: got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminDeleteRepoNotFound(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router, _ := buildAdminRouter(t, database)
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "DELETE", "/v1/repositories/nonexistent/repo", "")
|
||||||
|
if rr.Code != 404 {
|
||||||
|
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdminDeleteRepoNonAdmin(t *testing.T) {
|
||||||
|
database := openAdminTestDB(t)
|
||||||
|
router := buildNonAdminRouter(t, database)
|
||||||
|
|
||||||
|
seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest"})
|
||||||
|
|
||||||
|
rr := adminReq(t, router, "DELETE", "/v1/repositories/myorg/myapp", "")
|
||||||
|
if rr.Code != 403 {
|
||||||
|
t.Fatalf("status: got %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
56
internal/server/admin_routes.go
Normal file
56
internal/server/admin_routes.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AdminDeps holds the dependencies needed by admin routes.
|
||||||
|
type AdminDeps struct {
|
||||||
|
DB *db.DB
|
||||||
|
Login LoginClient
|
||||||
|
Engine PolicyReloader
|
||||||
|
AuditFn AuditFunc
|
||||||
|
GCState *GCState
|
||||||
|
}
|
||||||
|
|
||||||
|
// MountAdminRoutes adds admin REST endpoints to the router.
|
||||||
|
// Auth middleware is applied at the route group level.
|
||||||
|
func MountAdminRoutes(r chi.Router, validator TokenValidator, serviceName string, deps AdminDeps) {
|
||||||
|
// Health endpoint - no auth required.
|
||||||
|
r.Get("/v1/health", AdminHealthHandler())
|
||||||
|
|
||||||
|
// Auth endpoints - no bearer auth required (login uses credentials).
|
||||||
|
r.Post("/v1/auth/login", AdminLoginHandler(deps.Login))
|
||||||
|
|
||||||
|
// Authenticated endpoints.
|
||||||
|
r.Route("/v1", func(v1 chi.Router) {
|
||||||
|
v1.Use(RequireAuth(validator, serviceName))
|
||||||
|
|
||||||
|
// Logout.
|
||||||
|
v1.Post("/auth/logout", AdminLogoutHandler())
|
||||||
|
|
||||||
|
// Repositories - list and detail require auth, delete requires admin.
|
||||||
|
v1.Get("/repositories", AdminListReposHandler(deps.DB))
|
||||||
|
v1.Get("/repositories/*", AdminGetRepoHandler(deps.DB))
|
||||||
|
v1.With(RequireAdmin()).Delete("/repositories/*", AdminDeleteRepoHandler(deps.DB, deps.AuditFn))
|
||||||
|
|
||||||
|
// Policy - all require admin.
|
||||||
|
v1.Route("/policy/rules", func(pr chi.Router) {
|
||||||
|
pr.Use(RequireAdmin())
|
||||||
|
pr.Get("/", AdminListPolicyRulesHandler(deps.DB))
|
||||||
|
pr.Post("/", AdminCreatePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn))
|
||||||
|
pr.Get("/{id}", AdminGetPolicyRuleHandler(deps.DB))
|
||||||
|
pr.Patch("/{id}", AdminUpdatePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn))
|
||||||
|
pr.Delete("/{id}", AdminDeletePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Audit - requires admin.
|
||||||
|
v1.With(RequireAdmin()).Get("/audit", AdminListAuditHandler(deps.DB))
|
||||||
|
|
||||||
|
// GC - requires admin.
|
||||||
|
v1.With(RequireAdmin()).Post("/gc", AdminTriggerGCHandler(deps.GCState))
|
||||||
|
v1.With(RequireAdmin()).Get("/gc/status", AdminGCStatusHandler(deps.GCState))
|
||||||
|
})
|
||||||
|
}
|
||||||
148
internal/server/admin_test.go
Normal file
148
internal/server/admin_test.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||||
|
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func openAdminTestDB(t *testing.T) *db.DB {
|
||||||
|
t.Helper()
|
||||||
|
path := filepath.Join(t.TempDir(), "test.db")
|
||||||
|
d, err := db.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Open: %v", err)
|
||||||
|
}
|
||||||
|
t.Cleanup(func() { _ = d.Close() })
|
||||||
|
if err := d.Migrate(); err != nil {
|
||||||
|
t.Fatalf("Migrate: %v", err)
|
||||||
|
}
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
type fakePolicyReloader struct {
|
||||||
|
reloadCount int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *fakePolicyReloader) Reload(_ policy.RuleStore) error {
|
||||||
|
f.reloadCount++
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAdminRouter creates a chi router with admin routes wired up,
|
||||||
|
// using a fakeValidator that returns admin claims for any bearer token.
|
||||||
|
func buildAdminRouter(t *testing.T, database *db.DB) (chi.Router, *fakePolicyReloader) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
validator := &fakeValidator{
|
||||||
|
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
|
||||||
|
}
|
||||||
|
login := &fakeLoginClient{token: "test-token", expiresIn: 3600}
|
||||||
|
reloader := &fakePolicyReloader{}
|
||||||
|
gcState := &GCState{}
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
MountAdminRoutes(r, validator, "mcr-test", AdminDeps{
|
||||||
|
DB: database,
|
||||||
|
Login: login,
|
||||||
|
Engine: reloader,
|
||||||
|
AuditFn: nil,
|
||||||
|
GCState: gcState,
|
||||||
|
})
|
||||||
|
return r, reloader
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildNonAdminRouter creates a chi router that returns non-admin claims.
|
||||||
|
func buildNonAdminRouter(t *testing.T, database *db.DB) chi.Router {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
validator := &fakeValidator{
|
||||||
|
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
|
||||||
|
}
|
||||||
|
login := &fakeLoginClient{token: "test-token", expiresIn: 3600}
|
||||||
|
reloader := &fakePolicyReloader{}
|
||||||
|
gcState := &GCState{}
|
||||||
|
|
||||||
|
r := chi.NewRouter()
|
||||||
|
MountAdminRoutes(r, validator, "mcr-test", AdminDeps{
|
||||||
|
DB: database,
|
||||||
|
Login: login,
|
||||||
|
Engine: reloader,
|
||||||
|
AuditFn: nil,
|
||||||
|
GCState: gcState,
|
||||||
|
})
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// adminReq is a convenience helper for making HTTP requests against the admin
|
||||||
|
// router, automatically including the Authorization header.
|
||||||
|
func adminReq(t *testing.T, router http.Handler, method, path string, body string) *httptest.ResponseRecorder {
|
||||||
|
t.Helper()
|
||||||
|
var bodyReader io.Reader
|
||||||
|
if body != "" {
|
||||||
|
bodyReader = strings.NewReader(body)
|
||||||
|
}
|
||||||
|
req := httptest.NewRequest(method, path, bodyReader)
|
||||||
|
req.Header.Set("Authorization", "Bearer valid-token")
|
||||||
|
if body != "" {
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
}
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
router.ServeHTTP(rr, req)
|
||||||
|
return rr
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAdminAllowed(t *testing.T) {
|
||||||
|
claims := &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}}
|
||||||
|
handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req = req.WithContext(auth.ContextWithClaims(req.Context(), claims))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusOK {
|
||||||
|
t.Fatalf("admin allowed: got %d, want 200", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAdminDenied(t *testing.T) {
|
||||||
|
claims := &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}}
|
||||||
|
handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
t.Fatal("inner handler should not be called")
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
req = req.WithContext(auth.ContextWithClaims(req.Context(), claims))
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusForbidden {
|
||||||
|
t.Fatalf("non-admin denied: got %d, want 403", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRequireAdminNoClaims(t *testing.T) {
|
||||||
|
handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||||
|
t.Fatal("inner handler should not be called")
|
||||||
|
}))
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/test", nil)
|
||||||
|
rr := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(rr, req)
|
||||||
|
|
||||||
|
if rr.Code != http.StatusUnauthorized {
|
||||||
|
t.Fatalf("no claims: got %d, want 401", rr.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user