Batch A: blob storage layer, MCIAS auth, OCI token endpoint

Phase 2 — internal/storage/:
Content-addressed blob storage with atomic writes via rename.
BlobWriter stages data in uploads dir with running SHA-256 hash,
commits by verifying digest then renaming to layers/sha256/<prefix>/<hex>.
Reader provides Open, Stat, Delete, Exists with digest validation.

Phase 3 — internal/auth/ + internal/server/:
MCIAS client with Login and ValidateToken, 30s SHA-256-keyed cache
with lazy eviction and injectable clock for testing. TLS 1.3 minimum
with optional custom CA cert.
Chi router with RequireAuth middleware (Bearer token extraction,
WWW-Authenticate header, OCI error format), token endpoint (Basic
auth → bearer exchange via MCIAS), and /v2/ version check handler.

52 tests passing (14 storage + 9 auth + 9 server + 20 existing).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-19 14:51:19 -07:00
parent fde66be9c1
commit 3314b7a618
25 changed files with 1696 additions and 6 deletions

View File

@@ -6,13 +6,15 @@ See `PROJECT_PLAN.md` for the implementation roadmap and
## Current State
**Phase:** 1 complete, ready for Batch A (Phase 2 + Phase 3)
**Phase:** Batch A complete (Phases 2 + 3), ready for Phase 4 (policy engine)
**Last updated:** 2026-03-19
### Completed
- Phase 0: Project scaffolding (all 4 steps)
- Phase 1: Configuration & database (all 3 steps)
- Phase 2: Blob storage layer (all 2 steps)
- Phase 3: MCIAS authentication (all 4 steps)
- `ARCHITECTURE.md` — Full design specification (18 sections)
- `CLAUDE.md` — AI development guidance
- `PROJECT_PLAN.md` — Implementation plan (14 phases, 40+ steps)
@@ -20,14 +22,72 @@ See `PROJECT_PLAN.md` for the implementation roadmap and
### Next Steps
1. Begin Batch A: Phase 2 (blob storage) and Phase 3 (MCIAS auth)
in parallel — these are independent
2. After both complete, proceed to Phase 4 (policy engine)
1. Phase 4: Policy engine (depends on Phase 3)
2. After Phase 4, Batch B: Phase 5 (OCI pull) and Phase 8 (admin REST)
---
## Log
### 2026-03-19 — Batch A: Phase 2 (blob storage) + Phase 3 (MCIAS auth)
**Task:** Implement content-addressed blob storage and MCIAS authentication
with OCI token endpoint and auth middleware.
**Changes:**
Phase 2 — `internal/storage/` (Steps 2.1 + 2.2):
- `storage.go`: `Store` struct with `layersPath`/`uploadsPath`, `New()`
constructor, digest validation (`^sha256:[a-f0-9]{64}$`), content-addressed
path layout: `<layers>/sha256/<first-2-hex>/<full-64-hex>`
- `writer.go`: `BlobWriter` wrapping `*os.File` + `crypto/sha256` running hash
via `io.MultiWriter`. `StartUpload(uuid)` creates temp file in uploads dir.
`Write()` updates both file and hash. `Commit(expectedDigest)` finalizes hash,
verifies digest, `MkdirAll` prefix dir, `Rename` atomically. `Cancel()` cleans
up temp file. `BytesWritten()` returns offset.
- `reader.go`: `Open(digest)` returns `io.ReadCloser`, `Stat(digest)` returns
size, `Delete(digest)` removes blob + best-effort prefix dir cleanup,
`Exists(digest)` returns bool. All validate digest format first.
- `errors.go`: `ErrBlobNotFound`, `ErrDigestMismatch`, `ErrInvalidDigest`
- No new dependencies (stdlib only)
Phase 3 — `internal/auth/` (Steps 3.1) + `internal/server/` (Steps 3.23.4):
- `auth/client.go`: `Client` with `NewClient(serverURL, caCert, serviceName,
tags)`, TLS 1.3 minimum, optional custom CA cert, 10s HTTP timeout.
`Login()` POSTs to MCIAS `/v1/auth/login`. `ValidateToken()` with SHA-256
cache keying and 30s TTL.
- `auth/claims.go`: `Claims` struct (Subject, AccountType, Roles) with context
helpers `ContextWithClaims`/`ClaimsFromContext`
- `auth/cache.go`: `validationCache` with `sync.RWMutex`, lazy eviction,
injectable `now` function for testing
- `auth/errors.go`: `ErrUnauthorized`, `ErrMCIASUnavailable`
- `server/middleware.go`: `TokenValidator` interface, `RequireAuth` middleware
(Bearer token extraction, `WWW-Authenticate` header, OCI error format)
- `server/token.go`: `LoginClient` interface, `TokenHandler` (Basic auth →
bearer token exchange via MCIAS, RFC 3339 `issued_at`)
- `server/v2.go`: `V2Handler` returning 200 `{}`
- `server/routes.go`: `NewRouter` with chi: `/v2/token` (no auth),
`/v2/` (RequireAuth middleware)
- `server/ocierror.go`: `writeOCIError()` helper for OCI error JSON format
- New dependency: `github.com/go-chi/chi/v5`
**Verification:**
- `make all` passes: vet clean, lint 0 issues, 52 tests passing
(7 config + 13 db/audit + 14 storage + 9 auth + 9 server), all 3 binaries built
- Storage tests: new store, digest validation (3 valid + 9 invalid), path layout,
write+commit, digest mismatch rejection (temp cleanup verified), cancel cleanup,
bytes written tracking, concurrent writes to different UUIDs, open after write,
stat, exists, delete (verify gone), open not found, invalid digest format
(covers Open/Stat/Delete/Exists)
- Auth tests: cache put/get, TTL expiry with clock injection, concurrent cache
access, login success/failure (httptest mock), validate success/revoked,
cache hit (request counter), cache expiry (clock advance)
- Server tests: RequireAuth valid/missing/invalid token, token handler
success/invalid creds/missing auth, routes integration (authenticated /v2/,
unauthenticated /v2/ → 401, token endpoint bypasses auth)
---
### 2026-03-19 — Phase 1: Configuration & database
**Task:** Implement TOML config loading with env overrides and validation,

