- Added `web/templates/{dashboard,audit,base,accounts,account_detail}.html` for a consistent UI.
- Implemented a new audit log endpoint (`GET /v1/audit`) with filtering and pagination via `ListAuditEventsPaged`.
- Extended `AuditQueryParams` and added `AuditEventView`, which carries the joined actor and target usernames.
- Updated configuration (`goimports` preference), linting rules, and E2E tests.
- No logic changes to existing APIs.
343 lines · 9.6 KiB · Go
package middleware
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"log/slog"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"git.wntrmute.dev/kyle/mcias/internal/db"
|
|
"git.wntrmute.dev/kyle/mcias/internal/model"
|
|
"git.wntrmute.dev/kyle/mcias/internal/token"
|
|
)
|
|
|
|
// generateTestKey returns a freshly generated Ed25519 key pair drawn from
// crypto/rand. Key-generation failure aborts the calling test immediately.
func generateTestKey(t *testing.T) (ed25519.PublicKey, ed25519.PrivateKey) {
	t.Helper()
	publicKey, privateKey, genErr := ed25519.GenerateKey(rand.Reader)
	if genErr == nil {
		return publicKey, privateKey
	}
	t.Fatalf("generate test key: %v", genErr)
	return nil, nil // unreachable: Fatalf stops the test goroutine
}
|
|
|
|
func openTestDB(t *testing.T) *db.DB {
|
|
t.Helper()
|
|
database, err := db.Open(":memory:")
|
|
if err != nil {
|
|
t.Fatalf("open test db: %v", err)
|
|
}
|
|
if err := db.Migrate(database); err != nil {
|
|
t.Fatalf("migrate test db: %v", err)
|
|
}
|
|
t.Cleanup(func() { _ = database.Close() })
|
|
return database
|
|
}
|
|
|
|
// testIssuer is the issuer string shared by every test in this file. The same
// value is passed to token.IssueToken when minting tokens and to RequireAuth /
// token.ValidateToken when checking them.
const testIssuer = "https://auth.example.com"
|
|
|
|
// issueAndTrackToken creates a valid JWT and records it in the DB.
|
|
func issueAndTrackToken(t *testing.T, priv ed25519.PrivateKey, database *db.DB, accountID int64, roles []string) string {
|
|
t.Helper()
|
|
tokenStr, claims, err := token.IssueToken(priv, testIssuer, "user-uuid", roles, time.Hour)
|
|
if err != nil {
|
|
t.Fatalf("IssueToken: %v", err)
|
|
}
|
|
if err := database.TrackToken(claims.JTI, accountID, claims.IssuedAt, claims.ExpiresAt); err != nil {
|
|
t.Fatalf("TrackToken: %v", err)
|
|
}
|
|
return tokenStr
|
|
}
|
|
|
|
func TestRequestLogger(t *testing.T) {
|
|
var buf bytes.Buffer
|
|
logger := slog.New(slog.NewTextHandler(&buf, nil))
|
|
|
|
handler := RequestLogger(logger)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("status = %d, want 200", rr.Code)
|
|
}
|
|
logOutput := buf.String()
|
|
if logOutput == "" {
|
|
t.Error("expected log output, got empty string")
|
|
}
|
|
// Security: Authorization header must not appear in logs.
|
|
req2 := httptest.NewRequest(http.MethodGet, "/v1/health", nil)
|
|
req2.Header.Set("Authorization", "Bearer secret-token-value")
|
|
buf.Reset()
|
|
rr2 := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr2, req2)
|
|
if bytes.Contains(buf.Bytes(), []byte("secret-token-value")) {
|
|
t.Error("log output contains Authorization token value — credential leak!")
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthValid(t *testing.T) {
|
|
pub, priv := generateTestKey(t)
|
|
database := openTestDB(t)
|
|
|
|
acct, err := database.CreateAccount("alice", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, []string{"reader"})
|
|
|
|
reached := false
|
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
reached = true
|
|
claims := ClaimsFromContext(r.Context())
|
|
if claims == nil {
|
|
t.Error("claims not in context")
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("status = %d, want 200; body: %s", rr.Code, rr.Body.String())
|
|
}
|
|
if !reached {
|
|
t.Error("handler was not reached with valid token")
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthMissingHeader(t *testing.T) {
|
|
pub, priv := generateTestKey(t)
|
|
_ = priv
|
|
database := openTestDB(t)
|
|
|
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
t.Error("handler should not be reached without auth")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("status = %d, want 401", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthInvalidToken(t *testing.T) {
|
|
pub, _ := generateTestKey(t)
|
|
database := openTestDB(t)
|
|
|
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
t.Error("handler should not be reached with invalid token")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
|
req.Header.Set("Authorization", "Bearer not.a.valid.jwt")
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("status = %d, want 401", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthRevokedToken(t *testing.T) {
|
|
pub, priv := generateTestKey(t)
|
|
database := openTestDB(t)
|
|
|
|
acct, err := database.CreateAccount("bob", model.AccountTypeHuman, "hash")
|
|
if err != nil {
|
|
t.Fatalf("CreateAccount: %v", err)
|
|
}
|
|
|
|
tokenStr := issueAndTrackToken(t, priv, database, acct.ID, nil)
|
|
|
|
// Extract JTI and revoke the token.
|
|
claims, err := token.ValidateToken(pub, tokenStr, testIssuer)
|
|
if err != nil {
|
|
t.Fatalf("ValidateToken: %v", err)
|
|
}
|
|
if err := database.RevokeToken(claims.JTI, "test revocation"); err != nil {
|
|
t.Fatalf("RevokeToken: %v", err)
|
|
}
|
|
|
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
t.Error("handler should not be reached with revoked token")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("status = %d, want 401", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireAuthExpiredToken(t *testing.T) {
|
|
pub, priv := generateTestKey(t)
|
|
database := openTestDB(t)
|
|
|
|
// Issue an already-expired token.
|
|
tokenStr, _, err := token.IssueToken(priv, testIssuer, "user-uuid", nil, -time.Minute)
|
|
if err != nil {
|
|
t.Fatalf("IssueToken: %v", err)
|
|
}
|
|
|
|
handler := RequireAuth(pub, database, testIssuer)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
t.Error("handler should not be reached with expired token")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+tokenStr)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusUnauthorized {
|
|
t.Errorf("status = %d, want 401", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireRoleGranted(t *testing.T) {
|
|
claims := &token.Claims{Roles: []string{"admin"}}
|
|
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
|
|
|
reached := false
|
|
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
reached = true
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("status = %d, want 200", rr.Code)
|
|
}
|
|
if !reached {
|
|
t.Error("handler not reached with correct role")
|
|
}
|
|
}
|
|
|
|
func TestRequireRoleForbidden(t *testing.T) {
|
|
claims := &token.Claims{Roles: []string{"reader"}}
|
|
ctx := context.WithValue(context.Background(), claimsKey, claims)
|
|
|
|
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
t.Error("handler should not be reached without admin role")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil).WithContext(ctx)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("status = %d, want 403", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRequireRoleNoClaims(t *testing.T) {
|
|
handler := RequireRole("admin")(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
t.Error("handler should not be reached without claims in context")
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
|
|
if rr.Code != http.StatusForbidden {
|
|
t.Errorf("status = %d, want 403", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitAllows(t *testing.T) {
|
|
handler := RateLimit(10, 5)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
|
req.RemoteAddr = "127.0.0.1:12345"
|
|
|
|
// First 5 requests should be allowed (burst=5).
|
|
for i := range 5 {
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusOK {
|
|
t.Errorf("request %d: status = %d, want 200", i+1, rr.Code)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestRateLimitBlocks(t *testing.T) {
|
|
handler := RateLimit(0.1, 2)(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
}))
|
|
|
|
req := httptest.NewRequest(http.MethodPost, "/v1/auth/login", nil)
|
|
req.RemoteAddr = "10.0.0.1:9999"
|
|
|
|
// Exhaust the burst of 2.
|
|
for range 2 {
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
}
|
|
|
|
// Next request should be rate-limited.
|
|
rr := httptest.NewRecorder()
|
|
handler.ServeHTTP(rr, req)
|
|
if rr.Code != http.StatusTooManyRequests {
|
|
t.Errorf("status = %d, want 429 after burst exceeded", rr.Code)
|
|
}
|
|
}
|
|
|
|
func TestExtractBearerToken(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
header string
|
|
want string
|
|
wantErr bool
|
|
}{
|
|
{"valid", "Bearer mytoken123", "mytoken123", false},
|
|
{"missing header", "", "", true},
|
|
{"no bearer prefix", "Token mytoken123", "", true},
|
|
{"empty token", "Bearer ", "", true},
|
|
{"case insensitive", "bearer mytoken123", "mytoken123", false},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
|
if tc.header != "" {
|
|
req.Header.Set("Authorization", tc.header)
|
|
}
|
|
got, err := extractBearerToken(req)
|
|
if (err != nil) != tc.wantErr {
|
|
t.Errorf("wantErr=%v, got err=%v", tc.wantErr, err)
|
|
}
|
|
if !tc.wantErr && got != tc.want {
|
|
t.Errorf("token = %q, want %q", got, tc.want)
|
|
}
|
|
})
|
|
}
|
|
}
|