From 4469c650ccb065ee400b9563b86a7086c157ebb6 Mon Sep 17 00:00:00 2001 From: Kyle Isom Date: Sun, 15 Mar 2026 13:44:32 -0700 Subject: [PATCH] Cache issued tgz in memory for one-time download Instead of streaming the tgz directly to the response (which was fragile under server write timeouts), handleIssueCert now: - Builds the tgz into a bytes.Buffer - Stores it in a sync.Map (tgzCache) under a random 16-byte hex token - Redirects the browser to /pki/download/{token} handleTGZDownload serves the cached bytes via LoadAndDelete, so the archive is removed from memory after the first (and only) download. An unknown or already-used token returns 404. Also adds TestHandleTGZDownload covering the one-time-use and not-found cases, and wires issueCertFn into mockVault. Co-authored-by: Junie --- internal/webserver/cert_detail_test.go | 60 ++++++++++++++++++++++++ internal/webserver/routes.go | 63 ++++++++++++++++++-------- internal/webserver/server.go | 8 ++++ 3 files changed, 112 insertions(+), 19 deletions(-) diff --git a/internal/webserver/cert_detail_test.go b/internal/webserver/cert_detail_test.go index 509ae26..1b7c51e 100644 --- a/internal/webserver/cert_detail_test.go +++ b/internal/webserver/cert_detail_test.go @@ -24,6 +24,7 @@ type mockVault struct { validateTokenFn func(ctx context.Context, token string) (*TokenInfo, error) listMountsFn func(ctx context.Context, token string) ([]MountInfo, error) getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error) + issueCertFn func(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) } func (m *mockVault) Status(ctx context.Context) (string, error) { @@ -80,6 +81,9 @@ func (m *mockVault) ListIssuers(ctx context.Context, token, mount string) ([]str } func (m *mockVault) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) { + if m.issueCertFn != nil { + return m.issueCertFn(ctx, token, req) + } return nil, fmt.Errorf("not implemented") } @@ -108,6 +112,62 @@ func (m *mockVault) DeleteCert(ctx context.Context, token, mount, serial string) func (m *mockVault) Close() error { return nil } +// ---- handleTGZDownload tests ---- + +func TestHandleTGZDownload(t *testing.T) { + t.Run("valid token serves archive and removes it from cache", func(t *testing.T) { + ws := newTestWebServer(t, &mockVault{}) + + const dlToken = "abc123" + ws.tgzCache.Store(dlToken, &tgzEntry{ + filename: "test.tgz", + data: []byte("fake-tgz-data"), + }) + + r := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken}) + r = addAuthCookie(r, &TokenInfo{Username: "testuser"}) + + w := httptest.NewRecorder() + ws.handleTGZDownload(w, r) + + if w.Code != http.StatusOK { + t.Errorf("status = %d, want %d", w.Code, http.StatusOK) + } + if ct := w.Header().Get("Content-Type"); ct != "application/gzip" { + t.Errorf("Content-Type = %q, want application/gzip", ct) + } + if body := w.Body.String(); body != "fake-tgz-data" { + t.Errorf("body = %q, want fake-tgz-data", body) + } + + // Entry must have been removed — a second request should 404. + w2 := httptest.NewRecorder() + r2 := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken}) + r2 = addAuthCookie(r2, &TokenInfo{Username: "testuser"}) + ws.handleTGZDownload(w2, r2) + if w2.Code != http.StatusNotFound { + t.Errorf("second request status = %d, want %d", w2.Code, http.StatusNotFound) + } + }) + + t.Run("unknown token returns 404", func(t *testing.T) { + ws := newTestWebServer(t, &mockVault{}) + + r := newChiRequest(http.MethodGet, "/pki/download/nosuchtoken", map[string]string{"token": "nosuchtoken"}) + r = addAuthCookie(r, &TokenInfo{Username: "testuser"}) + + w := httptest.NewRecorder() + ws.handleTGZDownload(w, r) + + if w.Code != http.StatusNotFound { + t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound) + } + if !strings.Contains(w.Body.String(), "download not found") { + t.Errorf("body %q does not contain 'download not found'", w.Body.String()) + } + }) +} + // newTestWebServer builds a WebServer wired to the given mock, suitable for unit tests. func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer { t.Helper() diff --git a/internal/webserver/routes.go b/internal/webserver/routes.go index ba87941..42e979d 100644 --- a/internal/webserver/routes.go +++ b/internal/webserver/routes.go @@ -2,7 +2,10 @@ package webserver import ( "archive/tar" + "bytes" "compress/gzip" + "crypto/rand" + "encoding/hex" "fmt" "io" "net/http" @@ -40,6 +43,7 @@ func (ws *WebServer) registerRoutes(r chi.Router) { r.Post("/import-root", ws.requireAuth(ws.handleImportRoot)) r.Post("/create-issuer", ws.requireAuth(ws.handleCreateIssuer)) r.Post("/issue", ws.requireAuth(ws.handleIssueCert)) + r.Get("/download/{token}", ws.requireAuth(ws.handleTGZDownload)) r.Get("/issuer/{issuer}", ws.requireAuth(ws.handleIssuerDetail)) r.Get("/cert/{serial}", ws.requireAuth(ws.handleCertDetail)) r.Get("/cert/{serial}/download", ws.requireAuth(ws.handleCertDownload)) @@ -479,11 +483,6 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request) } func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) { - // Disable the server-wide write deadline for this handler: it streams a - // tgz response only after several serial gRPC calls, which can easily - // consume the 30 s WriteTimeout before we start writing. We set our own - // 60 s deadline just before the write phase below. - _ = http.NewResponseController(w).SetWriteDeadline(time.Time{}) info := tokenInfoFromContext(r.Context()) token := extractCookie(r) @@ -538,17 +537,11 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) { return } - // Stream a tgz archive containing the private key (PKCS8) and certificate. - // Extend the write deadline before streaming so that slow gRPC backends - // don't consume the server WriteTimeout before we start writing. - rc := http.NewResponseController(w) - _ = rc.SetWriteDeadline(time.Now().Add(60 * time.Second)) - - filename := issuedCert.Serial + ".tgz" - w.Header().Set("Content-Type", "application/gzip") - w.Header().Set("Content-Disposition", "attachment; filename=\""+filename+"\"") - - gw := gzip.NewWriter(w) + // Build the tgz archive in memory, store it in the cache, then redirect + // the browser to the one-time download URL so the archive is only served + // once and then discarded. + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) tw := tar.NewWriter(gw) writeTarFile := func(name string, data []byte) error { @@ -566,16 +559,48 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) { } if err := writeTarFile("key.pem", []byte(issuedCert.KeyPEM)); err != nil { - ws.logger.Error("write key to tgz", "error", err) + ws.logger.Error("build tgz key", "error", err) + http.Error(w, "failed to build archive", http.StatusInternalServerError) return } if err := writeTarFile("cert.pem", []byte(issuedCert.CertPEM)); err != nil { - ws.logger.Error("write cert to tgz", "error", err) + ws.logger.Error("build tgz cert", "error", err) + http.Error(w, "failed to build archive", http.StatusInternalServerError) return } - _ = tw.Close() _ = gw.Close() + + // Generate a random one-time token for the download URL. + var raw [16]byte + if _, err := rand.Read(raw[:]); err != nil { + ws.logger.Error("generate download token", "error", err) + http.Error(w, "internal server error", http.StatusInternalServerError) + return + } + dlToken := hex.EncodeToString(raw[:]) + + ws.tgzCache.Store(dlToken, &tgzEntry{ + filename: issuedCert.Serial + ".tgz", + data: buf.Bytes(), + }) + + http.Redirect(w, r, "/pki/download/"+dlToken, http.StatusSeeOther) +} + +func (ws *WebServer) handleTGZDownload(w http.ResponseWriter, r *http.Request) { + dlToken := chi.URLParam(r, "token") + + val, ok := ws.tgzCache.LoadAndDelete(dlToken) + if !ok { + http.Error(w, "download not found or already used", http.StatusNotFound) + return + } + entry := val.(*tgzEntry) + + w.Header().Set("Content-Type", "application/gzip") + w.Header().Set("Content-Disposition", "attachment; filename=\""+entry.filename+"\"") + _, _ = w.Write(entry.data) } func (ws *WebServer) handleCertDetail(w http.ResponseWriter, r *http.Request) { diff --git a/internal/webserver/server.go b/internal/webserver/server.go index 5d9fe5b..00836ac 100644 --- a/internal/webserver/server.go +++ b/internal/webserver/server.go @@ -10,6 +10,7 @@ import ( "io/fs" "log/slog" "net/http" + "sync" "time" "github.com/go-chi/chi/v5" @@ -42,6 +43,12 @@ type vaultBackend interface { Close() error } +// tgzEntry holds a cached tgz archive pending download. +type tgzEntry struct { + filename string + data []byte +} + // WebServer is the standalone web UI server. type WebServer struct { cfg *config.Config @@ -49,6 +56,7 @@ type WebServer struct { logger *slog.Logger httpSrv *http.Server staticFS fs.FS + tgzCache sync.Map // key: UUID string → *tgzEntry } // New creates a new WebServer. It dials the vault gRPC endpoint.