View File

@@ -11,8 +11,8 @@ design specification.
|-------|-------------|--------|
| 0 | Project scaffolding | **Complete** |
| 1 | Configuration & database | **Complete** |
| 2 | Blob storage layer | Not started |
| 3 | MCIAS authentication | Not started |
| 2 | Blob storage layer | **Complete** |
| 3 | MCIAS authentication | **Complete** |
| 4 | Policy engine | Not started |
| 5 | OCI API — pull path | Not started |
| 6 | OCI API — push path | Not started |

1
go.mod
View File

@@ -4,6 +4,7 @@ go 1.25.7
require (
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/go-chi/chi/v5 v5.2.5 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect

2
go.sum
View File

@@ -1,6 +1,8 @@
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug=
github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=

65
internal/auth/cache.go Normal file
View File

@@ -0,0 +1,65 @@
package auth
import (
"sync"
"time"
)
// cacheEntry holds a cached Claims value and its expiration time.
type cacheEntry struct {
claims *Claims
expiresAt time.Time
}
// validationCache provides a concurrency-safe, TTL-based cache for token
// validation results. Tokens are keyed by their SHA-256 hex digest.
type validationCache struct {
mu sync.RWMutex
entries map[string]cacheEntry
ttl time.Duration
now func() time.Time // injectable clock for testing
}
// newCache creates a validationCache with the given TTL.
func newCache(ttl time.Duration) *validationCache {
return &validationCache{
entries: make(map[string]cacheEntry),
ttl: ttl,
now: time.Now,
}
}
// get returns cached claims for the given token hash, or false if the
// entry is missing or expired. Expired entries are lazily evicted.
func (c *validationCache) get(tokenHash string) (*Claims, bool) {
c.mu.RLock()
entry, ok := c.entries[tokenHash]
c.mu.RUnlock()
if !ok {
return nil, false
}
if c.now().After(entry.expiresAt) {
// Lazy evict the expired entry.
c.mu.Lock()
// Re-check under write lock in case another goroutine already evicted.
if e, exists := c.entries[tokenHash]; exists && c.now().After(e.expiresAt) {
delete(c.entries, tokenHash)
}
c.mu.Unlock()
return nil, false
}
return entry.claims, true
}
// put stores claims in the cache with an expiration of now + TTL.
func (c *validationCache) put(tokenHash string, claims *Claims) {
c.mu.Lock()
c.entries[tokenHash] = cacheEntry{
claims: claims,
expiresAt: c.now().Add(c.ttl),
}
c.mu.Unlock()
}

View File

@@ -0,0 +1,71 @@
package auth
import (
"sync"
"testing"
"time"
)
func TestCachePutGet(t *testing.T) {
t.Helper()
c := newCache(30 * time.Second)
claims := &Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}}
c.put("abc123", claims)
got, ok := c.get("abc123")
if !ok {
t.Fatal("expected cache hit, got miss")
}
if got.Subject != "alice" {
t.Fatalf("subject: got %q, want %q", got.Subject, "alice")
}
}
func TestCacheTTLExpiry(t *testing.T) {
t.Helper()
now := time.Now()
c := newCache(30 * time.Second)
c.now = func() time.Time { return now }
claims := &Claims{Subject: "bob"}
c.put("def456", claims)
// Still within TTL.
got, ok := c.get("def456")
if !ok {
t.Fatal("expected cache hit before TTL expiry")
}
if got.Subject != "bob" {
t.Fatalf("subject: got %q, want %q", got.Subject, "bob")
}
// Advance clock past TTL.
c.now = func() time.Time { return now.Add(31 * time.Second) }
_, ok = c.get("def456")
if ok {
t.Fatal("expected cache miss after TTL expiry, got hit")
}
}
func TestCacheConcurrent(t *testing.T) {
t.Helper()
c := newCache(30 * time.Second)
var wg sync.WaitGroup
for i := range 100 {
wg.Add(2)
key := string(rune('A' + i%26))
go func() {
defer wg.Done()
c.put(key, &Claims{Subject: key})
}()
go func() {
defer wg.Done()
c.get(key) //nolint:gosec // result intentionally ignored in concurrency test
}()
}
wg.Wait()
// If we get here without a race detector complaint, the test passes.
}

27
internal/auth/claims.go Normal file
View File

@@ -0,0 +1,27 @@
package auth
import "context"
// Claims represents the validated identity and roles extracted from an
// MCIAS token.
type Claims struct {
Subject string
AccountType string
Roles []string
}
// claimsKey is an unexported type used as the context key for Claims,
// preventing collisions with keys from other packages.
type claimsKey struct{}
// ContextWithClaims returns a new context carrying the given Claims.
func ContextWithClaims(ctx context.Context, c *Claims) context.Context {
return context.WithValue(ctx, claimsKey{}, c)
}
// ClaimsFromContext extracts Claims from the context. It returns nil if
// no claims are present.
func ClaimsFromContext(ctx context.Context) *Claims {
c, _ := ctx.Value(claimsKey{}).(*Claims)
return c
}

179
internal/auth/client.go Normal file
View File

