Files
mcr/internal/webserver/server_test.go
Kyle Isom 9d7043a594 Block guest accounts from web UI login
The web UI now validates the MCIAS token after login and rejects
accounts with the guest role before setting the session cookie.
This is defense-in-depth alongside the env:restricted MCIAS tag.

The webserver.New() constructor takes a new ValidateFunc parameter
that inspects token roles post-authentication. MCIAS login does not
return roles, so this requires an extra ValidateToken round-trip at
login time (result is cached for 30s).

Security: guest role accounts are denied web UI access

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 23:02:22 -07:00

682 lines
18 KiB
Go

package webserver
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
)
// fakeRegistryService is an in-memory stand-in for the registry gRPC service.
// Tests shape its answers by setting fields directly:
//   - repos is returned verbatim by ListRepositories.
//   - repoErr, when non-nil, makes GetRepository fail with that error.
//   - repoResp, when non-nil and repoErr is nil, is returned by GetRepository.
type fakeRegistryService struct {
	mcrv1.UnimplementedRegistryServiceServer

	repos    []*mcrv1.RepositoryMetadata
	repoResp *mcrv1.GetRepositoryResponse
	repoErr  error
}

// ListRepositories returns the canned repository list; the request is ignored.
func (f *fakeRegistryService) ListRepositories(_ context.Context, _ *mcrv1.ListRepositoriesRequest) (*mcrv1.ListRepositoriesResponse, error) {
	resp := &mcrv1.ListRepositoriesResponse{Repositories: f.repos}
	return resp, nil
}

// GetRepository answers in priority order: a configured error wins, then a
// configured response; otherwise it echoes a response carrying only the
// requested repository name.
func (f *fakeRegistryService) GetRepository(_ context.Context, req *mcrv1.GetRepositoryRequest) (*mcrv1.GetRepositoryResponse, error) {
	switch {
	case f.repoErr != nil:
		return nil, f.repoErr
	case f.repoResp != nil:
		return f.repoResp, nil
	default:
		return &mcrv1.GetRepositoryResponse{Name: req.GetName()}, nil
	}
}
// fakePolicyService is an in-memory policy backend. rules is the mutable rule
// table read and modified by the List/Get/Update/Delete RPCs; created records
// the rule most recently built by CreatePolicyRule (which does NOT add it to
// rules).
type fakePolicyService struct {
	mcrv1.UnimplementedPolicyServiceServer

	rules   []*mcrv1.PolicyRule
	created *mcrv1.PolicyRule
}

// ListPolicyRules returns the current rule table; the request is ignored.
func (f *fakePolicyService) ListPolicyRules(_ context.Context, _ *mcrv1.ListPolicyRulesRequest) (*mcrv1.ListPolicyRulesResponse, error) {
	resp := &mcrv1.ListPolicyRulesResponse{Rules: f.rules}
	return resp, nil
}

// CreatePolicyRule copies the request fields into a new rule with the fixed
// ID 1, remembers it in f.created and returns it. The rule table is untouched.
func (f *fakePolicyService) CreatePolicyRule(_ context.Context, req *mcrv1.CreatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
	f.created = &mcrv1.PolicyRule{
		Id:           1,
		Priority:     req.GetPriority(),
		Description:  req.GetDescription(),
		Effect:       req.GetEffect(),
		Actions:      req.GetActions(),
		Repositories: req.GetRepositories(),
		Enabled:      req.GetEnabled(),
	}
	return f.created, nil
}

// GetPolicyRule returns the first rule whose ID matches the request, or a
// NotFound status error.
func (f *fakePolicyService) GetPolicyRule(_ context.Context, req *mcrv1.GetPolicyRuleRequest) (*mcrv1.PolicyRule, error) {
	want := req.GetId()
	for i := range f.rules {
		if f.rules[i].GetId() == want {
			return f.rules[i], nil
		}
	}
	return nil, status.Errorf(codes.NotFound, "policy rule not found")
}

