package webserver import ( "context" "fmt" "net" "net/http" "net/http/httptest" "net/url" "strings" "testing" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/status" mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1" ) // fakeRegistryService implements RegistryServiceServer for testing. type fakeRegistryService struct { mcrv1.UnimplementedRegistryServiceServer repos []*mcrv1.RepositoryMetadata repoResp *mcrv1.GetRepositoryResponse repoErr error } func (f *fakeRegistryService) ListRepositories(_ context.Context, _ *mcrv1.ListRepositoriesRequest) (*mcrv1.ListRepositoriesResponse, error) { return &mcrv1.ListRepositoriesResponse{Repositories: f.repos}, nil } func (f *fakeRegistryService) GetRepository(_ context.Context, req *mcrv1.GetRepositoryRequest) (*mcrv1.GetRepositoryResponse, error) { if f.repoErr != nil { return nil, f.repoErr } if f.repoResp != nil { return f.repoResp, nil } return &mcrv1.GetRepositoryResponse{Name: req.GetName()}, nil } // fakePolicyService implements PolicyServiceServer for testing. type fakePolicyService struct { mcrv1.UnimplementedPolicyServiceServer rules []*mcrv1.PolicyRule created *mcrv1.PolicyRule } func (f *fakePolicyService) ListPolicyRules(_ context.Context, _ *mcrv1.ListPolicyRulesRequest) (*mcrv1.ListPolicyRulesResponse, error) { return &mcrv1.ListPolicyRulesResponse{Rules: f.rules}, nil } func (f *fakePolicyService) CreatePolicyRule(_ context.Context, req *mcrv1.CreatePolicyRuleRequest) (*mcrv1.PolicyRule, error) { rule := &mcrv1.PolicyRule{ Id: 1, Priority: req.GetPriority(), Description: req.GetDescription(), Effect: req.GetEffect(), Actions: req.GetActions(), Repositories: req.GetRepositories(), Enabled: req.GetEnabled(), } f.created = rule return rule, nil } func (f *fakePolicyService) GetPolicyRule(_ context.Context, req *mcrv1.GetPolicyRuleRequest) (*mcrv1.PolicyRule, error) { for _, r := range f.rules { if r.GetId() == req.GetId() { return r, nil } } return nil, status.Errorf(codes.NotFound, "policy rule not found") } func (f *fakePolicyService) UpdatePolicyRule(_ context.Context, req *mcrv1.UpdatePolicyRuleRequest) (*mcrv1.PolicyRule, error) { for _, r := range f.rules { if r.GetId() == req.GetId() { r.Enabled = req.GetEnabled() return r, nil } } return nil, status.Errorf(codes.NotFound, "policy rule not found") } func (f *fakePolicyService) DeletePolicyRule(_ context.Context, req *mcrv1.DeletePolicyRuleRequest) (*mcrv1.DeletePolicyRuleResponse, error) { for i, r := range f.rules { if r.GetId() == req.GetId() { f.rules = append(f.rules[:i], f.rules[i+1:]...) return &mcrv1.DeletePolicyRuleResponse{}, nil } } return nil, status.Errorf(codes.NotFound, "policy rule not found") } // fakeAuditService implements AuditServiceServer for testing. type fakeAuditService struct { mcrv1.UnimplementedAuditServiceServer events []*mcrv1.AuditEvent } func (f *fakeAuditService) ListAuditEvents(_ context.Context, _ *mcrv1.ListAuditEventsRequest) (*mcrv1.ListAuditEventsResponse, error) { return &mcrv1.ListAuditEventsResponse{Events: f.events}, nil } // fakeAdminService implements AdminServiceServer for testing. type fakeAdminService struct { mcrv1.UnimplementedAdminServiceServer } func (f *fakeAdminService) Health(_ context.Context, _ *mcrv1.HealthRequest) (*mcrv1.HealthResponse, error) { return &mcrv1.HealthResponse{Status: "ok"}, nil } // testEnv holds a test server and its dependencies. type testEnv struct { server *Server grpcServer *grpc.Server grpcConn *grpc.ClientConn registry *fakeRegistryService policyFake *fakePolicyService auditFake *fakeAuditService } func (e *testEnv) close() { _ = e.grpcConn.Close() e.grpcServer.Stop() } // setupTestEnv creates a test environment with fake gRPC backends. func setupTestEnv(t *testing.T) *testEnv { t.Helper() registrySvc := &fakeRegistryService{ repos: []*mcrv1.RepositoryMetadata{ {Name: "library/nginx", TagCount: 3, ManifestCount: 2, TotalSize: 1024 * 1024, CreatedAt: "2024-01-15T10:00:00Z"}, {Name: "library/alpine", TagCount: 1, ManifestCount: 1, TotalSize: 512 * 1024, CreatedAt: "2024-01-16T10:00:00Z"}, }, } policySvc := &fakePolicyService{ rules: []*mcrv1.PolicyRule{ {Id: 1, Priority: 100, Description: "Allow all pulls", Effect: "allow", Actions: []string{"pull"}, Repositories: []string{"*"}, Enabled: true}, }, } auditSvc := &fakeAuditService{ events: []*mcrv1.AuditEvent{ {Id: 1, EventTime: "2024-01-15T12:00:00Z", EventType: "manifest_pushed", ActorId: "user1", Repository: "library/nginx", Digest: "sha256:abc123", IpAddress: "10.0.0.1"}, }, } adminSvc := &fakeAdminService{} // Start in-process gRPC server. gs := grpc.NewServer() mcrv1.RegisterRegistryServiceServer(gs, registrySvc) mcrv1.RegisterPolicyServiceServer(gs, policySvc) mcrv1.RegisterAuditServiceServer(gs, auditSvc) mcrv1.RegisterAdminServiceServer(gs, adminSvc) lis, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen: %v", err) } go func() { _ = gs.Serve(lis) }() // Connect client. conn, err := grpc.NewClient( lis.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions(grpc.ForceCodecV2(mcrv1.JSONCodec{})), ) if err != nil { gs.Stop() t.Fatalf("dial: %v", err) } csrfKey := []byte("test-csrf-key-32-bytes-long!1234") loginFn := func(username, password string) (string, int, error) { if username == "admin" && password == "secret" { return "test-token-12345", 3600, nil } return "", 0, fmt.Errorf("invalid credentials") } srv, err := New( mcrv1.NewRegistryServiceClient(conn), mcrv1.NewPolicyServiceClient(conn), mcrv1.NewAuditServiceClient(conn), mcrv1.NewAdminServiceClient(conn), loginFn, csrfKey, ) if err != nil { _ = conn.Close() gs.Stop() t.Fatalf("create server: %v", err) } return &testEnv{ server: srv, grpcServer: gs, grpcConn: conn, registry: registrySvc, policyFake: policySvc, auditFake: auditSvc, } } func TestLoginPageRenders(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/login", nil) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("GET /login: status %d, want %d", rec.Code, http.StatusOK) } body := rec.Body.String() if !strings.Contains(body, "MCR Login") { t.Error("login page does not contain 'MCR Login'") } if !strings.Contains(body, "_csrf") { t.Error("login page does not contain CSRF token field") } } func TestLoginInvalidCredentials(t *testing.T) { env := setupTestEnv(t) defer env.close() // First get a CSRF token. getReq := httptest.NewRequest(http.MethodGet, "/login", nil) getRec := httptest.NewRecorder() env.server.Handler().ServeHTTP(getRec, getReq) // Extract CSRF cookie and token. var csrfCookie *http.Cookie for _, c := range getRec.Result().Cookies() { if c.Name == "csrf_token" { csrfCookie = c break } } if csrfCookie == nil { t.Fatal("no csrf_token cookie set") } // Extract the CSRF token from the cookie value (token.signature). parts := strings.SplitN(csrfCookie.Value, ".", 2) csrfToken := parts[0] // Submit login with wrong credentials. form := url.Values{ "username": {"baduser"}, "password": {"badpass"}, "_csrf": {csrfToken}, } postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") postReq.AddCookie(csrfCookie) postRec := httptest.NewRecorder() env.server.Handler().ServeHTTP(postRec, postReq) if postRec.Code != http.StatusOK { t.Fatalf("POST /login: status %d, want %d", postRec.Code, http.StatusOK) } body := postRec.Body.String() if !strings.Contains(body, "Invalid username or password") { t.Error("response does not contain error message for invalid credentials") } } func TestDashboardRequiresSession(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/", nil) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusSeeOther { t.Fatalf("GET / without session: status %d, want %d", rec.Code, http.StatusSeeOther) } loc := rec.Header().Get("Location") if loc != "/login" { t.Fatalf("redirect location: got %q, want /login", loc) } } func TestDashboardWithSession(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/", nil) req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"}) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("GET / with session: status %d, want %d", rec.Code, http.StatusOK) } body := rec.Body.String() if !strings.Contains(body, "Dashboard") { t.Error("dashboard page does not contain 'Dashboard'") } if !strings.Contains(body, "Repositories") { t.Error("dashboard page does not show repository count") } } func TestRepositoriesPageRenders(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/repositories", nil) req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"}) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("GET /repositories: status %d, want %d", rec.Code, http.StatusOK) } body := rec.Body.String() if !strings.Contains(body, "library/nginx") { t.Error("repositories page does not contain 'library/nginx'") } if !strings.Contains(body, "library/alpine") { t.Error("repositories page does not contain 'library/alpine'") } } func TestRepositoryDetailRenders(t *testing.T) { env := setupTestEnv(t) defer env.close() env.registry.repoResp = &mcrv1.GetRepositoryResponse{ Name: "library/nginx", TotalSize: 2048, Tags: []*mcrv1.TagInfo{ {Name: "latest", Digest: "sha256:abc123def456"}, }, Manifests: []*mcrv1.ManifestInfo{ {Digest: "sha256:abc123def456", MediaType: "application/vnd.oci.image.manifest.v1+json", Size: 2048, CreatedAt: "2024-01-15T10:00:00Z"}, }, } req := httptest.NewRequest(http.MethodGet, "/repositories/library/nginx", nil) req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"}) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("GET /repositories/library/nginx: status %d, want %d", rec.Code, http.StatusOK) } body := rec.Body.String() if !strings.Contains(body, "library/nginx") { t.Error("repository detail page does not contain repo name") } if !strings.Contains(body, "latest") { t.Error("repository detail page does not contain tag 'latest'") } } func TestCSRFTokenValidation(t *testing.T) { env := setupTestEnv(t) defer env.close() // POST without CSRF token should fail. form := url.Values{ "username": {"admin"}, "password": {"secret"}, } req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) body := rec.Body.String() // Should show the error about invalid form submission. if !strings.Contains(body, "Invalid or expired form submission") { t.Error("POST without CSRF token should show error, got: " + body[:min(200, len(body))]) } } func TestLogout(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/logout", nil) req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"}) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusSeeOther { t.Fatalf("GET /logout: status %d, want %d", rec.Code, http.StatusSeeOther) } loc := rec.Header().Get("Location") if loc != "/login" { t.Fatalf("redirect location: got %q, want /login", loc) } // Verify session cookie is cleared. var sessionCleared bool for _, c := range rec.Result().Cookies() { if c.Name == "mcr_session" && c.MaxAge < 0 { sessionCleared = true break } } if !sessionCleared { t.Error("session cookie was not cleared on logout") } } func TestPoliciesPage(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/policies", nil) req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"}) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("GET /policies: status %d, want %d", rec.Code, http.StatusOK) } body := rec.Body.String() if !strings.Contains(body, "Allow all pulls") { t.Error("policies page does not contain policy description") } } func TestAuditPage(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/audit", nil) req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"}) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("GET /audit: status %d, want %d", rec.Code, http.StatusOK) } body := rec.Body.String() if !strings.Contains(body, "manifest_pushed") { t.Error("audit page does not contain event type") } } func TestStaticFiles(t *testing.T) { env := setupTestEnv(t) defer env.close() req := httptest.NewRequest(http.MethodGet, "/static/style.css", nil) rec := httptest.NewRecorder() env.server.Handler().ServeHTTP(rec, req) if rec.Code != http.StatusOK { t.Fatalf("GET /static/style.css: status %d, want %d", rec.Code, http.StatusOK) } body := rec.Body.String() if !strings.Contains(body, "font-family") { t.Error("style.css does not appear to contain CSS") } } func TestFormatSize(t *testing.T) { tests := []struct { input int64 want string }{ {0, "0 B"}, {512, "512 B"}, {1024, "1.0 KiB"}, {1048576, "1.0 MiB"}, {1073741824, "1.0 GiB"}, {1099511627776, "1.0 TiB"}, } for _, tt := range tests { got := formatSize(tt.input) if got != tt.want { t.Errorf("formatSize(%d) = %q, want %q", tt.input, got, tt.want) } } } func TestFormatTime(t *testing.T) { got := formatTime("2024-01-15T10:30:00Z") want := "2024-01-15 10:30:00" if got != want { t.Errorf("formatTime = %q, want %q", got, want) } // Invalid time returns the input. got = formatTime("not-a-time") if got != "not-a-time" { t.Errorf("formatTime(invalid) = %q, want %q", got, "not-a-time") } } func TestTruncate(t *testing.T) { got := truncate("sha256:abc123def456", 12) want := "sha256:abc12..." if got != want { t.Errorf("truncate = %q, want %q", got, want) } // Short strings are not truncated. got = truncate("short", 10) if got != "short" { t.Errorf("truncate(short) = %q, want %q", got, "short") } } func TestLoginSuccessSetsCookie(t *testing.T) { env := setupTestEnv(t) defer env.close() // Get CSRF token. getReq := httptest.NewRequest(http.MethodGet, "/login", nil) getRec := httptest.NewRecorder() env.server.Handler().ServeHTTP(getRec, getReq) var csrfCookie *http.Cookie for _, c := range getRec.Result().Cookies() { if c.Name == "csrf_token" { csrfCookie = c break } } if csrfCookie == nil { t.Fatal("no csrf_token cookie") } parts := strings.SplitN(csrfCookie.Value, ".", 2) csrfToken := parts[0] form := url.Values{ "username": {"admin"}, "password": {"secret"}, "_csrf": {csrfToken}, } postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode())) postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded") postReq.AddCookie(csrfCookie) postRec := httptest.NewRecorder() env.server.Handler().ServeHTTP(postRec, postReq) if postRec.Code != http.StatusSeeOther { t.Fatalf("POST /login: status %d, want %d; body: %s", postRec.Code, http.StatusSeeOther, postRec.Body.String()) } var sessionCookie *http.Cookie for _, c := range postRec.Result().Cookies() { if c.Name == "mcr_session" { sessionCookie = c break } } if sessionCookie == nil { t.Fatal("no mcr_session cookie set after login") } if sessionCookie.Value != "test-token-12345" { t.Errorf("session cookie value = %q, want %q", sessionCookie.Value, "test-token-12345") } if !sessionCookie.HttpOnly { t.Error("session cookie is not HttpOnly") } if !sessionCookie.Secure { t.Error("session cookie is not Secure") } if sessionCookie.SameSite != http.SameSiteStrictMode { t.Error("session cookie SameSite is not Strict") } }