@@ -0,0 +1,179 @@
package auth
import (
"bytes"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"encoding/hex"
"encoding/json"
"fmt"
"net/http"
"os"
"strings"
"time"
)
const cacheTTL = 30 * time.Second
// Client communicates with an MCIAS server for authentication and token
// validation. It caches successful validation results for 30 seconds.
type Client struct {
httpClient *http.Client
baseURL string
serviceName string
tags []string
cache *validationCache
}
// NewClient creates an auth Client that talks to the MCIAS server at
// serverURL. If caCert is non-empty, it is loaded as a PEM file and
// used as the only trusted root CA. TLS 1.3 is required for all HTTPS
// connections.
//
// For plain HTTP URLs (used in tests), TLS configuration is skipped.
func NewClient(serverURL, caCert, serviceName string, tags []string) (*Client, error) {
transport := &http.Transport{}
if !strings.HasPrefix(serverURL, "http://") {
tlsCfg := &tls.Config{
MinVersion: tls.VersionTLS13,
}
if caCert != "" {
pem, err := os.ReadFile(caCert) //nolint:gosec // CA cert path is operator-supplied
if err != nil {
return nil, fmt.Errorf("auth: read CA cert %s: %w", caCert, err)
}
pool := x509.NewCertPool()
if !pool.AppendCertsFromPEM(pem) {
return nil, fmt.Errorf("auth: no valid certificates in %s", caCert)
}
tlsCfg.RootCAs = pool
}
transport.TLSClientConfig = tlsCfg
}
return &Client{
httpClient: &http.Client{
Transport: transport,
Timeout: 10 * time.Second,
},
baseURL: strings.TrimRight(serverURL, "/"),
serviceName: serviceName,
tags: tags,
cache: newCache(cacheTTL),
}, nil
}
// loginRequest is the JSON body sent to MCIAS /v1/auth/login.
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
ServiceName string `json:"service_name"`
Tags []string `json:"tags,omitempty"`
}
// loginResponse is the JSON body returned by MCIAS /v1/auth/login.
type loginResponse struct {
Token string `json:"token"`
ExpiresIn int `json:"expires_in"`
}
// Login authenticates a user against MCIAS and returns a bearer token.
func (c *Client) Login(username, password string) (token string, expiresIn int, err error) {
body, err := json.Marshal(loginRequest{ //nolint:gosec // G117: password is intentionally sent to MCIAS for authentication
Username: username,
Password: password,
ServiceName: c.serviceName,
Tags: c.tags,
})
if err != nil {
return "", 0, fmt.Errorf("auth: marshal login request: %w", err)
}
resp, err := c.httpClient.Post(
c.baseURL+"/v1/auth/login",
"application/json",
bytes.NewReader(body),
)
if err != nil {
return "", 0, fmt.Errorf("auth: MCIAS login: %w", ErrMCIASUnavailable)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return "", 0, ErrUnauthorized
}
var lr loginResponse
if err := json.NewDecoder(resp.Body).Decode(&lr); err != nil {
return "", 0, fmt.Errorf("auth: decode login response: %w", err)
}
return lr.Token, lr.ExpiresIn, nil
}
// validateRequest is the JSON body sent to MCIAS /v1/token/validate.
type validateRequest struct {
Token string `json:"token"`
}
// validateResponse is the JSON body returned by MCIAS /v1/token/validate.
type validateResponse struct {
Valid bool `json:"valid"`
Claims struct {
Subject string `json:"subject"`
AccountType string `json:"account_type"`
Roles []string `json:"roles"`
} `json:"claims"`
}
// ValidateToken checks a bearer token against MCIAS. Results are cached
// by SHA-256 hash for 30 seconds.
func (c *Client) ValidateToken(token string) (*Claims, error) {
h := sha256.Sum256([]byte(token))
tokenHash := hex.EncodeToString(h[:])
if claims, ok := c.cache.get(tokenHash); ok {
return claims, nil
}
body, err := json.Marshal(validateRequest{Token: token})
if err != nil {
return nil, fmt.Errorf("auth: marshal validate request: %w", err)
}
resp, err := c.httpClient.Post(
c.baseURL+"/v1/token/validate",
"application/json",
bytes.NewReader(body),
)
if err != nil {
return nil, fmt.Errorf("auth: MCIAS validate: %w", ErrMCIASUnavailable)
}
defer func() { _ = resp.Body.Close() }()
if resp.StatusCode != http.StatusOK {
return nil, ErrUnauthorized
}
var vr validateResponse
if err := json.NewDecoder(resp.Body).Decode(&vr); err != nil {
return nil, fmt.Errorf("auth: decode validate response: %w", err)
}
if !vr.Valid {
return nil, ErrUnauthorized
}
claims := &Claims{
Subject: vr.Claims.Subject,
AccountType: vr.Claims.AccountType,
Roles: vr.Claims.Roles,
}
c.cache.put(tokenHash, claims)
return claims, nil
}

View File

