package ui

import (
	"crypto/ed25519"
	"crypto/rand"
	"io"
	"log/slog"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"git.wntrmute.dev/kyle/mcias/internal/config"
	"git.wntrmute.dev/kyle/mcias/internal/db"
)

const testIssuer = "https://auth.example.com"

// newTestMux builds a UIServer backed by a fresh, migrated in-memory database
// and a freshly generated Ed25519 key pair and master key, then returns a new
// ServeMux with all UI routes registered via Register.
//
// NOTE(review): the mux itself is not wrapped here; the routes are expected to
// be wrapped with securityHeaders by Register, as in production. The tests in
// this file exist to verify exactly that.
func newTestMux(t *testing.T) http.Handler {
	t.Helper()

	pub, priv, err := ed25519.GenerateKey(rand.Reader)
	if err != nil {
		t.Fatalf("generate key: %v", err)
	}

	database, err := db.Open(":memory:")
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	// Register cleanup before any further fallible step so the database is
	// closed even when migration or server construction fails the test.
	// The close error is deliberately ignored: the test is already over.
	t.Cleanup(func() { _ = database.Close() })
	if err := db.Migrate(database); err != nil {
		t.Fatalf("migrate db: %v", err)
	}

	masterKey := make([]byte, 32)
	if _, err := rand.Read(masterKey); err != nil {
		t.Fatalf("generate master key: %v", err)
	}

	cfg := config.NewTestConfig(testIssuer)
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))

	uiSrv, err := New(database, cfg, priv, pub, masterKey, logger)
	if err != nil {
		t.Fatalf("new UIServer: %v", err)
	}

	mux := http.NewServeMux()
	uiSrv.Register(mux)
	return mux
}

// assertSecurityHeaders verifies that every mandatory defensive header is
// present in h with an acceptable value. label identifies the endpoint under
// test in failure messages.
func assertSecurityHeaders(t *testing.T, h http.Header, label string) { t.Helper() checks := []struct { header string wantSub string }{ {"Content-Security-Policy", "default-src 'self'"}, {"X-Content-Type-Options", "nosniff"}, {"X-Frame-Options", "DENY"}, {"Strict-Transport-Security", "max-age="}, {"Referrer-Policy", "no-referrer"}, } for _, c := range checks { val := h.Get(c.header) if val == "" { t.Errorf("[%s] missing security header %s", label, c.header) continue } if c.wantSub != "" && !strings.Contains(val, c.wantSub) { t.Errorf("[%s] %s = %q, want substring %q", label, c.header, val, c.wantSub) } } } // TestSecurityHeadersOnLoginPage verifies headers are present on the public login page. func TestSecurityHeadersOnLoginPage(t *testing.T) { mux := newTestMux(t) req := httptest.NewRequest(http.MethodGet, "/login", nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) assertSecurityHeaders(t, rr.Result().Header, "GET /login") } // TestSecurityHeadersOnUnauthenticatedDashboard verifies headers are present even // when the response is a redirect to login (no session cookie supplied). func TestSecurityHeadersOnUnauthenticatedDashboard(t *testing.T) { mux := newTestMux(t) req := httptest.NewRequest(http.MethodGet, "/dashboard", nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) assertSecurityHeaders(t, rr.Result().Header, "GET /dashboard (no session)") } // TestSecurityHeadersOnRootRedirect verifies headers on the "/" → "/login" redirect. func TestSecurityHeadersOnRootRedirect(t *testing.T) { mux := newTestMux(t) req := httptest.NewRequest(http.MethodGet, "/", nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) assertSecurityHeaders(t, rr.Result().Header, "GET /") } // TestSecurityHeadersOnStaticAsset verifies headers are present on static file responses. 
func TestSecurityHeadersOnStaticAsset(t *testing.T) { mux := newTestMux(t) req := httptest.NewRequest(http.MethodGet, "/static/style.css", nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) // 200 or 404 — either way the securityHeaders wrapper must fire. assertSecurityHeaders(t, rr.Result().Header, "GET /static/style.css") } // TestCSPDirectives verifies the Content-Security-Policy includes same-origin // directives for scripts and styles. func TestCSPDirectives(t *testing.T) { mux := newTestMux(t) req := httptest.NewRequest(http.MethodGet, "/login", nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) csp := rr.Header().Get("Content-Security-Policy") for _, directive := range []string{ "default-src 'self'", "script-src 'self'", "style-src 'self'", } { if !strings.Contains(csp, directive) { t.Errorf("CSP missing directive %q; full value: %q", directive, csp) } } } // TestHSTSMinAge verifies HSTS max-age is at least two years (63072000 seconds). func TestHSTSMinAge(t *testing.T) { mux := newTestMux(t) req := httptest.NewRequest(http.MethodGet, "/login", nil) rr := httptest.NewRecorder() mux.ServeHTTP(rr, req) hsts := rr.Header().Get("Strict-Transport-Security") if !strings.Contains(hsts, "max-age=63072000") { t.Errorf("HSTS = %q, want max-age=63072000 (2 years)", hsts) } } // TestSecurityHeadersMiddlewareUnit tests the securityHeaders middleware in // isolation, independent of routing, to guard against future refactoring. func TestSecurityHeadersMiddlewareUnit(t *testing.T) { reached := false inner := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { reached = true w.WriteHeader(http.StatusOK) }) handler := securityHeaders(inner) req := httptest.NewRequest(http.MethodGet, "/test", nil) rr := httptest.NewRecorder() handler.ServeHTTP(rr, req) if !reached { t.Error("inner handler was not reached") } assertSecurityHeaders(t, rr.Result().Header, "unit test") }