// UpdatePolicyRule flips only the Enabled flag of the matching rule (in place)
// and returns it; every other request field is ignored. Unknown IDs yield
// NotFound.
func (f *fakePolicyService) UpdatePolicyRule(_ context.Context, req *mcrv1.UpdatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
	want := req.GetId()
	for i := range f.rules {
		if f.rules[i].GetId() != want {
			continue
		}
		f.rules[i].Enabled = req.GetEnabled()
		return f.rules[i], nil
	}
	return nil, status.Errorf(codes.NotFound, "policy rule not found")
}

// DeletePolicyRule removes the first rule with the requested ID, preserving
// the order of the remaining rules. Unknown IDs yield NotFound.
func (f *fakePolicyService) DeletePolicyRule(_ context.Context, req *mcrv1.DeletePolicyRuleRequest) (*mcrv1.DeletePolicyRuleResponse, error) {
	want := req.GetId()
	for i := range f.rules {
		if f.rules[i].GetId() != want {
			continue
		}
		// Shift the tail down one slot in the same backing array, then shrink.
		copy(f.rules[i:], f.rules[i+1:])
		f.rules = f.rules[:len(f.rules)-1]
		return &mcrv1.DeletePolicyRuleResponse{}, nil
	}
	return nil, status.Errorf(codes.NotFound, "policy rule not found")
}
// fakeAuditService serves a fixed, test-supplied list of audit events.
type fakeAuditService struct {
	mcrv1.UnimplementedAuditServiceServer

	events []*mcrv1.AuditEvent
}

// ListAuditEvents returns every configured event; the request is ignored.
func (f *fakeAuditService) ListAuditEvents(_ context.Context, _ *mcrv1.ListAuditEventsRequest) (*mcrv1.ListAuditEventsResponse, error) {
	resp := &mcrv1.ListAuditEventsResponse{Events: f.events}
	return resp, nil
}
// fakeAdminService answers Health checks; all other admin RPCs come from the
// embedded UnimplementedAdminServiceServer.
type fakeAdminService struct {
	mcrv1.UnimplementedAdminServiceServer
}

// Health always reports status "ok"; the request is ignored.
func (*fakeAdminService) Health(_ context.Context, _ *mcrv1.HealthRequest) (*mcrv1.HealthResponse, error) {
	resp := &mcrv1.HealthResponse{Status: "ok"}
	return resp, nil
}
// testEnv bundles the webserver under test with the in-process gRPC server
// and client connection backing it, plus handles to the fakes so individual
// tests can adjust their canned responses.
type testEnv struct {
	server     *Server
	grpcServer *grpc.Server
	grpcConn   *grpc.ClientConn
	registry   *fakeRegistryService
	policyFake *fakePolicyService
	auditFake  *fakeAuditService
}