@@ -0,0 +1,220 @@
package auth
import (
"encoding/json"
"errors"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"time"
)
// newTestServer starts an httptest.Server that routes MCIAS endpoints.
// The handler functions are pluggable per test.
func newTestServer(t *testing.T, loginHandler, validateHandler http.HandlerFunc) *httptest.Server {
t.Helper()
mux := http.NewServeMux()
if loginHandler != nil {
mux.HandleFunc("/v1/auth/login", loginHandler)
}
if validateHandler != nil {
mux.HandleFunc("/v1/token/validate", validateHandler)
}
srv := httptest.NewServer(mux)
t.Cleanup(srv.Close)
return srv
}
func newTestClient(t *testing.T, serverURL string) *Client {
t.Helper()
c, err := NewClient(serverURL, "", "mcr-test", []string{"env:test"})
if err != nil {
t.Fatalf("NewClient: %v", err)
}
return c
}
func TestLoginSuccess(t *testing.T) {
srv := newTestServer(t, func(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(loginResponse{
Token: "tok-abc",
ExpiresIn: 3600,
})
}, nil)
c := newTestClient(t, srv.URL)
token, expiresIn, err := c.Login("alice", "secret")
if err != nil {
t.Fatalf("Login: %v", err)
}
if token != "tok-abc" {
t.Fatalf("token: got %q, want %q", token, "tok-abc")
}
if expiresIn != 3600 {
t.Fatalf("expiresIn: got %d, want %d", expiresIn, 3600)
}
}
func TestLoginFailure(t *testing.T) {
srv := newTestServer(t, func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
}, nil)
c := newTestClient(t, srv.URL)
_, _, err := c.Login("alice", "wrong")
if !errors.Is(err, ErrUnauthorized) {
t.Fatalf("Login error: got %v, want %v", err, ErrUnauthorized)
}
}
func TestValidateSuccess(t *testing.T) {
srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(validateResponse{
Valid: true,
Claims: struct {
Subject string `json:"subject"`
AccountType string `json:"account_type"`
Roles []string `json:"roles"`
}{
Subject: "alice",
AccountType: "user",
Roles: []string{"reader", "writer"},
},
})
})
c := newTestClient(t, srv.URL)
claims, err := c.ValidateToken("valid-token-123")
if err != nil {
t.Fatalf("ValidateToken: %v", err)
}
if claims.Subject != "alice" {
t.Fatalf("subject: got %q, want %q", claims.Subject, "alice")
}
if claims.AccountType != "user" {
t.Fatalf("account_type: got %q, want %q", claims.AccountType, "user")
}
if len(claims.Roles) != 2 || claims.Roles[0] != "reader" || claims.Roles[1] != "writer" {
t.Fatalf("roles: got %v, want [reader writer]", claims.Roles)
}
}
func TestValidateRevoked(t *testing.T) {
srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(validateResponse{Valid: false})
})
c := newTestClient(t, srv.URL)
_, err := c.ValidateToken("revoked-token")
if !errors.Is(err, ErrUnauthorized) {
t.Fatalf("ValidateToken error: got %v, want %v", err, ErrUnauthorized)
}
}
func TestValidateCacheHit(t *testing.T) {
var callCount atomic.Int64
srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
callCount.Add(1)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(validateResponse{
Valid: true,
Claims: struct {
Subject string `json:"subject"`
AccountType string `json:"account_type"`
Roles []string `json:"roles"`
}{
Subject: "bob",
AccountType: "service",
Roles: []string{"admin"},
},
})
})
c := newTestClient(t, srv.URL)
// First call — should hit the server.
claims1, err := c.ValidateToken("cached-token")
if err != nil {
t.Fatalf("first ValidateToken: %v", err)
}
if callCount.Load() != 1 {
t.Fatalf("expected 1 server call after first validate, got %d", callCount.Load())
}
// Second call — should come from cache.
claims2, err := c.ValidateToken("cached-token")
if err != nil {
t.Fatalf("second ValidateToken: %v", err)
}
if callCount.Load() != 1 {
t.Fatalf("expected 1 server call after second validate (cache hit), got %d", callCount.Load())
}
if claims1.Subject != claims2.Subject {
t.Fatalf("cached claims mismatch: %q vs %q", claims1.Subject, claims2.Subject)
}
}
func TestValidateCacheExpiry(t *testing.T) {
var callCount atomic.Int64
srv := newTestServer(t, nil, func(w http.ResponseWriter, _ *http.Request) {
callCount.Add(1)
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(validateResponse{
Valid: true,
Claims: struct {
Subject string `json:"subject"`
AccountType string `json:"account_type"`
Roles []string `json:"roles"`
}{
Subject: "charlie",
AccountType: "user",
Roles: nil,
},
})
})
c := newTestClient(t, srv.URL)
// Inject a controllable clock.
now := time.Now()
c.cache.now = func() time.Time { return now }
// First call.
if _, err := c.ValidateToken("expiry-token"); err != nil {
t.Fatalf("first ValidateToken: %v", err)
}
if callCount.Load() != 1 {
t.Fatalf("expected 1 server call, got %d", callCount.Load())
}
// Second call within TTL — cache hit.
if _, err := c.ValidateToken("expiry-token"); err != nil {
t.Fatalf("second ValidateToken: %v", err)
}
if callCount.Load() != 1 {
t.Fatalf("expected 1 server call (cache hit), got %d", callCount.Load())
}
// Advance clock past the 30s TTL.
c.cache.now = func() time.Time { return now.Add(31 * time.Second) }
// Third call — cache miss, should hit server again.
if _, err := c.ValidateToken("expiry-token"); err != nil {
t.Fatalf("third ValidateToken: %v", err)
}
if callCount.Load() != 2 {
t.Fatalf("expected 2 server calls after cache expiry, got %d", callCount.Load())
}
}

8
internal/auth/errors.go Normal file
View File

@@ -0,0 +1,8 @@
package auth
import "errors"
var (
ErrUnauthorized = errors.New("auth: unauthorized")
ErrMCIASUnavailable = errors.New("auth: MCIAS unavailable")
)

View File

@@ -0,0 +1,59 @@
package server
import (
"fmt"
"net/http"
"strings"
"git.wntrmute.dev/kyle/mcr/internal/auth"
)
// TokenValidator abstracts token validation so the middleware can work
// with the real MCIAS client or a test fake.
type TokenValidator interface {
ValidateToken(token string) (*auth.Claims, error)
}
// RequireAuth returns middleware that validates Bearer tokens via the
// given TokenValidator. On success the authenticated Claims are injected
// into the request context. On failure a 401 with an OCI-format error
// body and a WWW-Authenticate header is returned.
func RequireAuth(validator TokenValidator, serviceName string) func(http.Handler) http.Handler {
wwwAuth := fmt.Sprintf(`Bearer realm="/v2/token",service="%s"`, serviceName)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := extractBearerToken(r)
if token == "" {
w.Header().Set("WWW-Authenticate", wwwAuth)
writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication required")
return
}
claims, err := validator.ValidateToken(token)
if err != nil {
w.Header().Set("WWW-Authenticate", wwwAuth)
writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication required")
return
}
ctx := auth.ContextWithClaims(r.Context(), claims)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// extractBearerToken parses a "Bearer <token>" value from the
// Authorization header. It returns an empty string if the header is
// missing or malformed.
func extractBearerToken(r *http.Request) string {
h := r.Header.Get("Authorization")
if h == "" {
return ""
}
const prefix = "Bearer "
if !strings.HasPrefix(h, prefix) {
return ""
}
return strings.TrimSpace(h[len(prefix):])
}

View File

@@ -0,0 +1,112 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.wntrmute.dev/kyle/mcr/internal/auth"
)
type fakeValidator struct {
claims *auth.Claims
err error
}
func (f *fakeValidator) ValidateToken(_ string) (*auth.Claims, error) {
return f.claims, f.err
}
func TestRequireAuthValid(t *testing.T) {
t.Helper()
claims := &auth.Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}}
validator := &fakeValidator{claims: claims}
var gotClaims *auth.Claims
inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
gotClaims = auth.ClaimsFromContext(r.Context())
})
handler := RequireAuth(validator, "mcr-test")(inner)
req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
req.Header.Set("Authorization", "Bearer valid-token")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK)
}
if gotClaims == nil {
t.Fatal("expected claims in context, got nil")
}
if gotClaims.Subject != "alice" {
t.Fatalf("subject: got %q, want %q", gotClaims.Subject, "alice")
}
}
func TestRequireAuthMissing(t *testing.T) {
t.Helper()
validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized}
inner := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
t.Fatal("inner handler should not be called")
})
handler := RequireAuth(validator, "mcr-test")(inner)
req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
// No Authorization header.
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
}
wwwAuth := rec.Header().Get("WWW-Authenticate")
want := `Bearer realm="/v2/token",service="mcr-test"`
if wwwAuth != want {
t.Fatalf("WWW-Authenticate: got %q, want %q", wwwAuth, want)
}
var ociErr ociErrorResponse
if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil {
t.Fatalf("decode OCI error: %v", err)
}
if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" {
t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors)
}
}
func TestRequireAuthInvalid(t *testing.T) {
t.Helper()
validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized}
inner := http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
t.Fatal("inner handler should not be called")
})
handler := RequireAuth(validator, "mcr-test")(inner)
req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
req.Header.Set("Authorization", "Bearer bad-token")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
}
var ociErr ociErrorResponse
if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil {
t.Fatalf("decode OCI error: %v", err)
}
if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" {
t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors)
}
}

