Phases 11, 12: mcrctl CLI tool and mcr-web UI
Phase 11 implements the admin CLI with dual REST/gRPC transport, global flags (--server, --grpc, --token, --ca-cert, --json), and all commands: status, repo list/delete, policy CRUD, audit tail, gc trigger/status/reconcile, and snapshot. Phase 12 implements the HTMX web UI with chi router, session-based auth (HttpOnly/Secure/SameSite=Strict cookies), CSRF protection (HMAC-SHA256 signed double-submit), and pages for dashboard, repositories, manifest detail, policy management, and audit log. Security: CSRF via signed double-submit cookie, session cookies with HttpOnly/Secure/SameSite=Strict, TLS 1.3 minimum on all connections, form body size limits via http.MaxBytesReader. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
184
internal/webserver/auth.go
Normal file
184
internal/webserver/auth.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// sessionKey is the context key for the session token.
|
||||
type sessionKey struct{}
|
||||
|
||||
// tokenFromContext retrieves the bearer token from context.
|
||||
func tokenFromContext(ctx context.Context) string {
|
||||
s, _ := ctx.Value(sessionKey{}).(string)
|
||||
return s
|
||||
}
|
||||
|
||||
// contextWithToken stores a bearer token in the context.
|
||||
func contextWithToken(ctx context.Context, token string) context.Context {
|
||||
return context.WithValue(ctx, sessionKey{}, token)
|
||||
}
|
||||
|
||||
// sessionMiddleware checks for a valid mcr_session cookie and adds the
|
||||
// token to the request context. If no session is present, it redirects
|
||||
// to the login page.
|
||||
func (s *Server) sessionMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cookie, err := r.Cookie("mcr_session")
|
||||
if err != nil || cookie.Value == "" {
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := contextWithToken(r.Context(), cookie.Value)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
// handleLoginPage renders the login form.
|
||||
func (s *Server) handleLoginPage(w http.ResponseWriter, r *http.Request) {
|
||||
csrf := s.generateCSRFToken(w)
|
||||
s.templates.render(w, "login", map[string]any{
|
||||
"CSRFToken": csrf,
|
||||
"Session": false,
|
||||
})
|
||||
}
|
||||
|
||||
// handleLoginSubmit processes the login form.
|
||||
func (s *Server) handleLoginSubmit(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.validateCSRFToken(r) {
|
||||
csrf := s.generateCSRFToken(w)
|
||||
s.templates.render(w, "login", map[string]any{
|
||||
"Error": "Invalid or expired form submission. Please try again.",
|
||||
"CSRFToken": csrf,
|
||||
"Session": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
username := r.FormValue("username")
|
||||
password := r.FormValue("password")
|
||||
|
||||
if username == "" || password == "" {
|
||||
csrf := s.generateCSRFToken(w)
|
||||
s.templates.render(w, "login", map[string]any{
|
||||
"Error": "Username and password are required.",
|
||||
"CSRFToken": csrf,
|
||||
"Session": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
token, _, err := s.loginFn(username, password)
|
||||
if err != nil {
|
||||
log.Printf("login failed for user %q: %v", username, err)
|
||||
csrf := s.generateCSRFToken(w)
|
||||
s.templates.render(w, "login", map[string]any{
|
||||
"Error": "Invalid username or password.",
|
||||
"CSRFToken": csrf,
|
||||
"Session": false,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "mcr_session",
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
http.Redirect(w, r, "/", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// handleLogout clears the session and redirects to login.
|
||||
func (s *Server) handleLogout(w http.ResponseWriter, r *http.Request) {
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "mcr_session",
|
||||
Value: "",
|
||||
Path: "/",
|
||||
MaxAge: -1,
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
http.Redirect(w, r, "/login", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// generateCSRFToken creates a random token, signs it with HMAC, stores
|
||||
// the signed value in a cookie, and returns the token for form embedding.
|
||||
func (s *Server) generateCSRFToken(w http.ResponseWriter) string {
|
||||
b := make([]byte, 16)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
// Crypto RNG failure is fatal; this should never happen.
|
||||
log.Printf("csrf: failed to generate random bytes: %v", err)
|
||||
return ""
|
||||
}
|
||||
|
||||
token := hex.EncodeToString(b)
|
||||
sig := s.signCSRF(token)
|
||||
cookieVal := token + "." + sig
|
||||
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: "csrf_token",
|
||||
Value: cookieVal,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
// validateCSRFToken verifies the form _csrf field matches the cookie and
|
||||
// the HMAC signature is valid.
|
||||
func (s *Server) validateCSRFToken(r *http.Request) bool {
|
||||
formToken := r.FormValue("_csrf")
|
||||
if formToken == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
cookie, err := r.Cookie("csrf_token")
|
||||
if err != nil || cookie.Value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
parts := strings.SplitN(cookie.Value, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
|
||||
cookieToken := parts[0]
|
||||
cookieSig := parts[1]
|
||||
|
||||
// Verify the form token matches the cookie token.
|
||||
if !hmac.Equal([]byte(formToken), []byte(cookieToken)) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Verify the HMAC signature.
|
||||
expectedSig := s.signCSRF(cookieToken)
|
||||
return hmac.Equal([]byte(cookieSig), []byte(expectedSig))
|
||||
}
|
||||
|
||||
// signCSRF computes an HMAC-SHA256 signature for a CSRF token.
|
||||
func (s *Server) signCSRF(token string) string {
|
||||
mac := hmac.New(sha256.New, s.csrfKey)
|
||||
mac.Write([]byte(token))
|
||||
return hex.EncodeToString(mac.Sum(nil))
|
||||
}
|
||||
457
internal/webserver/handlers.go
Normal file
457
internal/webserver/handlers.go
Normal file
@@ -0,0 +1,457 @@
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
// grpcContext creates a context with the bearer token from the session
|
||||
// attached as gRPC outgoing metadata.
|
||||
func grpcContext(r *http.Request) context.Context {
|
||||
token := tokenFromContext(r.Context())
|
||||
return metadata.AppendToOutgoingContext(r.Context(), "authorization", "Bearer "+token)
|
||||
}
|
||||
|
||||
// handleDashboard renders the dashboard with repo stats and recent activity.
|
||||
func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := grpcContext(r)
|
||||
|
||||
repos, err := s.registry.ListRepositories(ctx, &mcrv1.ListRepositoriesRequest{})
|
||||
if err != nil {
|
||||
s.renderError(w, "dashboard", "Failed to load repositories.", err)
|
||||
return
|
||||
}
|
||||
|
||||
var repoCount int
|
||||
var totalSize int64
|
||||
for _, repo := range repos.GetRepositories() {
|
||||
repoCount++
|
||||
totalSize += repo.GetTotalSize()
|
||||
}
|
||||
|
||||
// Fetch recent audit events for dashboard activity.
|
||||
var events []*mcrv1.AuditEvent
|
||||
auditResp, auditErr := s.audit.ListAuditEvents(ctx, &mcrv1.ListAuditEventsRequest{
|
||||
Pagination: &mcrv1.PaginationRequest{Limit: 10},
|
||||
})
|
||||
if auditErr == nil {
|
||||
events = auditResp.GetEvents()
|
||||
}
|
||||
// If audit fails with PermissionDenied, just show no events (user is not admin).
|
||||
|
||||
s.templates.render(w, "dashboard", map[string]any{
|
||||
"Session": true,
|
||||
"RepoCount": repoCount,
|
||||
"TotalSize": formatSize(totalSize),
|
||||
"Events": events,
|
||||
})
|
||||
}
|
||||
|
||||
// handleRepositories renders the repository list.
|
||||
func (s *Server) handleRepositories(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := grpcContext(r)
|
||||
|
||||
resp, err := s.registry.ListRepositories(ctx, &mcrv1.ListRepositoriesRequest{})
|
||||
if err != nil {
|
||||
s.renderError(w, "repositories", "Failed to load repositories.", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.templates.render(w, "repositories", map[string]any{
|
||||
"Session": true,
|
||||
"Repositories": resp.GetRepositories(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleRepositoryDetail renders a single repository's tags and manifests.
|
||||
func (s *Server) handleRepositoryDetail(w http.ResponseWriter, r *http.Request) {
|
||||
name := extractRepoName(r)
|
||||
if name == "" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := grpcContext(r)
|
||||
|
||||
resp, err := s.registry.GetRepository(ctx, &mcrv1.GetRepositoryRequest{Name: name})
|
||||
if err != nil {
|
||||
s.templates.render(w, "repository_detail", map[string]any{
|
||||
"Session": true,
|
||||
"Name": name,
|
||||
"Error": grpcErrorMessage(err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
s.templates.render(w, "repository_detail", map[string]any{
|
||||
"Session": true,
|
||||
"Name": resp.GetName(),
|
||||
"Tags": resp.GetTags(),
|
||||
"Manifests": resp.GetManifests(),
|
||||
"TotalSize": resp.GetTotalSize(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleManifestDetail renders details for a specific manifest.
|
||||
func (s *Server) handleManifestDetail(w http.ResponseWriter, r *http.Request) {
|
||||
// URL format: /repositories/{name}/manifests/{digest}
|
||||
// The name can contain slashes, so we parse manually.
|
||||
path := r.URL.Path
|
||||
const manifestsPrefix = "/manifests/"
|
||||
idx := strings.LastIndex(path, manifestsPrefix)
|
||||
if idx < 0 {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
digest := path[idx+len(manifestsPrefix):]
|
||||
repoPath := path[len("/repositories/"):idx]
|
||||
|
||||
if repoPath == "" || digest == "" {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := grpcContext(r)
|
||||
|
||||
resp, err := s.registry.GetRepository(ctx, &mcrv1.GetRepositoryRequest{Name: repoPath})
|
||||
if err != nil {
|
||||
s.templates.render(w, "manifest_detail", map[string]any{
|
||||
"Session": true,
|
||||
"RepoName": repoPath,
|
||||
"Error": grpcErrorMessage(err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Find the specific manifest.
|
||||
var manifest *mcrv1.ManifestInfo
|
||||
for _, m := range resp.GetManifests() {
|
||||
if m.GetDigest() == digest {
|
||||
manifest = m
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if manifest == nil {
|
||||
s.templates.render(w, "manifest_detail", map[string]any{
|
||||
"Session": true,
|
||||
"RepoName": repoPath,
|
||||
"Error": "Manifest not found.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
s.templates.render(w, "manifest_detail", map[string]any{
|
||||
"Session": true,
|
||||
"RepoName": repoPath,
|
||||
"Manifest": manifest,
|
||||
})
|
||||
}
|
||||
|
||||
// handlePolicies renders the policy list and create form.
|
||||
func (s *Server) handlePolicies(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := grpcContext(r)
|
||||
csrf := s.generateCSRFToken(w)
|
||||
|
||||
resp, err := s.policy.ListPolicyRules(ctx, &mcrv1.ListPolicyRulesRequest{})
|
||||
if err != nil {
|
||||
s.templates.render(w, "policies", map[string]any{
|
||||
"Session": true,
|
||||
"CSRFToken": csrf,
|
||||
"Error": grpcErrorMessage(err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
s.templates.render(w, "policies", map[string]any{
|
||||
"Session": true,
|
||||
"CSRFToken": csrf,
|
||||
"Policies": resp.GetRules(),
|
||||
})
|
||||
}
|
||||
|
||||
// handleCreatePolicy processes the policy creation form.
|
||||
func (s *Server) handleCreatePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.validateCSRFToken(r) {
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
priority, _ := strconv.ParseInt(r.FormValue("priority"), 10, 32)
|
||||
actions := splitCSV(r.FormValue("actions"))
|
||||
repos := splitCSV(r.FormValue("repositories"))
|
||||
|
||||
ctx := grpcContext(r)
|
||||
|
||||
_, err := s.policy.CreatePolicyRule(ctx, &mcrv1.CreatePolicyRuleRequest{
|
||||
Priority: int32(priority),
|
||||
Description: r.FormValue("description"),
|
||||
Effect: r.FormValue("effect"),
|
||||
Actions: actions,
|
||||
Repositories: repos,
|
||||
Enabled: true,
|
||||
})
|
||||
if err != nil {
|
||||
csrf := s.generateCSRFToken(w)
|
||||
s.templates.render(w, "policies", map[string]any{
|
||||
"Session": true,
|
||||
"CSRFToken": csrf,
|
||||
"Error": grpcErrorMessage(err),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/policies", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// handleTogglePolicy toggles a policy rule's enabled state.
|
||||
func (s *Server) handleTogglePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.validateCSRFToken(r) {
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := extractPolicyID(r.URL.Path, "/toggle")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid policy ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := grpcContext(r)
|
||||
|
||||
// Get current state.
|
||||
rule, err := s.policy.GetPolicyRule(ctx, &mcrv1.GetPolicyRuleRequest{Id: id})
|
||||
if err != nil {
|
||||
http.Error(w, grpcErrorMessage(err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// Toggle the enabled field.
|
||||
_, err = s.policy.UpdatePolicyRule(ctx, &mcrv1.UpdatePolicyRuleRequest{
|
||||
Id: id,
|
||||
Enabled: !rule.GetEnabled(),
|
||||
UpdateMask: []string{"enabled"},
|
||||
// Carry forward required fields.
|
||||
Priority: rule.GetPriority(),
|
||||
Description: rule.GetDescription(),
|
||||
Effect: rule.GetEffect(),
|
||||
Actions: rule.GetActions(),
|
||||
Repositories: rule.GetRepositories(),
|
||||
})
|
||||
if err != nil {
|
||||
http.Error(w, grpcErrorMessage(err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/policies", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// handleDeletePolicy deletes a policy rule.
|
||||
func (s *Server) handleDeletePolicy(w http.ResponseWriter, r *http.Request) {
|
||||
r.Body = http.MaxBytesReader(w, r.Body, 1<<20) // 1 MiB limit
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "bad request", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.validateCSRFToken(r) {
|
||||
http.Error(w, "invalid CSRF token", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
idStr := extractPolicyID(r.URL.Path, "/delete")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, "invalid policy ID", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := grpcContext(r)
|
||||
|
||||
_, err = s.policy.DeletePolicyRule(ctx, &mcrv1.DeletePolicyRuleRequest{Id: id})
|
||||
if err != nil {
|
||||
http.Error(w, grpcErrorMessage(err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
http.Redirect(w, r, "/policies", http.StatusSeeOther)
|
||||
}
|
||||
|
||||
// handleAudit renders the audit log with filters and pagination.
|
||||
func (s *Server) handleAudit(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := grpcContext(r)
|
||||
|
||||
q := r.URL.Query()
|
||||
eventType := q.Get("event_type")
|
||||
repo := q.Get("repository")
|
||||
since := q.Get("since")
|
||||
until := q.Get("until")
|
||||
pageStr := q.Get("page")
|
||||
|
||||
page := 1
|
||||
if pageStr != "" {
|
||||
if p, err := strconv.Atoi(pageStr); err == nil && p > 0 {
|
||||
page = p
|
||||
}
|
||||
}
|
||||
|
||||
const pageSize int32 = 50
|
||||
offset := int32(page-1) * pageSize
|
||||
|
||||
req := &mcrv1.ListAuditEventsRequest{
|
||||
Pagination: &mcrv1.PaginationRequest{
|
||||
Limit: pageSize + 1, // fetch one extra to detect next page
|
||||
Offset: offset,
|
||||
},
|
||||
EventType: eventType,
|
||||
Repository: repo,
|
||||
Since: since,
|
||||
Until: until,
|
||||
}
|
||||
|
||||
resp, err := s.audit.ListAuditEvents(ctx, req)
|
||||
if err != nil {
|
||||
s.templates.render(w, "audit", map[string]any{
|
||||
"Session": true,
|
||||
"Error": grpcErrorMessage(err),
|
||||
"FilterType": eventType,
|
||||
"FilterRepo": repo,
|
||||
"FilterSince": since,
|
||||
"FilterUntil": until,
|
||||
"Page": page,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
events := resp.GetEvents()
|
||||
hasNext := len(events) > int(pageSize)
|
||||
if hasNext {
|
||||
events = events[:pageSize]
|
||||
}
|
||||
|
||||
// Build pagination URLs.
|
||||
buildURL := func(p int) string {
|
||||
v := url.Values{}
|
||||
if eventType != "" {
|
||||
v.Set("event_type", eventType)
|
||||
}
|
||||
if repo != "" {
|
||||
v.Set("repository", repo)
|
||||
}
|
||||
if since != "" {
|
||||
v.Set("since", since)
|
||||
}
|
||||
if until != "" {
|
||||
v.Set("until", until)
|
||||
}
|
||||
v.Set("page", strconv.Itoa(p))
|
||||
return "/audit?" + v.Encode()
|
||||
}
|
||||
|
||||
s.templates.render(w, "audit", map[string]any{
|
||||
"Session": true,
|
||||
"Events": events,
|
||||
"FilterType": eventType,
|
||||
"FilterRepo": repo,
|
||||
"FilterSince": since,
|
||||
"FilterUntil": until,
|
||||
"Page": page,
|
||||
"HasNext": hasNext,
|
||||
"PrevURL": buildURL(page - 1),
|
||||
"NextURL": buildURL(page + 1),
|
||||
})
|
||||
}
|
||||
|
||||
// renderError renders a template with an error message derived from a gRPC error.
|
||||
func (s *Server) renderError(w http.ResponseWriter, tmpl, fallback string, err error) {
|
||||
msg := fallback
|
||||
if st, ok := status.FromError(err); ok {
|
||||
if st.Code() == codes.PermissionDenied {
|
||||
msg = "Access denied."
|
||||
}
|
||||
}
|
||||
s.templates.render(w, tmpl, map[string]any{
|
||||
"Session": true,
|
||||
"Error": msg,
|
||||
})
|
||||
}
|
||||
|
||||
// grpcErrorMessage extracts a human-readable message from a gRPC error.
|
||||
func grpcErrorMessage(err error) string {
|
||||
if st, ok := status.FromError(err); ok {
|
||||
if st.Code() == codes.PermissionDenied {
|
||||
return "Access denied."
|
||||
}
|
||||
return st.Message()
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
// extractRepoName extracts the repository name from the URL path.
|
||||
// The name may contain slashes (e.g., "library/nginx").
|
||||
// URL format: /repositories/{name...}
|
||||
func extractRepoName(r *http.Request) string {
|
||||
path := r.URL.Path
|
||||
prefix := "/repositories/"
|
||||
if !strings.HasPrefix(path, prefix) {
|
||||
return ""
|
||||
}
|
||||
name := path[len(prefix):]
|
||||
|
||||
// Strip trailing slash.
|
||||
name = strings.TrimRight(name, "/")
|
||||
|
||||
// If the path contains /manifests/, extract only the repo name part.
|
||||
if idx := strings.Index(name, "/manifests/"); idx >= 0 {
|
||||
name = name[:idx]
|
||||
}
|
||||
|
||||
return name
|
||||
}
|
||||
|
||||
// extractPolicyID extracts the policy ID from paths like /policies/{id}/toggle
|
||||
// or /policies/{id}/delete.
|
||||
func extractPolicyID(path, suffix string) string {
|
||||
path = strings.TrimSuffix(path, suffix)
|
||||
path = strings.TrimPrefix(path, "/policies/")
|
||||
return path
|
||||
}
|
||||
|
||||
// splitCSV splits a comma-separated string, trimming whitespace.
|
||||
func splitCSV(s string) []string {
|
||||
if s == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(s, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
for _, p := range parts {
|
||||
p = strings.TrimSpace(p)
|
||||
if p != "" {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
133
internal/webserver/server.go
Normal file
133
internal/webserver/server.go
Normal file
@@ -0,0 +1,133 @@
|
||||
// Package webserver implements the MCR web UI server.
|
||||
//
|
||||
// It serves HTML pages rendered from Go templates with htmx for
|
||||
// interactive elements. All data is fetched via gRPC from the main
|
||||
// mcrsrv API server. Authentication is handled via MCIAS, with session
|
||||
// tokens stored in secure HttpOnly cookies.
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"io/fs"
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/go-chi/chi/v5/middleware"
|
||||
|
||||
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
"git.wntrmute.dev/kyle/mcr/web"
|
||||
)
|
||||
|
||||
// LoginFunc authenticates a user and returns a bearer token.
|
||||
type LoginFunc func(username, password string) (token string, expiresIn int, err error)
|
||||
|
||||
// Server is the MCR web UI server.
|
||||
type Server struct {
|
||||
router chi.Router
|
||||
templates *templateSet
|
||||
registry mcrv1.RegistryServiceClient
|
||||
policy mcrv1.PolicyServiceClient
|
||||
audit mcrv1.AuditServiceClient
|
||||
admin mcrv1.AdminServiceClient
|
||||
loginFn LoginFunc
|
||||
csrfKey []byte // 32-byte key for HMAC signing
|
||||
}
|
||||
|
||||
// New creates a new web UI server with the given gRPC clients and login function.
|
||||
func New(
|
||||
registry mcrv1.RegistryServiceClient,
|
||||
policy mcrv1.PolicyServiceClient,
|
||||
audit mcrv1.AuditServiceClient,
|
||||
admin mcrv1.AdminServiceClient,
|
||||
loginFn LoginFunc,
|
||||
csrfKey []byte,
|
||||
) (*Server, error) {
|
||||
tmpl, err := loadTemplates()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Server{
|
||||
templates: tmpl,
|
||||
registry: registry,
|
||||
policy: policy,
|
||||
audit: audit,
|
||||
admin: admin,
|
||||
loginFn: loginFn,
|
||||
csrfKey: csrfKey,
|
||||
}
|
||||
|
||||
s.router = s.buildRouter()
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Handler returns the http.Handler for the server.
|
||||
func (s *Server) Handler() http.Handler {
|
||||
return s.router
|
||||
}
|
||||
|
||||
// buildRouter sets up the chi router with all routes and middleware.
|
||||
func (s *Server) buildRouter() chi.Router {
|
||||
r := chi.NewRouter()
|
||||
|
||||
// Global middleware.
|
||||
r.Use(middleware.Recoverer)
|
||||
r.Use(middleware.RequestID)
|
||||
r.Use(middleware.RealIP)
|
||||
|
||||
// Static files (no auth required).
|
||||
staticFS, err := fs.Sub(web.Content, "static")
|
||||
if err != nil {
|
||||
log.Fatalf("webserver: failed to create static sub-filesystem: %v", err)
|
||||
}
|
||||
r.Handle("/static/*", http.StripPrefix("/static/", http.FileServer(http.FS(staticFS))))
|
||||
|
||||
// Public routes (no session required).
|
||||
r.Get("/login", s.handleLoginPage)
|
||||
r.Post("/login", s.handleLoginSubmit)
|
||||
r.Get("/logout", s.handleLogout)
|
||||
|
||||
// Protected routes (session required).
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(s.sessionMiddleware)
|
||||
|
||||
r.Get("/", s.handleDashboard)
|
||||
|
||||
// Repository routes — name may contain slashes.
|
||||
r.Get("/repositories", s.handleRepositories)
|
||||
r.Get("/repositories/*", s.handleRepositoryOrManifest)
|
||||
|
||||
// Policy routes (admin — gRPC interceptors enforce this).
|
||||
r.Get("/policies", s.handlePolicies)
|
||||
r.Post("/policies", s.handleCreatePolicy)
|
||||
r.Post("/policies/{id}/toggle", s.handleTogglePolicy)
|
||||
r.Post("/policies/{id}/delete", s.handleDeletePolicy)
|
||||
|
||||
// Audit routes (admin — gRPC interceptors enforce this).
|
||||
r.Get("/audit", s.handleAudit)
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// handleRepositoryOrManifest dispatches between repository detail and
|
||||
// manifest detail based on the URL path. This is necessary because
|
||||
// repository names can contain slashes.
|
||||
func (s *Server) handleRepositoryOrManifest(w http.ResponseWriter, r *http.Request) {
|
||||
path := r.URL.Path
|
||||
if idx := lastIndex(path, "/manifests/"); idx >= 0 {
|
||||
s.handleManifestDetail(w, r)
|
||||
return
|
||||
}
|
||||
s.handleRepositoryDetail(w, r)
|
||||
}
|
||||
|
||||
// lastIndex returns the index of the last occurrence of sep in s, or -1.
|
||||
func lastIndex(s, sep string) int {
|
||||
for i := len(s) - len(sep); i >= 0; i-- {
|
||||
if s[i:i+len(sep)] == sep {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
608
internal/webserver/server_test.go
Normal file
608
internal/webserver/server_test.go
Normal file
@@ -0,0 +1,608 @@
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
mcrv1 "git.wntrmute.dev/kyle/mcr/gen/mcr/v1"
|
||||
)
|
||||
|
||||
// fakeRegistryService implements RegistryServiceServer for testing.
|
||||
type fakeRegistryService struct {
|
||||
mcrv1.UnimplementedRegistryServiceServer
|
||||
repos []*mcrv1.RepositoryMetadata
|
||||
repoResp *mcrv1.GetRepositoryResponse
|
||||
repoErr error
|
||||
}
|
||||
|
||||
func (f *fakeRegistryService) ListRepositories(_ context.Context, _ *mcrv1.ListRepositoriesRequest) (*mcrv1.ListRepositoriesResponse, error) {
|
||||
return &mcrv1.ListRepositoriesResponse{Repositories: f.repos}, nil
|
||||
}
|
||||
|
||||
func (f *fakeRegistryService) GetRepository(_ context.Context, req *mcrv1.GetRepositoryRequest) (*mcrv1.GetRepositoryResponse, error) {
|
||||
if f.repoErr != nil {
|
||||
return nil, f.repoErr
|
||||
}
|
||||
if f.repoResp != nil {
|
||||
return f.repoResp, nil
|
||||
}
|
||||
return &mcrv1.GetRepositoryResponse{Name: req.GetName()}, nil
|
||||
}
|
||||
|
||||
// fakePolicyService implements PolicyServiceServer for testing.
|
||||
type fakePolicyService struct {
|
||||
mcrv1.UnimplementedPolicyServiceServer
|
||||
rules []*mcrv1.PolicyRule
|
||||
created *mcrv1.PolicyRule
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) ListPolicyRules(_ context.Context, _ *mcrv1.ListPolicyRulesRequest) (*mcrv1.ListPolicyRulesResponse, error) {
|
||||
return &mcrv1.ListPolicyRulesResponse{Rules: f.rules}, nil
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) CreatePolicyRule(_ context.Context, req *mcrv1.CreatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
|
||||
rule := &mcrv1.PolicyRule{
|
||||
Id: 1,
|
||||
Priority: req.GetPriority(),
|
||||
Description: req.GetDescription(),
|
||||
Effect: req.GetEffect(),
|
||||
Actions: req.GetActions(),
|
||||
Repositories: req.GetRepositories(),
|
||||
Enabled: req.GetEnabled(),
|
||||
}
|
||||
f.created = rule
|
||||
return rule, nil
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) GetPolicyRule(_ context.Context, req *mcrv1.GetPolicyRuleRequest) (*mcrv1.PolicyRule, error) {
|
||||
for _, r := range f.rules {
|
||||
if r.GetId() == req.GetId() {
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) UpdatePolicyRule(_ context.Context, req *mcrv1.UpdatePolicyRuleRequest) (*mcrv1.PolicyRule, error) {
|
||||
for _, r := range f.rules {
|
||||
if r.GetId() == req.GetId() {
|
||||
r.Enabled = req.GetEnabled()
|
||||
return r, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
|
||||
func (f *fakePolicyService) DeletePolicyRule(_ context.Context, req *mcrv1.DeletePolicyRuleRequest) (*mcrv1.DeletePolicyRuleResponse, error) {
|
||||
for i, r := range f.rules {
|
||||
if r.GetId() == req.GetId() {
|
||||
f.rules = append(f.rules[:i], f.rules[i+1:]...)
|
||||
return &mcrv1.DeletePolicyRuleResponse{}, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "policy rule not found")
|
||||
}
|
||||
|
||||
// fakeAuditService implements AuditServiceServer for testing.
|
||||
type fakeAuditService struct {
|
||||
mcrv1.UnimplementedAuditServiceServer
|
||||
events []*mcrv1.AuditEvent
|
||||
}
|
||||
|
||||
func (f *fakeAuditService) ListAuditEvents(_ context.Context, _ *mcrv1.ListAuditEventsRequest) (*mcrv1.ListAuditEventsResponse, error) {
|
||||
return &mcrv1.ListAuditEventsResponse{Events: f.events}, nil
|
||||
}
|
||||
|
||||
// fakeAdminService implements AdminServiceServer for testing.
|
||||
type fakeAdminService struct {
|
||||
mcrv1.UnimplementedAdminServiceServer
|
||||
}
|
||||
|
||||
func (f *fakeAdminService) Health(_ context.Context, _ *mcrv1.HealthRequest) (*mcrv1.HealthResponse, error) {
|
||||
return &mcrv1.HealthResponse{Status: "ok"}, nil
|
||||
}
|
||||
|
||||
// testEnv holds a test server and its dependencies.
|
||||
type testEnv struct {
|
||||
server *Server
|
||||
grpcServer *grpc.Server
|
||||
grpcConn *grpc.ClientConn
|
||||
registry *fakeRegistryService
|
||||
policyFake *fakePolicyService
|
||||
auditFake *fakeAuditService
|
||||
}
|
||||
|
||||
func (e *testEnv) close() {
|
||||
_ = e.grpcConn.Close()
|
||||
e.grpcServer.Stop()
|
||||
}
|
||||
|
||||
// setupTestEnv creates a test environment with fake gRPC backends.
|
||||
func setupTestEnv(t *testing.T) *testEnv {
|
||||
t.Helper()
|
||||
|
||||
registrySvc := &fakeRegistryService{
|
||||
repos: []*mcrv1.RepositoryMetadata{
|
||||
{Name: "library/nginx", TagCount: 3, ManifestCount: 2, TotalSize: 1024 * 1024, CreatedAt: "2024-01-15T10:00:00Z"},
|
||||
{Name: "library/alpine", TagCount: 1, ManifestCount: 1, TotalSize: 512 * 1024, CreatedAt: "2024-01-16T10:00:00Z"},
|
||||
},
|
||||
}
|
||||
policySvc := &fakePolicyService{
|
||||
rules: []*mcrv1.PolicyRule{
|
||||
{Id: 1, Priority: 100, Description: "Allow all pulls", Effect: "allow", Actions: []string{"pull"}, Repositories: []string{"*"}, Enabled: true},
|
||||
},
|
||||
}
|
||||
auditSvc := &fakeAuditService{
|
||||
events: []*mcrv1.AuditEvent{
|
||||
{Id: 1, EventTime: "2024-01-15T12:00:00Z", EventType: "manifest_pushed", ActorId: "user1", Repository: "library/nginx", Digest: "sha256:abc123", IpAddress: "10.0.0.1"},
|
||||
},
|
||||
}
|
||||
adminSvc := &fakeAdminService{}
|
||||
|
||||
// Start in-process gRPC server.
|
||||
gs := grpc.NewServer()
|
||||
mcrv1.RegisterRegistryServiceServer(gs, registrySvc)
|
||||
mcrv1.RegisterPolicyServiceServer(gs, policySvc)
|
||||
mcrv1.RegisterAuditServiceServer(gs, auditSvc)
|
||||
mcrv1.RegisterAdminServiceServer(gs, adminSvc)
|
||||
|
||||
lis, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatalf("listen: %v", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
_ = gs.Serve(lis)
|
||||
}()
|
||||
|
||||
// Connect client.
|
||||
conn, err := grpc.NewClient(
|
||||
lis.Addr().String(),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithDefaultCallOptions(grpc.ForceCodecV2(mcrv1.JSONCodec{})),
|
||||
)
|
||||
if err != nil {
|
||||
gs.Stop()
|
||||
t.Fatalf("dial: %v", err)
|
||||
}
|
||||
|
||||
csrfKey := []byte("test-csrf-key-32-bytes-long!1234")
|
||||
|
||||
loginFn := func(username, password string) (string, int, error) {
|
||||
if username == "admin" && password == "secret" {
|
||||
return "test-token-12345", 3600, nil
|
||||
}
|
||||
return "", 0, fmt.Errorf("invalid credentials")
|
||||
}
|
||||
|
||||
srv, err := New(
|
||||
mcrv1.NewRegistryServiceClient(conn),
|
||||
mcrv1.NewPolicyServiceClient(conn),
|
||||
mcrv1.NewAuditServiceClient(conn),
|
||||
mcrv1.NewAdminServiceClient(conn),
|
||||
loginFn,
|
||||
csrfKey,
|
||||
)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
gs.Stop()
|
||||
t.Fatalf("create server: %v", err)
|
||||
}
|
||||
|
||||
return &testEnv{
|
||||
server: srv,
|
||||
grpcServer: gs,
|
||||
grpcConn: conn,
|
||||
registry: registrySvc,
|
||||
policyFake: policySvc,
|
||||
auditFake: auditSvc,
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginPageRenders(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /login: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "MCR Login") {
|
||||
t.Error("login page does not contain 'MCR Login'")
|
||||
}
|
||||
if !strings.Contains(body, "_csrf") {
|
||||
t.Error("login page does not contain CSRF token field")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginInvalidCredentials(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
// First get a CSRF token.
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
env.server.Handler().ServeHTTP(getRec, getReq)
|
||||
|
||||
// Extract CSRF cookie and token.
|
||||
var csrfCookie *http.Cookie
|
||||
for _, c := range getRec.Result().Cookies() {
|
||||
if c.Name == "csrf_token" {
|
||||
csrfCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if csrfCookie == nil {
|
||||
t.Fatal("no csrf_token cookie set")
|
||||
}
|
||||
|
||||
// Extract the CSRF token from the cookie value (token.signature).
|
||||
parts := strings.SplitN(csrfCookie.Value, ".", 2)
|
||||
csrfToken := parts[0]
|
||||
|
||||
// Submit login with wrong credentials.
|
||||
form := url.Values{
|
||||
"username": {"baduser"},
|
||||
"password": {"badpass"},
|
||||
"_csrf": {csrfToken},
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
postReq.AddCookie(csrfCookie)
|
||||
postRec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(postRec, postReq)
|
||||
|
||||
if postRec.Code != http.StatusOK {
|
||||
t.Fatalf("POST /login: status %d, want %d", postRec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := postRec.Body.String()
|
||||
if !strings.Contains(body, "Invalid username or password") {
|
||||
t.Error("response does not contain error message for invalid credentials")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardRequiresSession(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusSeeOther {
|
||||
t.Fatalf("GET / without session: status %d, want %d", rec.Code, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
loc := rec.Header().Get("Location")
|
||||
if loc != "/login" {
|
||||
t.Fatalf("redirect location: got %q, want /login", loc)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDashboardWithSession(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET / with session: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "Dashboard") {
|
||||
t.Error("dashboard page does not contain 'Dashboard'")
|
||||
}
|
||||
if !strings.Contains(body, "Repositories") {
|
||||
t.Error("dashboard page does not show repository count")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepositoriesPageRenders(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/repositories", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /repositories: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "library/nginx") {
|
||||
t.Error("repositories page does not contain 'library/nginx'")
|
||||
}
|
||||
if !strings.Contains(body, "library/alpine") {
|
||||
t.Error("repositories page does not contain 'library/alpine'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRepositoryDetailRenders(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
env.registry.repoResp = &mcrv1.GetRepositoryResponse{
|
||||
Name: "library/nginx",
|
||||
TotalSize: 2048,
|
||||
Tags: []*mcrv1.TagInfo{
|
||||
{Name: "latest", Digest: "sha256:abc123def456"},
|
||||
},
|
||||
Manifests: []*mcrv1.ManifestInfo{
|
||||
{Digest: "sha256:abc123def456", MediaType: "application/vnd.oci.image.manifest.v1+json", Size: 2048, CreatedAt: "2024-01-15T10:00:00Z"},
|
||||
},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/repositories/library/nginx", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /repositories/library/nginx: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "library/nginx") {
|
||||
t.Error("repository detail page does not contain repo name")
|
||||
}
|
||||
if !strings.Contains(body, "latest") {
|
||||
t.Error("repository detail page does not contain tag 'latest'")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFTokenValidation(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
// POST without CSRF token should fail.
|
||||
form := url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"secret"},
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
body := rec.Body.String()
|
||||
// Should show the error about invalid form submission.
|
||||
if !strings.Contains(body, "Invalid or expired form submission") {
|
||||
t.Error("POST without CSRF token should show error, got: " + body[:min(200, len(body))])
|
||||
}
|
||||
}
|
||||
|
||||
func TestLogout(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/logout", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusSeeOther {
|
||||
t.Fatalf("GET /logout: status %d, want %d", rec.Code, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
loc := rec.Header().Get("Location")
|
||||
if loc != "/login" {
|
||||
t.Fatalf("redirect location: got %q, want /login", loc)
|
||||
}
|
||||
|
||||
// Verify session cookie is cleared.
|
||||
var sessionCleared bool
|
||||
for _, c := range rec.Result().Cookies() {
|
||||
if c.Name == "mcr_session" && c.MaxAge < 0 {
|
||||
sessionCleared = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !sessionCleared {
|
||||
t.Error("session cookie was not cleared on logout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPoliciesPage(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/policies", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /policies: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "Allow all pulls") {
|
||||
t.Error("policies page does not contain policy description")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuditPage(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/audit", nil)
|
||||
req.AddCookie(&http.Cookie{Name: "mcr_session", Value: "test-token"})
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /audit: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "manifest_pushed") {
|
||||
t.Error("audit page does not contain event type")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStaticFiles(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/static/style.css", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Fatalf("GET /static/style.css: status %d, want %d", rec.Code, http.StatusOK)
|
||||
}
|
||||
|
||||
body := rec.Body.String()
|
||||
if !strings.Contains(body, "font-family") {
|
||||
t.Error("style.css does not appear to contain CSS")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
input int64
|
||||
want string
|
||||
}{
|
||||
{0, "0 B"},
|
||||
{512, "512 B"},
|
||||
{1024, "1.0 KiB"},
|
||||
{1048576, "1.0 MiB"},
|
||||
{1073741824, "1.0 GiB"},
|
||||
{1099511627776, "1.0 TiB"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := formatSize(tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("formatSize(%d) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFormatTime(t *testing.T) {
|
||||
got := formatTime("2024-01-15T10:30:00Z")
|
||||
want := "2024-01-15 10:30:00"
|
||||
if got != want {
|
||||
t.Errorf("formatTime = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// Invalid time returns the input.
|
||||
got = formatTime("not-a-time")
|
||||
if got != "not-a-time" {
|
||||
t.Errorf("formatTime(invalid) = %q, want %q", got, "not-a-time")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTruncate(t *testing.T) {
|
||||
got := truncate("sha256:abc123def456", 12)
|
||||
want := "sha256:abc12..."
|
||||
if got != want {
|
||||
t.Errorf("truncate = %q, want %q", got, want)
|
||||
}
|
||||
|
||||
// Short strings are not truncated.
|
||||
got = truncate("short", 10)
|
||||
if got != "short" {
|
||||
t.Errorf("truncate(short) = %q, want %q", got, "short")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoginSuccessSetsCookie(t *testing.T) {
|
||||
env := setupTestEnv(t)
|
||||
defer env.close()
|
||||
|
||||
// Get CSRF token.
|
||||
getReq := httptest.NewRequest(http.MethodGet, "/login", nil)
|
||||
getRec := httptest.NewRecorder()
|
||||
env.server.Handler().ServeHTTP(getRec, getReq)
|
||||
|
||||
var csrfCookie *http.Cookie
|
||||
for _, c := range getRec.Result().Cookies() {
|
||||
if c.Name == "csrf_token" {
|
||||
csrfCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if csrfCookie == nil {
|
||||
t.Fatal("no csrf_token cookie")
|
||||
}
|
||||
|
||||
parts := strings.SplitN(csrfCookie.Value, ".", 2)
|
||||
csrfToken := parts[0]
|
||||
|
||||
form := url.Values{
|
||||
"username": {"admin"},
|
||||
"password": {"secret"},
|
||||
"_csrf": {csrfToken},
|
||||
}
|
||||
|
||||
postReq := httptest.NewRequest(http.MethodPost, "/login", strings.NewReader(form.Encode()))
|
||||
postReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
postReq.AddCookie(csrfCookie)
|
||||
postRec := httptest.NewRecorder()
|
||||
|
||||
env.server.Handler().ServeHTTP(postRec, postReq)
|
||||
|
||||
if postRec.Code != http.StatusSeeOther {
|
||||
t.Fatalf("POST /login: status %d, want %d; body: %s", postRec.Code, http.StatusSeeOther, postRec.Body.String())
|
||||
}
|
||||
|
||||
var sessionCookie *http.Cookie
|
||||
for _, c := range postRec.Result().Cookies() {
|
||||
if c.Name == "mcr_session" {
|
||||
sessionCookie = c
|
||||
break
|
||||
}
|
||||
}
|
||||
if sessionCookie == nil {
|
||||
t.Fatal("no mcr_session cookie set after login")
|
||||
}
|
||||
if sessionCookie.Value != "test-token-12345" {
|
||||
t.Errorf("session cookie value = %q, want %q", sessionCookie.Value, "test-token-12345")
|
||||
}
|
||||
if !sessionCookie.HttpOnly {
|
||||
t.Error("session cookie is not HttpOnly")
|
||||
}
|
||||
if !sessionCookie.Secure {
|
||||
t.Error("session cookie is not Secure")
|
||||
}
|
||||
if sessionCookie.SameSite != http.SameSiteStrictMode {
|
||||
t.Error("session cookie SameSite is not Strict")
|
||||
}
|
||||
}
|
||||
136
internal/webserver/templates.go
Normal file
136
internal/webserver/templates.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcr/web"
|
||||
)
|
||||
|
||||
// templateSet wraps parsed templates and provides a render method.
|
||||
type templateSet struct {
|
||||
templates map[string]*template.Template
|
||||
}
|
||||
|
||||
// templateFuncs returns the function map used across all templates.
|
||||
func templateFuncs() template.FuncMap {
|
||||
return template.FuncMap{
|
||||
"formatSize": formatSize,
|
||||
"formatTime": formatTime,
|
||||
"truncate": truncate,
|
||||
"joinStrings": joinStrings,
|
||||
}
|
||||
}
|
||||
|
||||
// loadTemplates parses all page templates with the layout template.
|
||||
func loadTemplates() (*templateSet, error) {
|
||||
// Read layout template.
|
||||
layoutBytes, err := fs.ReadFile(web.Content, "templates/layout.html")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("webserver: read layout template: %w", err)
|
||||
}
|
||||
layoutStr := string(layoutBytes)
|
||||
|
||||
pages := []string{
|
||||
"login",
|
||||
"dashboard",
|
||||
"repositories",
|
||||
"repository_detail",
|
||||
"manifest_detail",
|
||||
"policies",
|
||||
"audit",
|
||||
}
|
||||
|
||||
ts := &templateSet{
|
||||
templates: make(map[string]*template.Template, len(pages)),
|
||||
}
|
||||
|
||||
for _, page := range pages {
|
||||
pageBytes, readErr := fs.ReadFile(web.Content, "templates/"+page+".html")
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("webserver: read template %s: %w", page, readErr)
|
||||
}
|
||||
|
||||
t, parseErr := template.New("layout").Funcs(templateFuncs()).Parse(layoutStr)
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("webserver: parse layout for %s: %w", page, parseErr)
|
||||
}
|
||||
|
||||
_, parseErr = t.Parse(string(pageBytes))
|
||||
if parseErr != nil {
|
||||
return nil, fmt.Errorf("webserver: parse template %s: %w", page, parseErr)
|
||||
}
|
||||
|
||||
ts.templates[page] = t
|
||||
}
|
||||
|
||||
return ts, nil
|
||||
}
|
||||
|
||||
// render executes a named template and writes the result to w.
|
||||
func (ts *templateSet) render(w http.ResponseWriter, name string, data any) {
|
||||
t, ok := ts.templates[name]
|
||||
if !ok {
|
||||
http.Error(w, "template not found", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
if err := t.Execute(w, data); err != nil {
|
||||
// Template already started writing; log but don't send another error.
|
||||
_ = err // best-effort; headers may already be sent
|
||||
}
|
||||
}
|
||||
|
||||
// formatSize converts bytes to a human-readable string.
|
||||
func formatSize(b int64) string {
|
||||
const (
|
||||
kib = 1024
|
||||
mib = 1024 * kib
|
||||
gib = 1024 * mib
|
||||
tib = 1024 * gib
|
||||
)
|
||||
|
||||
switch {
|
||||
case b >= tib:
|
||||
return fmt.Sprintf("%.1f TiB", float64(b)/float64(tib))
|
||||
case b >= gib:
|
||||
return fmt.Sprintf("%.1f GiB", float64(b)/float64(gib))
|
||||
case b >= mib:
|
||||
return fmt.Sprintf("%.1f MiB", float64(b)/float64(mib))
|
||||
case b >= kib:
|
||||
return fmt.Sprintf("%.1f KiB", float64(b)/float64(kib))
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
}
|
||||
|
||||
// formatTime converts an RFC3339 string to a more readable format.
|
||||
func formatTime(s string) string {
|
||||
t, err := time.Parse(time.RFC3339, s)
|
||||
if err != nil {
|
||||
// Try RFC3339Nano.
|
||||
t, err = time.Parse(time.RFC3339Nano, s)
|
||||
if err != nil {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return t.Format("2006-01-02 15:04:05")
|
||||
}
|
||||
|
||||
// truncate returns the first n characters of s, appending "..." if truncated.
|
||||
func truncate(s string, n int) string {
|
||||
if len(s) <= n {
|
||||
return s
|
||||
}
|
||||
return s[:n] + "..."
|
||||
}
|
||||
|
||||
// joinStrings joins string slices for template rendering.
|
||||
func joinStrings(ss []string, sep string) string {
|
||||
return strings.Join(ss, sep)
|
||||
}
|
||||
Reference in New Issue
Block a user