Migrate CSRF, web templates, session cookies, and snapshot to mcdsl
CSRF: Replace local csrfProtect with mcdsl/csrf.Protect. Delete internal/webserver/csrf.go. Web: Replace renderTemplate with web.RenderTemplate + csrf.TemplateFunc. Replace extractCookie with web.GetSessionToken. Replace manual session cookie SetCookie with web.SetSessionCookie. Snapshot: Replace local sqliteBackup with mcdsl/db.Snapshot. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,12 +1,11 @@
|
|||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
||||||
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
||||||
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
"git.wntrmute.dev/kyle/metacrypt/internal/db"
|
||||||
)
|
)
|
||||||
@@ -42,17 +41,9 @@ func runSnapshot(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
defer func() { _ = database.Close() }()
|
defer func() { _ = database.Close() }()
|
||||||
|
|
||||||
if err := sqliteBackup(database, snapshotOutput); err != nil {
|
if err := mcdsldb.Snapshot(database, snapshotOutput); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
fmt.Printf("Snapshot saved to %s\n", snapshotOutput)
|
fmt.Printf("Snapshot saved to %s\n", snapshotOutput)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func sqliteBackup(srcDB *sql.DB, dstPath string) error {
|
|
||||||
_, err := srcDB.Exec("VACUUM INTO ?", dstPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("snapshot: %w", err)
|
|
||||||
}
|
|
||||||
return os.Chmod(dstPath, 0600)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -307,7 +307,7 @@ func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer {
|
|||||||
vault: vault,
|
vault: vault,
|
||||||
logger: slog.Default(),
|
logger: slog.Default(),
|
||||||
staticFS: staticFS,
|
staticFS: staticFS,
|
||||||
csrf: newCSRFProtect(),
|
csrf: newTestCSRF(t),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,119 +0,0 @@
|
|||||||
package webserver
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/hmac"
|
|
||||||
"crypto/rand"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
csrfCookieName = "metacrypt_csrf"
|
|
||||||
csrfFieldName = "csrf_token"
|
|
||||||
csrfTokenLen = 32
|
|
||||||
)
|
|
||||||
|
|
||||||
// csrfProtect provides CSRF protection using the signed double-submit cookie
|
|
||||||
// pattern. A random secret is generated at startup. CSRF tokens are an HMAC of
|
|
||||||
// a random nonce, sent as both a cookie and a hidden form field. On POST the
|
|
||||||
// middleware verifies that the form field matches the cookie's HMAC.
|
|
||||||
type csrfProtect struct {
|
|
||||||
secret []byte
|
|
||||||
once sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func newCSRFProtect() *csrfProtect {
|
|
||||||
secret := make([]byte, 32)
|
|
||||||
if _, err := rand.Read(secret); err != nil {
|
|
||||||
panic("csrf: failed to generate secret: " + err.Error())
|
|
||||||
}
|
|
||||||
return &csrfProtect{secret: secret}
|
|
||||||
}
|
|
||||||
|
|
||||||
// generateToken creates a new CSRF token: base64(nonce) + "." + base64(hmac(nonce)).
|
|
||||||
func (c *csrfProtect) generateToken() (string, error) {
|
|
||||||
nonce := make([]byte, csrfTokenLen)
|
|
||||||
if _, err := rand.Read(nonce); err != nil {
|
|
||||||
return "", fmt.Errorf("csrf: generate nonce: %w", err)
|
|
||||||
}
|
|
||||||
nonceB64 := base64.RawURLEncoding.EncodeToString(nonce)
|
|
||||||
mac := hmac.New(sha256.New, c.secret)
|
|
||||||
mac.Write(nonce)
|
|
||||||
sigB64 := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
|
||||||
return nonceB64 + "." + sigB64, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// validToken checks that a token has a valid HMAC signature.
|
|
||||||
func (c *csrfProtect) validToken(token string) bool {
|
|
||||||
parts := strings.SplitN(token, ".", 2)
|
|
||||||
if len(parts) != 2 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
nonce, err := base64.RawURLEncoding.DecodeString(parts[0])
|
|
||||||
if err != nil || len(nonce) != csrfTokenLen {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
sig, err := base64.RawURLEncoding.DecodeString(parts[1])
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
mac := hmac.New(sha256.New, c.secret)
|
|
||||||
mac.Write(nonce)
|
|
||||||
return hmac.Equal(mac.Sum(nil), sig)
|
|
||||||
}
|
|
||||||
|
|
||||||
// setToken generates a new CSRF token, sets it as a cookie, and returns it
|
|
||||||
// for embedding in a form.
|
|
||||||
func (c *csrfProtect) setToken(w http.ResponseWriter) string {
|
|
||||||
token, err := c.generateToken()
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
http.SetCookie(w, &http.Cookie{
|
|
||||||
Name: csrfCookieName,
|
|
||||||
Value: token,
|
|
||||||
Path: "/",
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
})
|
|
||||||
return token
|
|
||||||
}
|
|
||||||
|
|
||||||
// middleware returns an HTTP middleware that enforces CSRF validation on
|
|
||||||
// mutation requests (POST, PUT, PATCH, DELETE). GET/HEAD/OPTIONS are passed
|
|
||||||
// through. The HTMX hx-post for /v1/seal is excluded since it hits the API
|
|
||||||
// server directly and uses token auth, not cookies.
|
|
||||||
func (c *csrfProtect) middleware(next http.Handler) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
switch r.Method {
|
|
||||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read token from form field (works for both regular forms and
|
|
||||||
// multipart forms since ParseForm/ParseMultipartForm will have
|
|
||||||
// been called or the field is available via FormValue).
|
|
||||||
formToken := r.FormValue(csrfFieldName)
|
|
||||||
|
|
||||||
// Read token from cookie.
|
|
||||||
cookie, err := r.Cookie(csrfCookieName)
|
|
||||||
if err != nil || cookie.Value == "" {
|
|
||||||
http.Error(w, "CSRF validation failed", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Both tokens must be present, match each other, and be validly signed.
|
|
||||||
if formToken == "" || formToken != cookie.Value || !c.validToken(formToken) {
|
|
||||||
http.Error(w, "CSRF validation failed", http.StatusForbidden)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
next.ServeHTTP(w, r)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,44 +1,28 @@
|
|||||||
package webserver
|
package webserver
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcdsl/csrf"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCSRFTokenGenerateAndValidate(t *testing.T) {
|
func newTestCSRF(t *testing.T) *csrf.Protect {
|
||||||
c := newCSRFProtect()
|
t.Helper()
|
||||||
token, err := c.generateToken()
|
secret := make([]byte, 32)
|
||||||
if err != nil {
|
if _, err := rand.Read(secret); err != nil {
|
||||||
t.Fatalf("generateToken: %v", err)
|
t.Fatal(err)
|
||||||
}
|
|
||||||
if !c.validToken(token) {
|
|
||||||
t.Fatal("valid token rejected")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCSRFTokenInvalidFormats(t *testing.T) {
|
|
||||||
c := newCSRFProtect()
|
|
||||||
for _, bad := range []string{"", "nodot", "a.b.c", "abc.def"} {
|
|
||||||
if c.validToken(bad) {
|
|
||||||
t.Errorf("should reject %q", bad)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCSRFTokenCrossSecret(t *testing.T) {
|
|
||||||
c1 := newCSRFProtect()
|
|
||||||
c2 := newCSRFProtect()
|
|
||||||
token, _ := c1.generateToken()
|
|
||||||
if c2.validToken(token) {
|
|
||||||
t.Fatal("token from different secret should be rejected")
|
|
||||||
}
|
}
|
||||||
|
return csrf.New(secret, "metacrypt_csrf", "csrf_token")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFMiddlewareAllowsGET(t *testing.T) {
|
func TestCSRFMiddlewareAllowsGET(t *testing.T) {
|
||||||
c := newCSRFProtect()
|
c := newTestCSRF(t)
|
||||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
@@ -50,8 +34,8 @@ func TestCSRFMiddlewareAllowsGET(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
|
func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
|
||||||
c := newCSRFProtect()
|
c := newTestCSRF(t)
|
||||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
|
||||||
@@ -64,17 +48,20 @@ func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
|
func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
|
||||||
c := newCSRFProtect()
|
c := newTestCSRF(t)
|
||||||
token, _ := c.generateToken()
|
|
||||||
|
|
||||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
// Generate a token via SetToken (writes cookie to response).
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
token := c.SetToken(rec)
|
||||||
|
|
||||||
|
handler := c.Middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
body := csrfFieldName + "=" + token
|
body := "csrf_token=" + url.QueryEscape(token)
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token})
|
req.AddCookie(&http.Cookie{Name: "metacrypt_csrf", Value: token})
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
handler.ServeHTTP(w, req)
|
handler.ServeHTTP(w, req)
|
||||||
@@ -82,24 +69,3 @@ func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
|
|||||||
t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
|
t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCSRFMiddlewareRejectsMismatch(t *testing.T) {
|
|
||||||
c := newCSRFProtect()
|
|
||||||
token1, _ := c.generateToken()
|
|
||||||
token2, _ := c.generateToken()
|
|
||||||
|
|
||||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
}))
|
|
||||||
|
|
||||||
body := csrfFieldName + "=" + token1
|
|
||||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
|
||||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
|
||||||
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token2})
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
handler.ServeHTTP(w, req)
|
|
||||||
if w.Code != http.StatusForbidden {
|
|
||||||
t.Fatalf("POST with mismatched tokens should be forbidden, got %d", w.Code)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ import (
|
|||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"git.wntrmute.dev/kyle/mcdsl/web"
|
||||||
)
|
)
|
||||||
|
|
||||||
// splitLines splits a newline-delimited string into non-empty trimmed lines.
|
// splitLines splits a newline-delimited string into non-empty trimmed lines.
|
||||||
@@ -224,14 +226,7 @@ func (ws *WebServer) handleLogin(w http.ResponseWriter, r *http.Request) {
|
|||||||
ws.renderTemplate(w, "login.html", map[string]interface{}{"Error": "Invalid credentials"})
|
ws.renderTemplate(w, "login.html", map[string]interface{}{"Error": "Invalid credentials"})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
http.SetCookie(w, &http.Cookie{
|
web.SetSessionCookie(w, "metacrypt_token", token)
|
||||||
Name: "metacrypt_token",
|
|
||||||
Value: token,
|
|
||||||
Path: "/",
|
|
||||||
HttpOnly: true,
|
|
||||||
Secure: true,
|
|
||||||
SameSite: http.SameSiteStrictMode,
|
|
||||||
})
|
|
||||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
||||||
default:
|
default:
|
||||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
|||||||
@@ -4,9 +4,9 @@ package webserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"html/template"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -16,6 +16,8 @@ import (
|
|||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
|
||||||
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
|
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
|
||||||
|
"git.wntrmute.dev/kyle/mcdsl/csrf"
|
||||||
|
"git.wntrmute.dev/kyle/mcdsl/web"
|
||||||
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
"git.wntrmute.dev/kyle/metacrypt/internal/config"
|
||||||
webui "git.wntrmute.dev/kyle/metacrypt/web"
|
webui "git.wntrmute.dev/kyle/metacrypt/web"
|
||||||
)
|
)
|
||||||
@@ -116,7 +118,7 @@ type WebServer struct {
|
|||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
httpSrv *http.Server
|
httpSrv *http.Server
|
||||||
staticFS fs.FS
|
staticFS fs.FS
|
||||||
csrf *csrfProtect
|
csrf *csrf.Protect
|
||||||
tgzCache sync.Map // key: UUID string → *tgzEntry
|
tgzCache sync.Map // key: UUID string → *tgzEntry
|
||||||
userCache sync.Map // key: UUID string → *cachedUsername
|
userCache sync.Map // key: UUID string → *cachedUsername
|
||||||
}
|
}
|
||||||
@@ -154,12 +156,17 @@ func New(cfg *config.Config, logger *slog.Logger) (*WebServer, error) {
|
|||||||
return nil, fmt.Errorf("webserver: static FS: %w", err)
|
return nil, fmt.Errorf("webserver: static FS: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
secret := make([]byte, 32)
|
||||||
|
if _, err := rand.Read(secret); err != nil {
|
||||||
|
return nil, fmt.Errorf("webserver: generate CSRF secret: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
ws := &WebServer{
|
ws := &WebServer{
|
||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
vault: vault,
|
vault: vault,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
staticFS: staticFS,
|
staticFS: staticFS,
|
||||||
csrf: newCSRFProtect(),
|
csrf: csrf.New(secret, "metacrypt_csrf", "csrf_token"),
|
||||||
}
|
}
|
||||||
|
|
||||||
if tok := cfg.MCIAS.ServiceToken; tok != "" {
|
if tok := cfg.MCIAS.ServiceToken; tok != "" {
|
||||||
@@ -220,7 +227,7 @@ func (lw *loggingResponseWriter) Unwrap() http.ResponseWriter {
|
|||||||
func (ws *WebServer) Start() error {
|
func (ws *WebServer) Start() error {
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Use(ws.loggingMiddleware)
|
r.Use(ws.loggingMiddleware)
|
||||||
r.Use(ws.csrf.middleware)
|
r.Use(ws.csrf.Middleware)
|
||||||
ws.registerRoutes(r)
|
ws.registerRoutes(r)
|
||||||
|
|
||||||
ws.httpSrv = &http.Server{
|
ws.httpSrv = &http.Server{
|
||||||
@@ -259,36 +266,9 @@ func (ws *WebServer) Shutdown(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebServer) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
func (ws *WebServer) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||||
csrfToken := ws.csrf.setToken(w)
|
web.RenderTemplate(w, webui.FS, name, data, ws.csrf.TemplateFunc(w))
|
||||||
|
|
||||||
funcMap := template.FuncMap{
|
|
||||||
"csrfField": func() template.HTML {
|
|
||||||
return template.HTML(fmt.Sprintf(
|
|
||||||
`<input type="hidden" name="%s" value="%s">`,
|
|
||||||
csrfFieldName, template.HTMLEscapeString(csrfToken),
|
|
||||||
))
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
tmpl, err := template.New("").Funcs(funcMap).ParseFS(webui.FS,
|
|
||||||
"templates/layout.html",
|
|
||||||
"templates/"+name,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
ws.logger.Error("parse template", "name", name, "error", err)
|
|
||||||
http.Error(w, "internal server error", http.StatusInternalServerError)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
|
||||||
if err := tmpl.ExecuteTemplate(w, "layout", data); err != nil {
|
|
||||||
ws.logger.Error("execute template", "name", name, "error", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func extractCookie(r *http.Request) string {
|
func extractCookie(r *http.Request) string {
|
||||||
c, err := r.Cookie("metacrypt_token")
|
return web.GetSessionToken(r, "metacrypt_token")
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return c.Value
|
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user