View File

@@ -0,0 +1,23 @@
package server
import (
"encoding/json"
"net/http"
)
type ociErrorEntry struct {
Code string `json:"code"`
Message string `json:"message"`
}
type ociErrorResponse struct {
Errors []ociErrorEntry `json:"errors"`
}
func writeOCIError(w http.ResponseWriter, code string, status int, message string) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(ociErrorResponse{
Errors: []ociErrorEntry{{Code: code, Message: message}},
})
}

21
internal/server/routes.go Normal file
View File

@@ -0,0 +1,21 @@
package server
import "github.com/go-chi/chi/v5"
// NewRouter builds the chi router with all OCI Distribution Spec
// endpoints and auth middleware wired up.
func NewRouter(validator TokenValidator, loginClient LoginClient, serviceName string) *chi.Mux {
r := chi.NewRouter()
// Token endpoint is NOT behind RequireAuth — clients use Basic auth
// here to obtain a bearer token.
r.Get("/v2/token", TokenHandler(loginClient))
// All other /v2 endpoints require a valid bearer token.
r.Route("/v2", func(v2 chi.Router) {
v2.Use(RequireAuth(validator, serviceName))
v2.Get("/", V2Handler())
})
return r
}

View File

@@ -0,0 +1,94 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.wntrmute.dev/kyle/mcr/internal/auth"
)
func TestRoutesV2Authenticated(t *testing.T) {
t.Helper()
validator := &fakeValidator{
claims: &auth.Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}},
}
loginClient := &fakeLoginClient{token: "tok-abc", expiresIn: 3600}
router := NewRouter(validator, loginClient, "mcr-test")
req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
req.Header.Set("Authorization", "Bearer valid-token")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK)
}
ct := rec.Header().Get("Content-Type")
if ct != "application/json" {
t.Fatalf("Content-Type: got %q, want %q", ct, "application/json")
}
body := rec.Body.String()
if body != "{}" {
t.Fatalf("body: got %q, want %q", body, "{}")
}
}
func TestRoutesV2Unauthenticated(t *testing.T) {
t.Helper()
validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized}
loginClient := &fakeLoginClient{}
router := NewRouter(validator, loginClient, "mcr-test")
req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
// No Authorization header.
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
}
wwwAuth := rec.Header().Get("WWW-Authenticate")
want := `Bearer realm="/v2/token",service="mcr-test"`
if wwwAuth != want {
t.Fatalf("WWW-Authenticate: got %q, want %q", wwwAuth, want)
}
}
func TestRoutesTokenEndpoint(t *testing.T) {
t.Helper()
// The validator should never be called for /v2/token.
validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized}
loginClient := &fakeLoginClient{token: "tok-from-login", expiresIn: 1800}
router := NewRouter(validator, loginClient, "mcr-test")
req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
req.SetBasicAuth("bob", "password")
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK)
}
var resp tokenResponse
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if resp.Token != "tok-from-login" {
t.Fatalf("token: got %q, want %q", resp.Token, "tok-from-login")
}
if resp.ExpiresIn != 1800 {
t.Fatalf("expires_in: got %d, want %d", resp.ExpiresIn, 1800)
}
if resp.IssuedAt == "" {
t.Fatal("issued_at: expected non-empty RFC 3339 timestamp")
}
}

45
internal/server/token.go Normal file
View File

@@ -0,0 +1,45 @@
package server
import (
"encoding/json"
"net/http"
"time"
)
// LoginClient abstracts the MCIAS login call so the handler can work
// with the real client or a test fake.
type LoginClient interface {
Login(username, password string) (token string, expiresIn int, err error)
}
// tokenResponse is the JSON body returned by the token endpoint.
type tokenResponse struct {
Token string `json:"token"`
ExpiresIn int `json:"expires_in"`
IssuedAt string `json:"issued_at"`
}
// TokenHandler returns an http.HandlerFunc that exchanges Basic
// credentials for a bearer token via the given LoginClient.
func TokenHandler(loginClient LoginClient) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
username, password, ok := r.BasicAuth()
if !ok || username == "" {
writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "basic authentication required")
return
}
token, expiresIn, err := loginClient.Login(username, password)
if err != nil {
writeOCIError(w, "UNAUTHORIZED", http.StatusUnauthorized, "authentication failed")
return
}
w.Header().Set("Content-Type", "application/json")
_ = json.NewEncoder(w).Encode(tokenResponse{
Token: token,
ExpiresIn: expiresIn,
IssuedAt: time.Now().UTC().Format(time.RFC3339),
})
}
}

View File

