Phases 11, 12: mcrctl CLI tool and mcr-web UI

Phase 11 implements the admin CLI with dual REST/gRPC transport,
global flags (--server, --grpc, --token, --ca-cert, --json), and
all commands: status, repo list/delete, policy CRUD, audit tail,
gc trigger/status/reconcile, and snapshot.

Phase 12 implements the HTMX web UI with chi router, session-based
auth (HttpOnly/Secure/SameSite=Strict cookies), CSRF protection
(HMAC-SHA256 signed double-submit), and pages for dashboard,
repositories, manifest detail, policy management, and audit log.

Security: CSRF via signed double-submit cookie, session cookies
with HttpOnly/Secure/SameSite=Strict, TLS 1.3 minimum on all
connections, form body size limits via http.MaxBytesReader.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-20 10:14:38 -07:00
parent 185b68ff6d
commit 593da3975d
23 changed files with 3737 additions and 66 deletions

184
internal/webserver/auth.go Normal file
View File

@@ -0,0 +1,184 @@
package webserver
import (
"context"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"log"
"net/http"
"strings"
)
// sessionKey is the context key for the session token.
type sessionKey struct{}
// tokenFromContext retrieves the bearer token from context.
func tokenFromContext(ctx context.Context) string {
s, _ := ctx.Value(sessionKey{}).(string)
return s
}
// contextWithToken stores a bearer token in the context.
func contextWithToken(ctx context.Context, token string) context.Context {
return context.WithValue(ctx, sessionKey{}, token)
}
// sessionMiddleware checks for a valid mcr_session cookie and adds the
// token to the request context. If no session is present, it redirects
// to the login page.
func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cookie, err := r.Cookie("mcr_session")
if err != nil || cookie.Value == "" {
http.Redirect(w, r, "/login", http.StatusSeeOther)
return
}
ctx := contextWithToken(r.Context(), cookie.Value)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
// handleLoginPage renders the login form.
func (s *Server) handleLoginPage(w http.ResponseWriter, r *http.Request) {
csrf := s.generateCSRFToken(w)
s.templates.render(w, "login", map[string]any{
"CSRFToken": csrf,
"Session": false,
})
}
// handleLoginSubmit processes the login form.
func (s *Server) handleLoginSubmit(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !s.validateCSRFToken(r) {
csrf := s.generateCSRFToken(w)
s.templates.render(w, "login", map[string]any{
"Error": "Invalid or expired form submission. Please try again.",
"CSRFToken": csrf,
"Session": false,
})
return
}
username := r.FormValue("username")
password := r.FormValue("password")
if username == "" || password == "" {
csrf := s.generateCSRFToken(w)
s.templates.render(w, "login", map[string]any{
"Error": "Username and password are required.",
"CSRFToken": csrf,
"Session": false,
})
return
}
token, _, err := s.loginFn(username, password)
if err != nil {
log.Printf("login failed for user %q: %v", username, err)
csrf := s.generateCSRFToken(w)
s.templates.render(w, "login", map[string]any{
"Error": "Invalid username or password.",
"CSRFToken": csrf,
"Session": false,
})
return
}
http.SetCookie(w, &http.Cookie{
Name: "mcr_session",
Value: token,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
http.Redirect(w, r, "/", http.StatusSeeOther)
}
// handleLogout clears the session and redirects to login.
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
http.SetCookie(w, &http.Cookie{
Name: "mcr_session",
Value: "",
Path: "/",
MaxAge: -1,
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
http.Redirect(w, r, "/login", http.StatusSeeOther)
}
// generateCSRFToken creates a random token, signs it with HMAC, stores
// the signed value in a cookie, and returns the token for form embedding.
func (s *Server) generateCSRFToken(w http.ResponseWriter) string {
b := make([]byte, 16)
if _, err := rand.Read(b); err != nil {
// Crypto RNG failure is fatal; this should never happen.
log.Printf("csrf: failed to generate random bytes: %v", err)
return ""
}
token := hex.EncodeToString(b)
sig := s.signCSRF(token)
cookieVal := token + "." + sig
http.SetCookie(w, &http.Cookie{
Name: "csrf_token",
Value: cookieVal,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
return token
}
// validateCSRFToken verifies the form _csrf field matches the cookie and
// the HMAC signature is valid.
func (s *Server) validateCSRFToken(r *http.Request) bool {
formToken := r.FormValue("_csrf")
if formToken == "" {
return false
}
cookie, err := r.Cookie("csrf_token")
if err != nil || cookie.Value == "" {
return false
}
parts := strings.SplitN(cookie.Value, ".", 2)
if len(parts) != 2 {
return false
}
cookieToken := parts[0]
cookieSig := parts[1]
// Verify the form token matches the cookie token.
if !hmac.Equal([]byte(formToken), []byte(cookieToken)) {
return false
}
// Verify the HMAC signature.
expectedSig := s.signCSRF(cookieToken)
return hmac.Equal([]byte(cookieSig), []byte(expectedSig))
}
// signCSRF computes an HMAC-SHA256 signature for a CSRF token.
func (s *Server) signCSRF(token string) string {
mac := hmac.New(sha256.New, s.csrfKey)
mac.Write([]byte(token))
return hex.EncodeToString(mac.Sum(nil))
}

View File

@@ -0,0 +1,457 @@
package webserver
import (
"context"
"net/http"
"net/url"
"strconv"
"strings"
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// grpcContext creates a context with the bearer token from the session
// attached as gRPC outgoing metadata.
func grpcContext(r *http.Request) context.Context {
token := tokenFromContext(r.Context())
return metadata.AppendToOutgoingContext(r.Context(), "authorization", "Bearer "+token)
}
// handleDashboard renders the dashboard with repo stats and recent activity.
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
ctx := grpcContext(r)
repos, err := s.registry.ListRepositories(ctx, &mcrv1.ListRepositoriesRequest{})
if err != nil {
s.renderError(w, "dashboard", "Failed to load repositories.", err)
return
}
var repoCount int
var totalSize int64
for _, repo := range repos.GetRepositories() {
repoCount++
totalSize += repo.GetTotalSize()
}
// Fetch recent audit events for dashboard activity.
var events []*mcrv1.AuditEvent
auditResp, auditErr := s.audit.ListAuditEvents(ctx, &mcrv1.ListAuditEventsRequest{
Pagination: &mcrv1.PaginationRequest{Limit: 10},
})
if auditErr == nil {
events = auditResp.GetEvents()
}
// If audit fails with PermissionDenied, just show no events (user is not admin).
s.templates.render(w, "dashboard", map[string]any{
"Session": true,
"RepoCount": repoCount,
"TotalSize": formatSize(totalSize),
"Events": events,
})
}
// handleRepositories renders the repository list.
func (s *Server) handleRepositories(w http.ResponseWriter, r *http.Request) {
ctx := grpcContext(r)
resp, err := s.registry.ListRepositories(ctx, &mcrv1.ListRepositoriesRequest{})
if err != nil {
s.renderError(w, "repositories", "Failed to load repositories.", err)
return
}
s.templates.render(w, "repositories", map[string]any{
"Session": true,
"Repositories": resp.GetRepositories(),
})
}
// handleRepositoryDetail renders a single repository's tags and manifests.
func (s *Server) handleRepositoryDetail(w http.ResponseWriter, r *http.Request) {
name := extractRepoName(r)
if name == "" {
http.NotFound(w, r)
return
}
ctx := grpcContext(r)
resp, err := s.registry.GetRepository(ctx, &mcrv1.GetRepositoryRequest{Name: name})
if err != nil {
s.templates.render(w, "repository_detail", map[string]any{
"Session": true,
"Name": name,
"Error": grpcErrorMessage(err),
})
return
}
s.templates.render(w, "repository_detail", map[string]any{
"Session": true,
"Name": resp.GetName(),
"Tags": resp.GetTags(),
"Manifests": resp.GetManifests(),
"TotalSize": resp.GetTotalSize(),
})
}
// handleManifestDetail renders details for a specific manifest.
func (s *Server) handleManifestDetail(w http.ResponseWriter, r *http.Request) {
// URL format: /repositories/{name}/manifests/{digest}
// The name can contain slashes, so we parse manually.
path := r.URL.Path
const manifestsPrefix = "/manifests/"
idx := strings.LastIndex(path, manifestsPrefix)
if idx < 0 {
http.NotFound(w, r)
return
}
digest := path[idx+len(manifestsPrefix):]
repoPath := path[len("/repositories/"):idx]
if repoPath == "" || digest == "" {
http.NotFound(w, r)
return
}
ctx := grpcContext(r)
resp, err := s.registry.GetRepository(ctx, &mcrv1.GetRepositoryRequest{Name: repoPath})
if err != nil {
s.templates.render(w, "manifest_detail", map[string]any{
"Session": true,
"RepoName": repoPath,
"Error": grpcErrorMessage(err),
})
return
}
// Find the specific manifest.
var manifest *mcrv1.ManifestInfo
for _, m := range resp.GetManifests() {
if m.GetDigest() == digest {
manifest = m
break
}
}
if manifest == nil {
s.templates.render(w, "manifest_detail", map[string]any{
"Session": true,
"RepoName": repoPath,
"Error": "Manifest not found.",
})
return
}
s.templates.render(w, "manifest_detail", map[string]any{
"Session": true,
"RepoName": repoPath,
"Manifest": manifest,
})
}
// handlePolicies renders the policy list and create form.
func (s *Server) handlePolicies(w http.ResponseWriter, r *http.Request) {
ctx := grpcContext(r)
csrf := s.generateCSRFToken(w)
resp, err := s.policy.ListPolicyRules(ctx, &mcrv1.ListPolicyRulesRequest{})
if err != nil {
s.templates.render(w, "policies", map[string]any{
"Session": true,
"CSRFToken": csrf,
"Error": grpcErrorMessage(err),
})
return
}
s.templates.render(w, "policies", map[string]any{
"Session": true,
"CSRFToken": csrf,
"Policies": resp.GetRules(),
})
}
// handleCreatePolicy processes the policy creation form.
func (s *Server) handleCreatePolicy(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !s.validateCSRFToken(r) {
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
priority, _ := strconv.ParseInt(r.FormValue("priority"), 10, 32)
actions := splitCSV(r.FormValue("actions"))
repos := splitCSV(r.FormValue("repositories"))
ctx := grpcContext(r)
_, err := s.policy.CreatePolicyRule(ctx, &mcrv1.CreatePolicyRuleRequest{
Priority: int32(priority),
Description: r.FormValue("description"),
Effect: r.FormValue("effect"),
Actions: actions,
Repositories: repos,
Enabled: true,
})
if err != nil {
csrf := s.generateCSRFToken(w)
s.templates.render(w, "policies", map[string]any{
"Session": true,
"CSRFToken": csrf,
"Error": grpcErrorMessage(err),
})
return
}
http.Redirect(w, r, "/policies", http.StatusSeeOther)
}
// handleTogglePolicy toggles a policy rule's enabled state.
func (s *Server) handleTogglePolicy(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !s.validateCSRFToken(r) {
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
idStr := extractPolicyID(r.URL.Path, "/toggle")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid policy ID", http.StatusBadRequest)
return
}
ctx := grpcContext(r)
// Get current state.
rule, err := s.policy.GetPolicyRule(ctx, &mcrv1.GetPolicyRuleRequest{Id: id})
if err != nil {
http.Error(w, grpcErrorMessage(err), http.StatusInternalServerError)
return
}
// Toggle the enabled field.
_, err = s.policy.UpdatePolicyRule(ctx, &mcrv1.UpdatePolicyRuleRequest{
Id: id,
Enabled: !rule.GetEnabled(),
UpdateMask: []string{"enabled"},
// Carry forward required fields.
Priority: rule.GetPriority(),
Description: rule.GetDescription(),
Effect: rule.GetEffect(),
Actions: rule.GetActions(),
Repositories: rule.GetRepositories(),
})
if err != nil {
http.Error(w, grpcErrorMessage(err), http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/policies", http.StatusSeeOther)
}
// handleDeletePolicy deletes a policy rule.
func (s *Server) handleDeletePolicy(w http.ResponseWriter, r *http.Request) {
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
if err := r.ParseForm(); err != nil {
http.Error(w, "bad request", http.StatusBadRequest)
return
}
if !s.validateCSRFToken(r) {
http.Error(w, "invalid CSRF token", http.StatusForbidden)
return
}
idStr := extractPolicyID(r.URL.Path, "/delete")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
http.Error(w, "invalid policy ID", http.StatusBadRequest)
return
}
ctx := grpcContext(r)
_, err = s.policy.DeletePolicyRule(ctx, &mcrv1.DeletePolicyRuleRequest{Id: id})
if err != nil {
http.Error(w, grpcErrorMessage(err), http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/policies", http.StatusSeeOther)
}
// handleAudit renders the audit log with filters and pagination.
func (s *Server) handleAudit(w http.ResponseWriter, r *http.Request) {
ctx := grpcContext(r)
q := r.URL.Query()
eventType := q.Get("event_type")
repo := q.Get("repository")
since := q.Get("since")
until := q.Get("until")
pageStr := q.Get("page")
page := 1
if pageStr != "" {
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
page = p
}
}
const pageSize int32 = 50
offset := int32(page-1) * pageSize
req := &mcrv1.ListAuditEventsRequest{
Pagination: &mcrv1.PaginationRequest{
Limit: pageSize + 1, // fetch one extra to detect next page
Offset: offset,
},
EventType: eventType,
Repository: repo,
Since: since,
Until: until,
}
resp, err := s.audit.ListAuditEvents(ctx, req)
if err != nil {
s.templates.render(w, "audit", map[string]any{
"Session": true,
"Error": grpcErrorMessage(err),
"FilterType": eventType,
"FilterRepo": repo,
"FilterSince": since,
"FilterUntil": until,
"Page": page,
})
return
}
events := resp.GetEvents()
hasNext := len(events) > int(pageSize)
if hasNext {
events = events[:pageSize]
}
// Build pagination URLs.
buildURL := func(p int) string {
v := url.Values{}
if eventType != "" {
v.Set("event_type", eventType)
}
if repo != "" {
v.Set("repository", repo)
}
if since != "" {
v.Set("since", since)
}
if until != "" {
v.Set("until", until)
}
v.Set("page", strconv.Itoa(p))
return "/audit?" + v.Encode()
}
s.templates.render(w, "audit", map[string]any{
"Session": true,
"Events": events,
"FilterType": eventType,
"FilterRepo": repo,
"FilterSince": since,
"FilterUntil": until,
"Page": page,
"HasNext": hasNext,
"PrevURL": buildURL(page - 1),
"NextURL": buildURL(page + 1),
})
}
// renderError renders a template with an error message derived from a gRPC error.
func (s *Server) renderError(w http.ResponseWriter, tmpl, fallback string, err error) {
msg := fallback
if st, ok := status.FromError(err); ok {
if st.Code() == codes.PermissionDenied {
msg = "Access denied."
}
}
s.templates.render(w, tmpl, map[string]any{
"Session": true,
"Error": msg,
})
}
// grpcErrorMessage extracts a human-readable message from a gRPC error.
func grpcErrorMessage(err error) string {
if st, ok := status.FromError(err); ok {
if st.Code() == codes.PermissionDenied {
return "Access denied."
}
return st.Message()
}
return err.Error()
}
// extractRepoName extracts the repository name from the URL path.
// The name may contain slashes (e.g., "library/nginx").
// URL format: /repositories/{name...}
func extractRepoName(r *http.Request) string {
path := r.URL.Path
prefix := "/repositories/"
if !strings.HasPrefix(path, prefix) {
return ""
}
name := path[len(prefix):]
// Strip trailing slash.
name = strings.TrimRight(name, "/")
// If the path contains /manifests/, extract only the repo name part.
if idx := strings.Index(name, "/manifests/"); idx >= 0 {
name = name[:idx]
}
return name
}
// extractPolicyID extracts the policy ID from paths like /policies/{id}/toggle
// or /policies/{id}/delete.
func extractPolicyID(path, suffix string) string {
path = strings.TrimSuffix(path, suffix)
path = strings.TrimPrefix(path, "/policies/")
return path
}
// splitCSV splits a comma-separated string, trimming whitespace.
func splitCSV(s string) []string {
if s == "" {
return nil
}
parts := strings.Split(s, ",")
result := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
result = append(result, p)
}
}
return result
}

View File

@@ -0,0 +1,133 @@
// Package webserver implements the MCR web UI server.
//
// It serves HTML pages rendered from Go templates with htmx for
// interactive elements. All data is fetched via gRPC from the main
// mcrsrv API server. Authentication is handled via MCIAS, with session
// tokens stored in secure HttpOnly cookies.
package webserver
import (
"io/fs"
"log"
"net/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
"git.wntrmute.dev/kyle/mcr/web"
)
// LoginFunc authenticates a user and returns a bearer token.
type LoginFunc func(username, password string) (token string, expiresIn int, err error)
// Server is the MCR web UI server.
type Server struct {
router chi.Router
templates *templateSet
registry mcrv1.RegistryServiceClient
policy mcrv1.PolicyServiceClient
audit mcrv1.AuditServiceClient
admin mcrv1.AdminServiceClient
loginFn LoginFunc
csrfKey []byte // 32-byte key for HMAC signing
}
// New creates a new web UI server with the given gRPC clients and login function.
func New(
registry mcrv1.RegistryServiceClient,
policy mcrv1.PolicyServiceClient,
audit mcrv1.AuditServiceClient,
admin mcrv1.AdminServiceClient,
loginFn LoginFunc,
csrfKey []byte,
) (*Server, error) {
tmpl, err := loadTemplates()
if err != nil {
return nil, err
}
s := &Server{
templates: tmpl,
registry: registry,
policy: policy,
audit: audit,
admin: admin,
loginFn: loginFn,
csrfKey: csrfKey,
}
s.router = s.buildRouter()
return s, nil
}
// Handler returns the http.Handler for the server.
func (s *Server) Handler() http.Handler {
return s.router
}
// buildRouter sets up the chi router with all routes and middleware.
func (s *Server) buildRouter() chi.Router {
r := chi.NewRouter()
// Global middleware.
r.Use(middleware.Recoverer)
r.Use(middleware.RequestID)
r.Use(middleware.RealIP)
// Static files (no auth required).
staticFS, err := fs.Sub(web.Content, "static")
if err != nil {
log.Fatalf("webserver: failed to create static sub-filesystem: %v", err)
}
r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS))))
// Public routes (no session required).
r.Get("/login", s.handleLoginPage)
r.Post("/login", s.handleLoginSubmit)
r.Get("/logout", s.handleLogout)
// Protected routes (session required).
r.Group(func(r chi.Router) {
r.Use(s.sessionMiddleware)
r.Get("/", s.handleDashboard)
// Repository routes — name may contain slashes.
r.Get("/repositories", s.handleRepositories)
r.Get("/repositories/*", s.handleRepositoryOrManifest)
// Policy routes (admin — gRPC interceptors enforce this).
r.Get("/policies", s.handlePolicies)
r.Post("/policies", s.handleCreatePolicy)
r.Post("/policies/{id}/toggle", s.handleTogglePolicy)
r.Post("/policies/{id}/delete", s.handleDeletePolicy)
// Audit routes (admin — gRPC interceptors enforce this).
r.Get("/audit", s.handleAudit)
})
return r
}
// handleRepositoryOrManifest dispatches between repository detail and
// manifest detail based on the URL path. This is necessary because
// repository names can contain slashes.
func (s *Server) handleRepositoryOrManifest(w http.ResponseWriter, r *http.Request) {
path := r.URL.Path
if idx := lastIndex(path, "/manifests/"); idx >= 0 {
s.handleManifestDetail(w, r)
return
}
s.handleRepositoryDetail(w, r)
}
// lastIndex returns the index of the last occurrence of sep in s, or -1.
func lastIndex(s, sep string) int {
for i := len(s) - len(sep); i >= 0; i-- {
if s[i:i+len(sep)] == sep {
return i
}
}
return -1
}

View File

@@ -0,0 +1,608 @@
package webserver
import (
"context"
"fmt"
"net"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
)
// fakeRegistryService implements RegistryServiceServer for testing.
type fakeRegistryService struct {
mcrv1.UnimplementedRegistryServiceServer
repos []*mcrv1.RepositoryMetadata
repoResp *mcrv1.GetRepositoryResponse
repoErr error
}
func (f *fakeRegistryService) ListRepositories(_ context.Context, _ *mcrv1.ListRepositoriesRequest) (*mcrv1.ListRepositoriesResponse, error) {
return &mcrv1.ListRepositoriesResponse{Repositories: f.repos}, nil
}
func (f *fakeRegistryService) GetRepository(_ context.Context, req *mcrv1.GetRepositoryRequest) (*mcrv1.GetRepositoryResponse, error) {
if f.repoErr != nil {
return nil, f.repoErr
}
if f.repoResp != nil {
return f.repoResp, nil
}
return &mcrv1.GetRepositoryResponse{Name: req.GetName()}, nil
}
// fakePolicyService implements PolicyServiceServer for testing.
type fakePolicyService struct {
mcrv1.UnimplementedPolicyServiceServer
rules []*mcrv1.PolicyRule
created *mcrv1.PolicyRule
}
func (f *fakePolicyService) ListPolicyRules(_ context.Context, _ *mcrv1.ListPolicyRulesRequest) (*mcrv1.ListPolicyRulesResponse, error) {
return &mcrv1.ListPolicyRulesResponse{Rules: f.rules}, nil
}
func (f *fakePolicyService) CreatePolicyRule(_ context.Context, req *mcrv1.CreatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
rule := &mcrv1.PolicyRule{
Id: 1,
Priority: req.GetPriority(),
Description: req.GetDescription(),
Effect: req.GetEffect(),
Actions: req.GetActions(),
Repositories: req.GetRepositories(),
Enabled: req.GetEnabled(),
}
f.created = rule
return rule, nil
}
func (f *fakePolicyService) GetPolicyRule(_ context.Context, req *mcrv1.GetPolicyRuleRequest) (*mcrv1.PolicyRule, error) {
for _, r := range f.rules {
if r.GetId() == req.GetId() {
return r, nil
}
}
return nil, status.Errorf(codes.NotFound, "policy rule not found")
}
func (f *fakePolicyService) UpdatePolicyRule(_ context.Context, req *mcrv1.UpdatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
for _, r := range f.rules {
if r.GetId() == req.GetId() {
r.Enabled = req.GetEnabled()
return r, nil
}
}
return nil, status.Errorf(codes.NotFound, "policy rule not found")
}
func (f *fakePolicyService) DeletePolicyRule(_ context.Context, req *mcrv1.DeletePolicyRuleRequest) (*mcrv1.DeletePolicyRuleResponse, error) {
for i, r := range f.rules {
if r.GetId() == req.GetId() {
f.rules = append(f.rules[:i], f.rules[i+1:]...)
return &mcrv1.DeletePolicyRuleResponse{}, nil
}
}
return nil, status.Errorf(codes.NotFound, "policy rule not found")
}
// fakeAuditService implements AuditServiceServer for testing.
type fakeAuditService struct {
mcrv1.UnimplementedAuditServiceServer
events []*mcrv1.AuditEvent
}
func (f *fakeAuditService) ListAuditEvents(_ context.Context, _ *mcrv1.ListAuditEventsRequest) (*mcrv1.ListAuditEventsResponse, error) {
return &mcrv1.ListAuditEventsResponse{Events: f.events}, nil
}
// fakeAdminService implements AdminServiceServer for testing.
type fakeAdminService struct {
mcrv1.UnimplementedAdminServiceServer
}
func (f *fakeAdminService) Health(_ context.Context, _ *mcrv1.HealthRequest) (*mcrv1.HealthResponse, error) {
return &mcrv1.HealthResponse{Status: "ok"}, nil
}
// testEnv holds a test server and its dependencies.
type testEnv struct {
server *Server
grpcServer *grpc.Server
grpcConn *grpc.ClientConn
registry *fakeRegistryService
policyFake *fakePolicyService
auditFake *fakeAuditService
}
func (e *testEnv) close() {
_ = e.grpcConn.Close()
e.grpcServer.Stop()
}
// setupTestEnv creates a test environment with fake gRPC backends.
func setupTestEnv(t *testing.T) *testEnv {
t.Helper()
registrySvc := &fakeRegistryService{
repos: []*mcrv1.RepositoryMetadata{
{Name: "library/nginx", TagCount: 3, ManifestCount: 2, TotalSize: 1024 * 1024, CreatedAt: "2024-01-15T10:00:00Z"},
{Name: "library/alpine", TagCount: 1, ManifestCount: 1, TotalSize: 512 * 1024, CreatedAt: "2024-01-16T10:00:00Z"},
},
}
policySvc := &fakePolicyService{
rules: []*mcrv1.PolicyRule{
{Id: 1, Priority: 100, Description: "Allow all pulls", Effect: "allow", Actions: []string{"pull"}, Repositories: []string{"*"}, Enabled: true},
},
}
auditSvc := &fakeAuditService{
events: []*mcrv1.AuditEvent{
{Id: 1, EventTime: "2024-01-15T12:00:00Z", EventType: "manifest_pushed", ActorId: "user1", Repository: "library/nginx", Digest: "sha256:abc123", IpAddress: "10.0.0.1"},
},
}
adminSvc := &fakeAdminService{}
// Start in-process gRPC server.
gs := grpc.NewServer()
mcrv1.RegisterRegistryServiceServer(gs, registrySvc)
mcrv1.RegisterPolicyServiceServer(gs, policySvc)
mcrv1.RegisterAuditServiceServer(gs, auditSvc)
mcrv1.RegisterAdminServiceServer(gs, adminSvc)
lis, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen: %v", err)
}
go func() {
_ = gs.Serve(lis)
}()
// Connect client.
conn, err := grpc.NewClient(
lis.Addr().String(),
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithDefaultCallOptions(grpc.ForceCodecV2(mcrv1.JSONCodec{})),
)
if err != nil {
gs.Stop()
t.Fatalf("dial: %v", err)
}
csrfKey := []byte("test-csrf-key-32-bytes-long!1234")
loginFn := func(username, password string) (string, int, error) {
if username == "admin" && password == "secret" {
return "test-token-12345", 3600, nil
}
return "", 0, fmt.Errorf("invalid credentials")
}
srv, err := New(
mcrv1.NewRegistryServiceClient(conn),
mcrv1.NewPolicyServiceClient(conn),
mcrv1.NewAuditServiceClient(conn),
mcrv1.NewAdminServiceClient(conn),
loginFn,
csrfKey,
)
if err != nil {
_ = conn.Close()
gs.Stop()
t.Fatalf("create server: %v", err)
}
return &testEnv{
server: srv,
grpcServer: gs,
grpcConn: conn,
registry: registrySvc,
policyFake: policySvc,
auditFake: auditSvc,
}
}
func TestLoginPageRenders(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/login", nil)
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("GET /login: status %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, "MCR Login") {
t.Error("login page does not contain 'MCR Login'")
}
if !strings.Contains(body, "_csrf") {
t.Error("login page does not contain CSRF token field")
}
}
func TestLoginInvalidCredentials(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
// First get a CSRF token.
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
getRec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(getRec, getReq)
// Extract CSRF cookie and token.
var csrfCookie *http.Cookie
for _, c := range getRec.Result().Cookies() {
if c.Name == "csrf_token" {
csrfCookie = c
break
}
}
if csrfCookie == nil {
t.Fatal("no csrf_token cookie set")
}
// Extract the CSRF token from the cookie value (token.signature).
parts := strings.SplitN(csrfCookie.Value, ".", 2)
csrfToken := parts[0]
// Submit login with wrong credentials.
form := url.Values{
"username": {"baduser"},
"password": {"badpass"},
"_csrf": {csrfToken},
}
postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
postReq.AddCookie(csrfCookie)
postRec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(postRec, postReq)
if postRec.Code != http.StatusOK {
t.Fatalf("POST /login: status %d, want %d", postRec.Code, http.StatusOK)
}
body := postRec.Body.String()
if !strings.Contains(body, "Invalid username or password") {
t.Error("response does not contain error message for invalid credentials")
}
}
func TestDashboardRequiresSession(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/", nil)
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusSeeOther {
t.Fatalf("GET / without session: status %d, want %d", rec.Code, http.StatusSeeOther)
}
loc := rec.Header().Get("Location")
if loc != "/login" {
t.Fatalf("redirect location: got %q, want /login", loc)
}
}
func TestDashboardWithSession(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/", nil)
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("GET / with session: status %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, "Dashboard") {
t.Error("dashboard page does not contain 'Dashboard'")
}
if !strings.Contains(body, "Repositories") {
t.Error("dashboard page does not show repository count")
}
}
func TestRepositoriesPageRenders(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/repositories", nil)
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("GET /repositories: status %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, "library/nginx") {
t.Error("repositories page does not contain 'library/nginx'")
}
if !strings.Contains(body, "library/alpine") {
t.Error("repositories page does not contain 'library/alpine'")
}
}
func TestRepositoryDetailRenders(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
env.registry.repoResp = &mcrv1.GetRepositoryResponse{
Name: "library/nginx",
TotalSize: 2048,
Tags: []*mcrv1.TagInfo{
{Name: "latest", Digest: "sha256:abc123def456"},
},
Manifests: []*mcrv1.ManifestInfo{
{Digest: "sha256:abc123def456", MediaType: "application/vnd.oci.image.manifest.v1+json", Size: 2048, CreatedAt: "2024-01-15T10:00:00Z"},
},
}
req := httptest.NewRequest(http.MethodGet, "/repositories/library/nginx", nil)
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("GET /repositories/library/nginx: status %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, "library/nginx") {
t.Error("repository detail page does not contain repo name")
}
if !strings.Contains(body, "latest") {
t.Error("repository detail page does not contain tag 'latest'")
}
}
func TestCSRFTokenValidation(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
// POST without CSRF token should fail.
form := url.Values{
"username": {"admin"},
"password": {"secret"},
}
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
body := rec.Body.String()
// Should show the error about invalid form submission.
if !strings.Contains(body, "Invalid or expired form submission") {
t.Error("POST without CSRF token should show error, got: " + body[:min(200, len(body))])
}
}
func TestLogout(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/logout", nil)
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusSeeOther {
t.Fatalf("GET /logout: status %d, want %d", rec.Code, http.StatusSeeOther)
}
loc := rec.Header().Get("Location")
if loc != "/login" {
t.Fatalf("redirect location: got %q, want /login", loc)
}
// Verify session cookie is cleared.
var sessionCleared bool
for _, c := range rec.Result().Cookies() {
if c.Name == "mcr_session" && c.MaxAge < 0 {
sessionCleared = true
break
}
}
if !sessionCleared {
t.Error("session cookie was not cleared on logout")
}
}
func TestPoliciesPage(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/policies", nil)
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("GET /policies: status %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, "Allow all pulls") {
t.Error("policies page does not contain policy description")
}
}
func TestAuditPage(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/audit", nil)
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("GET /audit: status %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, "manifest_pushed") {
t.Error("audit page does not contain event type")
}
}
func TestStaticFiles(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
req := httptest.NewRequest(http.MethodGet, "/static/style.css", nil)
rec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Fatalf("GET /static/style.css: status %d, want %d", rec.Code, http.StatusOK)
}
body := rec.Body.String()
if !strings.Contains(body, "font-family") {
t.Error("style.css does not appear to contain CSS")
}
}
func TestFormatSize(t *testing.T) {
tests := []struct {
input int64
want string
}{
{0, "0 B"},
{512, "512 B"},
{1024, "1.0 KiB"},
{1048576, "1.0 MiB"},
{1073741824, "1.0 GiB"},
{1099511627776, "1.0 TiB"},
}
for _, tt := range tests {
got := formatSize(tt.input)
if got != tt.want {
t.Errorf("formatSize(%d) = %q, want %q", tt.input, got, tt.want)
}
}
}
func TestFormatTime(t *testing.T) {
got := formatTime("2024-01-15T10:30:00Z")
want := "2024-01-15 10:30:00"
if got != want {
t.Errorf("formatTime = %q, want %q", got, want)
}
// Invalid time returns the input.
got = formatTime("not-a-time")
if got != "not-a-time" {
t.Errorf("formatTime(invalid) = %q, want %q", got, "not-a-time")
}
}
func TestTruncate(t *testing.T) {
got := truncate("sha256:abc123def456", 12)
want := "sha256:abc12..."
if got != want {
t.Errorf("truncate = %q, want %q", got, want)
}
// Short strings are not truncated.
got = truncate("short", 10)
if got != "short" {
t.Errorf("truncate(short) = %q, want %q", got, "short")
}
}
func TestLoginSuccessSetsCookie(t *testing.T) {
env := setupTestEnv(t)
defer env.close()
// Get CSRF token.
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
getRec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(getRec, getReq)
var csrfCookie *http.Cookie
for _, c := range getRec.Result().Cookies() {
if c.Name == "csrf_token" {
csrfCookie = c
break
}
}
if csrfCookie == nil {
t.Fatal("no csrf_token cookie")
}
parts := strings.SplitN(csrfCookie.Value, ".", 2)
csrfToken := parts[0]
form := url.Values{
"username": {"admin"},
"password": {"secret"},
"_csrf": {csrfToken},
}
postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
postReq.AddCookie(csrfCookie)
postRec := httptest.NewRecorder()
env.server.Handler().ServeHTTP(postRec, postReq)
if postRec.Code != http.StatusSeeOther {
t.Fatalf("POST /login: status %d, want %d; body: %s", postRec.Code, http.StatusSeeOther, postRec.Body.String())
}
var sessionCookie *http.Cookie
for _, c := range postRec.Result().Cookies() {
if c.Name == "mcr_session" {
sessionCookie = c
break
}
}
if sessionCookie == nil {
t.Fatal("no mcr_session cookie set after login")
}
if sessionCookie.Value != "test-token-12345" {
t.Errorf("session cookie value = %q, want %q", sessionCookie.Value, "test-token-12345")
}
if !sessionCookie.HttpOnly {
t.Error("session cookie is not HttpOnly")
}
if !sessionCookie.Secure {
t.Error("session cookie is not Secure")
}
if sessionCookie.SameSite != http.SameSiteStrictMode {
t.Error("session cookie SameSite is not Strict")
}
}

