package server import ( "context" "encoding/json" "fmt" "net/http" "net/http/httptest" "path/filepath" "strings" "testing" "github.com/go-chi/chi/v5" mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth" "git.wntrmute.dev/kyle/mcns/internal/db" ) // openTestDB creates a temporary SQLite database with all migrations applied. func openTestDB(t *testing.T) *db.DB { t.Helper() dir := t.TempDir() database, err := db.Open(filepath.Join(dir, "test.db")) if err != nil { t.Fatalf("open db: %v", err) } if err := database.Migrate(); err != nil { t.Fatalf("migrate: %v", err) } t.Cleanup(func() { _ = database.Close() }) return database } // createTestZone inserts a zone for use by record tests. func createTestZone(t *testing.T, database *db.DB) *db.Zone { t.Helper() zone, err := database.CreateZone("test.example.com", "ns.example.com.", "admin.example.com.", 3600, 600, 86400, 300) if err != nil { t.Fatalf("create zone: %v", err) } return zone } // newChiRequest builds a request with chi URL params injected into the context. func newChiRequest(method, target string, body string, params map[string]string) *http.Request { var r *http.Request if body != "" { r = httptest.NewRequest(method, target, strings.NewReader(body)) } else { r = httptest.NewRequest(method, target, nil) } r.Header.Set("Content-Type", "application/json") if len(params) > 0 { rctx := chi.NewRouteContext() for k, v := range params { rctx.URLParams.Add(k, v) } r = r.WithContext(context.WithValue(r.Context(), chi.RouteCtxKey, rctx)) } return r } // decodeJSON decodes the response body into v. func decodeJSON(t *testing.T, rec *httptest.ResponseRecorder, v any) { t.Helper() if err := json.NewDecoder(rec.Body).Decode(v); err != nil { t.Fatalf("decode json: %v", err) } } // ---- Zone handler tests ---- func TestListZonesHandler_SeedOnly(t *testing.T) { database := openTestDB(t) handler := listZonesHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones", "", nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } var resp map[string][]db.Zone decodeJSON(t, rec, &resp) zones := resp["zones"] if len(zones) != 2 { t.Fatalf("got %d zones, want 2 (seed zones)", len(zones)) } } func TestListZonesHandler_Populated(t *testing.T) { database := openTestDB(t) createTestZone(t, database) handler := listZonesHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones", "", nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } var resp map[string][]db.Zone decodeJSON(t, rec, &resp) zones := resp["zones"] // 2 seed + 1 created = 3. if len(zones) != 3 { t.Fatalf("got %d zones, want 3", len(zones)) } } func TestGetZoneHandler_Found(t *testing.T) { database := openTestDB(t) createTestZone(t, database) handler := getZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com", "", map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } var zone db.Zone decodeJSON(t, rec, &zone) if zone.Name != "test.example.com" { t.Fatalf("zone name = %q, want %q", zone.Name, "test.example.com") } } func TestGetZoneHandler_NotFound(t *testing.T) { database := openTestDB(t) handler := getZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/nonexistent.com", "", map[string]string{"zone": "nonexistent.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } func TestCreateZoneHandler_Success(t *testing.T) { database := openTestDB(t) body := `{"name":"new.example.com","primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}` handler := createZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones", body, nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusCreated { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String()) } var zone db.Zone decodeJSON(t, rec, &zone) if zone.Name != "new.example.com" { t.Fatalf("zone name = %q, want %q", zone.Name, "new.example.com") } if zone.PrimaryNS != "ns1.example.com." { t.Fatalf("primary_ns = %q, want %q", zone.PrimaryNS, "ns1.example.com.") } // SOA defaults should be applied. if zone.Refresh != 3600 { t.Fatalf("refresh = %d, want 3600", zone.Refresh) } } func TestCreateZoneHandler_MissingFields(t *testing.T) { tests := []struct { name string body string }{ {"missing name", `{"primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}`}, {"missing primary_ns", `{"name":"new.example.com","admin_email":"admin.example.com."}`}, {"missing admin_email", `{"name":"new.example.com","primary_ns":"ns1.example.com."}`}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { database := openTestDB(t) handler := createZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones", tt.body, nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } }) } } func TestCreateZoneHandler_Duplicate(t *testing.T) { database := openTestDB(t) createTestZone(t, database) body := `{"name":"test.example.com","primary_ns":"ns1.example.com.","admin_email":"admin.example.com."}` handler := createZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones", body, nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusConflict { t.Fatalf("status = %d, want %d", rec.Code, http.StatusConflict) } } func TestCreateZoneHandler_InvalidJSON(t *testing.T) { database := openTestDB(t) handler := createZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones", "not json", nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } func TestUpdateZoneHandler_Success(t *testing.T) { database := openTestDB(t) createTestZone(t, database) body := `{"primary_ns":"ns2.example.com.","admin_email":"newadmin.example.com.","refresh":7200}` handler := updateZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com", body, map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } var zone db.Zone decodeJSON(t, rec, &zone) if zone.PrimaryNS != "ns2.example.com." { t.Fatalf("primary_ns = %q, want %q", zone.PrimaryNS, "ns2.example.com.") } if zone.Refresh != 7200 { t.Fatalf("refresh = %d, want 7200", zone.Refresh) } } func TestUpdateZoneHandler_NotFound(t *testing.T) { database := openTestDB(t) body := `{"primary_ns":"ns2.example.com.","admin_email":"admin.example.com."}` handler := updateZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPut, "/v1/zones/nonexistent.com", body, map[string]string{"zone": "nonexistent.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } func TestUpdateZoneHandler_MissingFields(t *testing.T) { database := openTestDB(t) createTestZone(t, database) body := `{"admin_email":"admin.example.com."}` handler := updateZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com", body, map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } func TestDeleteZoneHandler_Success(t *testing.T) { database := openTestDB(t) createTestZone(t, database) handler := deleteZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com", "", map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } // Verify the zone is gone. _, err := database.GetZone("test.example.com") if err != db.ErrNotFound { t.Fatalf("expected ErrNotFound after delete, got %v", err) } } func TestDeleteZoneHandler_NotFound(t *testing.T) { database := openTestDB(t) handler := deleteZoneHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodDelete, "/v1/zones/nonexistent.com", "", map[string]string{"zone": "nonexistent.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } // ---- Record handler tests ---- func TestListRecordsHandler_WithZone(t *testing.T) { database := openTestDB(t) createTestZone(t, database) _, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) if err != nil { t.Fatalf("create record: %v", err) } _, err = database.CreateRecord("test.example.com", "mail", "A", "10.0.0.2", 300) if err != nil { t.Fatalf("create record: %v", err) } handler := listRecordsHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records", "", map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } var resp map[string][]db.Record decodeJSON(t, rec, &resp) records := resp["records"] if len(records) != 2 { t.Fatalf("got %d records, want 2", len(records)) } } func TestListRecordsHandler_ZoneNotFound(t *testing.T) { database := openTestDB(t) handler := listRecordsHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/nonexistent.com/records", "", map[string]string{"zone": "nonexistent.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } func TestListRecordsHandler_EmptyZone(t *testing.T) { database := openTestDB(t) createTestZone(t, database) handler := listRecordsHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records", "", map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } var resp map[string][]db.Record decodeJSON(t, rec, &resp) records := resp["records"] if len(records) != 0 { t.Fatalf("got %d records, want 0", len(records)) } } func TestListRecordsHandler_WithFilters(t *testing.T) { database := openTestDB(t) createTestZone(t, database) _, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) if err != nil { t.Fatalf("create record: %v", err) } _, err = database.CreateRecord("test.example.com", "www", "A", "10.0.0.2", 300) if err != nil { t.Fatalf("create record: %v", err) } _, err = database.CreateRecord("test.example.com", "mail", "A", "10.0.0.3", 300) if err != nil { t.Fatalf("create record: %v", err) } handler := listRecordsHandler(database) // Filter by name. rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records?name=www", "", map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } var resp map[string][]db.Record decodeJSON(t, rec, &resp) if len(resp["records"]) != 2 { t.Fatalf("got %d records for name=www, want 2", len(resp["records"])) } } func TestGetRecordHandler_Found(t *testing.T) { database := openTestDB(t) createTestZone(t, database) created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) if err != nil { t.Fatalf("create record: %v", err) } handler := getRecordHandler(database) rec := httptest.NewRecorder() idStr := fmt.Sprintf("%d", created.ID) req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/"+idStr, "", map[string]string{ "zone": "test.example.com", "id": idStr, }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } var record db.Record decodeJSON(t, rec, &record) if record.Name != "www" { t.Fatalf("record name = %q, want %q", record.Name, "www") } if record.Value != "10.0.0.1" { t.Fatalf("record value = %q, want %q", record.Value, "10.0.0.1") } } func TestGetRecordHandler_NotFound(t *testing.T) { database := openTestDB(t) handler := getRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/99999", "", map[string]string{ "zone": "test.example.com", "id": "99999", }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } func TestGetRecordHandler_InvalidID(t *testing.T) { database := openTestDB(t) handler := getRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodGet, "/v1/zones/test.example.com/records/abc", "", map[string]string{ "zone": "test.example.com", "id": "abc", }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } func TestCreateRecordHandler_Success(t *testing.T) { database := openTestDB(t) createTestZone(t, database) body := `{"name":"www","type":"A","value":"10.0.0.1","ttl":600}` handler := createRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusCreated { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusCreated, rec.Body.String()) } var record db.Record decodeJSON(t, rec, &record) if record.Name != "www" { t.Fatalf("record name = %q, want %q", record.Name, "www") } if record.Type != "A" { t.Fatalf("record type = %q, want %q", record.Type, "A") } if record.TTL != 600 { t.Fatalf("ttl = %d, want 600", record.TTL) } } func TestCreateRecordHandler_MissingFields(t *testing.T) { tests := []struct { name string body string }{ {"missing name", `{"type":"A","value":"10.0.0.1"}`}, {"missing type", `{"name":"www","value":"10.0.0.1"}`}, {"missing value", `{"name":"www","type":"A"}`}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { database := openTestDB(t) createTestZone(t, database) handler := createRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", tt.body, map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } }) } } func TestCreateRecordHandler_InvalidIP(t *testing.T) { database := openTestDB(t) createTestZone(t, database) body := `{"name":"www","type":"A","value":"not-an-ip"}` handler := createRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } func TestCreateRecordHandler_CNAMEConflict(t *testing.T) { database := openTestDB(t) createTestZone(t, database) // Create an A record first. _, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) if err != nil { t.Fatalf("create A record: %v", err) } // Try to create a CNAME for the same name via handler. body := `{"name":"www","type":"CNAME","value":"other.example.com."}` handler := createRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", body, map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusConflict { t.Fatalf("status = %d, want %d", rec.Code, http.StatusConflict) } } func TestCreateRecordHandler_ZoneNotFound(t *testing.T) { database := openTestDB(t) body := `{"name":"www","type":"A","value":"10.0.0.1"}` handler := createRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones/nonexistent.com/records", body, map[string]string{"zone": "nonexistent.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } func TestCreateRecordHandler_InvalidJSON(t *testing.T) { database := openTestDB(t) createTestZone(t, database) handler := createRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPost, "/v1/zones/test.example.com/records", "not json", map[string]string{"zone": "test.example.com"}) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } func TestUpdateRecordHandler_Success(t *testing.T) { database := openTestDB(t) createTestZone(t, database) created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) if err != nil { t.Fatalf("create record: %v", err) } idStr := fmt.Sprintf("%d", created.ID) body := `{"name":"www","type":"A","value":"10.0.0.2","ttl":600}` handler := updateRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/"+idStr, body, map[string]string{ "zone": "test.example.com", "id": idStr, }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d; body: %s", rec.Code, http.StatusOK, rec.Body.String()) } var record db.Record decodeJSON(t, rec, &record) if record.Value != "10.0.0.2" { t.Fatalf("value = %q, want %q", record.Value, "10.0.0.2") } if record.TTL != 600 { t.Fatalf("ttl = %d, want 600", record.TTL) } } func TestUpdateRecordHandler_NotFound(t *testing.T) { database := openTestDB(t) body := `{"name":"www","type":"A","value":"10.0.0.1"}` handler := updateRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/99999", body, map[string]string{ "zone": "test.example.com", "id": "99999", }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } func TestUpdateRecordHandler_InvalidID(t *testing.T) { database := openTestDB(t) body := `{"name":"www","type":"A","value":"10.0.0.1"}` handler := updateRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/abc", body, map[string]string{ "zone": "test.example.com", "id": "abc", }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } func TestUpdateRecordHandler_MissingFields(t *testing.T) { database := openTestDB(t) createTestZone(t, database) created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) if err != nil { t.Fatalf("create record: %v", err) } idStr := fmt.Sprintf("%d", created.ID) // Missing name. body := `{"type":"A","value":"10.0.0.1"}` handler := updateRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodPut, "/v1/zones/test.example.com/records/"+idStr, body, map[string]string{ "zone": "test.example.com", "id": idStr, }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } func TestDeleteRecordHandler_Success(t *testing.T) { database := openTestDB(t) createTestZone(t, database) created, err := database.CreateRecord("test.example.com", "www", "A", "10.0.0.1", 300) if err != nil { t.Fatalf("create record: %v", err) } idStr := fmt.Sprintf("%d", created.ID) handler := deleteRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/"+idStr, "", map[string]string{ "zone": "test.example.com", "id": idStr, }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNoContent { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNoContent) } // Verify record is gone. _, err = database.GetRecord(created.ID) if err != db.ErrNotFound { t.Fatalf("expected ErrNotFound after delete, got %v", err) } } func TestDeleteRecordHandler_NotFound(t *testing.T) { database := openTestDB(t) handler := deleteRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/99999", "", map[string]string{ "zone": "test.example.com", "id": "99999", }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusNotFound { t.Fatalf("status = %d, want %d", rec.Code, http.StatusNotFound) } } func TestDeleteRecordHandler_InvalidID(t *testing.T) { database := openTestDB(t) handler := deleteRecordHandler(database) rec := httptest.NewRecorder() req := newChiRequest(http.MethodDelete, "/v1/zones/test.example.com/records/abc", "", map[string]string{ "zone": "test.example.com", "id": "abc", }) handler.ServeHTTP(rec, req) if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } } // ---- Middleware tests ---- func TestRequireAdmin_WithAdminContext(t *testing.T) { called := false inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { called = true w.WriteHeader(http.StatusOK) }) handler := requireAdmin(inner) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/test", nil) // Inject admin TokenInfo into context. info := &mcdslauth.TokenInfo{ Username: "admin-user", IsAdmin: true, Roles: []string{"admin"}, } ctx := context.WithValue(req.Context(), tokenInfoKey, info) req = req.WithContext(ctx) handler.ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } if !called { t.Fatal("inner handler was not called") } } func TestRequireAdmin_WithNonAdminContext(t *testing.T) { called := false inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { called = true w.WriteHeader(http.StatusOK) }) handler := requireAdmin(inner) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/test", nil) // Inject non-admin TokenInfo into context. info := &mcdslauth.TokenInfo{ Username: "regular-user", IsAdmin: false, Roles: []string{"viewer"}, } ctx := context.WithValue(req.Context(), tokenInfoKey, info) req = req.WithContext(ctx) handler.ServeHTTP(rec, req) if rec.Code != http.StatusForbidden { t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) } if called { t.Fatal("inner handler should not have been called") } } func TestRequireAdmin_NoTokenInfo(t *testing.T) { called := false inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { called = true w.WriteHeader(http.StatusOK) }) handler := requireAdmin(inner) rec := httptest.NewRecorder() req := httptest.NewRequest(http.MethodGet, "/test", nil) handler.ServeHTTP(rec, req) if rec.Code != http.StatusForbidden { t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) } if called { t.Fatal("inner handler should not have been called") } } func TestExtractBearerToken(t *testing.T) { tests := []struct { name string header string want string }{ {"valid bearer", "Bearer abc123", "abc123"}, {"empty header", "", ""}, {"no prefix", "abc123", ""}, {"basic auth", "Basic abc123", ""}, {"bearer with spaces", "Bearer token-with-space ", "token-with-space"}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { r := httptest.NewRequest(http.MethodGet, "/", nil) if tt.header != "" { r.Header.Set("Authorization", tt.header) } got := extractBearerToken(r) if got != tt.want { t.Fatalf("extractBearerToken(%q) = %q, want %q", tt.header, got, tt.want) } }) } } func TestTokenInfoFromContext(t *testing.T) { // No token info in context. ctx := context.Background() if info := tokenInfoFromContext(ctx); info != nil { t.Fatal("expected nil, got token info") } // With token info. expected := &mcdslauth.TokenInfo{Username: "testuser", IsAdmin: true} ctx = context.WithValue(ctx, tokenInfoKey, expected) got := tokenInfoFromContext(ctx) if got == nil { t.Fatal("expected token info, got nil") } if got.Username != expected.Username { t.Fatalf("username = %q, want %q", got.Username, expected.Username) } if !got.IsAdmin { t.Fatal("expected IsAdmin to be true") } } // ---- writeJSON / writeError tests ---- func TestWriteJSON(t *testing.T) { rec := httptest.NewRecorder() writeJSON(rec, http.StatusOK, map[string]string{"key": "value"}) if rec.Code != http.StatusOK { t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK) } if ct := rec.Header().Get("Content-Type"); ct != "application/json" { t.Fatalf("content-type = %q, want %q", ct, "application/json") } var resp map[string]string decodeJSON(t, rec, &resp) if resp["key"] != "value" { t.Fatalf("got key=%q, want %q", resp["key"], "value") } } func TestWriteError(t *testing.T) { rec := httptest.NewRecorder() writeError(rec, http.StatusBadRequest, "bad input") if rec.Code != http.StatusBadRequest { t.Fatalf("status = %d, want %d", rec.Code, http.StatusBadRequest) } var resp map[string]string decodeJSON(t, rec, &resp) if resp["error"] != "bad input" { t.Fatalf("got error=%q, want %q", resp["error"], "bad input") } }