@@ -0,0 +1,98 @@
package server
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"git.wntrmute.dev/kyle/mcr/internal/auth"
)
type fakeLoginClient struct {
token string
expiresIn int
err error
}
func (f *fakeLoginClient) Login(_, _ string) (string, int, error) {
return f.token, f.expiresIn, f.err
}
func TestTokenHandlerSuccess(t *testing.T) {
t.Helper()
lc := &fakeLoginClient{token: "tok-xyz", expiresIn: 7200}
handler := TokenHandler(lc)
req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
req.SetBasicAuth("alice", "secret")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK)
}
var resp tokenResponse
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
t.Fatalf("decode response: %v", err)
}
if resp.Token != "tok-xyz" {
t.Fatalf("token: got %q, want %q", resp.Token, "tok-xyz")
}
if resp.ExpiresIn != 7200 {
t.Fatalf("expires_in: got %d, want %d", resp.ExpiresIn, 7200)
}
if resp.IssuedAt == "" {
t.Fatal("issued_at: expected non-empty RFC 3339 timestamp")
}
}
func TestTokenHandlerInvalidCreds(t *testing.T) {
t.Helper()
lc := &fakeLoginClient{err: auth.ErrUnauthorized}
handler := TokenHandler(lc)
req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
req.SetBasicAuth("alice", "wrong")
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
}
var ociErr ociErrorResponse
if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil {
t.Fatalf("decode OCI error: %v", err)
}
if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" {
t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors)
}
}
func TestTokenHandlerMissingAuth(t *testing.T) {
t.Helper()
lc := &fakeLoginClient{token: "should-not-matter"}
handler := TokenHandler(lc)
req := httptest.NewRequest(http.MethodGet, "/v2/token", nil)
// No Authorization header.
rec := httptest.NewRecorder()
handler.ServeHTTP(rec, req)
if rec.Code != http.StatusUnauthorized {
t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
}
var ociErr ociErrorResponse
if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil {
t.Fatalf("decode OCI error: %v", err)
}
if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" {
t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors)
}
}

13
internal/server/v2.go Normal file
View File

@@ -0,0 +1,13 @@
package server
import "net/http"
// V2Handler returns an http.HandlerFunc that responds with 200 OK and
// an empty JSON object, per the OCI Distribution Spec version check.
func V2Handler() http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("{}"))
}
}

View File

@@ -0,0 +1,9 @@
package storage
import "errors"
var (
ErrBlobNotFound = errors.New("storage: blob not found")
ErrDigestMismatch = errors.New("storage: digest mismatch")
ErrInvalidDigest = errors.New("storage: invalid digest format")
)

View File

@@ -0,0 +1,75 @@
package storage
import (
"errors"
"fmt"
"io"
"os"
"path/filepath"
)
// Open validates the digest and returns a ReadCloser for the blob.
// Returns ErrBlobNotFound if the blob does not exist on disk.
func (s *Store) Open(digest string) (io.ReadCloser, error) {
if err := validateDigest(digest); err != nil {
return nil, err
}
f, err := os.Open(s.blobPath(digest))
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil, ErrBlobNotFound
}
return nil, fmt.Errorf("storage: open blob: %w", err)
}
return f, nil
}
// Stat returns the size of the blob in bytes.
// Returns ErrBlobNotFound if the blob does not exist on disk.
func (s *Store) Stat(digest string) (int64, error) {
if err := validateDigest(digest); err != nil {
return 0, err
}
info, err := os.Stat(s.blobPath(digest))
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return 0, ErrBlobNotFound
}
return 0, fmt.Errorf("storage: stat blob: %w", err)
}
return info.Size(), nil
}
// Delete removes the blob file and attempts to clean up its prefix
// directory. Non-empty or already-removed prefix directories are
// silently ignored.
func (s *Store) Delete(digest string) error {
if err := validateDigest(digest); err != nil {
return err
}
path := s.blobPath(digest)
if err := os.Remove(path); err != nil {
if errors.Is(err, os.ErrNotExist) {
return ErrBlobNotFound
}
return fmt.Errorf("storage: delete blob: %w", err)
}
// Best-effort cleanup of the prefix directory.
_ = os.Remove(filepath.Dir(path))
return nil
}
// Exists reports whether the blob exists on disk.
func (s *Store) Exists(digest string) bool {
if err := validateDigest(digest); err != nil {
return false
}
_, err := os.Stat(s.blobPath(digest))
return err == nil
}

View File

@@ -0,0 +1,107 @@
package storage
import (
"errors"
"io"
"testing"
)
func TestOpenAfterWrite(t *testing.T) {
s := newTestStore(t)
data := []byte("readable blob content")
digest := writeTestBlob(t, s, data)
rc, err := s.Open(digest)
if err != nil {
t.Fatalf("Open: %v", err)
}
defer func() { _ = rc.Close() }()
got, err := io.ReadAll(rc)
if err != nil {
t.Fatalf("ReadAll: %v", err)
}
if string(got) != string(data) {
t.Fatalf("Open content: got %q, want %q", got, data)
}
}
func TestStatAfterWrite(t *testing.T) {
s := newTestStore(t)
data := []byte("stat this blob")
digest := writeTestBlob(t, s, data)
size, err := s.Stat(digest)
if err != nil {
t.Fatalf("Stat: %v", err)
}
if size != int64(len(data)) {
t.Fatalf("Stat size: got %d, want %d", size, len(data))
}
}
func TestExists(t *testing.T) {
s := newTestStore(t)
data := []byte("existence check")
digest := writeTestBlob(t, s, data)
if !s.Exists(digest) {
t.Fatal("Exists returned false for written blob")
}
nonexistent := "sha256:0000000000000000000000000000000000000000000000000000000000000000"
if s.Exists(nonexistent) {
t.Fatal("Exists returned true for nonexistent blob")
}
}
func TestDelete(t *testing.T) {
s := newTestStore(t)
data := []byte("delete me")
digest := writeTestBlob(t, s, data)
if err := s.Delete(digest); err != nil {
t.Fatalf("Delete: %v", err)
}
if s.Exists(digest) {
t.Fatal("Exists returned true after Delete")
}
_, err := s.Open(digest)
if !errors.Is(err, ErrBlobNotFound) {
t.Fatalf("Open after Delete: got %v, want ErrBlobNotFound", err)
}
}
func TestOpenNotFound(t *testing.T) {
s := newTestStore(t)
digest := "sha256:aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
_, err := s.Open(digest)
if !errors.Is(err, ErrBlobNotFound) {
t.Fatalf("Open nonexistent: got %v, want ErrBlobNotFound", err)
}
}
func TestInvalidDigestFormat(t *testing.T) {
s := newTestStore(t)
bad := "not-a-digest"
if _, err := s.Open(bad); !errors.Is(err, ErrInvalidDigest) {
t.Fatalf("Open with bad digest: got %v, want ErrInvalidDigest", err)
}
if _, err := s.Stat(bad); !errors.Is(err, ErrInvalidDigest) {
t.Fatalf("Stat with bad digest: got %v, want ErrInvalidDigest", err)
}
if err := s.Delete(bad); !errors.Is(err, ErrInvalidDigest) {
t.Fatalf("Delete with bad digest: got %v, want ErrInvalidDigest", err)
}
// Exists should return false for an invalid digest, not panic.
if s.Exists(bad) {
t.Fatal("Exists returned true for invalid digest")
}
}