View File

@@ -0,0 +1,136 @@
package webserver
import (
"fmt"
"html/template"
"io/fs"
"net/http"
"strings"
"time"
"git.wntrmute.dev/kyle/mcr/web"
)
// templateSet wraps parsed templates and provides a render method.
type templateSet struct {
templates map[string]*template.Template
}
// templateFuncs returns the function map used across all templates.
func templateFuncs() template.FuncMap {
return template.FuncMap{
"formatSize": formatSize,
"formatTime": formatTime,
"truncate": truncate,
"joinStrings": joinStrings,
}
}
// loadTemplates parses all page templates with the layout template.
func loadTemplates() (*templateSet, error) {
// Read layout template.
layoutBytes, err := fs.ReadFile(web.Content, "templates/layout.html")
if err != nil {
return nil, fmt.Errorf("webserver: read layout template: %w", err)
}
layoutStr := string(layoutBytes)
pages := []string{
"login",
"dashboard",
"repositories",
"repository_detail",
"manifest_detail",
"policies",
"audit",
}
ts := &templateSet{
templates: make(map[string]*template.Template, len(pages)),
}
for _, page := range pages {
pageBytes, readErr := fs.ReadFile(web.Content, "templates/"+page+".html")
if readErr != nil {
return nil, fmt.Errorf("webserver: read template %s: %w", page, readErr)
}
t, parseErr := template.New("layout").Funcs(templateFuncs()).Parse(layoutStr)
if parseErr != nil {
return nil, fmt.Errorf("webserver: parse layout for %s: %w", page, parseErr)
}
_, parseErr = t.Parse(string(pageBytes))
if parseErr != nil {
return nil, fmt.Errorf("webserver: parse template %s: %w", page, parseErr)
}
ts.templates[page] = t
}
return ts, nil
}
// render executes a named template and writes the result to w.
func (ts *templateSet) render(w http.ResponseWriter, name string, data any) {
t, ok := ts.templates[name]
if !ok {
http.Error(w, "template not found", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/html; charset=utf-8")
if err := t.Execute(w, data); err != nil {
// Template already started writing; log but don't send another error.
_ = err // best-effort; headers may already be sent
}
}
// formatSize converts bytes to a human-readable string.
func formatSize(b int64) string {
const (
kib = 1024
mib = 1024 * kib
gib = 1024 * mib
tib = 1024 * gib
)
switch {
case b >= tib:
return fmt.Sprintf("%.1f TiB", float64(b)/float64(tib))
case b >= gib:
return fmt.Sprintf("%.1f GiB", float64(b)/float64(gib))
case b >= mib:
return fmt.Sprintf("%.1f MiB", float64(b)/float64(mib))
case b >= kib:
return fmt.Sprintf("%.1f KiB", float64(b)/float64(kib))
default:
return fmt.Sprintf("%d B", b)
}
}
// formatTime converts an RFC3339 string to a more readable format.
func formatTime(s string) string {
t, err := time.Parse(time.RFC3339, s)
if err != nil {
// Try RFC3339Nano.
t, err = time.Parse(time.RFC3339Nano, s)
if err != nil {
return s
}
}
return t.Format("2006-01-02 15:04:05")
}
// truncate returns the first n characters of s, appending "..." if truncated.
func truncate(s string, n int) string {
if len(s) <= n {
return s
}
return s[:n] + "..."
}
// joinStrings joins string slices for template rendering.
func joinStrings(ss []string, sep string) string {
return strings.Join(ss, sep)
}