package server

import (
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"testing"

	"git.wntrmute.dev/mc/mcr/internal/auth"
)

// fakeValidator is a test double passed to RequireAuth. Its ValidateToken
// ignores the token it is given and always returns the configured claims
// and error.
type fakeValidator struct {
	claims *auth.Claims
	err    error
}

// ValidateToken returns the canned claims and error regardless of the token.
func (f *fakeValidator) ValidateToken(_ string) (*auth.Claims, error) {
	return f.claims, f.err
}

// failIfCalled returns a handler that fails the test if it is ever invoked.
// It is used as the inner handler when RequireAuth is expected to reject
// the request before delegating.
func failIfCalled(t *testing.T) http.Handler {
	t.Helper()
	return http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
		t.Fatal("inner handler should not be called")
	})
}

// assertOCIUnauthorized fails the test unless rec holds a 401 response whose
// body decodes as an ociErrorResponse containing exactly one error with code
// "UNAUTHORIZED". It calls t.Helper so failures are reported at the caller.
func assertOCIUnauthorized(t *testing.T, rec *httptest.ResponseRecorder) {
	t.Helper()

	if rec.Code != http.StatusUnauthorized {
		t.Fatalf("status: got %d, want %d", rec.Code, http.StatusUnauthorized)
	}

	var ociErr ociErrorResponse
	if err := json.NewDecoder(rec.Body).Decode(&ociErr); err != nil {
		t.Fatalf("decode OCI error: %v", err)
	}
	if len(ociErr.Errors) != 1 || ociErr.Errors[0].Code != "UNAUTHORIZED" {
		t.Fatalf("OCI error: got %+v, want UNAUTHORIZED", ociErr.Errors)
	}
}

// TestRequireAuthValid checks that a request with a token the validator
// accepts reaches the inner handler with the validator's claims stored in
// the request context.
func TestRequireAuthValid(t *testing.T) {
	claims := &auth.Claims{Subject: "alice", AccountType: "user", Roles: []string{"reader"}}
	validator := &fakeValidator{claims: claims}

	var gotClaims *auth.Claims
	inner := http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
		gotClaims = auth.ClaimsFromContext(r.Context())
	})

	handler := RequireAuth(validator, "mcr-test")(inner)

	req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
	req.Header.Set("Authorization", "Bearer valid-token")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	// The inner handler writes nothing, so the recorder's default 200 is
	// expected when the middleware lets the request through.
	if rec.Code != http.StatusOK {
		t.Fatalf("status: got %d, want %d", rec.Code, http.StatusOK)
	}
	if gotClaims == nil {
		t.Fatal("expected claims in context, got nil")
	}
	if gotClaims.Subject != "alice" {
		t.Fatalf("subject: got %q, want %q", gotClaims.Subject, "alice")
	}
}

// TestRequireAuthMissing checks that a request without an Authorization
// header is rejected with 401, a Bearer WWW-Authenticate challenge naming
// the configured service, and an OCI UNAUTHORIZED error body.
func TestRequireAuthMissing(t *testing.T) {
	validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized}
	handler := RequireAuth(validator, "mcr-test")(failIfCalled(t))

	req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
	// No Authorization header.
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	assertOCIUnauthorized(t, rec)

	wwwAuth := rec.Header().Get("WWW-Authenticate")
	want := `Bearer realm="https://example.com/v2/token",service="mcr-test"`
	if wwwAuth != want {
		t.Fatalf("WWW-Authenticate: got %q, want %q", wwwAuth, want)
	}
}

// TestRequireAuthInvalid checks that a request whose bearer token the
// validator rejects gets a 401 with an OCI UNAUTHORIZED error body and never
// reaches the inner handler.
func TestRequireAuthInvalid(t *testing.T) {
	validator := &fakeValidator{claims: nil, err: auth.ErrUnauthorized}
	handler := RequireAuth(validator, "mcr-test")(failIfCalled(t))

	req := httptest.NewRequest(http.MethodGet, "/v2/", nil)
	req.Header.Set("Authorization", "Bearer bad-token")
	rec := httptest.NewRecorder()
	handler.ServeHTTP(rec, req)

	assertOCIUnauthorized(t, rec)
}