View File

@@ -0,0 +1,38 @@
package storage
import (
"path/filepath"
"regexp"
)
var digestRe = regexp.MustCompile(`^sha256:[a-f0-9]{64}$`)
// Store manages blob storage on the local filesystem.
type Store struct {
layersPath string
uploadsPath string
}
// New creates a Store that will write final blobs under layersPath and
// stage in-progress uploads under uploadsPath.
func New(layersPath, uploadsPath string) *Store {
return &Store{
layersPath: layersPath,
uploadsPath: uploadsPath,
}
}
// validateDigest checks that digest matches sha256:<64 lowercase hex chars>.
func validateDigest(digest string) error {
if !digestRe.MatchString(digest) {
return ErrInvalidDigest
}
return nil
}
// blobPath returns the filesystem path for a blob with the given digest.
// The layout is: <layersPath>/sha256/<first-2-hex>/<full-64-hex>
func (s *Store) blobPath(digest string) string {
hex := digest[len("sha256:"):]
return filepath.Join(s.layersPath, "sha256", hex[0:2], hex)
}

View File

@@ -0,0 +1,68 @@
package storage
import (
"errors"
"path/filepath"
"testing"
)
func newTestStore(t *testing.T) *Store {
t.Helper()
dir := t.TempDir()
return New(filepath.Join(dir, "layers"), filepath.Join(dir, "uploads"))
}
func TestNew(t *testing.T) {
s := newTestStore(t)
if s == nil {
t.Fatal("New returned nil")
}
if s.layersPath == "" {
t.Fatal("layersPath is empty")
}
if s.uploadsPath == "" {
t.Fatal("uploadsPath is empty")
}
}
func TestValidateDigest(t *testing.T) {
valid := []string{
"sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
"sha256:0000000000000000000000000000000000000000000000000000000000000000",
"sha256:abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789",
}
for _, d := range valid {
if err := validateDigest(d); err != nil {
t.Errorf("validateDigest(%q) = %v, want nil", d, err)
}
}
invalid := []string{
"",
"sha256:",
"sha256:abc",
"sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b85", // 63 chars
"sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b8555", // 65 chars
"sha256:E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855", // uppercase
"md5:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // wrong algo
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", // missing prefix
"sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b85g", // non-hex char
}
for _, d := range invalid {
if err := validateDigest(d); !errors.Is(err, ErrInvalidDigest) {
t.Errorf("validateDigest(%q) = %v, want ErrInvalidDigest", d, err)
}
}
}
func TestBlobPath(t *testing.T) {
s := New("/data/layers", "/data/uploads")
digest := "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
got := s.blobPath(digest)
want := filepath.Join("/data/layers", "sha256", "e3",
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855")
if got != want {
t.Fatalf("blobPath(%q)\n got %q\nwant %q", digest, got, want)
}
}

107
internal/storage/writer.go Normal file
View File

@@ -0,0 +1,107 @@
package storage
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"hash"
"io"
"os"
"path/filepath"
)
// BlobWriter stages blob data in a temporary file while computing its
// SHA-256 digest on the fly.
type BlobWriter struct {
file *os.File
hash hash.Hash
mw io.Writer
path string
written int64
closed bool
store *Store
}
// StartUpload begins a new blob upload, creating a temp file at
// <uploadsPath>/<uuid>.
func (s *Store) StartUpload(uuid string) (*BlobWriter, error) {
if err := os.MkdirAll(s.uploadsPath, 0700); err != nil {
return nil, fmt.Errorf("storage: create uploads dir: %w", err)
}
path := filepath.Join(s.uploadsPath, uuid)
f, err := os.Create(path) //nolint:gosec // upload UUID is server-generated, not user input
if err != nil {
return nil, fmt.Errorf("storage: create upload file: %w", err)
}
h := sha256.New()
return &BlobWriter{
file: f,
hash: h,
mw: io.MultiWriter(f, h),
path: path,
store: s,
}, nil
}
// Write writes p to both the staging file and the running hash.
func (bw *BlobWriter) Write(p []byte) (int, error) {
n, err := bw.mw.Write(p)
bw.written += int64(n)
if err != nil {
return n, fmt.Errorf("storage: write: %w", err)
}
return n, nil
}
// Commit finalises the upload. It closes the staging file, verifies
// the computed digest matches expectedDigest, and atomically moves
// the file to its content-addressed location.
func (bw *BlobWriter) Commit(expectedDigest string) (string, error) {
if !bw.closed {
bw.closed = true
if err := bw.file.Close(); err != nil {
return "", fmt.Errorf("storage: close upload file: %w", err)
}
}
if err := validateDigest(expectedDigest); err != nil {
_ = os.Remove(bw.path)
return "", err
}
computed := "sha256:" + hex.EncodeToString(bw.hash.Sum(nil))
if computed != expectedDigest {
_ = os.Remove(bw.path)
return "", ErrDigestMismatch
}
dst := bw.store.blobPath(computed)
if err := os.MkdirAll(filepath.Dir(dst), 0700); err != nil {
return "", fmt.Errorf("storage: create blob dir: %w", err)
}
if err := os.Rename(bw.path, dst); err != nil {
return "", fmt.Errorf("storage: rename blob: %w", err)
}
return computed, nil
}
// Cancel aborts the upload, closing and removing the temp file.
func (bw *BlobWriter) Cancel() error {
if !bw.closed {
bw.closed = true
_ = bw.file.Close()
}
if err := os.Remove(bw.path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("storage: remove upload file: %w", err)
}
return nil
}
// BytesWritten returns the number of bytes written so far.
func (bw *BlobWriter) BytesWritten() int64 {
return bw.written
}

