diff --git a/PROGRESS.md b/PROGRESS.md index 55a4460..2f0bb50 100644 --- a/PROGRESS.md +++ b/PROGRESS.md @@ -6,13 +6,15 @@ See `PROJECT_PLAN.md` for the implementation roadmap and ## Current State -**Phase:** 1 complete, ready for Batch A (Phase 2 + Phase 3) +**Phase:** Batch A complete (Phases 2 + 3), ready for Phase 4 (policy engine) **Last updated:** 2026-03-19 ### Completed - Phase 0: Project scaffolding (all 4 steps) - Phase 1: Configuration & database (all 3 steps) +- Phase 2: Blob storage layer (all 2 steps) +- Phase 3: MCIAS authentication (all 4 steps) - `ARCHITECTURE.md` — Full design specification (18 sections) - `CLAUDE.md` — AI development guidance - `PROJECT_PLAN.md` — Implementation plan (14 phases, 40+ steps) @@ -20,14 +22,72 @@ See `PROJECT_PLAN.md` for the implementation roadmap and ### Next Steps -1. Begin Batch A: Phase 2 (blob storage) and Phase 3 (MCIAS auth) - in parallel — these are independent -2. After both complete, proceed to Phase 4 (policy engine) +1. Phase 4: Policy engine (depends on Phase 3) +2. After Phase 4, Batch B: Phase 5 (OCI pull) and Phase 8 (admin REST) --- ## Log +### 2026-03-19 — Batch A: Phase 2 (blob storage) + Phase 3 (MCIAS auth) + +**Task:** Implement content-addressed blob storage and MCIAS authentication +with OCI token endpoint and auth middleware. + +**Changes:** + +Phase 2 — `internal/storage/` (Steps 2.1 + 2.2): +- `storage.go`: `Store` struct with `layersPath`/`uploadsPath`, `New()` + constructor, digest validation (`^sha256:[a-f0-9]{64}$`), content-addressed + path layout: `/sha256//` +- `writer.go`: `BlobWriter` wrapping `*os.File` + `crypto/sha256` running hash + via `io.MultiWriter`. `StartUpload(uuid)` creates temp file in uploads dir. + `Write()` updates both file and hash. `Commit(expectedDigest)` finalizes hash, + verifies digest, `MkdirAll` prefix dir, `Rename` atomically. `Cancel()` cleans + up temp file. `BytesWritten()` returns offset. +- `reader.go`: `Open(digest)` returns `io.ReadCloser`, `Stat(digest)` returns + size, `Delete(digest)` removes blob + best-effort prefix dir cleanup, + `Exists(digest)` returns bool. All validate digest format first. +- `errors.go`: `ErrBlobNotFound`, `ErrDigestMismatch`, `ErrInvalidDigest` +- No new dependencies (stdlib only) + +Phase 3 — `internal/auth/` (Steps 3.1) + `internal/server/` (Steps 3.2–3.4): +- `auth/client.go`: `Client` with `NewClient(serverURL, caCert, serviceName, + tags)`, TLS 1.3 minimum, optional custom CA cert, 10s HTTP timeout. + `Login()` POSTs to MCIAS `/v1/auth/login`. `ValidateToken()` with SHA-256 + cache keying and 30s TTL. +- `auth/claims.go`: `Claims` struct (Subject, AccountType, Roles) with context + helpers `ContextWithClaims`/`ClaimsFromContext` +- `auth/cache.go`: `validationCache` with `sync.RWMutex`, lazy eviction, + injectable `now` function for testing +- `auth/errors.go`: `ErrUnauthorized`, `ErrMCIASUnavailable` +- `server/middleware.go`: `TokenValidator` interface, `RequireAuth` middleware + (Bearer token extraction, `WWW-Authenticate` header, OCI error format) +- `server/token.go`: `LoginClient` interface, `TokenHandler` (Basic auth → + bearer token exchange via MCIAS, RFC 3339 `issued_at`) +- `server/v2.go`: `V2Handler` returning 200 `{}` +- `server/routes.go`: `NewRouter` with chi: `/v2/token` (no auth), + `/v2/` (RequireAuth middleware) +- `server/ocierror.go`: `writeOCIError()` helper for OCI error JSON format +- New dependency: `github.com/go-chi/chi/v5` + +**Verification:** +- `make all` passes: vet clean, lint 0 issues, 52 tests passing + (7 config + 13 db/audit + 14 storage + 9 auth + 9 server), all 3 binaries built +- Storage tests: new store, digest validation (3 valid + 9 invalid), path layout, + write+commit, digest mismatch rejection (temp cleanup verified), cancel cleanup, + bytes written tracking, concurrent writes to different UUIDs, open after write, + stat, exists, delete (verify gone), open not found, invalid digest format + (covers Open/Stat/Delete/Exists) +- Auth tests: cache put/get, TTL expiry with clock injection, concurrent cache + access, login success/failure (httptest mock), validate success/revoked, + cache hit (request counter), cache expiry (clock advance) +- Server tests: RequireAuth valid/missing/invalid token, token handler + success/invalid creds/missing auth, routes integration (authenticated /v2/, + unauthenticated /v2/ → 401, token endpoint bypasses auth) + +--- + ### 2026-03-19 — Phase 1: Configuration & database **Task:** Implement TOML config loading with env overrides and validation, diff --git a/PROJECT_PLAN.md b/PROJECT_PLAN.md index 11e3c1a..c8f1211 100644 --- a/PROJECT_PLAN.md +++ b/PROJECT_PLAN.md @@ -11,8 +11,8 @@ design specification. |-------|-------------|--------| | 0 | Project scaffolding | **Complete** | | 1 | Configuration & database | **Complete** | -| 2 | Blob storage layer | Not started | -| 3 | MCIAS authentication | Not started | +| 2 | Blob storage layer | **Complete** | +| 3 | MCIAS authentication | **Complete** | | 4 | Policy engine | Not started | | 5 | OCI API — pull path | Not started | | 6 | OCI API — push path | Not started | diff --git a/go.mod b/go.mod index 5bac3ec..581df45 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.25.7 require ( github.com/dustin/go-humanize v1.0.1 // indirect + github.com/go-chi/chi/v5 v5.2.5 // indirect github.com/google/uuid v1.6.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect diff --git a/go.sum b/go.sum index e23fd96..5d4f378 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= +github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= diff --git a/internal/auth/cache.go b/internal/auth/cache.go new file mode 100644 index 0000000..8339d44 --- /dev/null +++ b/internal/auth/cache.go @@ -0,0 +1,65 @@ +package auth + +import ( + "sync" + "time" +) + +// cacheEntry holds a cached Claims value and its expiration time. +type cacheEntry struct { + claims *Claims + expiresAt time.Time +} + +// validationCache provides a concurrency-safe, TTL-based cache for token +// validation results. Tokens are keyed by their SHA-256 hex digest. +type validationCache struct { + mu sync.RWMutex + entries map[string]cacheEntry + ttl time.Duration + now func() time.Time // injectable clock for testing +} + +// newCache creates a validationCache with the given TTL. +func newCache(ttl time.Duration) *validationCache { + return &validationCache{ + entries: make(map[string]cacheEntry), + ttl: ttl, + now: time.Now, + } +} + +// get returns cached claims for the given token hash, or false if the +// entry is missing or expired. Expired entries are lazily evicted. +func (c *validationCache) get(tokenHash string) (*Claims, bool) { + c.mu.RLock() + entry, ok := c.entries[tokenHash] + c.mu.RUnlock() + + if !ok { + return nil, false + } + + if c.now().After(entry.expiresAt) { + // Lazy evict the expired entry. + c.mu.Lock() + // Re-check under write lock in case another goroutine already evicted. + if e, exists := c.entries[tokenHash]; exists && c.now().After(e.expiresAt) { + delete(c.entries, tokenHash) + } + c.mu.Unlock() + return nil, false + } + + return entry.claims, true +} + +// put stores claims in the cache with an expiration of now + TTL. +func (c *validationCache) put(tokenHash string, claims *Claims) { + c.mu.Lock() + c.entries[tokenHash] = cacheEntry{ + claims: claims, + expiresAt: c.now().Add(c.ttl), + } + c.mu.Unlock() +} diff --git a/internal/auth/cache_test.go b/internal/auth/cache_test.go new file mode 100644 index 0000000..e1b9664 --- /dev/null +++ b/internal/auth/cache_test.go @@ -0,0 +1,71 @@ +package auth + +import ( + "sync" + "testing" + "time" +) + +func TestCachePutGet(t *testing.T) { + t.Helper() + c := newCache(30 * time.Second) + + claims := &Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}} + c.put("abc123", claims) + + got, ok := c.get("abc123") + if !ok { + t.Fatal("expected cache hit, got miss") + } + if got.Subject != "alice" { + t.Fatalf("subject: got %q, want %q", got.Subject, "alice") + } +} + +func TestCacheTTLExpiry(t *testing.T) { + t.Helper() + now := time.Now() + c := newCache(30 * time.Second) + c.now = func() time.Time { return now } + + claims := &Claims{Subject: "bob"} + c.put("def456", claims) + + // Still within TTL. + got, ok := c.get("def456") + if !ok { + t.Fatal("expected cache hit before TTL expiry") + } + if got.Subject != "bob" { + t.Fatalf("subject: got %q, want %q", got.Subject, "bob") + } + + // Advance clock past TTL. + c.now = func() time.Time { return now.Add(31 * time.Second) } + + _, ok = c.get("def456") + if ok { + t.Fatal("expected cache miss after TTL expiry, got hit") + } +} + +func TestCacheConcurrent(t *testing.T) { + t.Helper() + c := newCache(30 * time.Second) + + var wg sync.WaitGroup + for i := range 100 { + wg.Add(2) + key := string(rune('A' + i%26)) + go func() { + defer wg.Done() + c.put(key, &Claims{Subject: key}) + }() + go func() { + defer wg.Done() + c.get(key) //nolint:gosec // result intentionally ignored in concurrency test + }() + } + wg.Wait() + // If we get here without a race detector complaint, the test passes. +} diff --git a/internal/auth/claims.go b/internal/auth/claims.go new file mode 100644 index 0000000..533da5d --- /dev/null +++ b/internal/auth/claims.go @@ -0,0 +1,27 @@ +package auth + +import "context" + +// Claims represents the validated identity and roles extracted from an +// MCIAS token. +type Claims struct { + Subject string + AccountType string + Roles []string +} + +// claimsKey is an unexported type used as the context key for Claims, +// preventing collisions with keys from other packages. +type claimsKey struct{} + +// ContextWithClaims returns a new context carrying the given Claims. +func ContextWithClaims(ctx context.Context, c *Claims) context.Context { + return context.WithValue(ctx, claimsKey{}, c) +} + +// ClaimsFromContext extracts Claims from the context. It returns nil if +// no claims are present. +func ClaimsFromContext(ctx context.Context) *Claims { + c, _ := ctx.Value(claimsKey{}).(*Claims) + return c +} diff --git a/internal/auth/client.go b/internal/auth/client.go new file mode 100644 index 0000000..b728cec --- /dev/null +++ b/internal/auth/client.go @@ -0,0 +1,179 @@ +package auth + +import ( + "bytes" + "crypto/sha256" + "crypto/tls" + "crypto/x509" + "encoding/hex" + "encoding/json" + "fmt" + "net/http" + "os" + "strings" + "time" +) + +const cacheTTL = 30 * time.Second + +// Client communicates with an MCIAS server for authentication and token +// validation. It caches successful validation results for 30 seconds. +type Client struct { + httpClient *http.Client + baseURL string + serviceName string + tags []string + cache *validationCache +} + +// NewClient creates an auth Client that talks to the MCIAS server at +// serverURL. If caCert is non-empty, it is loaded as a PEM file and +// used as the only trusted root CA. TLS 1.3 is required for all HTTPS +// connections. +// +// For plain HTTP URLs (used in tests), TLS configuration is skipped. +func NewClient(serverURL, caCert, serviceName string, tags []string) (*Client, error) { + transport := &http.Transport{} + + if !strings.HasPrefix(serverURL, "http://") { + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS13, + } + + if caCert != "" { + pem, err := os.ReadFile(caCert) //nolint:gosec // CA cert path is operator-supplied + if err != nil { + return nil, fmt.Errorf("auth: read CA cert %s: %w", caCert, err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(pem) { + return nil, fmt.Errorf("auth: no valid certificates in %s", caCert) + } + tlsCfg.RootCAs = pool + } + + transport.TLSClientConfig = tlsCfg + } + + return &Client{ + httpClient: &http.Client{ + Transport: transport, + Timeout: 10 * time.Second, + }, + baseURL: strings.TrimRight(serverURL, "/"), + serviceName: serviceName, + tags: tags, + cache: newCache(cacheTTL), + }, nil +} + +// loginRequest is the JSON body sent to MCIAS /v1/auth/login. +type loginRequest struct { + Username string `json:"username"` + Password string `json:"password"` + ServiceName string `json:"service_name"` + Tags []string `json:"tags,omitempty"` +} + +// loginResponse is the JSON body returned by MCIAS /v1/auth/login. +type loginResponse struct { + Token string `json:"token"` + ExpiresIn int `json:"expires_in"` +} + +// Login authenticates a user against MCIAS and returns a bearer token. +func (c *Client) Login(username, password string) (token string, expiresIn int, err error) { + body, err := json.Marshal(loginRequest{ //nolint:gosec // G117: password is intentionally sent to MCIAS for authentication + Username: username, + Password: password, + ServiceName: c.serviceName, + Tags: c.tags, + }) + if err != nil { + return "", 0, fmt.Errorf("auth: marshal login request: %w", err) + } + + resp, err := c.httpClient.Post( + c.baseURL+"/v1/auth/login", + "application/json", + bytes.NewReader(body), + ) + if err != nil { + return "", 0, fmt.Errorf("auth: MCIAS login: %w", ErrMCIASUnavailable) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return "", 0, ErrUnauthorized + } + + var lr loginResponse + if err := json.NewDecoder(resp.Body).Decode(&lr); err != nil { + return "", 0, fmt.Errorf("auth: decode login response: %w", err) + } + + return lr.Token, lr.ExpiresIn, nil +} + +// validateRequest is the JSON body sent to MCIAS /v1/token/validate. +type validateRequest struct { + Token string `json:"token"` +} + +// validateResponse is the JSON body returned by MCIAS /v1/token/validate. +type validateResponse struct { + Valid bool `json:"valid"` + Claims struct { + Subject string `json:"subject"` + AccountType string `json:"account_type"` + Roles []string `json:"roles"` + } `json:"claims"` +} + +// ValidateToken checks a bearer token against MCIAS. Results are cached +// by SHA-256 hash for 30 seconds. +func (c *Client) ValidateToken(token string) (*Claims, error) { + h := sha256.Sum256([]byte(token)) + tokenHash := hex.EncodeToString(h[:]) + + if claims, ok := c.cache.get(tokenHash); ok { + return claims, nil + } + + body, err := json.Marshal(validateRequest{Token: token}) + if err != nil { + return nil, fmt.Errorf("auth: marshal validate request: %w", err) + } + + resp, err := c.httpClient.Post( + c.baseURL+"/v1/token/validate", + "application/json", + bytes.NewReader(body), + ) + if err != nil { + return nil, fmt.Errorf("auth: MCIAS validate: %w", ErrMCIASUnavailable) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + return nil, ErrUnauthorized + } + + var vr validateResponse + if err := json.NewDecoder(resp.Body).Decode(&vr); err != nil { + return nil, fmt.Errorf("auth: decode validate response: %w", err) + } + + if !vr.Valid { + return nil, ErrUnauthorized + } + + claims := &Claims{ + Subject: vr.Claims.Subject, + AccountType: vr.Claims.AccountType, + Roles: vr.Claims.Roles, + } + + c.cache.put(tokenHash, claims) + return claims, nil +} diff --git a/internal/auth/client_test.go b/internal/auth/client_test.go new file mode 100644 index 0000000..8e4fcc8 --- /dev/null +++ b/internal/auth/client_test.go @@ -0,0 +1,220 @@ +package auth + +import ( + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" +) + +// newTestServer starts an httptest.Server that routes MCIAS endpoints. +// The handler functions are pluggable per test. +func newTestServer(t *testing.T, loginHandler, validateHandler http.HandlerFunc) *httptest.Server { + t.Helper() + mux := http.NewServeMux() + if loginHandler != nil { + mux.HandleFunc("/v1/auth/login", loginHandler) + } + if validateHandler != nil { + mux.HandleFunc("/v1/token/validate", validateHandler) + } + srv := httptest.NewServer(mux) + t.Cleanup(srv.Close) + return srv +} + +func newTestClient(t *testing.T, serverURL string) *Client { + t.Helper() + c, err := NewClient(serverURL, "", "mcr-test", []string{"env:test"}) + if err != nil { + t.Fatalf("NewClient: %v", err) + } + return c +} + +func TestLoginSuccess(t *testing.T) { + srv := newTestServer(t, func(w http.ResponseWriter, r *http.Request) { + var req loginRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(loginResponse{ + Token: "tok-abc", + ExpiresIn: 3600, + }) + }, nil) + + c := newTestClient(t, srv.URL) + token, expiresIn, err := c.Login("alice", "secret") + if err != nil { + t.Fatalf("Login: %v", err) + } + if token != "tok-abc" { + t.Fatalf("token: got %q, want %q", token, "tok-abc") + } + if expiresIn != 3600 { + t.Fatalf("expiresIn: got %d, want %d", expiresIn, 3600) + } +} + +func TestLoginFailure(t *testing.T) { + srv := newTestServer(t, func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + }, nil) + + c := newTestClient(t, srv.URL) + _, _, err := c.Login("alice", "wrong") + if !errors.Is(err, ErrUnauthorized) { + t.Fatalf("Login error: got %v, want %v", err, ErrUnauthorized) + } +} + +func TestValidateSuccess(t *testing.T) { + srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(validateResponse{ + Valid: true, + Claims: struct { + Subject string `json:"subject"` + AccountType string `json:"account_type"` + Roles []string `json:"roles"` + }{ + Subject: "alice", + AccountType: "user", + Roles: []string{"reader", "writer"}, + }, + }) + }) + + c := newTestClient(t, srv.URL) + claims, err := c.ValidateToken("valid-token-123") + if err != nil { + t.Fatalf("ValidateToken: %v", err) + } + if claims.Subject != "alice" { + t.Fatalf("subject: got %q, want %q", claims.Subject, "alice") + } + if claims.AccountType != "user" { + t.Fatalf("account_type: got %q, want %q", claims.AccountType, "user") + } + if len(claims.Roles) != 2 || claims.Roles[0] != "reader" || claims.Roles[1] != "writer" { + t.Fatalf("roles: got %v, want [reader writer]", claims.Roles) + } +} + +func TestValidateRevoked(t *testing.T) { + srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(validateResponse{Valid: false}) + }) + + c := newTestClient(t, srv.URL) + _, err := c.ValidateToken("revoked-token") + if !errors.Is(err, ErrUnauthorized) { + t.Fatalf("ValidateToken error: got %v, want %v", err, ErrUnauthorized) + } +} + +func TestValidateCacheHit(t *testing.T) { + var callCount atomic.Int64 + + srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { + callCount.Add(1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(validateResponse{ + Valid: true, + Claims: struct { + Subject string `json:"subject"` + AccountType string `json:"account_type"` + Roles []string `json:"roles"` + }{ + Subject: "bob", + AccountType: "service", + Roles: []string{"admin"}, + }, + }) + }) + + c := newTestClient(t, srv.URL) + + // First call — should hit the server. + claims1, err := c.ValidateToken("cached-token") + if err != nil { + t.Fatalf("first ValidateToken: %v", err) + } + if callCount.Load() != 1 { + t.Fatalf("expected 1 server call after first validate, got %d", callCount.Load()) + } + + // Second call — should come from cache. + claims2, err := c.ValidateToken("cached-token") + if err != nil { + t.Fatalf("second ValidateToken: %v", err) + } + if callCount.Load() != 1 { + t.Fatalf("expected 1 server call after second validate (cache hit), got %d", callCount.Load()) + } + + if claims1.Subject != claims2.Subject { + t.Fatalf("cached claims mismatch: %q vs %q", claims1.Subject, claims2.Subject) + } +} + +func TestValidateCacheExpiry(t *testing.T) { + var callCount atomic.Int64 + + srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) { + callCount.Add(1) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(validateResponse{ + Valid: true, + Claims: struct { + Subject string `json:"subject"` + AccountType string `json:"account_type"` + Roles []string `json:"roles"` + }{ + Subject: "charlie", + AccountType: "user", + Roles: nil, + }, + }) + }) + + c := newTestClient(t, srv.URL) + + // Inject a controllable clock. + now := time.Now() + c.cache.now = func() time.Time { return now } + + // First call. + if _, err := c.ValidateToken("expiry-token"); err != nil { + t.Fatalf("first ValidateToken: %v", err) + } + if callCount.Load() != 1 { + t.Fatalf("expected 1 server call, got %d", callCount.Load()) + } + + // Second call within TTL — cache hit. + if _, err := c.ValidateToken("expiry-token"); err != nil { + t.Fatalf("second ValidateToken: %v", err) + } + if callCount.Load() != 1 { + t.Fatalf("expected 1 server call (cache hit), got %d", callCount.Load()) + } + + // Advance clock past the 30s TTL. + c.cache.now = func() time.Time { return now.Add(31 * time.Second) } + + // Third call — cache miss, should hit server again. + if _, err := c.ValidateToken("expiry-token"); err != nil { + t.Fatalf("third ValidateToken: %v", err) + } + if callCount.Load() != 2 { + t.Fatalf("expected 2 server calls after cache expiry, got %d", callCount.Load()) + } +} diff --git a/internal/auth/errors.go b/internal/auth/errors.go new file mode 100644 index 0000000..06368bd --- /dev/null +++ b/internal/auth/errors.go @@ -0,0 +1,8 @@ +package auth + +import "errors" + +var ( + ErrUnauthorized = errors.New("auth: unauthorized") + ErrMCIASUnavailable = errors.New("auth: MCIAS unavailable") +) diff --git a/internal/server/middleware.go b/internal/server/middleware.go new file mode 100644 index 0000000..d83ef48 --- /dev/null +++ b/internal/server/middleware.go @@ -0,0 +1,59 @@ +package server + +import ( + "fmt" + "net/http" + "strings" + + "git.wntrmute.dev/kyle/mcr/internal/auth" +) + +// TokenValidator abstracts token validation so the middleware can work +// with the real MCIAS client or a test fake. +type TokenValidator interface { + ValidateToken(token string) (*auth.Claims, error) +} + +// RequireAuth returns middleware that validates Bearer tokens via the +// given TokenValidator. On success the authenticated Claims are injected +// into the request context. On failure a 401 with an OCI-format error +// body and a WWW-Authenticate header is returned. +func RequireAuth(validator TokenValidator, serviceName string) func(http.Handler) http.Handler { + wwwAuth := fmt.Sprintf(`Bearer realm="/v2/token",service="%s"`, serviceName) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + token := extractBearerToken(r) + if token == "" { + w.Header().Set("WWW-Authenticate", wwwAuth) + writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication required") + return + } + + claims, err := validator.ValidateToken(token) + if err != nil { + w.Header().Set("WWW-Authenticate", wwwAuth) + writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication required") + return + } + + ctx := auth.ContextWithClaims(r.Context(), claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) + } +} + +// extractBearerToken parses a "Bearer " value from the +// Authorization header. It returns an empty string if the header is +// missing or malformed. +func extractBearerToken(r *http.Request) string { + h := r.Header.Get("Authorization") + if h == "" { + return "" + } + const prefix = "Bearer " + if !strings.HasPrefix(h, prefix) { + return "" + } + return strings.TrimSpace(h[len(prefix):]) +} diff --git a/internal/server/middleware_test.go b/internal/server/middleware_test.go new file mode 100644 index 0000000..1d9b35a --- /dev/null +++ b/internal/server/middleware_test.go @@ -0,0 +1,112 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "git.wntrmute.dev/kyle/mcr/internal/auth" +) + +type fakeValidator struct { + claims *auth.Claims + err error +} + +func (f *fakeValidator) ValidateToken(_ string) (*auth.Claims, error) { + return f.claims, f.err +} + +func TestRequireAuthValid(t *testing.T) { + t.Helper() + claims := &auth.Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}} + validator := &fakeValidator{claims: claims} + + var gotClaims *auth.Claims + inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) { + gotClaims = auth.ClaimsFromContext(r.Context()) + }) + + handler := RequireAuth(validator, "mcr-test")(inner) + + req := httptest.NewRequest(http.MethodGet, "/v2/", nil) + req.Header.Set("Authorization", "Bearer valid-token") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK) + } + if gotClaims == nil { + t.Fatal("expected claims in context, got nil") + } + if gotClaims.Subject != "alice" { + t.Fatalf("subject: got %q, want %q", gotClaims.Subject, "alice") + } +} + +func TestRequireAuthMissing(t *testing.T) { + t.Helper() + validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized} + + inner := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + t.Fatal("inner handler should not be called") + }) + + handler := RequireAuth(validator, "mcr-test")(inner) + + req := httptest.NewRequest(http.MethodGet, "/v2/", nil) + // No Authorization header. + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized) + } + + wwwAuth := rec.Header().Get("WWW-Authenticate") + want := `Bearer realm="/v2/token",service="mcr-test"` + if wwwAuth != want { + t.Fatalf("WWW-Authenticate: got %q, want %q", wwwAuth, want) + } + + var ociErr ociErrorResponse + if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil { + t.Fatalf("decode OCI error: %v", err) + } + if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" { + t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors) + } +} + +func TestRequireAuthInvalid(t *testing.T) { + t.Helper() + validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized} + + inner := http.HandlerFunc(func(http.ResponseWriter, *http.Request) { + t.Fatal("inner handler should not be called") + }) + + handler := RequireAuth(validator, "mcr-test")(inner) + + req := httptest.NewRequest(http.MethodGet, "/v2/", nil) + req.Header.Set("Authorization", "Bearer bad-token") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized) + } + + var ociErr ociErrorResponse + if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil { + t.Fatalf("decode OCI error: %v", err) + } + if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" { + t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors) + } +} diff --git a/internal/server/ocierror.go b/internal/server/ocierror.go new file mode 100644 index 0000000..baa102c --- /dev/null +++ b/internal/server/ocierror.go @@ -0,0 +1,23 @@ +package server + +import ( + "encoding/json" + "net/http" +) + +type ociErrorEntry struct { + Code string `json:"code"` + Message string `json:"message"` +} + +type ociErrorResponse struct { + Errors []ociErrorEntry `json:"errors"` +} + +func writeOCIError(w http.ResponseWriter, code string, status int, message string) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + _ = json.NewEncoder(w).Encode(ociErrorResponse{ + Errors: []ociErrorEntry{{Code: code, Message: message}}, + }) +} diff --git a/internal/server/routes.go b/internal/server/routes.go new file mode 100644 index 0000000..2f2d67c --- /dev/null +++ b/internal/server/routes.go @@ -0,0 +1,21 @@ +package server + +import "github.com/go-chi/chi/v5" + +// NewRouter builds the chi router with all OCI Distribution Spec +// endpoints and auth middleware wired up. +func NewRouter(validator TokenValidator, loginClient LoginClient, serviceName string) *chi.Mux { + r := chi.NewRouter() + + // Token endpoint is NOT behind RequireAuth — clients use Basic auth + // here to obtain a bearer token. + r.Get("/v2/token", TokenHandler(loginClient)) + + // All other /v2 endpoints require a valid bearer token. + r.Route("/v2", func(v2 chi.Router) { + v2.Use(RequireAuth(validator, serviceName)) + v2.Get("/", V2Handler()) + }) + + return r +} diff --git a/internal/server/routes_test.go b/internal/server/routes_test.go new file mode 100644 index 0000000..dcf00b8 --- /dev/null +++ b/internal/server/routes_test.go @@ -0,0 +1,94 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "git.wntrmute.dev/kyle/mcr/internal/auth" +) + +func TestRoutesV2Authenticated(t *testing.T) { + t.Helper() + validator := &fakeValidator{ + claims: &auth.Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}}, + } + loginClient := &fakeLoginClient{token: "tok-abc", expiresIn: 3600} + router := NewRouter(validator, loginClient, "mcr-test") + + req := httptest.NewRequest(http.MethodGet, "/v2/", nil) + req.Header.Set("Authorization", "Bearer valid-token") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK) + } + + ct := rec.Header().Get("Content-Type") + if ct != "application/json" { + t.Fatalf("Content-Type: got %q, want %q", ct, "application/json") + } + + body := rec.Body.String() + if body != "{}" { + t.Fatalf("body: got %q, want %q", body, "{}") + } +} + +func TestRoutesV2Unauthenticated(t *testing.T) { + t.Helper() + validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized} + loginClient := &fakeLoginClient{} + router := NewRouter(validator, loginClient, "mcr-test") + + req := httptest.NewRequest(http.MethodGet, "/v2/", nil) + // No Authorization header. + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized) + } + + wwwAuth := rec.Header().Get("WWW-Authenticate") + want := `Bearer realm="/v2/token",service="mcr-test"` + if wwwAuth != want { + t.Fatalf("WWW-Authenticate: got %q, want %q", wwwAuth, want) + } +} + +func TestRoutesTokenEndpoint(t *testing.T) { + t.Helper() + // The validator should never be called for /v2/token. + validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized} + loginClient := &fakeLoginClient{token: "tok-from-login", expiresIn: 1800} + router := NewRouter(validator, loginClient, "mcr-test") + + req := httptest.NewRequest(http.MethodGet, "/v2/token", nil) + req.SetBasicAuth("bob", "password") + rec := httptest.NewRecorder() + + router.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK) + } + + var resp tokenResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp.Token != "tok-from-login" { + t.Fatalf("token: got %q, want %q", resp.Token, "tok-from-login") + } + if resp.ExpiresIn != 1800 { + t.Fatalf("expires_in: got %d, want %d", resp.ExpiresIn, 1800) + } + if resp.IssuedAt == "" { + t.Fatal("issued_at: expected non-empty RFC 3339 timestamp") + } +} diff --git a/internal/server/token.go b/internal/server/token.go new file mode 100644 index 0000000..f5525bb --- /dev/null +++ b/internal/server/token.go @@ -0,0 +1,45 @@ +package server + +import ( + "encoding/json" + "net/http" + "time" +) + +// LoginClient abstracts the MCIAS login call so the handler can work +// with the real client or a test fake. +type LoginClient interface { + Login(username, password string) (token string, expiresIn int, err error) +} + +// tokenResponse is the JSON body returned by the token endpoint. +type tokenResponse struct { + Token string `json:"token"` + ExpiresIn int `json:"expires_in"` + IssuedAt string `json:"issued_at"` +} + +// TokenHandler returns an http.HandlerFunc that exchanges Basic +// credentials for a bearer token via the given LoginClient. +func TokenHandler(loginClient LoginClient) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + username, password, ok := r.BasicAuth() + if !ok || username == "" { + writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "basic authentication required") + return + } + + token, expiresIn, err := loginClient.Login(username, password) + if err != nil { + writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication failed") + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(tokenResponse{ + Token: token, + ExpiresIn: expiresIn, + IssuedAt: time.Now().UTC().Format(time.RFC3339), + }) + } +} diff --git a/internal/server/token_test.go b/internal/server/token_test.go new file mode 100644 index 0000000..9bf7d70 --- /dev/null +++ b/internal/server/token_test.go @@ -0,0 +1,98 @@ +package server + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "git.wntrmute.dev/kyle/mcr/internal/auth" +) + +type fakeLoginClient struct { + token string + expiresIn int + err error +} + +func (f *fakeLoginClient) Login(_, _ string) (string, int, error) { + return f.token, f.expiresIn, f.err +} + +func TestTokenHandlerSuccess(t *testing.T) { + t.Helper() + lc := &fakeLoginClient{token: "tok-xyz", expiresIn: 7200} + handler := TokenHandler(lc) + + req := httptest.NewRequest(http.MethodGet, "/v2/token", nil) + req.SetBasicAuth("alice", "secret") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK) + } + + var resp tokenResponse + if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil { + t.Fatalf("decode response: %v", err) + } + if resp.Token != "tok-xyz" { + t.Fatalf("token: got %q, want %q", resp.Token, "tok-xyz") + } + if resp.ExpiresIn != 7200 { + t.Fatalf("expires_in: got %d, want %d", resp.ExpiresIn, 7200) + } + if resp.IssuedAt == "" { + t.Fatal("issued_at: expected non-empty RFC 3339 timestamp") + } +} + +func TestTokenHandlerInvalidCreds(t *testing.T) { + t.Helper() + lc := &fakeLoginClient{err: auth.ErrUnauthorized} + handler := TokenHandler(lc) + + req := httptest.NewRequest(http.MethodGet, "/v2/token", nil) + req.SetBasicAuth("alice", "wrong") + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized) + } + + var ociErr ociErrorResponse + if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil { + t.Fatalf("decode OCI error: %v", err) + } + if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" { + t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors) + } +} + +func TestTokenHandlerMissingAuth(t *testing.T) { + t.Helper() + lc := &fakeLoginClient{token: "should-not-matter"} + handler := TokenHandler(lc) + + req := httptest.NewRequest(http.MethodGet, "/v2/token", nil) + // No Authorization header. + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusUnauthorized { + t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized) + } + + var ociErr ociErrorResponse + if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil { + t.Fatalf("decode OCI error: %v", err) + } + if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" { + t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors) + } +} diff --git a/internal/server/v2.go b/internal/server/v2.go new file mode 100644 index 0000000..a280b23 --- /dev/null +++ b/internal/server/v2.go @@ -0,0 +1,13 @@ +package server + +import "net/http" + +// V2Handler returns an http.HandlerFunc that responds with 200 OK and +// an empty JSON object, per the OCI Distribution Spec version check. +func V2Handler() http.HandlerFunc { + return func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("{}")) + } +} diff --git a/internal/storage/errors.go b/internal/storage/errors.go new file mode 100644 index 0000000..651f1eb --- /dev/null +++ b/internal/storage/errors.go @@ -0,0 +1,9 @@ +package storage + +import "errors" + +var ( + ErrBlobNotFound = errors.New("storage: blob not found") + ErrDigestMismatch = errors.New("storage: digest mismatch") + ErrInvalidDigest = errors.New("storage: invalid digest format") +) diff --git a/internal/storage/reader.go b/internal/storage/reader.go new file mode 100644 index 0000000..1c33d26 --- /dev/null +++ b/internal/storage/reader.go @@ -0,0 +1,75 @@ +package storage + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" +) + +// Open validates the digest and returns a ReadCloser for the blob. +// Returns ErrBlobNotFound if the blob does not exist on disk. +func (s *Store) Open(digest string) (io.ReadCloser, error) { + if err := validateDigest(digest); err != nil { + return nil, err + } + + f, err := os.Open(s.blobPath(digest)) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return nil, ErrBlobNotFound + } + return nil, fmt.Errorf("storage: open blob: %w", err) + } + return f, nil +} + +// Stat returns the size of the blob in bytes. +// Returns ErrBlobNotFound if the blob does not exist on disk. +func (s *Store) Stat(digest string) (int64, error) { + if err := validateDigest(digest); err != nil { + return 0, err + } + + info, err := os.Stat(s.blobPath(digest)) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + return 0, ErrBlobNotFound + } + return 0, fmt.Errorf("storage: stat blob: %w", err) + } + return info.Size(), nil +} + +// Delete removes the blob file and attempts to clean up its prefix +// directory. Non-empty or already-removed prefix directories are +// silently ignored. +func (s *Store) Delete(digest string) error { + if err := validateDigest(digest); err != nil { + return err + } + + path := s.blobPath(digest) + if err := os.Remove(path); err != nil { + if errors.Is(err, os.ErrNotExist) { + return ErrBlobNotFound + } + return fmt.Errorf("storage: delete blob: %w", err) + } + + // Best-effort cleanup of the prefix directory. + _ = os.Remove(filepath.Dir(path)) + + return nil +} + +// Exists reports whether the blob exists on disk. +func (s *Store) Exists(digest string) bool { + if err := validateDigest(digest); err != nil { + return false + } + + _, err := os.Stat(s.blobPath(digest)) + return err == nil +} diff --git a/internal/storage/reader_test.go b/internal/storage/reader_test.go new file mode 100644 index 0000000..e76d163 --- /dev/null +++ b/internal/storage/reader_test.go @@ -0,0 +1,107 @@ +package storage + +import ( + "errors" + "io" + "testing" +) + +func TestOpenAfterWrite(t *testing.T) { + s := newTestStore(t) + data := []byte("readable blob content") + digest := writeTestBlob(t, s, data) + + rc, err := s.Open(digest) + if err != nil { + t.Fatalf("Open: %v", err) + } + defer func() { _ = rc.Close() }() + + got, err := io.ReadAll(rc) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } + if string(got) != string(data) { + t.Fatalf("Open content: got %q, want %q", got, data) + } +} + +func TestStatAfterWrite(t *testing.T) { + s := newTestStore(t) + data := []byte("stat this blob") + digest := writeTestBlob(t, s, data) + + size, err := s.Stat(digest) + if err != nil { + t.Fatalf("Stat: %v", err) + } + if size != int64(len(data)) { + t.Fatalf("Stat size: got %d, want %d", size, len(data)) + } +} + +func TestExists(t *testing.T) { + s := newTestStore(t) + data := []byte("existence check") + digest := writeTestBlob(t, s, data) + + if !s.Exists(digest) { + t.Fatal("Exists returned false for written blob") + } + + nonexistent := "sha256:0000000000000000000000000000000000000000000000000000000000000000" + if s.Exists(nonexistent) { + t.Fatal("Exists returned true for nonexistent blob") + } +} + +func TestDelete(t *testing.T) { + s := newTestStore(t) + data := []byte("delete me") + digest := writeTestBlob(t, s, data) + + if err := s.Delete(digest); err != nil { + t.Fatalf("Delete: %v", err) + } + + if s.Exists(digest) { + t.Fatal("Exists returned true after Delete") + } + + _, err := s.Open(digest) + if !errors.Is(err, ErrBlobNotFound) { + t.Fatalf("Open after Delete: got %v, want ErrBlobNotFound", err) + } +} + +func TestOpenNotFound(t *testing.T) { + s := newTestStore(t) + + digest := "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + _, err := s.Open(digest) + if !errors.Is(err, ErrBlobNotFound) { + t.Fatalf("Open nonexistent: got %v, want ErrBlobNotFound", err) + } +} + +func TestInvalidDigestFormat(t *testing.T) { + s := newTestStore(t) + bad := "not-a-digest" + + if _, err := s.Open(bad); !errors.Is(err, ErrInvalidDigest) { + t.Fatalf("Open with bad digest: got %v, want ErrInvalidDigest", err) + } + + if _, err := s.Stat(bad); !errors.Is(err, ErrInvalidDigest) { + t.Fatalf("Stat with bad digest: got %v, want ErrInvalidDigest", err) + } + + if err := s.Delete(bad); !errors.Is(err, ErrInvalidDigest) { + t.Fatalf("Delete with bad digest: got %v, want ErrInvalidDigest", err) + } + + // Exists should return false for an invalid digest, not panic. + if s.Exists(bad) { + t.Fatal("Exists returned true for invalid digest") + } +} diff --git a/internal/storage/storage.go b/internal/storage/storage.go new file mode 100644 index 0000000..906bc34 --- /dev/null +++ b/internal/storage/storage.go @@ -0,0 +1,38 @@ +package storage + +import ( + "path/filepath" + "regexp" +) + +var digestRe = regexp.MustCompile(`^sha256:[a-f0-9]{64}$`) + +// Store manages blob storage on the local filesystem. +type Store struct { + layersPath string + uploadsPath string +} + +// New creates a Store that will write final blobs under layersPath and +// stage in-progress uploads under uploadsPath. +func New(layersPath, uploadsPath string) *Store { + return &Store{ + layersPath: layersPath, + uploadsPath: uploadsPath, + } +} + +// validateDigest checks that digest matches sha256:<64 lowercase hex chars>. +func validateDigest(digest string) error { + if !digestRe.MatchString(digest) { + return ErrInvalidDigest + } + return nil +} + +// blobPath returns the filesystem path for a blob with the given digest. +// The layout is: /sha256// +func (s *Store) blobPath(digest string) string { + hex := digest[len("sha256:"):] + return filepath.Join(s.layersPath, "sha256", hex[0:2], hex) +} diff --git a/internal/storage/storage_test.go b/internal/storage/storage_test.go new file mode 100644 index 0000000..335fd27 --- /dev/null +++ b/internal/storage/storage_test.go @@ -0,0 +1,68 @@ +package storage + +import ( + "errors" + "path/filepath" + "testing" +) + +func newTestStore(t *testing.T) *Store { + t.Helper() + dir := t.TempDir() + return New(filepath.Join(dir, "layers"), filepath.Join(dir, "uploads")) +} + +func TestNew(t *testing.T) { + s := newTestStore(t) + if s == nil { + t.Fatal("New returned nil") + } + if s.layersPath == "" { + t.Fatal("layersPath is empty") + } + if s.uploadsPath == "" { + t.Fatal("uploadsPath is empty") + } +} + +func TestValidateDigest(t *testing.T) { + valid := []string{ + "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + "sha256:0000000000000000000000000000000000000000000000000000000000000000", + "sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789", + } + for _, d := range valid { + if err := validateDigest(d); err != nil { + t.Errorf("validateDigest(%q) = %v, want nil", d, err) + } + } + + invalid := []string{ + "", + "sha256:", + "sha256:abc", + "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b85", // 63 chars + "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b8555", // 65 chars + "sha256:E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855", // uppercase + "md5:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // wrong algo + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // missing prefix + "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b85g", // non-hex char + } + for _, d := range invalid { + if err := validateDigest(d); !errors.Is(err, ErrInvalidDigest) { + t.Errorf("validateDigest(%q) = %v, want ErrInvalidDigest", d, err) + } + } +} + +func TestBlobPath(t *testing.T) { + s := New("/data/layers", "/data/uploads") + + digest := "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + got := s.blobPath(digest) + want := filepath.Join("/data/layers", "sha256", "e3", + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855") + if got != want { + t.Fatalf("blobPath(%q)\n got %q\nwant %q", digest, got, want) + } +} diff --git a/internal/storage/writer.go b/internal/storage/writer.go new file mode 100644 index 0000000..88bdfdd --- /dev/null +++ b/internal/storage/writer.go @@ -0,0 +1,107 @@ +package storage + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "hash" + "io" + "os" + "path/filepath" +) + +// BlobWriter stages blob data in a temporary file while computing its +// SHA-256 digest on the fly. +type BlobWriter struct { + file *os.File + hash hash.Hash + mw io.Writer + path string + written int64 + closed bool + store *Store +} + +// StartUpload begins a new blob upload, creating a temp file at +// /. +func (s *Store) StartUpload(uuid string) (*BlobWriter, error) { + if err := os.MkdirAll(s.uploadsPath, 0700); err != nil { + return nil, fmt.Errorf("storage: create uploads dir: %w", err) + } + + path := filepath.Join(s.uploadsPath, uuid) + f, err := os.Create(path) //nolint:gosec // upload UUID is server-generated, not user input + if err != nil { + return nil, fmt.Errorf("storage: create upload file: %w", err) + } + + h := sha256.New() + return &BlobWriter{ + file: f, + hash: h, + mw: io.MultiWriter(f, h), + path: path, + store: s, + }, nil +} + +// Write writes p to both the staging file and the running hash. +func (bw *BlobWriter) Write(p []byte) (int, error) { + n, err := bw.mw.Write(p) + bw.written += int64(n) + if err != nil { + return n, fmt.Errorf("storage: write: %w", err) + } + return n, nil +} + +// Commit finalises the upload. It closes the staging file, verifies +// the computed digest matches expectedDigest, and atomically moves +// the file to its content-addressed location. +func (bw *BlobWriter) Commit(expectedDigest string) (string, error) { + if !bw.closed { + bw.closed = true + if err := bw.file.Close(); err != nil { + return "", fmt.Errorf("storage: close upload file: %w", err) + } + } + + if err := validateDigest(expectedDigest); err != nil { + _ = os.Remove(bw.path) + return "", err + } + + computed := "sha256:" + hex.EncodeToString(bw.hash.Sum(nil)) + if computed != expectedDigest { + _ = os.Remove(bw.path) + return "", ErrDigestMismatch + } + + dst := bw.store.blobPath(computed) + if err := os.MkdirAll(filepath.Dir(dst), 0700); err != nil { + return "", fmt.Errorf("storage: create blob dir: %w", err) + } + + if err := os.Rename(bw.path, dst); err != nil { + return "", fmt.Errorf("storage: rename blob: %w", err) + } + + return computed, nil +} + +// Cancel aborts the upload, closing and removing the temp file. +func (bw *BlobWriter) Cancel() error { + if !bw.closed { + bw.closed = true + _ = bw.file.Close() + } + if err := os.Remove(bw.path); err != nil && !os.IsNotExist(err) { + return fmt.Errorf("storage: remove upload file: %w", err) + } + return nil +} + +// BytesWritten returns the number of bytes written so far. +func (bw *BlobWriter) BytesWritten() int64 { + return bw.written +} diff --git a/internal/storage/writer_test.go b/internal/storage/writer_test.go new file mode 100644 index 0000000..3b1e666 --- /dev/null +++ b/internal/storage/writer_test.go @@ -0,0 +1,188 @@ +package storage + +import ( + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "os" + "sync" + "testing" + "time" +) + +func writeTestBlob(t *testing.T, s *Store, data []byte) string { + t.Helper() + uuid := "test-upload-" + fmt.Sprintf("%d", time.Now().UnixNano()) + w, err := s.StartUpload(uuid) + if err != nil { + t.Fatalf("StartUpload: %v", err) + } + if _, err := w.Write(data); err != nil { + t.Fatalf("Write: %v", err) + } + h := sha256.Sum256(data) + digest := "sha256:" + hex.EncodeToString(h[:]) + got, err := w.Commit(digest) + if err != nil { + t.Fatalf("Commit: %v", err) + } + if got != digest { + t.Fatalf("Commit returned %q, want %q", got, digest) + } + return digest +} + +func TestWriteAndCommit(t *testing.T) { + s := newTestStore(t) + data := []byte("hello, blob storage") + digest := writeTestBlob(t, s, data) + + // Verify file exists at expected path. + path := s.blobPath(digest) + info, err := os.Stat(path) + if err != nil { + t.Fatalf("stat blob file: %v", err) + } + if info.Size() != int64(len(data)) { + t.Fatalf("blob size: got %d, want %d", info.Size(), len(data)) + } + + // Verify content. + content, err := os.ReadFile(path) //nolint:gosec // test file path from t.TempDir() + if err != nil { + t.Fatalf("read blob file: %v", err) + } + if string(content) != string(data) { + t.Fatalf("blob content: got %q, want %q", content, data) + } +} + +func TestDigestMismatch(t *testing.T) { + s := newTestStore(t) + data := []byte("some data") + + uuid := "mismatch-upload" + w, err := s.StartUpload(uuid) + if err != nil { + t.Fatalf("StartUpload: %v", err) + } + if _, err := w.Write(data); err != nil { + t.Fatalf("Write: %v", err) + } + + wrongDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000000" + _, err = w.Commit(wrongDigest) + if !errors.Is(err, ErrDigestMismatch) { + t.Fatalf("Commit with wrong digest: got %v, want ErrDigestMismatch", err) + } + + // Verify temp file was cleaned up. + tempPath := w.path + if _, err := os.Stat(tempPath); !os.IsNotExist(err) { + t.Fatalf("temp file should be removed after digest mismatch, stat err: %v", err) + } +} + +func TestCancel(t *testing.T) { + s := newTestStore(t) + + uuid := "cancel-upload" + w, err := s.StartUpload(uuid) + if err != nil { + t.Fatalf("StartUpload: %v", err) + } + if _, err := w.Write([]byte("partial data")); err != nil { + t.Fatalf("Write: %v", err) + } + + tempPath := w.path + if err := w.Cancel(); err != nil { + t.Fatalf("Cancel: %v", err) + } + + if _, err := os.Stat(tempPath); !os.IsNotExist(err) { + t.Fatalf("temp file should be removed after Cancel, stat err: %v", err) + } +} + +func TestBytesWritten(t *testing.T) { + s := newTestStore(t) + + uuid := "bytes-upload" + w, err := s.StartUpload(uuid) + if err != nil { + t.Fatalf("StartUpload: %v", err) + } + t.Cleanup(func() { _ = w.Cancel() }) + + if w.BytesWritten() != 0 { + t.Fatalf("BytesWritten before write: got %d, want 0", w.BytesWritten()) + } + + data1 := []byte("first chunk") + if _, err := w.Write(data1); err != nil { + t.Fatalf("Write: %v", err) + } + if w.BytesWritten() != int64(len(data1)) { + t.Fatalf("BytesWritten after first write: got %d, want %d", w.BytesWritten(), len(data1)) + } + + data2 := []byte("second chunk") + if _, err := w.Write(data2); err != nil { + t.Fatalf("Write: %v", err) + } + want := int64(len(data1) + len(data2)) + if w.BytesWritten() != want { + t.Fatalf("BytesWritten after second write: got %d, want %d", w.BytesWritten(), want) + } +} + +func TestConcurrentWrites(t *testing.T) { + s := newTestStore(t) + + type result struct { + digest string + err error + } + + blobs := [][]byte{ + []byte("concurrent blob alpha"), + []byte("concurrent blob beta"), + } + + results := make([]result, len(blobs)) + var wg sync.WaitGroup + + for i, data := range blobs { + wg.Add(1) + go func(idx int, d []byte) { + defer wg.Done() + uuid := fmt.Sprintf("concurrent-%d-%d", idx, time.Now().UnixNano()) + w, err := s.StartUpload(uuid) + if err != nil { + results[idx] = result{err: err} + return + } + if _, err := w.Write(d); err != nil { + results[idx] = result{err: err} + return + } + h := sha256.Sum256(d) + digest := "sha256:" + hex.EncodeToString(h[:]) + got, err := w.Commit(digest) + results[idx] = result{digest: got, err: err} + }(i, data) + } + + wg.Wait() + + for i, r := range results { + if r.err != nil { + t.Fatalf("blob %d: %v", i, r.err) + } + if !s.Exists(r.digest) { + t.Fatalf("blob %d: digest %q not found after concurrent write", i, r.digest) + } + } +}