Phases 5, 6, 8: OCI pull/push paths and admin REST API
Phase 5 (OCI pull): internal/oci/ package with manifest GET/HEAD by tag/digest, blob GET/HEAD with repo membership check, tag listing with OCI pagination, catalog listing. Multi-segment repo names via parseOCIPath() right-split routing. DB query layer in internal/db/repository.go. Phase 6 (OCI push): blob uploads (monolithic and chunked) with uploadManager tracking in-progress BlobWriters, manifest push implementing full ARCHITECTURE.md §5 flow in a single SQLite transaction (create repo, upsert manifest, populate manifest_blobs, atomic tag move). Digest verification on both blob commit and manifest push-by-digest. Phase 8 (admin REST): /v1 endpoints for auth (login/logout/health), repository management (list/detail/delete), policy CRUD with engine reload, audit log listing with filters, GC trigger/status stubs. RequireAdmin middleware, platform-standard error format. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
52
internal/server/admin.go
Normal file
52
internal/server/admin.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
)
|
||||
|
||||
type adminErrorResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func writeAdminError(w http.ResponseWriter, status int, message string) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(adminErrorResponse{Error: message})
|
||||
}
|
||||
|
||||
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
// RequireAdmin returns middleware that checks for the admin role.
|
||||
// Returns 403 with an admin error format if the caller is not an admin.
|
||||
func RequireAdmin() func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
claims := auth.ClaimsFromContext(r.Context())
|
||||
if claims == nil {
|
||||
writeAdminError(w, http.StatusUnauthorized, "authentication required")
|
||||
return
|
||||
}
|
||||
if !hasRole(claims.Roles, "admin") {
|
||||
writeAdminError(w, http.StatusForbidden, "admin role required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func hasRole(roles []string, target string) bool {
|
||||
for _, r := range roles {
|
||||
if r == target {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
48
internal/server/admin_audit.go
Normal file
48
internal/server/admin_audit.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
// AdminListAuditHandler handles GET /v1/audit.
|
||||
func AdminListAuditHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
q := r.URL.Query()
|
||||
|
||||
limit := 50
|
||||
if n := q.Get("n"); n != "" {
|
||||
if v, err := strconv.Atoi(n); err == nil && v > 0 {
|
||||
limit = v
|
||||
}
|
||||
}
|
||||
offset := 0
|
||||
if o := q.Get("offset"); o != "" {
|
||||
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
|
||||
offset = v
|
||||
}
|
||||
}
|
||||
|
||||
filter := db.AuditFilter{
|
||||
EventType: q.Get("event_type"),
|
||||
ActorID: q.Get("actor_id"),
|
||||
Repository: q.Get("repository"),
|
||||
Since: q.Get("since"),
|
||||
Until: q.Get("until"),
|
||||
Limit: limit,
|
||||
Offset: offset,
|
||||
}
|
||||
|
||||
events, err := database.ListAuditEvents(filter)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
if events == nil {
|
||||
events = []db.AuditEvent{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, events)
|
||||
}
|
||||
}
|
||||
152
internal/server/admin_audit_test.go
Normal file
152
internal/server/admin_audit_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
func TestAdminListAuditEvents(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
// Seed some audit events.
|
||||
if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "10.0.0.1", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
if err := database.WriteAuditEvent("manifest_pushed", "user-1", "myapp", "sha256:abc", "10.0.0.1",
|
||||
map[string]string{"tag": "latest"}); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
if err := database.WriteAuditEvent("login_ok", "user-2", "", "", "10.0.0.2", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/audit", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var events []db.AuditEvent
|
||||
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(events) != 3 {
|
||||
t.Fatalf("event count: got %d, want 3", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListAuditEventsWithFilter(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
if err := database.WriteAuditEvent("manifest_pushed", "user-1", "myapp", "", "", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
if err := database.WriteAuditEvent("login_ok", "user-2", "", "", "", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
|
||||
// Filter by event_type.
|
||||
rr := adminReq(t, router, "GET", "/v1/audit?event_type=login_ok", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var events []db.AuditEvent
|
||||
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("event count: got %d, want 2", len(events))
|
||||
}
|
||||
for _, e := range events {
|
||||
if e.EventType != "login_ok" {
|
||||
t.Fatalf("unexpected event type: %q", e.EventType)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListAuditEventsFilterByActor(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
if err := database.WriteAuditEvent("login_ok", "actor-a", "", "", "", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
if err := database.WriteAuditEvent("login_ok", "actor-b", "", "", "", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/audit?actor_id=actor-a", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var events []db.AuditEvent
|
||||
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(events) != 1 {
|
||||
t.Fatalf("event count: got %d, want 1", len(events))
|
||||
}
|
||||
if events[0].ActorID != "actor-a" {
|
||||
t.Fatalf("actor_id: got %q, want %q", events[0].ActorID, "actor-a")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListAuditEventsPagination(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
for range 5 {
|
||||
if err := database.WriteAuditEvent("login_ok", "user-1", "", "", "", nil); err != nil {
|
||||
t.Fatalf("WriteAuditEvent: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/audit?n=2&offset=0", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var events []db.AuditEvent
|
||||
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(events) != 2 {
|
||||
t.Fatalf("event count: got %d, want 2", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListAuditEventsEmpty(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/audit", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var events []db.AuditEvent
|
||||
if err := json.NewDecoder(rr.Body).Decode(&events); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(events) != 0 {
|
||||
t.Fatalf("event count: got %d, want 0", len(events))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListAuditEventsNonAdmin(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router := buildNonAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/audit", "")
|
||||
if rr.Code != 403 {
|
||||
t.Fatalf("status: got %d, want 403", rr.Code)
|
||||
}
|
||||
}
|
||||
60
internal/server/admin_auth.go
Normal file
60
internal/server/admin_auth.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
type adminLoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
type adminLoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
ExpiresAt string `json:"expires_at"`
|
||||
}
|
||||
|
||||
// AdminLoginHandler handles POST /v1/auth/login.
|
||||
func AdminLoginHandler(loginClient LoginClient) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req adminLoginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
if req.Username == "" || req.Password == "" {
|
||||
writeAdminError(w, http.StatusBadRequest, "username and password required")
|
||||
return
|
||||
}
|
||||
|
||||
token, expiresIn, err := loginClient.Login(req.Username, req.Password)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusUnauthorized, "authentication failed")
|
||||
return
|
||||
}
|
||||
|
||||
expiresAt := time.Now().UTC().Add(time.Duration(expiresIn) * time.Second).Format(time.RFC3339)
|
||||
writeJSON(w, http.StatusOK, adminLoginResponse{
|
||||
Token: token,
|
||||
ExpiresAt: expiresAt,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// AdminLogoutHandler handles POST /v1/auth/logout.
|
||||
func AdminLogoutHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
// MCIAS token revocation is not currently supported.
|
||||
// The client should discard the token.
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminHealthHandler handles GET /v1/health.
|
||||
func AdminHealthHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
|
||||
}
|
||||
}
|
||||
116
internal/server/admin_auth_test.go
Normal file
116
internal/server/admin_auth_test.go
Normal file
@@ -0,0 +1,116 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
)
|
||||
|
||||
func TestAdminHealthHandler(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
// Health endpoint does not require auth.
|
||||
rr := adminReq(t, router, "GET", "/v1/health", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var resp map[string]string
|
||||
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if resp["status"] != "ok" {
|
||||
t.Fatalf("status field: got %q, want %q", resp["status"], "ok")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminLoginSuccess(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
body := `{"username":"admin","password":"secret"}`
|
||||
rr := adminReq(t, router, "POST", "/v1/auth/login", body)
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var resp adminLoginResponse
|
||||
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if resp.Token != "test-token" {
|
||||
t.Fatalf("token: got %q, want %q", resp.Token, "test-token")
|
||||
}
|
||||
if resp.ExpiresAt == "" {
|
||||
t.Fatal("expires_at: expected non-empty")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminLoginInvalidCreds(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
|
||||
}
|
||||
login := &fakeLoginClient{err: auth.ErrUnauthorized}
|
||||
reloader := &fakePolicyReloader{}
|
||||
gcState := &GCState{}
|
||||
|
||||
r := chi.NewRouter()
|
||||
MountAdminRoutes(r, validator, "mcr-test", AdminDeps{
|
||||
DB: database,
|
||||
Login: login,
|
||||
Engine: reloader,
|
||||
AuditFn: nil,
|
||||
GCState: gcState,
|
||||
})
|
||||
|
||||
body := `{"username":"admin","password":"wrong"}`
|
||||
rr := adminReq(t, r, "POST", "/v1/auth/login", body)
|
||||
if rr.Code != 401 {
|
||||
t.Fatalf("status: got %d, want 401", rr.Code)
|
||||
}
|
||||
|
||||
var errResp adminErrorResponse
|
||||
if err := json.NewDecoder(rr.Body).Decode(&errResp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if errResp.Error != "authentication failed" {
|
||||
t.Fatalf("error: got %q, want %q", errResp.Error, "authentication failed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminLoginMissingFields(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
body := `{"username":"admin"}`
|
||||
rr := adminReq(t, router, "POST", "/v1/auth/login", body)
|
||||
if rr.Code != 400 {
|
||||
t.Fatalf("status: got %d, want 400", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminLoginBadJSON(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "POST", "/v1/auth/login", "not json")
|
||||
if rr.Code != 400 {
|
||||
t.Fatalf("status: got %d, want 400", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminLogout(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "POST", "/v1/auth/logout", "")
|
||||
if rr.Code != 204 {
|
||||
t.Fatalf("status: got %d, want 204", rr.Code)
|
||||
}
|
||||
}
|
||||
66
internal/server/admin_gc.go
Normal file
66
internal/server/admin_gc.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// GCLastRun records the result of the last GC run.
|
||||
type GCLastRun struct {
|
||||
StartedAt string `json:"started_at"`
|
||||
CompletedAt string `json:"completed_at,omitempty"`
|
||||
BlobsRemoved int `json:"blobs_removed"`
|
||||
BytesFreed int64 `json:"bytes_freed"`
|
||||
}
|
||||
|
||||
// GCState tracks the current state of garbage collection.
|
||||
type GCState struct {
|
||||
mu sync.Mutex
|
||||
Running bool `json:"running"`
|
||||
LastRun *GCLastRun `json:"last_run,omitempty"`
|
||||
}
|
||||
|
||||
type gcStatusResponse struct {
|
||||
Running bool `json:"running"`
|
||||
LastRun *GCLastRun `json:"last_run,omitempty"`
|
||||
}
|
||||
|
||||
type gcTriggerResponse struct {
|
||||
ID string `json:"id"`
|
||||
}
|
||||
|
||||
// AdminTriggerGCHandler handles POST /v1/gc.
|
||||
func AdminTriggerGCHandler(state *GCState) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
state.mu.Lock()
|
||||
if state.Running {
|
||||
state.mu.Unlock()
|
||||
writeAdminError(w, http.StatusConflict, "garbage collection already running")
|
||||
return
|
||||
}
|
||||
state.Running = true
|
||||
state.mu.Unlock()
|
||||
|
||||
// GC engine is Phase 9 -- for now, just mark as running and return.
|
||||
// The actual GC goroutine will be wired up in Phase 9.
|
||||
gcID := uuid.New().String()
|
||||
|
||||
writeJSON(w, http.StatusAccepted, gcTriggerResponse{ID: gcID})
|
||||
}
|
||||
}
|
||||
|
||||
// AdminGCStatusHandler handles GET /v1/gc/status.
|
||||
func AdminGCStatusHandler(state *GCState) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
state.mu.Lock()
|
||||
resp := gcStatusResponse{
|
||||
Running: state.Running,
|
||||
LastRun: state.LastRun,
|
||||
}
|
||||
state.mu.Unlock()
|
||||
|
||||
writeJSON(w, http.StatusOK, resp)
|
||||
}
|
||||
}
|
||||
94
internal/server/admin_gc_test.go
Normal file
94
internal/server/admin_gc_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAdminTriggerGC(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "POST", "/v1/gc", "")
|
||||
if rr.Code != 202 {
|
||||
t.Fatalf("status: got %d, want 202; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var resp gcTriggerResponse
|
||||
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if resp.ID == "" {
|
||||
t.Fatal("expected non-empty GC ID")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminTriggerGCConflict(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
// First trigger should succeed.
|
||||
rr := adminReq(t, router, "POST", "/v1/gc", "")
|
||||
if rr.Code != 202 {
|
||||
t.Fatalf("first trigger status: got %d, want 202", rr.Code)
|
||||
}
|
||||
|
||||
// Second trigger should conflict because GC is still "running".
|
||||
rr = adminReq(t, router, "POST", "/v1/gc", "")
|
||||
if rr.Code != 409 {
|
||||
t.Fatalf("second trigger status: got %d, want 409; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGCStatus(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
// Before triggering GC.
|
||||
rr := adminReq(t, router, "GET", "/v1/gc/status", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var resp gcStatusResponse
|
||||
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if resp.Running {
|
||||
t.Fatal("expected running=false before trigger")
|
||||
}
|
||||
|
||||
// Trigger GC.
|
||||
rr = adminReq(t, router, "POST", "/v1/gc", "")
|
||||
if rr.Code != 202 {
|
||||
t.Fatalf("trigger status: got %d, want 202", rr.Code)
|
||||
}
|
||||
|
||||
// After triggering GC.
|
||||
rr = adminReq(t, router, "GET", "/v1/gc/status", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(rr.Body).Decode(&resp); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if !resp.Running {
|
||||
t.Fatal("expected running=true after trigger")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGCNonAdmin(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router := buildNonAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "POST", "/v1/gc", "")
|
||||
if rr.Code != 403 {
|
||||
t.Fatalf("trigger status: got %d, want 403", rr.Code)
|
||||
}
|
||||
|
||||
rr = adminReq(t, router, "GET", "/v1/gc/status", "")
|
||||
if rr.Code != 403 {
|
||||
t.Fatalf("status status: got %d, want 403", rr.Code)
|
||||
}
|
||||
}
|
||||
343
internal/server/admin_policy.go
Normal file
343
internal/server/admin_policy.go
Normal file
@@ -0,0 +1,343 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||
)
|
||||
|
||||
// PolicyReloader can reload policy rules from a store.
|
||||
type PolicyReloader interface {
|
||||
Reload(store policy.RuleStore) error
|
||||
}
|
||||
|
||||
// policyCreateRequest is the JSON body for creating a policy rule.
|
||||
type policyCreateRequest struct {
|
||||
Priority int `json:"priority"`
|
||||
Description string `json:"description"`
|
||||
Effect string `json:"effect"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
AccountTypes []string `json:"account_types,omitempty"`
|
||||
SubjectUUID string `json:"subject_uuid,omitempty"`
|
||||
Actions []string `json:"actions"`
|
||||
Repositories []string `json:"repositories,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"` // pointer to distinguish unset from false
|
||||
}
|
||||
|
||||
// policyUpdateRequest is the JSON body for updating a policy rule.
|
||||
type policyUpdateRequest struct {
|
||||
Priority *int `json:"priority,omitempty"`
|
||||
Description *string `json:"description,omitempty"`
|
||||
Effect *string `json:"effect,omitempty"`
|
||||
Roles []string `json:"roles,omitempty"`
|
||||
AccountTypes []string `json:"account_types,omitempty"`
|
||||
SubjectUUID *string `json:"subject_uuid,omitempty"`
|
||||
Actions []string `json:"actions,omitempty"`
|
||||
Repositories []string `json:"repositories,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
}
|
||||
|
||||
var validActions = map[string]bool{
|
||||
string(policy.ActionVersionCheck): true,
|
||||
string(policy.ActionPull): true,
|
||||
string(policy.ActionPush): true,
|
||||
string(policy.ActionDelete): true,
|
||||
string(policy.ActionCatalog): true,
|
||||
string(policy.ActionPolicyManage): true,
|
||||
}
|
||||
|
||||
func validateActions(actions []string) error {
|
||||
for _, a := range actions {
|
||||
if !validActions[a] {
|
||||
return fmt.Errorf("invalid action: %q", a)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateEffect(effect string) error {
|
||||
if effect != "allow" && effect != "deny" {
|
||||
return fmt.Errorf("invalid effect: %q (must be 'allow' or 'deny')", effect)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AdminListPolicyRulesHandler handles GET /v1/policy/rules.
|
||||
func AdminListPolicyRulesHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
limit := 50
|
||||
offset := 0
|
||||
if n := r.URL.Query().Get("n"); n != "" {
|
||||
if v, err := strconv.Atoi(n); err == nil && v > 0 {
|
||||
limit = v
|
||||
}
|
||||
}
|
||||
if o := r.URL.Query().Get("offset"); o != "" {
|
||||
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
|
||||
offset = v
|
||||
}
|
||||
}
|
||||
|
||||
rules, err := database.ListPolicyRules(limit, offset)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
if rules == nil {
|
||||
rules = []db.PolicyRuleRow{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, rules)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminCreatePolicyRuleHandler handles POST /v1/policy/rules.
|
||||
func AdminCreatePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req policyCreateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Priority < 1 {
|
||||
writeAdminError(w, http.StatusBadRequest, "priority must be >= 1 (0 is reserved for built-ins)")
|
||||
return
|
||||
}
|
||||
if req.Description == "" {
|
||||
writeAdminError(w, http.StatusBadRequest, "description is required")
|
||||
return
|
||||
}
|
||||
if err := validateEffect(req.Effect); err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
if len(req.Actions) == 0 {
|
||||
writeAdminError(w, http.StatusBadRequest, "at least one action is required")
|
||||
return
|
||||
}
|
||||
if err := validateActions(req.Actions); err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
enabled := true
|
||||
if req.Enabled != nil {
|
||||
enabled = *req.Enabled
|
||||
}
|
||||
|
||||
claims := auth.ClaimsFromContext(r.Context())
|
||||
createdBy := ""
|
||||
if claims != nil {
|
||||
createdBy = claims.Subject
|
||||
}
|
||||
|
||||
row := db.PolicyRuleRow{
|
||||
Priority: req.Priority,
|
||||
Description: req.Description,
|
||||
Effect: req.Effect,
|
||||
Roles: req.Roles,
|
||||
AccountTypes: req.AccountTypes,
|
||||
SubjectUUID: req.SubjectUUID,
|
||||
Actions: req.Actions,
|
||||
Repositories: req.Repositories,
|
||||
Enabled: enabled,
|
||||
CreatedBy: createdBy,
|
||||
}
|
||||
|
||||
id, err := database.CreatePolicyRule(row)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
// Reload policy engine.
|
||||
if engine != nil {
|
||||
_ = engine.Reload(database)
|
||||
}
|
||||
|
||||
if auditFn != nil {
|
||||
auditFn("policy_rule_created", createdBy, "", "", r.RemoteAddr, map[string]string{
|
||||
"rule_id": strconv.FormatInt(id, 10),
|
||||
})
|
||||
}
|
||||
|
||||
created, err := database.GetPolicyRule(id)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusCreated, created)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminGetPolicyRuleHandler handles GET /v1/policy/rules/{id}.
|
||||
func AdminGetPolicyRuleHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, "invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
rule, err := database.GetPolicyRule(id)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||
writeAdminError(w, http.StatusNotFound, "policy rule not found")
|
||||
return
|
||||
}
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, rule)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminUpdatePolicyRuleHandler handles PATCH /v1/policy/rules/{id}.
|
||||
func AdminUpdatePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, "invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req policyUpdateRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Priority != nil && *req.Priority < 1 {
|
||||
writeAdminError(w, http.StatusBadRequest, "priority must be >= 1 (0 is reserved for built-ins)")
|
||||
return
|
||||
}
|
||||
if req.Effect != nil {
|
||||
if err := validateEffect(*req.Effect); err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.Actions != nil {
|
||||
if len(req.Actions) == 0 {
|
||||
writeAdminError(w, http.StatusBadRequest, "at least one action is required")
|
||||
return
|
||||
}
|
||||
if err := validateActions(req.Actions); err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
updates := db.PolicyRuleRow{}
|
||||
if req.Priority != nil {
|
||||
updates.Priority = *req.Priority
|
||||
}
|
||||
if req.Description != nil {
|
||||
updates.Description = *req.Description
|
||||
}
|
||||
if req.Effect != nil {
|
||||
updates.Effect = *req.Effect
|
||||
}
|
||||
if req.Roles != nil {
|
||||
updates.Roles = req.Roles
|
||||
}
|
||||
if req.AccountTypes != nil {
|
||||
updates.AccountTypes = req.AccountTypes
|
||||
}
|
||||
if req.SubjectUUID != nil {
|
||||
updates.SubjectUUID = *req.SubjectUUID
|
||||
}
|
||||
if req.Actions != nil {
|
||||
updates.Actions = req.Actions
|
||||
}
|
||||
if req.Repositories != nil {
|
||||
updates.Repositories = req.Repositories
|
||||
}
|
||||
|
||||
if err := database.UpdatePolicyRule(id, updates); err != nil {
|
||||
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||
writeAdminError(w, http.StatusNotFound, "policy rule not found")
|
||||
return
|
||||
}
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
// Handle enabled separately since it's a bool.
|
||||
if req.Enabled != nil {
|
||||
if err := database.SetPolicyRuleEnabled(id, *req.Enabled); err != nil {
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Reload policy engine.
|
||||
if engine != nil {
|
||||
_ = engine.Reload(database)
|
||||
}
|
||||
|
||||
if auditFn != nil {
|
||||
claims := auth.ClaimsFromContext(r.Context())
|
||||
actorID := ""
|
||||
if claims != nil {
|
||||
actorID = claims.Subject
|
||||
}
|
||||
auditFn("policy_rule_updated", actorID, "", "", r.RemoteAddr, map[string]string{
|
||||
"rule_id": strconv.FormatInt(id, 10),
|
||||
})
|
||||
}
|
||||
|
||||
updated, err := database.GetPolicyRule(id)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, updated)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminDeletePolicyRuleHandler handles DELETE /v1/policy/rules/{id}.
|
||||
func AdminDeletePolicyRuleHandler(database *db.DB, engine PolicyReloader, auditFn AuditFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
id, err := strconv.ParseInt(chi.URLParam(r, "id"), 10, 64)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusBadRequest, "invalid rule ID")
|
||||
return
|
||||
}
|
||||
|
||||
if err := database.DeletePolicyRule(id); err != nil {
|
||||
if errors.Is(err, db.ErrPolicyRuleNotFound) {
|
||||
writeAdminError(w, http.StatusNotFound, "policy rule not found")
|
||||
return
|
||||
}
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
// Reload policy engine.
|
||||
if engine != nil {
|
||||
_ = engine.Reload(database)
|
||||
}
|
||||
|
||||
if auditFn != nil {
|
||||
claims := auth.ClaimsFromContext(r.Context())
|
||||
actorID := ""
|
||||
if claims != nil {
|
||||
actorID = claims.Subject
|
||||
}
|
||||
auditFn("policy_rule_deleted", actorID, "", "", r.RemoteAddr, map[string]string{
|
||||
"rule_id": strconv.FormatInt(id, 10),
|
||||
})
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
337
internal/server/admin_policy_test.go
Normal file
337
internal/server/admin_policy_test.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
func TestAdminPolicyCRUDCycle(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, reloader := buildAdminRouter(t, database)
|
||||
|
||||
// Create a rule.
|
||||
createBody := `{
|
||||
"priority": 50,
|
||||
"description": "allow CI push",
|
||||
"effect": "allow",
|
||||
"roles": ["ci"],
|
||||
"actions": ["registry:push", "registry:pull"],
|
||||
"repositories": ["production/*"],
|
||||
"enabled": true
|
||||
}`
|
||||
rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody)
|
||||
if rr.Code != 201 {
|
||||
t.Fatalf("create status: got %d, want 201; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
if reloader.reloadCount != 1 {
|
||||
t.Fatalf("reload count after create: got %d, want 1", reloader.reloadCount)
|
||||
}
|
||||
|
||||
var created db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&created); err != nil {
|
||||
t.Fatalf("decode created: %v", err)
|
||||
}
|
||||
if created.ID == 0 {
|
||||
t.Fatal("expected non-zero ID")
|
||||
}
|
||||
if created.Effect != "allow" {
|
||||
t.Fatalf("effect: got %q, want %q", created.Effect, "allow")
|
||||
}
|
||||
if len(created.Roles) != 1 || created.Roles[0] != "ci" {
|
||||
t.Fatalf("roles: got %v, want [ci]", created.Roles)
|
||||
}
|
||||
if created.CreatedBy != "admin-uuid" {
|
||||
t.Fatalf("created_by: got %q, want %q", created.CreatedBy, "admin-uuid")
|
||||
}
|
||||
|
||||
// Get the rule.
|
||||
rr = adminReq(t, router, "GET", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("get status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var got db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&got); err != nil {
|
||||
t.Fatalf("decode got: %v", err)
|
||||
}
|
||||
if got.ID != created.ID {
|
||||
t.Fatalf("id: got %d, want %d", got.ID, created.ID)
|
||||
}
|
||||
|
||||
// List rules.
|
||||
rr = adminReq(t, router, "GET", "/v1/policy/rules/", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("list status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var rules []db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&rules); err != nil {
|
||||
t.Fatalf("decode list: %v", err)
|
||||
}
|
||||
if len(rules) != 1 {
|
||||
t.Fatalf("rule count: got %d, want 1", len(rules))
|
||||
}
|
||||
|
||||
// Update the rule.
|
||||
updateBody := `{"priority": 25, "description": "updated CI push"}`
|
||||
rr = adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), updateBody)
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("update status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
if reloader.reloadCount != 2 {
|
||||
t.Fatalf("reload count after update: got %d, want 2", reloader.reloadCount)
|
||||
}
|
||||
|
||||
var updated db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&updated); err != nil {
|
||||
t.Fatalf("decode updated: %v", err)
|
||||
}
|
||||
if updated.Priority != 25 {
|
||||
t.Fatalf("updated priority: got %d, want 25", updated.Priority)
|
||||
}
|
||||
if updated.Description != "updated CI push" {
|
||||
t.Fatalf("updated description: got %q, want %q", updated.Description, "updated CI push")
|
||||
}
|
||||
// Effect should be unchanged.
|
||||
if updated.Effect != "allow" {
|
||||
t.Fatalf("updated effect: got %q, want %q (unchanged)", updated.Effect, "allow")
|
||||
}
|
||||
|
||||
// Delete the rule.
|
||||
rr = adminReq(t, router, "DELETE", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "")
|
||||
if rr.Code != 204 {
|
||||
t.Fatalf("delete status: got %d, want 204", rr.Code)
|
||||
}
|
||||
if reloader.reloadCount != 3 {
|
||||
t.Fatalf("reload count after delete: got %d, want 3", reloader.reloadCount)
|
||||
}
|
||||
|
||||
// Verify it's gone.
|
||||
rr = adminReq(t, router, "GET", fmt.Sprintf("/v1/policy/rules/%d", created.ID), "")
|
||||
if rr.Code != 404 {
|
||||
t.Fatalf("after delete status: got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminCreatePolicyRuleValidation(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "priority too low",
|
||||
body: `{"priority":0,"description":"test","effect":"allow","actions":["registry:pull"]}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "missing description",
|
||||
body: `{"priority":1,"effect":"allow","actions":["registry:pull"]}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "invalid effect",
|
||||
body: `{"priority":1,"description":"test","effect":"maybe","actions":["registry:pull"]}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "no actions",
|
||||
body: `{"priority":1,"description":"test","effect":"allow","actions":[]}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "invalid action",
|
||||
body: `{"priority":1,"description":"test","effect":"allow","actions":["bogus:action"]}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "bad JSON",
|
||||
body: `not json`,
|
||||
want: 400,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rr := adminReq(t, router, "POST", "/v1/policy/rules/", tc.body)
|
||||
if rr.Code != tc.want {
|
||||
t.Fatalf("status: got %d, want %d; body: %s", rr.Code, tc.want, rr.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminUpdatePolicyRuleValidation(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
// Create a rule to update.
|
||||
createBody := `{"priority":50,"description":"test","effect":"allow","actions":["registry:pull"]}`
|
||||
rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody)
|
||||
if rr.Code != 201 {
|
||||
t.Fatalf("create status: got %d, want 201", rr.Code)
|
||||
}
|
||||
var created db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&created); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body string
|
||||
want int
|
||||
}{
|
||||
{
|
||||
name: "priority too low",
|
||||
body: `{"priority":0}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "invalid effect",
|
||||
body: `{"effect":"maybe"}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "empty actions",
|
||||
body: `{"actions":[]}`,
|
||||
want: 400,
|
||||
},
|
||||
{
|
||||
name: "invalid action",
|
||||
body: `{"actions":["bogus:action"]}`,
|
||||
want: 400,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
rr := adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), tc.body)
|
||||
if rr.Code != tc.want {
|
||||
t.Fatalf("status: got %d, want %d; body: %s", rr.Code, tc.want, rr.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminUpdatePolicyRuleNotFound(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "PATCH", "/v1/policy/rules/9999", `{"description":"nope"}`)
|
||||
if rr.Code != 404 {
|
||||
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminDeletePolicyRuleNotFound(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "DELETE", "/v1/policy/rules/9999", "")
|
||||
if rr.Code != 404 {
|
||||
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetPolicyRuleNotFound(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/policy/rules/9999", "")
|
||||
if rr.Code != 404 {
|
||||
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetPolicyRuleInvalidID(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/policy/rules/not-a-number", "")
|
||||
if rr.Code != 400 {
|
||||
t.Fatalf("status: got %d, want 400", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminPolicyRulesNonAdmin(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router := buildNonAdminRouter(t, database)
|
||||
|
||||
// All policy rule endpoints require admin.
|
||||
endpoints := []struct {
|
||||
method string
|
||||
path string
|
||||
body string
|
||||
}{
|
||||
{"GET", "/v1/policy/rules/", ""},
|
||||
{"POST", "/v1/policy/rules/", `{"priority":1,"description":"test","effect":"allow","actions":["registry:pull"]}`},
|
||||
{"GET", "/v1/policy/rules/1", ""},
|
||||
{"PATCH", "/v1/policy/rules/1", `{"description":"updated"}`},
|
||||
{"DELETE", "/v1/policy/rules/1", ""},
|
||||
}
|
||||
|
||||
for _, ep := range endpoints {
|
||||
t.Run(ep.method+" "+ep.path, func(t *testing.T) {
|
||||
rr := adminReq(t, router, ep.method, ep.path, ep.body)
|
||||
if rr.Code != 403 {
|
||||
t.Fatalf("status: got %d, want 403; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminUpdatePolicyRuleEnabled(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
// Create an enabled rule.
|
||||
createBody := `{"priority":50,"description":"test","effect":"allow","actions":["registry:pull"],"enabled":true}`
|
||||
rr := adminReq(t, router, "POST", "/v1/policy/rules/", createBody)
|
||||
if rr.Code != 201 {
|
||||
t.Fatalf("create status: got %d, want 201", rr.Code)
|
||||
}
|
||||
var created db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&created); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if !created.Enabled {
|
||||
t.Fatal("expected created rule to be enabled")
|
||||
}
|
||||
|
||||
// Disable it.
|
||||
rr = adminReq(t, router, "PATCH", fmt.Sprintf("/v1/policy/rules/%d", created.ID), `{"enabled":false}`)
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("update status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
var updated db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&updated); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if updated.Enabled {
|
||||
t.Fatal("expected updated rule to be disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListPolicyRulesEmpty(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/policy/rules/", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var rules []db.PolicyRuleRow
|
||||
if err := json.NewDecoder(rr.Body).Decode(&rules); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(rules) != 0 {
|
||||
t.Fatalf("rule count: got %d, want 0", len(rules))
|
||||
}
|
||||
}
|
||||
94
internal/server/admin_repo.go
Normal file
94
internal/server/admin_repo.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
// AdminListReposHandler handles GET /v1/repositories.
|
||||
func AdminListReposHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
limit := 50
|
||||
offset := 0
|
||||
if n := r.URL.Query().Get("n"); n != "" {
|
||||
if v, err := strconv.Atoi(n); err == nil && v > 0 {
|
||||
limit = v
|
||||
}
|
||||
}
|
||||
if o := r.URL.Query().Get("offset"); o != "" {
|
||||
if v, err := strconv.Atoi(o); err == nil && v >= 0 {
|
||||
offset = v
|
||||
}
|
||||
}
|
||||
|
||||
repos, err := database.ListRepositoriesWithMetadata(limit, offset)
|
||||
if err != nil {
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
if repos == nil {
|
||||
repos = []db.RepoMetadata{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, repos)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminGetRepoHandler handles GET /v1/repositories/*.
|
||||
func AdminGetRepoHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "*")
|
||||
if name == "" {
|
||||
writeAdminError(w, http.StatusBadRequest, "repository name required")
|
||||
return
|
||||
}
|
||||
|
||||
detail, err := database.GetRepositoryDetail(name)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrRepoNotFound) {
|
||||
writeAdminError(w, http.StatusNotFound, "repository not found")
|
||||
return
|
||||
}
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, detail)
|
||||
}
|
||||
}
|
||||
|
||||
// AdminDeleteRepoHandler handles DELETE /v1/repositories/*.
|
||||
// Requires admin role (enforced by RequireAdmin middleware).
|
||||
func AdminDeleteRepoHandler(database *db.DB, auditFn AuditFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "*")
|
||||
if name == "" {
|
||||
writeAdminError(w, http.StatusBadRequest, "repository name required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := database.DeleteRepository(name); err != nil {
|
||||
if errors.Is(err, db.ErrRepoNotFound) {
|
||||
writeAdminError(w, http.StatusNotFound, "repository not found")
|
||||
return
|
||||
}
|
||||
writeAdminError(w, http.StatusInternalServerError, "internal error")
|
||||
return
|
||||
}
|
||||
|
||||
if auditFn != nil {
|
||||
claims := auth.ClaimsFromContext(r.Context())
|
||||
actorID := ""
|
||||
if claims != nil {
|
||||
actorID = claims.Subject
|
||||
}
|
||||
auditFn("repo_deleted", actorID, name, "", r.RemoteAddr, nil)
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
186
internal/server/admin_repo_test.go
Normal file
186
internal/server/admin_repo_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
// seedRepoForAdmin inserts a repository with a manifest and tags into the test DB.
|
||||
func seedRepoForAdmin(t *testing.T, database *db.DB, name string, tags []string) {
|
||||
t.Helper()
|
||||
_, err := database.Exec(`INSERT INTO repositories (name) VALUES (?)`, name)
|
||||
if err != nil {
|
||||
t.Fatalf("insert repo %q: %v", name, err)
|
||||
}
|
||||
var repoID int64
|
||||
if err := database.QueryRow(`SELECT id FROM repositories WHERE name = ?`, name).Scan(&repoID); err != nil {
|
||||
t.Fatalf("select repo id: %v", err)
|
||||
}
|
||||
_, err = database.Exec(
|
||||
`INSERT INTO manifests (repository_id, digest, media_type, content, size)
|
||||
VALUES (?, ?, 'application/vnd.oci.image.manifest.v1+json', '{}', 512)`,
|
||||
repoID, "sha256:manifest-"+name,
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("insert manifest: %v", err)
|
||||
}
|
||||
var manifestID int64
|
||||
if err := database.QueryRow(`SELECT id FROM manifests WHERE repository_id = ?`, repoID).Scan(&manifestID); err != nil {
|
||||
t.Fatalf("select manifest id: %v", err)
|
||||
}
|
||||
for _, tag := range tags {
|
||||
_, err := database.Exec(`INSERT INTO tags (repository_id, name, manifest_id) VALUES (?, ?, ?)`,
|
||||
repoID, tag, manifestID)
|
||||
if err != nil {
|
||||
t.Fatalf("insert tag %q: %v", tag, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListRepos(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
seedRepoForAdmin(t, database, "alpha/app", []string{"latest", "v1.0"})
|
||||
seedRepoForAdmin(t, database, "bravo/lib", []string{"latest"})
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/repositories", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var repos []db.RepoMetadata
|
||||
if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(repos) != 2 {
|
||||
t.Fatalf("repo count: got %d, want 2", len(repos))
|
||||
}
|
||||
if repos[0].Name != "alpha/app" {
|
||||
t.Fatalf("first repo: got %q, want %q", repos[0].Name, "alpha/app")
|
||||
}
|
||||
if repos[0].TagCount != 2 {
|
||||
t.Fatalf("tag count: got %d, want 2", repos[0].TagCount)
|
||||
}
|
||||
if repos[0].ManifestCount != 1 {
|
||||
t.Fatalf("manifest count: got %d, want 1", repos[0].ManifestCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListReposEmpty(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/repositories", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var repos []db.RepoMetadata
|
||||
if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(repos) != 0 {
|
||||
t.Fatalf("repo count: got %d, want 0", len(repos))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminListReposPagination(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
seedRepoForAdmin(t, database, "alpha/app", []string{"latest"})
|
||||
seedRepoForAdmin(t, database, "bravo/lib", []string{"latest"})
|
||||
seedRepoForAdmin(t, database, "charlie/svc", []string{"latest"})
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/repositories?n=2&offset=0", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200", rr.Code)
|
||||
}
|
||||
|
||||
var repos []db.RepoMetadata
|
||||
if err := json.NewDecoder(rr.Body).Decode(&repos); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if len(repos) != 2 {
|
||||
t.Fatalf("page 1 count: got %d, want 2", len(repos))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetRepoDetail(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest", "v1.0"})
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/repositories/myorg/myapp", "")
|
||||
if rr.Code != 200 {
|
||||
t.Fatalf("status: got %d, want 200; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
var detail db.RepoDetail
|
||||
if err := json.NewDecoder(rr.Body).Decode(&detail); err != nil {
|
||||
t.Fatalf("decode: %v", err)
|
||||
}
|
||||
if detail.Name != "myorg/myapp" {
|
||||
t.Fatalf("name: got %q, want %q", detail.Name, "myorg/myapp")
|
||||
}
|
||||
if len(detail.Tags) != 2 {
|
||||
t.Fatalf("tag count: got %d, want 2", len(detail.Tags))
|
||||
}
|
||||
if len(detail.Manifests) != 1 {
|
||||
t.Fatalf("manifest count: got %d, want 1", len(detail.Manifests))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminGetRepoDetailNotFound(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "GET", "/v1/repositories/nonexistent/repo", "")
|
||||
if rr.Code != 404 {
|
||||
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminDeleteRepo(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest"})
|
||||
|
||||
rr := adminReq(t, router, "DELETE", "/v1/repositories/myorg/myapp", "")
|
||||
if rr.Code != 204 {
|
||||
t.Fatalf("status: got %d, want 204; body: %s", rr.Code, rr.Body.String())
|
||||
}
|
||||
|
||||
// Verify it's gone.
|
||||
rr = adminReq(t, router, "GET", "/v1/repositories/myorg/myapp", "")
|
||||
if rr.Code != 404 {
|
||||
t.Fatalf("after delete status: got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminDeleteRepoNotFound(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router, _ := buildAdminRouter(t, database)
|
||||
|
||||
rr := adminReq(t, router, "DELETE", "/v1/repositories/nonexistent/repo", "")
|
||||
if rr.Code != 404 {
|
||||
t.Fatalf("status: got %d, want 404", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdminDeleteRepoNonAdmin(t *testing.T) {
|
||||
database := openAdminTestDB(t)
|
||||
router := buildNonAdminRouter(t, database)
|
||||
|
||||
seedRepoForAdmin(t, database, "myorg/myapp", []string{"latest"})
|
||||
|
||||
rr := adminReq(t, router, "DELETE", "/v1/repositories/myorg/myapp", "")
|
||||
if rr.Code != 403 {
|
||||
t.Fatalf("status: got %d, want 403", rr.Code)
|
||||
}
|
||||
}
|
||||
56
internal/server/admin_routes.go
Normal file
56
internal/server/admin_routes.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
)
|
||||
|
||||
// AdminDeps holds the dependencies needed by admin routes.
|
||||
type AdminDeps struct {
|
||||
DB *db.DB
|
||||
Login LoginClient
|
||||
Engine PolicyReloader
|
||||
AuditFn AuditFunc
|
||||
GCState *GCState
|
||||
}
|
||||
|
||||
// MountAdminRoutes adds admin REST endpoints to the router.
|
||||
// Auth middleware is applied at the route group level.
|
||||
func MountAdminRoutes(r chi.Router, validator TokenValidator, serviceName string, deps AdminDeps) {
|
||||
// Health endpoint - no auth required.
|
||||
r.Get("/v1/health", AdminHealthHandler())
|
||||
|
||||
// Auth endpoints - no bearer auth required (login uses credentials).
|
||||
r.Post("/v1/auth/login", AdminLoginHandler(deps.Login))
|
||||
|
||||
// Authenticated endpoints.
|
||||
r.Route("/v1", func(v1 chi.Router) {
|
||||
v1.Use(RequireAuth(validator, serviceName))
|
||||
|
||||
// Logout.
|
||||
v1.Post("/auth/logout", AdminLogoutHandler())
|
||||
|
||||
// Repositories - list and detail require auth, delete requires admin.
|
||||
v1.Get("/repositories", AdminListReposHandler(deps.DB))
|
||||
v1.Get("/repositories/*", AdminGetRepoHandler(deps.DB))
|
||||
v1.With(RequireAdmin()).Delete("/repositories/*", AdminDeleteRepoHandler(deps.DB, deps.AuditFn))
|
||||
|
||||
// Policy - all require admin.
|
||||
v1.Route("/policy/rules", func(pr chi.Router) {
|
||||
pr.Use(RequireAdmin())
|
||||
pr.Get("/", AdminListPolicyRulesHandler(deps.DB))
|
||||
pr.Post("/", AdminCreatePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn))
|
||||
pr.Get("/{id}", AdminGetPolicyRuleHandler(deps.DB))
|
||||
pr.Patch("/{id}", AdminUpdatePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn))
|
||||
pr.Delete("/{id}", AdminDeletePolicyRuleHandler(deps.DB, deps.Engine, deps.AuditFn))
|
||||
})
|
||||
|
||||
// Audit - requires admin.
|
||||
v1.With(RequireAdmin()).Get("/audit", AdminListAuditHandler(deps.DB))
|
||||
|
||||
// GC - requires admin.
|
||||
v1.With(RequireAdmin()).Post("/gc", AdminTriggerGCHandler(deps.GCState))
|
||||
v1.With(RequireAdmin()).Get("/gc/status", AdminGCStatusHandler(deps.GCState))
|
||||
})
|
||||
}
|
||||
148
internal/server/admin_test.go
Normal file
148
internal/server/admin_test.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/internal/auth"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/db"
|
||||
"git.wntrmute.dev/kyle/mcr/internal/policy"
|
||||
)
|
||||
|
||||
func openAdminTestDB(t *testing.T) *db.DB {
|
||||
t.Helper()
|
||||
path := filepath.Join(t.TempDir(), "test.db")
|
||||
d, err := db.Open(path)
|
||||
if err != nil {
|
||||
t.Fatalf("Open: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = d.Close() })
|
||||
if err := d.Migrate(); err != nil {
|
||||
t.Fatalf("Migrate: %v", err)
|
||||
}
|
||||
return d
|
||||
}
|
||||
|
||||
type fakePolicyReloader struct {
|
||||
reloadCount int
|
||||
}
|
||||
|
||||
func (f *fakePolicyReloader) Reload(_ policy.RuleStore) error {
|
||||
f.reloadCount++
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildAdminRouter creates a chi router with admin routes wired up,
|
||||
// using a fakeValidator that returns admin claims for any bearer token.
|
||||
func buildAdminRouter(t *testing.T, database *db.DB) (chi.Router, *fakePolicyReloader) {
|
||||
t.Helper()
|
||||
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}},
|
||||
}
|
||||
login := &fakeLoginClient{token: "test-token", expiresIn: 3600}
|
||||
reloader := &fakePolicyReloader{}
|
||||
gcState := &GCState{}
|
||||
|
||||
r := chi.NewRouter()
|
||||
MountAdminRoutes(r, validator, "mcr-test", AdminDeps{
|
||||
DB: database,
|
||||
Login: login,
|
||||
Engine: reloader,
|
||||
AuditFn: nil,
|
||||
GCState: gcState,
|
||||
})
|
||||
return r, reloader
|
||||
}
|
||||
|
||||
// buildNonAdminRouter creates a chi router that returns non-admin claims.
|
||||
func buildNonAdminRouter(t *testing.T, database *db.DB) chi.Router {
|
||||
t.Helper()
|
||||
|
||||
validator := &fakeValidator{
|
||||
claims: &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}},
|
||||
}
|
||||
login := &fakeLoginClient{token: "test-token", expiresIn: 3600}
|
||||
reloader := &fakePolicyReloader{}
|
||||
gcState := &GCState{}
|
||||
|
||||
r := chi.NewRouter()
|
||||
MountAdminRoutes(r, validator, "mcr-test", AdminDeps{
|
||||
DB: database,
|
||||
Login: login,
|
||||
Engine: reloader,
|
||||
AuditFn: nil,
|
||||
GCState: gcState,
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
// adminReq is a convenience helper for making HTTP requests against the admin
|
||||
// router, automatically including the Authorization header.
|
||||
func adminReq(t *testing.T, router http.Handler, method, path string, body string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
var bodyReader io.Reader
|
||||
if body != "" {
|
||||
bodyReader = strings.NewReader(body)
|
||||
}
|
||||
req := httptest.NewRequest(method, path, bodyReader)
|
||||
req.Header.Set("Authorization", "Bearer valid-token")
|
||||
if body != "" {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
rr := httptest.NewRecorder()
|
||||
router.ServeHTTP(rr, req)
|
||||
return rr
|
||||
}
|
||||
|
||||
func TestRequireAdminAllowed(t *testing.T) {
|
||||
claims := &auth.Claims{Subject: "admin-uuid", AccountType: "human", Roles: []string{"admin"}}
|
||||
handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = req.WithContext(auth.ContextWithClaims(req.Context(), claims))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusOK {
|
||||
t.Fatalf("admin allowed: got %d, want 200", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAdminDenied(t *testing.T) {
|
||||
claims := &auth.Claims{Subject: "user-uuid", AccountType: "human", Roles: []string{"user"}}
|
||||
handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Fatal("inner handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
req = req.WithContext(auth.ContextWithClaims(req.Context(), claims))
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusForbidden {
|
||||
t.Fatalf("non-admin denied: got %d, want 403", rr.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequireAdminNoClaims(t *testing.T) {
|
||||
handler := RequireAdmin()(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
t.Fatal("inner handler should not be called")
|
||||
}))
|
||||
|
||||
req := httptest.NewRequest("GET", "/test", nil)
|
||||
rr := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rr, req)
|
||||
|
||||
if rr.Code != http.StatusUnauthorized {
|
||||
t.Fatalf("no claims: got %d, want 401", rr.Code)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user