// close tears the environment down: client connection first, then the gRPC
// server. The connection's Close error is deliberately discarded — there is
// nothing useful a test can do with it during teardown.
func (e *testEnv) close() {
	conn, srv := e.grpcConn, e.grpcServer
	_ = conn.Close()
	srv.Stop()
}
// setupTestEnv wires up a complete test environment:
//   - fake registry, policy, audit and admin services seeded with sample data,
//   - a real gRPC server for them on an ephemeral 127.0.0.1 port,
//   - a client connection using the JSON codec, and
//   - a Server whose login/validate hooks know three fixed accounts
//     ("admin", "guest", "user", all with password "secret").
//
// Callers are expected to defer env.close(). Setup failures abort the test
// with t.Fatalf; failures after the gRPC server exists stop it first.
func setupTestEnv(t *testing.T) *testEnv {
	t.Helper()

	// Seed data for the fakes.
	registry := &fakeRegistryService{
		repos: []*mcrv1.RepositoryMetadata{
			{Name: "library/nginx", TagCount: 3, ManifestCount: 2, TotalSize: 1024 * 1024, CreatedAt: "2024-01-15T10:00:00Z"},
			{Name: "library/alpine", TagCount: 1, ManifestCount: 1, TotalSize: 512 * 1024, CreatedAt: "2024-01-16T10:00:00Z"},
		},
	}
	policies := &fakePolicyService{
		rules: []*mcrv1.PolicyRule{
			{Id: 1, Priority: 100, Description: "Allow all pulls", Effect: "allow", Actions: []string{"pull"}, Repositories: []string{"*"}, Enabled: true},
		},
	}
	audit := &fakeAuditService{
		events: []*mcrv1.AuditEvent{
			{Id: 1, EventTime: "2024-01-15T12:00:00Z", EventType: "manifest_pushed", ActorId: "user1", Repository: "library/nginx", Digest: "sha256:abc123", IpAddress: "10.0.0.1"},
		},
	}

	// Serve the fakes from an in-process gRPC server. The Serve goroutine
	// exits once gs.Stop is called (from env.close or a failure path below).
	gs := grpc.NewServer()
	mcrv1.RegisterRegistryServiceServer(gs, registry)
	mcrv1.RegisterPolicyServiceServer(gs, policies)
	mcrv1.RegisterAuditServiceServer(gs, audit)
	mcrv1.RegisterAdminServiceServer(gs, &fakeAdminService{})

	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("listen: %v", err)
	}
	go func() { _ = gs.Serve(lis) }()

	conn, err := grpc.NewClient(
		lis.Addr().String(),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithDefaultCallOptions(grpc.ForceCodecV2(mcrv1.JSONCodec{})),
	)
	if err != nil {
		gs.Stop()
		t.Fatalf("dial: %v", err)
	}

	// Account tables: username -> issued token, and token -> its single role.
	tokenFor := map[string]string{
		"admin": "test-token-12345",
		"guest": "test-token-guest",
		"user":  "test-token-user",
	}
	roleFor := map[string]string{
		"test-token-12345": "admin",
		"test-token-guest": "guest",
		"test-token-user":  "user",
	}
	loginFn := func(username, password string) (string, int, error) {
		if tok, known := tokenFor[username]; known && password == "secret" {
			return tok, 3600, nil
		}
		return "", 0, fmt.Errorf("invalid credentials")
	}
	validateFn := func(token string) ([]string, error) {
		role, known := roleFor[token]
		if !known {
			return nil, fmt.Errorf("invalid token")
		}
		// Build a fresh slice per call so callers cannot mutate shared state.
		return []string{role}, nil
	}

	csrfKey := []byte("test-csrf-key-32-bytes-long!1234")
	srv, err := New(
		mcrv1.NewRegistryServiceClient(conn),
		mcrv1.NewPolicyServiceClient(conn),
		mcrv1.NewAuditServiceClient(conn),
		mcrv1.NewAdminServiceClient(conn),
		loginFn,
		validateFn,
		csrfKey,
	)
	if err != nil {
		_ = conn.Close()
		gs.Stop()
		t.Fatalf("create server: %v", err)
	}

	return &testEnv{
		server:     srv,
		grpcServer: gs,
		grpcConn:   conn,
		registry:   registry,
		policyFake: policies,
		auditFake:  audit,
	}
}
// TestLoginPageRenders checks that an anonymous GET /login returns 200 and
// that the page carries both its title and a CSRF form field.
func TestLoginPageRenders(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/login", nil))

	if got, want := rec.Code, http.StatusOK; got != want {
		t.Fatalf("GET /login: status %d, want %d", got, want)
	}
	page := rec.Body.String()
	if !strings.Contains(page, "MCR Login") {
		t.Error("login page does not contain 'MCR Login'")
	}
	if !strings.Contains(page, "_csrf") {
		t.Error("login page does not contain CSRF token field")
	}
}
// TestLoginInvalidCredentials submits a correctly CSRF-protected login form
// with unknown credentials. It expects a 200 response containing the
// invalid-credentials message and, crucially, no session cookie.
func TestLoginInvalidCredentials(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	// First get a CSRF token.
	getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
	getRec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(getRec, getReq)

	// Extract CSRF cookie and token.
	var csrfCookie *http.Cookie
	for _, c := range getRec.Result().Cookies() {
		if c.Name == "csrf_token" {
			csrfCookie = c
			break
		}
	}
	if csrfCookie == nil {
		t.Fatal("no csrf_token cookie set")
	}
	// The cookie value is "token.signature"; the form field carries the token part.
	csrfToken, _, _ := strings.Cut(csrfCookie.Value, ".")

	// Submit login with wrong credentials.
	form := url.Values{
		"username": {"baduser"},
		"password": {"badpass"},
		"_csrf":    {csrfToken},
	}
	postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	postReq.AddCookie(csrfCookie)
	postRec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(postRec, postReq)

	if postRec.Code != http.StatusOK {
		t.Fatalf("POST /login: status %d, want %d", postRec.Code, http.StatusOK)
	}
	body := postRec.Body.String()
	if !strings.Contains(body, "Invalid username or password") {
		t.Error("response does not contain error message for invalid credentials")
	}
	// A rejected login must never hand out a session.
	for _, c := range postRec.Result().Cookies() {
		if c.Name == "mcr_session" {
			t.Error("session cookie should not be set for invalid credentials")
		}
	}
}
// TestDashboardRequiresSession verifies that an anonymous request for the
// dashboard is redirected to /login with 303 See Other.
func TestDashboardRequiresSession(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/", nil))

	if got := rec.Code; got != http.StatusSeeOther {
		t.Fatalf("GET / without session: status %d, want %d", got, http.StatusSeeOther)
	}
	if loc := rec.Header().Get("Location"); loc != "/login" {
		t.Fatalf("redirect location: got %q, want /login", loc)
	}
}
// TestDashboardWithSession checks that a request carrying a session cookie
// renders the dashboard (status 200, "Dashboard" and "Repositories" present).
//
// NOTE(review): "test-token" is not a token that setupTestEnv's validate hook
// accepts, so this test relies on page handlers only checking for the
// cookie's presence rather than re-validating it — confirm that is intended.
func TestDashboardWithSession(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	req := httptest.NewRequest(http.MethodGet, "/", nil)
	req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("GET / with session: status %d, want %d", rec.Code, http.StatusOK)
	}
	body := rec.Body.String()
	if !strings.Contains(body, "Dashboard") {
		t.Error("dashboard page does not contain 'Dashboard'")
	}
	// Only the section label is checked here, not the count itself, so the
	// failure message says exactly that.
	if !strings.Contains(body, "Repositories") {
		t.Error("dashboard page does not contain 'Repositories'")
	}
}
// TestRepositoriesPageRenders checks that the repository listing shows both
// repositories seeded into the registry fake by setupTestEnv.
func TestRepositoriesPageRenders(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	req := httptest.NewRequest(http.MethodGet, "/repositories", nil)
	req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, req)

	if code := rec.Code; code != http.StatusOK {
		t.Fatalf("GET /repositories: status %d, want %d", code, http.StatusOK)
	}
	page := rec.Body.String()
	for _, name := range []string{"library/nginx", "library/alpine"} {
		if !strings.Contains(page, name) {
			t.Errorf("repositories page does not contain '%s'", name)
		}
	}
}
// TestRepositoryDetailRenders preloads the registry fake with a repository
// holding one tag and one manifest, then checks that the detail page shows
// the repository name and the tag.
func TestRepositoryDetailRenders(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	tags := []*mcrv1.TagInfo{
		{Name: "latest", Digest: "sha256:abc123def456"},
	}
	manifests := []*mcrv1.ManifestInfo{
		{Digest: "sha256:abc123def456", MediaType: "application/vnd.oci.image.manifest.v1+json", Size: 2048, CreatedAt: "2024-01-15T10:00:00Z"},
	}
	env.registry.repoResp = &mcrv1.GetRepositoryResponse{
		Name:      "library/nginx",
		TotalSize: 2048,
		Tags:      tags,
		Manifests: manifests,
	}

	req := httptest.NewRequest(http.MethodGet, "/repositories/library/nginx", nil)
	req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, req)

	if code := rec.Code; code != http.StatusOK {
		t.Fatalf("GET /repositories/library/nginx: status %d, want %d", code, http.StatusOK)
	}
	page := rec.Body.String()
	if !strings.Contains(page, "library/nginx") {
		t.Error("repository detail page does not contain repo name")
	}
	if !strings.Contains(page, "latest") {
		t.Error("repository detail page does not contain tag 'latest'")
	}
}
// TestCSRFTokenValidation posts otherwise-valid admin credentials with no
// CSRF token or cookie. The submission must be rejected: the page has to
// report an invalid form submission, and no session cookie may be issued.
func TestCSRFTokenValidation(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	// POST without CSRF token should fail.
	form := url.Values{
		"username": {"admin"},
		"password": {"secret"},
	}
	req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, req)

	body := rec.Body.String()
	// Should show the error about invalid form submission.
	if !strings.Contains(body, "Invalid or expired form submission") {
		t.Errorf("POST without CSRF token should show error, got: %s", body[:min(200, len(body))])
	}
	// The credentials are valid, so the absence of a session cookie is the
	// real proof that the CSRF check blocked the login.
	for _, c := range rec.Result().Cookies() {
		if c.Name == "mcr_session" {
			t.Error("session cookie should not be set when the CSRF token is missing")
		}
	}
}
// TestLogout checks that GET /logout redirects to /login with 303 See Other
// and expires the session cookie.
func TestLogout(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	req := httptest.NewRequest(http.MethodGet, "/logout", nil)
	req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, req)

	if code := rec.Code; code != http.StatusSeeOther {
		t.Fatalf("GET /logout: status %d, want %d", code, http.StatusSeeOther)
	}
	if loc := rec.Header().Get("Location"); loc != "/login" {
		t.Fatalf("redirect location: got %q, want /login", loc)
	}

	// net/http parses a deleting Set-Cookie ("Max-Age=0") as MaxAge < 0.
	cleared := false
	for _, c := range rec.Result().Cookies() {
		if c.Name == "mcr_session" && c.MaxAge < 0 {
			cleared = true
		}
	}
	if !cleared {
		t.Error("session cookie was not cleared on logout")
	}
}
// TestPoliciesPage checks that the policies page lists the rule seeded by
// setupTestEnv, identified by its description.
func TestPoliciesPage(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	req := httptest.NewRequest(http.MethodGet, "/policies", nil)
	req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, req)

	if code := rec.Code; code != http.StatusOK {
		t.Fatalf("GET /policies: status %d, want %d", code, http.StatusOK)
	}
	if page := rec.Body.String(); !strings.Contains(page, "Allow all pulls") {
		t.Error("policies page does not contain policy description")
	}
}
// TestAuditPage checks that the audit page shows the event seeded by
// setupTestEnv, identified by its event type.
func TestAuditPage(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	req := httptest.NewRequest(http.MethodGet, "/audit", nil)
	req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, req)

	if code := rec.Code; code != http.StatusOK {
		t.Fatalf("GET /audit: status %d, want %d", code, http.StatusOK)
	}
	if page := rec.Body.String(); !strings.Contains(page, "manifest_pushed") {
		t.Error("audit page does not contain event type")
	}
}
// TestStaticFiles checks that the stylesheet is served without a session and
// looks like CSS (contains a font-family declaration).
func TestStaticFiles(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	rec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/static/style.css", nil))

	if code := rec.Code; code != http.StatusOK {
		t.Fatalf("GET /static/style.css: status %d, want %d", code, http.StatusOK)
	}
	if css := rec.Body.String(); !strings.Contains(css, "font-family") {
		t.Error("style.css does not appear to contain CSS")
	}
}
// TestFormatSize pins formatSize's output at zero, a sub-KiB value, and
// exactly one unit of each binary prefix from KiB through TiB.
func TestFormatSize(t *testing.T) {
	cases := []struct {
		bytes int64
		want  string
	}{
		{0, "0 B"},
		{512, "512 B"},
		{1024, "1.0 KiB"},
		{1048576, "1.0 MiB"},
		{1073741824, "1.0 GiB"},
		{1099511627776, "1.0 TiB"},
	}
	for _, c := range cases {
		if got := formatSize(c.bytes); got != c.want {
			t.Errorf("formatSize(%d) = %q, want %q", c.bytes, got, c.want)
		}
	}
}
// TestFormatTime checks that an RFC 3339 UTC timestamp is rendered as
// "YYYY-MM-DD hh:mm:ss" and that unparseable input is returned unchanged.
func TestFormatTime(t *testing.T) {
	cases := []struct {
		label, in, want string
	}{
		{"formatTime", "2024-01-15T10:30:00Z", "2024-01-15 10:30:00"},
		// Invalid time returns the input.
		{"formatTime(invalid)", "not-a-time", "not-a-time"},
	}
	for _, c := range cases {
		if got := formatTime(c.in); got != c.want {
			t.Errorf("%s = %q, want %q", c.label, got, c.want)
		}
	}
}
// TestTruncate checks that a long string keeps its first n characters plus
// "..." and that a string within the limit is returned unchanged.
func TestTruncate(t *testing.T) {
	cases := []struct {
		label string
		in    string
		limit int
		want  string
	}{
		{"truncate", "sha256:abc123def456", 12, "sha256:abc12..."},
		// Short strings are not truncated.
		{"truncate(short)", "short", 10, "short"},
	}
	for _, c := range cases {
		if got := truncate(c.in, c.limit); got != c.want {
			t.Errorf("%s = %q, want %q", c.label, got, c.want)
		}
	}
}
// TestLoginDeniesGuest logs in with correct credentials for an account whose
// only role is "guest". It expects a 200 response carrying the guest-denial
// message and no session cookie.
func TestLoginDeniesGuest(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	// Fetch the login page to obtain the CSRF cookie.
	getRec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(getRec, httptest.NewRequest(http.MethodGet, "/login", nil))

	var csrfCookie *http.Cookie
	for _, c := range getRec.Result().Cookies() {
		if c.Name != "csrf_token" {
			continue
		}
		csrfCookie = c
		break
	}
	if csrfCookie == nil {
		t.Fatal("no csrf_token cookie")
	}
	// The cookie value is "token.signature"; the form field carries the token part.
	csrfToken, _, _ := strings.Cut(csrfCookie.Value, ".")

	form := url.Values{
		"username": {"guest"},
		"password": {"secret"},
		"_csrf":    {csrfToken},
	}
	postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	postReq.AddCookie(csrfCookie)
	postRec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(postRec, postReq)

	if code := postRec.Code; code != http.StatusOK {
		t.Fatalf("POST /login as guest: status %d, want %d", code, http.StatusOK)
	}
	if page := postRec.Body.String(); !strings.Contains(page, "Guest accounts are not permitted") {
		t.Error("response does not contain guest denial message")
	}
	// Verify no session cookie was set.
	for _, c := range postRec.Result().Cookies() {
		if c.Name == "mcr_session" {
			t.Error("session cookie should not be set for guest login")
		}
	}
}
// TestLoginSuccessSetsCookie performs a full admin login and expects a 303
// redirect plus an mcr_session cookie that carries the issued token and is
// HttpOnly, Secure and SameSite=Strict.
func TestLoginSuccessSetsCookie(t *testing.T) {
	env := setupTestEnv(t)
	defer env.close()

	// findCookie returns the first cookie named name, or nil if absent.
	findCookie := func(cookies []*http.Cookie, name string) *http.Cookie {
		for _, c := range cookies {
			if c.Name == name {
				return c
			}
		}
		return nil
	}

	// Fetch the login page to obtain the CSRF cookie.
	getRec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(getRec, httptest.NewRequest(http.MethodGet, "/login", nil))
	csrfCookie := findCookie(getRec.Result().Cookies(), "csrf_token")
	if csrfCookie == nil {
		t.Fatal("no csrf_token cookie")
	}
	// The cookie value is "token.signature"; the form field carries the token part.
	csrfToken, _, _ := strings.Cut(csrfCookie.Value, ".")

	form := url.Values{
		"username": {"admin"},
		"password": {"secret"},
		"_csrf":    {csrfToken},
	}
	postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
	postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	postReq.AddCookie(csrfCookie)
	postRec := httptest.NewRecorder()
	env.server.Handler().ServeHTTP(postRec, postReq)

	if postRec.Code != http.StatusSeeOther {
		t.Fatalf("POST /login: status %d, want %d; body: %s", postRec.Code, http.StatusSeeOther, postRec.Body.String())
	}

	session := findCookie(postRec.Result().Cookies(), "mcr_session")
	if session == nil {
		t.Fatal("no mcr_session cookie set after login")
	}
	const wantToken = "test-token-12345"
	if session.Value != wantToken {
		t.Errorf("session cookie value = %q, want %q", session.Value, wantToken)
	}
	if !session.HttpOnly {
		t.Error("session cookie is not HttpOnly")
	}
	if !session.Secure {
		t.Error("session cookie is not Secure")
	}
	if session.SameSite != http.SameSiteStrictMode {
		t.Error("session cookie SameSite is not Strict")
	}
}