View File

@@ -0,0 +1,188 @@
package storage
import (
"crypto/sha256"
"encoding/hex"
"errors"
"fmt"
"os"
"sync"
"testing"
"time"
)
func writeTestBlob(t *testing.T, s *Store, data []byte) string {
t.Helper()
uuid := "test-upload-" + fmt.Sprintf("%d", time.Now().UnixNano())
w, err := s.StartUpload(uuid)
if err != nil {
t.Fatalf("StartUpload: %v", err)
}
if _, err := w.Write(data); err != nil {
t.Fatalf("Write: %v", err)
}
h := sha256.Sum256(data)
digest := "sha256:" + hex.EncodeToString(h[:])
got, err := w.Commit(digest)
if err != nil {
t.Fatalf("Commit: %v", err)
}
if got != digest {
t.Fatalf("Commit returned %q, want %q", got, digest)
}
return digest
}
func TestWriteAndCommit(t *testing.T) {
s := newTestStore(t)
data := []byte("hello, blob storage")
digest := writeTestBlob(t, s, data)
// Verify file exists at expected path.
path := s.blobPath(digest)
info, err := os.Stat(path)
if err != nil {
t.Fatalf("stat blob file: %v", err)
}
if info.Size() != int64(len(data)) {
t.Fatalf("blob size: got %d, want %d", info.Size(), len(data))
}
// Verify content.
content, err := os.ReadFile(path) //nolint:gosec // test file path from t.TempDir()
if err != nil {
t.Fatalf("read blob file: %v", err)
}
if string(content) != string(data) {
t.Fatalf("blob content: got %q, want %q", content, data)
}
}
func TestDigestMismatch(t *testing.T) {
s := newTestStore(t)
data := []byte("some data")
uuid := "mismatch-upload"
w, err := s.StartUpload(uuid)
if err != nil {
t.Fatalf("StartUpload: %v", err)
}
if _, err := w.Write(data); err != nil {
t.Fatalf("Write: %v", err)
}
wrongDigest := "sha256:0000000000000000000000000000000000000000000000000000000000000000"
_, err = w.Commit(wrongDigest)
if !errors.Is(err, ErrDigestMismatch) {
t.Fatalf("Commit with wrong digest: got %v, want ErrDigestMismatch", err)
}
// Verify temp file was cleaned up.
tempPath := w.path
if _, err := os.Stat(tempPath); !os.IsNotExist(err) {
t.Fatalf("temp file should be removed after digest mismatch, stat err: %v", err)
}
}
func TestCancel(t *testing.T) {
s := newTestStore(t)
uuid := "cancel-upload"
w, err := s.StartUpload(uuid)
if err != nil {
t.Fatalf("StartUpload: %v", err)
}
if _, err := w.Write([]byte("partial data")); err != nil {
t.Fatalf("Write: %v", err)
}
tempPath := w.path
if err := w.Cancel(); err != nil {
t.Fatalf("Cancel: %v", err)
}
if _, err := os.Stat(tempPath); !os.IsNotExist(err) {
t.Fatalf("temp file should be removed after Cancel, stat err: %v", err)
}
}
func TestBytesWritten(t *testing.T) {
s := newTestStore(t)
uuid := "bytes-upload"
w, err := s.StartUpload(uuid)
if err != nil {
t.Fatalf("StartUpload: %v", err)
}
t.Cleanup(func() { _ = w.Cancel() })
if w.BytesWritten() != 0 {
t.Fatalf("BytesWritten before write: got %d, want 0", w.BytesWritten())
}
data1 := []byte("first chunk")
if _, err := w.Write(data1); err != nil {
t.Fatalf("Write: %v", err)
}
if w.BytesWritten() != int64(len(data1)) {
t.Fatalf("BytesWritten after first write: got %d, want %d", w.BytesWritten(), len(data1))
}
data2 := []byte("second chunk")
if _, err := w.Write(data2); err != nil {
t.Fatalf("Write: %v", err)
}
want := int64(len(data1) + len(data2))
if w.BytesWritten() != want {
t.Fatalf("BytesWritten after second write: got %d, want %d", w.BytesWritten(), want)
}
}
func TestConcurrentWrites(t *testing.T) {
s := newTestStore(t)
type result struct {
digest string
err error
}
blobs := [][]byte{
[]byte("concurrent blob alpha"),
[]byte("concurrent blob beta"),
}
results := make([]result, len(blobs))
var wg sync.WaitGroup
for i, data := range blobs {
wg.Add(1)
go func(idx int, d []byte) {
defer wg.Done()
uuid := fmt.Sprintf("concurrent-%d-%d", idx, time.Now().UnixNano())
w, err := s.StartUpload(uuid)
if err != nil {
results[idx] = result{err: err}
return
}
if _, err := w.Write(d); err != nil {
results[idx] = result{err: err}
return
}
h := sha256.Sum256(d)
digest := "sha256:" + hex.EncodeToString(h[:])
got, err := w.Commit(digest)
results[idx] = result{digest: got, err: err}
}(i, data)
}
wg.Wait()
for i, r := range results {
if r.err != nil {
t.Fatalf("blob %d: %v", i, r.err)
}
if !s.Exists(r.digest) {
t.Fatalf("blob %d: digest %q not found after concurrent write", i, r.digest)
}
}
}