Phases 11, 12: mcrctl CLI tool and mcr-web UI
Phase 11 implements the admin CLI with dual REST/gRPC transport, global flags (--server, --grpc, --token, --ca-cert, --json), and all commands: status, repo list/delete, policy CRUD, audit tail, gc trigger/status/reconcile, and snapshot. Phase 12 implements the HTMX web UI with chi router, session-based auth (HttpOnly/Secure/SameSite=Strict cookies), CSRF protection (HMAC-SHA256 signed double-submit), and pages for dashboard, repositories, manifest detail, policy management, and audit log. Security: CSRF via signed double-submit cookie, session cookies with HttpOnly/Secure/SameSite=Strict, TLS 1.3 minimum on all connections, form body size limits via http.MaxBytesReader. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
608
internal/webserver/server_test.go
Normal file
608
internal/webserver/server_test.go
Normal file
@@ -0,0 +1,608 @@
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
)
|
||||
|
||||
// fakeRegistryService implements RegistryServiceServer for testing.
|
||||
type fakeRegistryService struct {
|
||||
mcrv1.UnimplementedRegistryServiceServer
|
||||
repos []*mcrv1.RepositoryMetadata
|
||||
repoResp *mcrv1.GetRepositoryResponse
|
||||
repoErr error
|
||||
}
|
||||
|
||||
func (f *fakeRegistryService) ListRepositories(_ context.Context, _ *mcrv1.ListRepositoriesRequest) (*mcrv1.ListRepositoriesResponse, error) {
|
||||
return &mcrv1.ListRepositoriesResponse{Repositories: f.repos}, nil
|
||||
}
|
||||
|
||||
func (f *fakeRegistryService) GetRepository(_ context.Context, req *mcrv1.GetRepositoryRequest) (*mcrv1.GetRepositoryResponse, error) {
|
||||
if f.repoErr != nil {
|
||||
return nil, f.repoErr
|
||||
}
|
||||
if f.repoResp != nil {
|
||||
return f.repoResp, nil
|
||||
}
|
||||
return &mcrv1.GetRepositoryResponse{Name: req.GetName()}, nil
|
||||
}
|
||||
|
||||
// fakePolicyService implements PolicyServiceServer for testing.
|
||||
type fakePolicyService struct {
|
||||
mcrv1.UnimplementedPolicyServiceServer
|
||||
rules []*mcrv1.PolicyRule
|
||||
created *mcrv1.PolicyRule
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) ListPolicyRules(_ context.Context, _ *mcrv1.ListPolicyRulesRequest) (*mcrv1.ListPolicyRulesResponse, error) {
|
||||
return &mcrv1.ListPolicyRulesResponse{Rules: f.rules}, nil
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) CreatePolicyRule(_ context.Context, req *mcrv1.CreatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
|
||||
rule := &mcrv1.PolicyRule{
|
||||
Id: 1,
|
||||
Priority: req.GetPriority(),
|
||||
Description: req.GetDescription(),
|
||||
Effect: req.GetEffect(),
|
||||
Actions: req.GetActions(),
|
||||
Repositories: req.GetRepositories(),
|
||||
Enabled: req.GetEnabled(),
|
||||
}
|
||||
f.created = rule
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) GetPolicyRule(_ context.Context, req *mcrv1.GetPolicyRuleRequest) (*mcrv1.PolicyRule, error) {
|
||||
for _, r := range f.rules {
|
||||
if r.GetId() == req.GetId() {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) UpdatePolicyRule(_ context.Context, req *mcrv1.UpdatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
|
||||
for _, r := range f.rules {
|
||||
if r.GetId() == req.GetId() {
|
||||
r.Enabled = req.GetEnabled()
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) DeletePolicyRule(_ context.Context, req *mcrv1.DeletePolicyRuleRequest) (*mcrv1.DeletePolicyRuleResponse, error) {
|
||||
for i, r := range f.rules {
|
||||
if r.GetId() == req.GetId() {
|
||||
f.rules = append(f.rules[:i], f.rules[i+1:]...)
|
||||
return &mcrv1.DeletePolicyRuleResponse{}, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
|
||||
// fakeAuditService implements AuditServiceServer for testing.
|
||||
type fakeAuditService struct {
|
||||
mcrv1.UnimplementedAuditServiceServer
|
||||
events []*mcrv1.AuditEvent
|
||||
}
|
||||
|
||||
func (f *fakeAuditService) ListAuditEvents(_ context.Context, _ *mcrv1.ListAuditEventsRequest) (*mcrv1.ListAuditEventsResponse, error) {
|
||||
return &mcrv1.ListAuditEventsResponse{Events: f.events}, nil
|
||||
}
|
||||
|
||||
// fakeAdminService implements AdminServiceServer for testing.
|
||||
type fakeAdminService struct {
|
||||
mcrv1.UnimplementedAdminServiceServer
|
||||
}
|
||||
|
||||
func (f *fakeAdminService) Health(_ context.Context, _ *mcrv1.HealthRequest) (*mcrv1.HealthResponse, error) {
|
||||
return &mcrv1.HealthResponse{Status: "ok"}, nil
|
||||
}
|
||||
|
||||
// testEnv holds a test server and its dependencies.
|
||||
type testEnv struct {
|
||||
server *Server
|
||||
grpcServer *grpc.Server
|
||||
grpcConn *grpc.ClientConn
|
||||
registry *fakeRegistryService
|
||||
policyFake *fakePolicyService
|
||||
auditFake *fakeAuditService
|
||||
}
|
||||
|
||||
func (e *testEnv) close() {
|
||||
_ = e.grpcConn.Close()
|
||||
e.grpcServer.Stop()
|
||||
}
|
||||
|
||||
// setupTestEnv creates a test environment with fake gRPC backends.
|
||||
func setupTestEnv(t *testing.T) *testEnv {
|
||||
t.Helper()
|
||||
|
||||
registrySvc := &fakeRegistryService{
|
||||
repos: []*mcrv1.RepositoryMetadata{
|
||||
{Name: "library/nginx", TagCount: 3, ManifestCount: 2, TotalSize: 1024 * 1024, CreatedAt: "2024-01-15T10:00:00Z"},
|
||||
{Name: "library/alpine", TagCount: 1, ManifestCount: 1, TotalSize: 512 * 1024, CreatedAt: "2024-01-16T10:00:00Z"},
|
||||
},
|
||||
}
|
||||
policySvc := &fakePolicyService{
|
||||
rules: []*mcrv1.PolicyRule{
|
||||
{Id: 1, Priority: 100, Description: "Allow all pulls", Effect: "allow", Actions: []string{"pull"}, Repositories: []string{"*"}, Enabled: true},
|
||||
},
|
||||
}
|
||||
auditSvc := &fakeAuditService{
|
||||
events: []*mcrv1.AuditEvent{
|
||||
{Id: 1, EventTime: "2024-01-15T12:00:00Z", EventType: "manifest_pushed", ActorId: "user1", Repository: "library/nginx", Digest: "sha256:abc123", IpAddress: "10.0.0.1"},
|
||||
},
|
||||
}
|
||||
adminSvc := &fakeAdminService{}
|
||||
|
||||
// Start in-process gRPC server.
|
||||
gs := grpc.NewServer()
|
||||
mcrv1.RegisterRegistryServiceServer(gs, registrySvc)
|
||||
mcrv1.RegisterPolicyServiceServer(gs, policySvc)
|
||||
mcrv1.RegisterAuditServiceServer(gs, auditSvc)
|
||||
mcrv1.RegisterAdminServiceServer(gs, adminSvc)
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = gs.Serve(lis)
|
||||
}()
|
||||
|
||||
// Connect client.
|
||||
conn, err := grpc.NewClient(
|
||||
lis.Addr().String(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(grpc.ForceCodecV2(mcrv1.JSONCodec{})),
|
||||
)
|
||||
if err != nil {
|
||||
gs.Stop()
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
csrfKey := []byte("test-csrf-key-32-bytes-long!1234")
|
||||
|
||||
loginFn := func(username, password string) (string, int, error) {
|
||||
if username == "admin" && password == "secret" {
|
||||
return "test-token-12345", 3600, nil
|
||||
}
|
||||
return "", 0, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
srv, err := New(
|
||||
mcrv1.NewRegistryServiceClient(conn),
|
||||
mcrv1.NewPolicyServiceClient(conn),
|
||||
mcrv1.NewAuditServiceClient(conn),
|
||||
mcrv1.NewAdminServiceClient(conn),
|
||||
loginFn,
|
||||
csrfKey,
|
||||
)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
gs.Stop()
|
||||
t.Fatalf("create server: %v", err)
|
||||
}
|
||||
|
||||
return &testEnv{
|
||||
server: srv,
|
||||
grpcServer: gs,
|
||||
grpcConn: conn,
|
||||
registry: registrySvc,
|
||||
policyFake: policySvc,
|
||||
auditFake: auditSvc,
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginPageRenders(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /login: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "MCR Login") {
|
||||
t.Error("login page does not contain 'MCR Login'")
|
||||
}
|
||||
if !strings.Contains(body, "_csrf") {
|
||||
t.Error("login page does not contain CSRF token field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginInvalidCredentials(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
// First get a CSRF token.
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
env.server.Handler().ServeHTTP(getRec, getReq)
|
||||
|
||||
// Extract CSRF cookie and token.
|
||||
var csrfCookie *http.Cookie
|
||||
for _, c := range getRec.Result().Cookies() {
|
||||
if c.Name == "csrf_token" {
|
||||
csrfCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if csrfCookie == nil {
|
||||
t.Fatal("no csrf_token cookie set")
|
||||
}
|
||||
|
||||
// Extract the CSRF token from the cookie value (token.signature).
|
||||
parts := strings.SplitN(csrfCookie.Value, ".", 2)
|
||||
csrfToken := parts[0]
|
||||
|
||||
// Submit login with wrong credentials.
|
||||
form := url.Values{
|
||||
"username": {"baduser"},
|
||||
"password": {"badpass"},
|
||||
"_csrf": {csrfToken},
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
postReq.AddCookie(csrfCookie)
|
||||
postRec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(postRec, postReq)
|
||||
|
||||
if postRec.Code != http.StatusOK {
|
||||
t.Fatalf("POST /login: status %d, want %d", postRec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := postRec.Body.String()
|
||||
if !strings.Contains(body, "Invalid username or password") {
|
||||
t.Error("response does not contain error message for invalid credentials")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardRequiresSession(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusSeeOther {
|
||||
t.Fatalf("GET / without session: status %d, want %d", rec.Code, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
loc := rec.Header().Get("Location")
|
||||
if loc != "/login" {
|
||||
t.Fatalf("redirect location: got %q, want /login", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardWithSession(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET / with session: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "Dashboard") {
|
||||
t.Error("dashboard page does not contain 'Dashboard'")
|
||||
}
|
||||
if !strings.Contains(body, "Repositories") {
|
||||
t.Error("dashboard page does not show repository count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepositoriesPageRenders(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/repositories", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /repositories: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "library/nginx") {
|
||||
t.Error("repositories page does not contain 'library/nginx'")
|
||||
}
|
||||
if !strings.Contains(body, "library/alpine") {
|
||||
t.Error("repositories page does not contain 'library/alpine'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepositoryDetailRenders(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
env.registry.repoResp = &mcrv1.GetRepositoryResponse{
|
||||
Name: "library/nginx",
|
||||
TotalSize: 2048,
|
||||
Tags: []*mcrv1.TagInfo{
|
||||
{Name: "latest", Digest: "sha256:abc123def456"},
|
||||
},
|
||||
Manifests: []*mcrv1.ManifestInfo{
|
||||
{Digest: "sha256:abc123def456", MediaType: "application/vnd.oci.image.manifest.v1+json", Size: 2048, CreatedAt: "2024-01-15T10:00:00Z"},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/repositories/library/nginx", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /repositories/library/nginx: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "library/nginx") {
|
||||
t.Error("repository detail page does not contain repo name")
|
||||
}
|
||||
if !strings.Contains(body, "latest") {
|
||||
t.Error("repository detail page does not contain tag 'latest'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFTokenValidation(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
// POST without CSRF token should fail.
|
||||
form := url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"secret"},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
body := rec.Body.String()
|
||||
// Should show the error about invalid form submission.
|
||||
if !strings.Contains(body, "Invalid or expired form submission") {
|
||||
t.Error("POST without CSRF token should show error, got: " + body[:min(200, len(body))])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusSeeOther {
|
||||
t.Fatalf("GET /logout: status %d, want %d", rec.Code, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
loc := rec.Header().Get("Location")
|
||||
if loc != "/login" {
|
||||
t.Fatalf("redirect location: got %q, want /login", loc)
|
||||
}
|
||||
|
||||
// Verify session cookie is cleared.
|
||||
var sessionCleared bool
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
if c.Name == "mcr_session" && c.MaxAge < 0 {
|
||||
sessionCleared = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sessionCleared {
|
||||
t.Error("session cookie was not cleared on logout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoliciesPage(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/policies", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /policies: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "Allow all pulls") {
|
||||
t.Error("policies page does not contain policy description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditPage(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/audit", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /audit: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "manifest_pushed") {
|
||||
t.Error("audit page does not contain event type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticFiles(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/static/style.css", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /static/style.css: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "font-family") {
|
||||
t.Error("style.css does not appear to contain CSS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
input int64
|
||||
want string
|
||||
}{
|
||||
{0, "0 B"},
|
||||
{512, "512 B"},
|
||||
{1024, "1.0 KiB"},
|
||||
{1048576, "1.0 MiB"},
|
||||
{1073741824, "1.0 GiB"},
|
||||
{1099511627776, "1.0 TiB"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := formatSize(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatSize(%d) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTime(t *testing.T) {
|
||||
got := formatTime("2024-01-15T10:30:00Z")
|
||||
want := "2024-01-15 10:30:00"
|
||||
if got != want {
|
||||
t.Errorf("formatTime = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// Invalid time returns the input.
|
||||
got = formatTime("not-a-time")
|
||||
if got != "not-a-time" {
|
||||
t.Errorf("formatTime(invalid) = %q, want %q", got, "not-a-time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
got := truncate("sha256:abc123def456", 12)
|
||||
want := "sha256:abc12..."
|
||||
if got != want {
|
||||
t.Errorf("truncate = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// Short strings are not truncated.
|
||||
got = truncate("short", 10)
|
||||
if got != "short" {
|
||||
t.Errorf("truncate(short) = %q, want %q", got, "short")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginSuccessSetsCookie(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
// Get CSRF token.
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
env.server.Handler().ServeHTTP(getRec, getReq)
|
||||
|
||||
var csrfCookie *http.Cookie
|
||||
for _, c := range getRec.Result().Cookies() {
|
||||
if c.Name == "csrf_token" {
|
||||
csrfCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if csrfCookie == nil {
|
||||
t.Fatal("no csrf_token cookie")
|
||||
}
|
||||
|
||||
parts := strings.SplitN(csrfCookie.Value, ".", 2)
|
||||
csrfToken := parts[0]
|
||||
|
||||
form := url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"secret"},
|
||||
"_csrf": {csrfToken},
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
postReq.AddCookie(csrfCookie)
|
||||
postRec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(postRec, postReq)
|
||||
|
||||
if postRec.Code != http.StatusSeeOther {
|
||||
t.Fatalf("POST /login: status %d, want %d; body: %s", postRec.Code, http.StatusSeeOther, postRec.Body.String())
|
||||
}
|
||||
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range postRec.Result().Cookies() {
|
||||
if c.Name == "mcr_session" {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if sessionCookie == nil {
|
||||
t.Fatal("no mcr_session cookie set after login")
|
||||
}
|
||||
if sessionCookie.Value != "test-token-12345" {
|
||||
t.Errorf("session cookie value = %q, want %q", sessionCookie.Value, "test-token-12345")
|
||||
}
|
||||
if !sessionCookie.HttpOnly {
|
||||
t.Error("session cookie is not HttpOnly")
|
||||
}
|
||||
if !sessionCookie.Secure {
|
||||
t.Error("session cookie is not Secure")
|
||||
}
|
||||
if sessionCookie.SameSite != http.SameSiteStrictMode {
|
||||
t.Error("session cookie SameSite is not Strict")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user