Cache issued tgz in memory for one-time download
Instead of streaming the tgz directly to the response (which was
fragile under server write timeouts), handleIssueCert now:
- Builds the tgz into a bytes.Buffer
- Stores it in a sync.Map (tgzCache) under a random 16-byte hex token
- Redirects the browser to /pki/download/{token}
handleTGZDownload serves the cached bytes via LoadAndDelete, so the
archive is removed from memory after the first (and only) download.
An unknown or already-used token returns 404.
Also adds TestHandleTGZDownload covering the one-time-use and
not-found cases, and wires issueCertFn into mockVault.
Co-authored-by: Junie <junie@jetbrains.com>
This commit is contained in:
@@ -24,6 +24,7 @@ type mockVault struct {
|
|||||||
validateTokenFn func(ctx context.Context, token string) (*TokenInfo, error)
|
validateTokenFn func(ctx context.Context, token string) (*TokenInfo, error)
|
||||||
listMountsFn func(ctx context.Context, token string) ([]MountInfo, error)
|
listMountsFn func(ctx context.Context, token string) ([]MountInfo, error)
|
||||||
getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
|
getCertFn func(ctx context.Context, token, mount, serial string) (*CertDetail, error)
|
||||||
|
issueCertFn func(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockVault) Status(ctx context.Context) (string, error) {
|
func (m *mockVault) Status(ctx context.Context) (string, error) {
|
||||||
@@ -80,6 +81,9 @@ func (m *mockVault) ListIssuers(ctx context.Context, token, mount string) ([]str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockVault) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) {
|
func (m *mockVault) IssueCert(ctx context.Context, token string, req IssueCertRequest) (*IssuedCert, error) {
|
||||||
|
if m.issueCertFn != nil {
|
||||||
|
return m.issueCertFn(ctx, token, req)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("not implemented")
|
return nil, fmt.Errorf("not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -108,6 +112,62 @@ func (m *mockVault) DeleteCert(ctx context.Context, token, mount, serial string)
|
|||||||
|
|
||||||
func (m *mockVault) Close() error { return nil }
|
func (m *mockVault) Close() error { return nil }
|
||||||
|
|
||||||
|
// ---- handleTGZDownload tests ----
|
||||||
|
|
||||||
|
func TestHandleTGZDownload(t *testing.T) {
|
||||||
|
t.Run("valid token serves archive and removes it from cache", func(t *testing.T) {
|
||||||
|
ws := newTestWebServer(t, &mockVault{})
|
||||||
|
|
||||||
|
const dlToken = "abc123"
|
||||||
|
ws.tgzCache.Store(dlToken, &tgzEntry{
|
||||||
|
filename: "test.tgz",
|
||||||
|
data: []byte("fake-tgz-data"),
|
||||||
|
})
|
||||||
|
|
||||||
|
r := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken})
|
||||||
|
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ws.handleTGZDownload(w, r)
|
||||||
|
|
||||||
|
if w.Code != http.StatusOK {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusOK)
|
||||||
|
}
|
||||||
|
if ct := w.Header().Get("Content-Type"); ct != "application/gzip" {
|
||||||
|
t.Errorf("Content-Type = %q, want application/gzip", ct)
|
||||||
|
}
|
||||||
|
if body := w.Body.String(); body != "fake-tgz-data" {
|
||||||
|
t.Errorf("body = %q, want fake-tgz-data", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entry must have been removed — a second request should 404.
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
r2 := newChiRequest(http.MethodGet, "/pki/download/"+dlToken, map[string]string{"token": dlToken})
|
||||||
|
r2 = addAuthCookie(r2, &TokenInfo{Username: "testuser"})
|
||||||
|
ws.handleTGZDownload(w2, r2)
|
||||||
|
if w2.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("second request status = %d, want %d", w2.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("unknown token returns 404", func(t *testing.T) {
|
||||||
|
ws := newTestWebServer(t, &mockVault{})
|
||||||
|
|
||||||
|
r := newChiRequest(http.MethodGet, "/pki/download/nosuchtoken", map[string]string{"token": "nosuchtoken"})
|
||||||
|
r = addAuthCookie(r, &TokenInfo{Username: "testuser"})
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
ws.handleTGZDownload(w, r)
|
||||||
|
|
||||||
|
if w.Code != http.StatusNotFound {
|
||||||
|
t.Errorf("status = %d, want %d", w.Code, http.StatusNotFound)
|
||||||
|
}
|
||||||
|
if !strings.Contains(w.Body.String(), "download not found") {
|
||||||
|
t.Errorf("body %q does not contain 'download not found'", w.Body.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// newTestWebServer builds a WebServer wired to the given mock, suitable for unit tests.
|
// newTestWebServer builds a WebServer wired to the given mock, suitable for unit tests.
|
||||||
func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer {
|
func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|||||||
@@ -2,7 +2,10 @@ package webserver
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/tar"
|
"archive/tar"
|
||||||
|
"bytes"
|
||||||
"compress/gzip"
|
"compress/gzip"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
@@ -40,6 +43,7 @@ func (ws *WebServer) registerRoutes(r chi.Router) {
|
|||||||
r.Post("/import-root", ws.requireAuth(ws.handleImportRoot))
|
r.Post("/import-root", ws.requireAuth(ws.handleImportRoot))
|
||||||
r.Post("/create-issuer", ws.requireAuth(ws.handleCreateIssuer))
|
r.Post("/create-issuer", ws.requireAuth(ws.handleCreateIssuer))
|
||||||
r.Post("/issue", ws.requireAuth(ws.handleIssueCert))
|
r.Post("/issue", ws.requireAuth(ws.handleIssueCert))
|
||||||
|
r.Get("/download/{token}", ws.requireAuth(ws.handleTGZDownload))
|
||||||
r.Get("/issuer/{issuer}", ws.requireAuth(ws.handleIssuerDetail))
|
r.Get("/issuer/{issuer}", ws.requireAuth(ws.handleIssuerDetail))
|
||||||
r.Get("/cert/{serial}", ws.requireAuth(ws.handleCertDetail))
|
r.Get("/cert/{serial}", ws.requireAuth(ws.handleCertDetail))
|
||||||
r.Get("/cert/{serial}/download", ws.requireAuth(ws.handleCertDownload))
|
r.Get("/cert/{serial}/download", ws.requireAuth(ws.handleCertDownload))
|
||||||
@@ -479,11 +483,6 @@ func (ws *WebServer) handleIssuerDetail(w http.ResponseWriter, r *http.Request)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
||||||
// Disable the server-wide write deadline for this handler: it streams a
|
|
||||||
// tgz response only after several serial gRPC calls, which can easily
|
|
||||||
// consume the 30 s WriteTimeout before we start writing. We set our own
|
|
||||||
// 60 s deadline just before the write phase below.
|
|
||||||
_ = http.NewResponseController(w).SetWriteDeadline(time.Time{})
|
|
||||||
|
|
||||||
info := tokenInfoFromContext(r.Context())
|
info := tokenInfoFromContext(r.Context())
|
||||||
token := extractCookie(r)
|
token := extractCookie(r)
|
||||||
@@ -538,17 +537,11 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stream a tgz archive containing the private key (PKCS8) and certificate.
|
// Build the tgz archive in memory, store it in the cache, then redirect
|
||||||
// Extend the write deadline before streaming so that slow gRPC backends
|
// the browser to the one-time download URL so the archive is only served
|
||||||
// don't consume the server WriteTimeout before we start writing.
|
// once and then discarded.
|
||||||
rc := http.NewResponseController(w)
|
var buf bytes.Buffer
|
||||||
_ = rc.SetWriteDeadline(time.Now().Add(60 * time.Second))
|
gw := gzip.NewWriter(&buf)
|
||||||
|
|
||||||
filename := issuedCert.Serial + ".tgz"
|
|
||||||
w.Header().Set("Content-Type", "application/gzip")
|
|
||||||
w.Header().Set("Content-Disposition", "attachment; filename=\""+filename+"\"")
|
|
||||||
|
|
||||||
gw := gzip.NewWriter(w)
|
|
||||||
tw := tar.NewWriter(gw)
|
tw := tar.NewWriter(gw)
|
||||||
|
|
||||||
writeTarFile := func(name string, data []byte) error {
|
writeTarFile := func(name string, data []byte) error {
|
||||||
@@ -566,16 +559,48 @@ func (ws *WebServer) handleIssueCert(w http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := writeTarFile("key.pem", []byte(issuedCert.KeyPEM)); err != nil {
|
if err := writeTarFile("key.pem", []byte(issuedCert.KeyPEM)); err != nil {
|
||||||
ws.logger.Error("write key to tgz", "error", err)
|
ws.logger.Error("build tgz key", "error", err)
|
||||||
|
http.Error(w, "failed to build archive", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := writeTarFile("cert.pem", []byte(issuedCert.CertPEM)); err != nil {
|
if err := writeTarFile("cert.pem", []byte(issuedCert.CertPEM)); err != nil {
|
||||||
ws.logger.Error("write cert to tgz", "error", err)
|
ws.logger.Error("build tgz cert", "error", err)
|
||||||
|
http.Error(w, "failed to build archive", http.StatusInternalServerError)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = tw.Close()
|
_ = tw.Close()
|
||||||
_ = gw.Close()
|
_ = gw.Close()
|
||||||
|
|
||||||
|
// Generate a random one-time token for the download URL.
|
||||||
|
var raw [16]byte
|
||||||
|
if _, err := rand.Read(raw[:]); err != nil {
|
||||||
|
ws.logger.Error("generate download token", "error", err)
|
||||||
|
http.Error(w, "internal server error", http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dlToken := hex.EncodeToString(raw[:])
|
||||||
|
|
||||||
|
ws.tgzCache.Store(dlToken, &tgzEntry{
|
||||||
|
filename: issuedCert.Serial + ".tgz",
|
||||||
|
data: buf.Bytes(),
|
||||||
|
})
|
||||||
|
|
||||||
|
http.Redirect(w, r, "/pki/download/"+dlToken, http.StatusSeeOther)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ws *WebServer) handleTGZDownload(w http.ResponseWriter, r *http.Request) {
|
||||||
|
dlToken := chi.URLParam(r, "token")
|
||||||
|
|
||||||
|
val, ok := ws.tgzCache.LoadAndDelete(dlToken)
|
||||||
|
if !ok {
|
||||||
|
http.Error(w, "download not found or already used", http.StatusNotFound)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
entry := val.(*tgzEntry)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/gzip")
|
||||||
|
w.Header().Set("Content-Disposition", "attachment; filename=\""+entry.filename+"\"")
|
||||||
|
_, _ = w.Write(entry.data)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ws *WebServer) handleCertDetail(w http.ResponseWriter, r *http.Request) {
|
func (ws *WebServer) handleCertDetail(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"io/fs"
|
"io/fs"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
@@ -42,6 +43,12 @@ type vaultBackend interface {
|
|||||||
Close() error
|
Close() error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// tgzEntry holds a cached tgz archive pending download.
|
||||||
|
type tgzEntry struct {
|
||||||
|
filename string
|
||||||
|
data []byte
|
||||||
|
}
|
||||||
|
|
||||||
// WebServer is the standalone web UI server.
|
// WebServer is the standalone web UI server.
|
||||||
type WebServer struct {
|
type WebServer struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -49,6 +56,7 @@ type WebServer struct {
|
|||||||
logger *slog.Logger
|
logger *slog.Logger
|
||||||
httpSrv *http.Server
|
httpSrv *http.Server
|
||||||
staticFS fs.FS
|
staticFS fs.FS
|
||||||
|
tgzCache sync.Map // key: UUID string → *tgzEntry
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new WebServer. It dials the vault gRPC endpoint.
|
// New creates a new WebServer. It dials the vault gRPC endpoint.
|
||||||
|
|||||||
Reference in New Issue
Block a user