Add MEK rotation, per-engine DEKs, and v2 ciphertext format (audit #6, #22)

Implement a two-level key hierarchy: the MEK now wraps per-engine DEKs
stored in a new barrier_keys table, rather than encrypting all barrier
entries directly. A v2 ciphertext format (0x02) embeds the key ID so the
barrier can resolve which DEK to use on decryption. v1 ciphertext remains
supported for backward compatibility.

Key changes:
- crypto: EncryptV2/DecryptV2/ExtractKeyID for v2 ciphertext with key IDs
- barrier: key registry (CreateKey, RotateKey, ListKeys, MigrateToV2, ReWrapKeys)
- seal: RotateMEK re-wraps DEKs without re-encrypting data
- engine: Mount auto-creates per-engine DEK
- REST + gRPC: barrier/keys, barrier/rotate-mek, barrier/rotate-key, barrier/migrate
- proto: BarrierService (v1 + v2) with ListKeys, RotateMEK, RotateKey, Migrate
- db: migration v2 adds barrier_keys table

Also includes: security audit report, CSRF protection, engine design specs
(sshca, transit, user), path-bound AAD migration tool, policy engine
enhancements, and ARCHITECTURE.md updates.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-16 18:27:44 -07:00
parent ac4577f778
commit 64d921827e
44 changed files with 5184 additions and 90 deletions

View File

@@ -13,8 +13,9 @@ import (
)
var (
ErrSealed = errors.New("barrier: sealed")
ErrNotFound = errors.New("barrier: entry not found")
ErrSealed = errors.New("barrier: sealed")
ErrNotFound = errors.New("barrier: entry not found")
ErrKeyNotFound = errors.New("barrier: key not found")
)
// Barrier is the encrypted storage barrier interface.
@@ -36,11 +37,20 @@ type Barrier interface {
List(ctx context.Context, prefix string) ([]string, error)
}
// KeyInfo holds metadata about a barrier key (DEK).
type KeyInfo struct {
KeyID string `json:"key_id"`
Version int `json:"version"`
CreatedAt string `json:"created_at"`
RotatedAt string `json:"rotated_at"`
}
// AESGCMBarrier implements Barrier using AES-256-GCM encryption.
type AESGCMBarrier struct {
db *sql.DB
mek []byte
mu sync.RWMutex
db *sql.DB
mek []byte
keys map[string][]byte // key_id → plaintext DEK
mu sync.RWMutex
}
// NewAESGCMBarrier creates a new AES-GCM barrier backed by the given database.
@@ -51,15 +61,56 @@ func NewAESGCMBarrier(db *sql.DB) *AESGCMBarrier {
func (b *AESGCMBarrier) Unseal(mek []byte) error {
b.mu.Lock()
defer b.mu.Unlock()
k := make([]byte, len(mek))
copy(k, mek)
b.mek = k
b.keys = make(map[string][]byte)
// Load DEKs from barrier_keys table.
if err := b.loadKeys(); err != nil {
// If the table doesn't exist yet (pre-migration), that's OK.
// The barrier will use MEK directly for v1 entries.
b.keys = make(map[string][]byte)
}
return nil
}
// loadKeys decrypts all DEKs from the barrier_keys table into memory.
// Caller must hold b.mu.
func (b *AESGCMBarrier) loadKeys() error {
rows, err := b.db.Query("SELECT key_id, encrypted_dek FROM barrier_keys")
if err != nil {
return err
}
defer func() { _ = rows.Close() }()
for rows.Next() {
var keyID string
var encDEK []byte
if err := rows.Scan(&keyID, &encDEK); err != nil {
return fmt.Errorf("barrier: scan key %q: %w", keyID, err)
}
dek, err := crypto.Decrypt(b.mek, encDEK, []byte(keyID))
if err != nil {
return fmt.Errorf("barrier: decrypt key %q: %w", keyID, err)
}
b.keys[keyID] = dek
}
return rows.Err()
}
func (b *AESGCMBarrier) Seal() error {
b.mu.Lock()
defer b.mu.Unlock()
// Zeroize all DEKs.
for _, dek := range b.keys {
crypto.Zeroize(dek)
}
b.keys = nil
if b.mek != nil {
crypto.Zeroize(b.mek)
b.mek = nil
@@ -73,9 +124,22 @@ func (b *AESGCMBarrier) IsSealed() bool {
return b.mek == nil
}
// resolveKeyID determines the key ID for a given barrier path.
func resolveKeyID(path string) string {
// Paths under engine/{type}/{mount}/... use per-engine DEKs.
if strings.HasPrefix(path, "engine/") {
parts := strings.SplitN(path, "/", 4) // engine/{type}/{mount}/...
if len(parts) >= 3 {
return "engine/" + parts[1] + "/" + parts[2]
}
}
return "system"
}
func (b *AESGCMBarrier) Get(ctx context.Context, path string) ([]byte, error) {
b.mu.RLock()
mek := b.mek
keys := b.keys
b.mu.RUnlock()
if mek == nil {
return nil, ErrSealed
@@ -91,22 +155,52 @@ func (b *AESGCMBarrier) Get(ctx context.Context, path string) ([]byte, error) {
return nil, fmt.Errorf("barrier: get %q: %w", path, err)
}
plaintext, err := crypto.Decrypt(mek, encrypted)
// Check version byte to determine decryption strategy.
if len(encrypted) > 0 && encrypted[0] == crypto.BarrierVersionV2 {
keyID, err := crypto.ExtractKeyID(encrypted)
if err != nil {
return nil, fmt.Errorf("barrier: extract key ID %q: %w", path, err)
}
dek, ok := keys[keyID]
if !ok {
return nil, fmt.Errorf("barrier: %w: %q for path %q", ErrKeyNotFound, keyID, path)
}
pt, _, err := crypto.DecryptV2(dek, encrypted, []byte(path))
if err != nil {
return nil, fmt.Errorf("barrier: decrypt %q: %w", path, err)
}
return pt, nil
}
// v1 ciphertext — use MEK directly (backward compat).
pt, err := crypto.Decrypt(mek, encrypted, []byte(path))
if err != nil {
return nil, fmt.Errorf("barrier: decrypt %q: %w", path, err)
}
return plaintext, nil
return pt, nil
}
func (b *AESGCMBarrier) Put(ctx context.Context, path string, value []byte) error {
b.mu.RLock()
mek := b.mek
keys := b.keys
b.mu.RUnlock()
if mek == nil {
return ErrSealed
}
encrypted, err := crypto.Encrypt(mek, value)
keyID := resolveKeyID(path)
var encrypted []byte
var err error
if dek, ok := keys[keyID]; ok {
// Use v2 format with the appropriate DEK.
encrypted, err = crypto.EncryptV2(dek, keyID, value, []byte(path))
} else {
// No DEK registered for this key ID — fall back to MEK with v1 format.
encrypted, err = crypto.Encrypt(mek, value, []byte(path))
}
if err != nil {
return fmt.Errorf("barrier: encrypt %q: %w", path, err)
}
@@ -159,9 +253,394 @@ func (b *AESGCMBarrier) List(ctx context.Context, prefix string) ([]string, erro
if err := rows.Scan(&p); err != nil {
return nil, fmt.Errorf("barrier: list scan: %w", err)
}
// Strip the prefix and return just the next segment.
remainder := strings.TrimPrefix(p, prefix)
paths = append(paths, remainder)
}
return paths, rows.Err()
}
// CreateKey generates a new DEK for the given key ID, wraps it with MEK,
// and stores it in the barrier_keys table.
func (b *AESGCMBarrier) CreateKey(ctx context.Context, keyID string) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.mek == nil {
return ErrSealed
}
if _, exists := b.keys[keyID]; exists {
return nil // Already exists.
}
dek, err := crypto.GenerateKey()
if err != nil {
return fmt.Errorf("barrier: generate DEK %q: %w", keyID, err)
}
encDEK, err := crypto.Encrypt(b.mek, dek, []byte(keyID))
if err != nil {
crypto.Zeroize(dek)
return fmt.Errorf("barrier: wrap DEK %q: %w", keyID, err)
}
_, err = b.db.ExecContext(ctx, `
INSERT INTO barrier_keys (key_id, version, encrypted_dek)
VALUES (?, 1, ?)
ON CONFLICT(key_id) DO NOTHING`,
keyID, encDEK)
if err != nil {
crypto.Zeroize(dek)
return fmt.Errorf("barrier: store DEK %q: %w", keyID, err)
}
b.keys[keyID] = dek
return nil
}
// RotateKey generates a new DEK for the given key ID and re-encrypts all
// barrier entries under that key ID's prefix with the new DEK.
func (b *AESGCMBarrier) RotateKey(ctx context.Context, keyID string) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.mek == nil {
return ErrSealed
}
oldDEK, ok := b.keys[keyID]
if !ok {
return fmt.Errorf("barrier: %w: %q", ErrKeyNotFound, keyID)
}
// Generate new DEK.
newDEK, err := crypto.GenerateKey()
if err != nil {
return fmt.Errorf("barrier: generate DEK: %w", err)
}
// Wrap new DEK with MEK.
encDEK, err := crypto.Encrypt(b.mek, newDEK, []byte(keyID))
if err != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: wrap DEK: %w", err)
}
// Determine the prefix for entries encrypted with this key.
prefix := keyID + "/"
if keyID == "system" {
// System key covers non-engine paths. Re-encrypt everything
// that doesn't start with "engine/".
prefix = ""
}
// Re-encrypt all entries under this key ID.
tx, err := b.db.BeginTx(ctx, nil)
if err != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
// Update the key in barrier_keys.
_, err = tx.ExecContext(ctx, `
UPDATE barrier_keys SET encrypted_dek = ?, version = version + 1, rotated_at = datetime('now')
WHERE key_id = ?`, encDEK, keyID)
if err != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: update key: %w", err)
}
// Fetch and re-encrypt entries.
var query string
var args []interface{}
if keyID == "system" {
query = "SELECT path, value FROM barrier_entries WHERE path NOT LIKE 'engine/%'"
} else {
query = "SELECT path, value FROM barrier_entries WHERE path LIKE ?"
args = append(args, prefix+"%")
}
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: query entries: %w", err)
}
type entry struct {
path string
value []byte
}
var entries []entry
for rows.Next() {
var e entry
if err := rows.Scan(&e.path, &e.value); err != nil {
_ = rows.Close()
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: scan entry: %w", err)
}
entries = append(entries, e)
}
_ = rows.Close()
if err := rows.Err(); err != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: rows error: %w", err)
}
for _, e := range entries {
// Decrypt with old DEK (handle v1 or v2).
var plaintext []byte
if len(e.value) > 0 && e.value[0] == crypto.BarrierVersionV2 {
pt, _, decErr := crypto.DecryptV2(oldDEK, e.value, []byte(e.path))
if decErr != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: decrypt %q during rotation: %w", e.path, decErr)
}
plaintext = pt
} else {
// v1: encrypted with MEK.
pt, decErr := crypto.Decrypt(b.mek, e.value, []byte(e.path))
if decErr != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: decrypt v1 %q during rotation: %w", e.path, decErr)
}
plaintext = pt
}
// Re-encrypt with new DEK using v2 format.
newCiphertext, encErr := crypto.EncryptV2(newDEK, keyID, plaintext, []byte(e.path))
if encErr != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: re-encrypt %q: %w", e.path, encErr)
}
_, err = tx.ExecContext(ctx,
"UPDATE barrier_entries SET value = ?, updated_at = datetime('now') WHERE path = ?",
newCiphertext, e.path)
if err != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: update %q: %w", e.path, err)
}
}
if err := tx.Commit(); err != nil {
crypto.Zeroize(newDEK)
return fmt.Errorf("barrier: commit rotation: %w", err)
}
// Swap the in-memory key.
crypto.Zeroize(oldDEK)
b.keys[keyID] = newDEK
return nil
}
// ListKeys returns metadata about all registered barrier keys.
func (b *AESGCMBarrier) ListKeys(ctx context.Context) ([]KeyInfo, error) {
b.mu.RLock()
mek := b.mek
b.mu.RUnlock()
if mek == nil {
return nil, ErrSealed
}
rows, err := b.db.QueryContext(ctx,
"SELECT key_id, version, created_at, rotated_at FROM barrier_keys ORDER BY key_id")
if err != nil {
return nil, fmt.Errorf("barrier: list keys: %w", err)
}
defer func() { _ = rows.Close() }()
var keys []KeyInfo
for rows.Next() {
var ki KeyInfo
if err := rows.Scan(&ki.KeyID, &ki.Version, &ki.CreatedAt, &ki.RotatedAt); err != nil {
return nil, fmt.Errorf("barrier: scan key info: %w", err)
}
keys = append(keys, ki)
}
return keys, rows.Err()
}
// MigrateToV2 creates per-engine DEKs and re-encrypts entries from v1
// (MEK-encrypted) to v2 (DEK-encrypted) format. On first call after upgrade,
// it creates a "system" DEK equal to the MEK for zero-cost backward compat,
// then creates per-engine DEKs and re-encrypts those entries.
func (b *AESGCMBarrier) MigrateToV2(ctx context.Context) (int, error) {
b.mu.Lock()
defer b.mu.Unlock()
if b.mek == nil {
return 0, ErrSealed
}
// Ensure the "system" key exists.
if _, ok := b.keys["system"]; !ok {
if err := b.createKeyLocked(ctx, "system"); err != nil {
return 0, fmt.Errorf("barrier: create system DEK: %w", err)
}
}
// Find all entries still in v1 format.
rows, err := b.db.QueryContext(ctx, "SELECT path, value FROM barrier_entries")
if err != nil {
return 0, fmt.Errorf("barrier: query entries: %w", err)
}
type entry struct {
path string
value []byte
}
var toMigrate []entry
for rows.Next() {
var e entry
if err := rows.Scan(&e.path, &e.value); err != nil {
_ = rows.Close()
return 0, fmt.Errorf("barrier: scan: %w", err)
}
if len(e.value) > 0 && e.value[0] == crypto.BarrierVersionV1 {
toMigrate = append(toMigrate, e)
}
}
_ = rows.Close()
if err := rows.Err(); err != nil {
return 0, err
}
if len(toMigrate) == 0 {
return 0, nil
}
tx, err := b.db.BeginTx(ctx, nil)
if err != nil {
return 0, fmt.Errorf("barrier: begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
migrated := 0
for _, e := range toMigrate {
// Decrypt with MEK (v1).
plaintext, decErr := crypto.Decrypt(b.mek, e.value, []byte(e.path))
if decErr != nil {
return migrated, fmt.Errorf("barrier: decrypt %q: %w", e.path, decErr)
}
keyID := resolveKeyID(e.path)
// Ensure the DEK exists for this key ID.
if _, ok := b.keys[keyID]; !ok {
if err := b.createKeyLockedTx(ctx, tx, keyID); err != nil {
return migrated, fmt.Errorf("barrier: create DEK %q: %w", keyID, err)
}
}
dek := b.keys[keyID]
newCiphertext, encErr := crypto.EncryptV2(dek, keyID, plaintext, []byte(e.path))
if encErr != nil {
return migrated, fmt.Errorf("barrier: encrypt v2 %q: %w", e.path, encErr)
}
_, err = tx.ExecContext(ctx,
"UPDATE barrier_entries SET value = ?, updated_at = datetime('now') WHERE path = ?",
newCiphertext, e.path)
if err != nil {
return migrated, fmt.Errorf("barrier: update %q: %w", e.path, err)
}
migrated++
}
if err := tx.Commit(); err != nil {
return migrated, fmt.Errorf("barrier: commit migration: %w", err)
}
return migrated, nil
}
// createKeyLocked generates and stores a new DEK. Caller must hold b.mu.
func (b *AESGCMBarrier) createKeyLocked(ctx context.Context, keyID string) error {
dek, err := crypto.GenerateKey()
if err != nil {
return err
}
encDEK, err := crypto.Encrypt(b.mek, dek, []byte(keyID))
if err != nil {
crypto.Zeroize(dek)
return err
}
_, err = b.db.ExecContext(ctx, `
INSERT INTO barrier_keys (key_id, version, encrypted_dek)
VALUES (?, 1, ?) ON CONFLICT(key_id) DO NOTHING`, keyID, encDEK)
if err != nil {
crypto.Zeroize(dek)
return err
}
b.keys[keyID] = dek
return nil
}
// createKeyLockedTx is like createKeyLocked but uses an existing transaction.
func (b *AESGCMBarrier) createKeyLockedTx(ctx context.Context, tx *sql.Tx, keyID string) error {
dek, err := crypto.GenerateKey()
if err != nil {
return err
}
encDEK, err := crypto.Encrypt(b.mek, dek, []byte(keyID))
if err != nil {
crypto.Zeroize(dek)
return err
}
_, err = tx.ExecContext(ctx, `
INSERT INTO barrier_keys (key_id, version, encrypted_dek)
VALUES (?, 1, ?) ON CONFLICT(key_id) DO NOTHING`, keyID, encDEK)
if err != nil {
crypto.Zeroize(dek)
return err
}
b.keys[keyID] = dek
return nil
}
// ReWrapKeys re-encrypts all DEKs with a new MEK. Called during MEK rotation.
// The new MEK is already set in b.mek by the caller.
func (b *AESGCMBarrier) ReWrapKeys(ctx context.Context, newMEK []byte) error {
b.mu.Lock()
defer b.mu.Unlock()
if b.mek == nil {
return ErrSealed
}
tx, err := b.db.BeginTx(ctx, nil)
if err != nil {
return fmt.Errorf("barrier: begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
for keyID, dek := range b.keys {
encDEK, err := crypto.Encrypt(newMEK, dek, []byte(keyID))
if err != nil {
return fmt.Errorf("barrier: re-wrap key %q: %w", keyID, err)
}
_, err = tx.ExecContext(ctx,
"UPDATE barrier_keys SET encrypted_dek = ?, rotated_at = datetime('now') WHERE key_id = ?",
encDEK, keyID)
if err != nil {
return fmt.Errorf("barrier: update key %q: %w", keyID, err)
}
}
if err := tx.Commit(); err != nil {
return fmt.Errorf("barrier: commit re-wrap: %w", err)
}
// Update the MEK in memory.
crypto.Zeroize(b.mek)
k := make([]byte, len(newMEK))
copy(k, newMEK)
b.mek = k
return nil
}

View File

@@ -158,3 +158,309 @@ func TestBarrierOverwrite(t *testing.T) {
t.Fatalf("overwrite: got %q, want %q", got, "v2")
}
}
// --- DEK / Key Registry Tests ---
func TestBarrierCreateKey(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
if err := b.CreateKey(ctx, "engine/ca/prod"); err != nil {
t.Fatalf("CreateKey: %v", err)
}
// Duplicate should be a no-op.
if err := b.CreateKey(ctx, "engine/ca/prod"); err != nil {
t.Fatalf("CreateKey duplicate: %v", err)
}
keys, err := b.ListKeys(ctx)
if err != nil {
t.Fatalf("ListKeys: %v", err)
}
if len(keys) != 1 {
t.Fatalf("expected 1 key, got %d", len(keys))
}
if keys[0].KeyID != "engine/ca/prod" {
t.Fatalf("key ID: got %q, want %q", keys[0].KeyID, "engine/ca/prod")
}
}
func TestBarrierDEKEncryption(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
// Create a DEK for ca/prod.
_ = b.CreateKey(ctx, "engine/ca/prod")
// Write data under the engine path — should use DEK.
data := []byte("engine secret data")
if err := b.Put(ctx, "engine/ca/prod/config.json", data); err != nil {
t.Fatalf("Put: %v", err)
}
got, err := b.Get(ctx, "engine/ca/prod/config.json")
if err != nil {
t.Fatalf("Get: %v", err)
}
if string(got) != string(data) {
t.Fatalf("roundtrip: got %q, want %q", got, data)
}
// Verify the raw ciphertext is v2 format.
var raw []byte
err = b.db.QueryRowContext(ctx,
"SELECT value FROM barrier_entries WHERE path = ?",
"engine/ca/prod/config.json").Scan(&raw)
if err != nil {
t.Fatalf("read raw: %v", err)
}
if raw[0] != crypto.BarrierVersionV2 {
t.Fatalf("expected v2 ciphertext, got version %d", raw[0])
}
}
func TestBarrierV1FallbackWithoutDEK(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
// Write system data without any DEK — should use v1 with MEK.
data := []byte("system data")
if err := b.Put(ctx, "policy/rule1", data); err != nil {
t.Fatalf("Put: %v", err)
}
got, err := b.Get(ctx, "policy/rule1")
if err != nil {
t.Fatalf("Get: %v", err)
}
if string(got) != string(data) {
t.Fatalf("roundtrip: got %q, want %q", got, data)
}
}
func TestBarrierSystemDEK(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
// Create system DEK.
_ = b.CreateKey(ctx, "system")
// Write system data — should use system DEK with v2 format.
data := []byte("system with dek")
if err := b.Put(ctx, "policy/rule1", data); err != nil {
t.Fatalf("Put: %v", err)
}
got, err := b.Get(ctx, "policy/rule1")
if err != nil {
t.Fatalf("Get: %v", err)
}
if string(got) != string(data) {
t.Fatalf("roundtrip: got %q, want %q", got, data)
}
// Verify v2 format.
var raw []byte
_ = b.db.QueryRowContext(ctx,
"SELECT value FROM barrier_entries WHERE path = ?",
"policy/rule1").Scan(&raw)
if raw[0] != crypto.BarrierVersionV2 {
t.Fatalf("expected v2 ciphertext, got version %d", raw[0])
}
}
func TestBarrierRotateKey(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
_ = b.CreateKey(ctx, "engine/ca/prod")
// Write some data.
_ = b.Put(ctx, "engine/ca/prod/cert1", []byte("cert-data-1"))
_ = b.Put(ctx, "engine/ca/prod/cert2", []byte("cert-data-2"))
// Rotate the key.
if err := b.RotateKey(ctx, "engine/ca/prod"); err != nil {
t.Fatalf("RotateKey: %v", err)
}
// Data should still be readable.
got, err := b.Get(ctx, "engine/ca/prod/cert1")
if err != nil {
t.Fatalf("Get after rotation: %v", err)
}
if string(got) != "cert-data-1" {
t.Fatalf("data corrupted after rotation: %q", got)
}
got2, err := b.Get(ctx, "engine/ca/prod/cert2")
if err != nil {
t.Fatalf("Get after rotation: %v", err)
}
if string(got2) != "cert-data-2" {
t.Fatalf("data corrupted after rotation: %q", got2)
}
// Check key version incremented.
keys, _ := b.ListKeys(ctx)
for _, k := range keys {
if k.KeyID == "engine/ca/prod" && k.Version != 2 {
t.Fatalf("expected version 2 after rotation, got %d", k.Version)
}
}
}
func TestBarrierRotateKeyNotFound(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
err := b.RotateKey(ctx, "nonexistent")
if !errors.Is(err, ErrKeyNotFound) {
t.Fatalf("expected ErrKeyNotFound, got: %v", err)
}
}
func TestBarrierMigrateToV2(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
// Write v1 data (no DEKs registered, so it uses MEK).
_ = b.Put(ctx, "policy/rule1", []byte("policy-data"))
_ = b.Put(ctx, "engine/ca/prod/config", []byte("ca-config"))
_ = b.Put(ctx, "engine/transit/main/key1", []byte("transit-key"))
// Migrate.
migrated, err := b.MigrateToV2(ctx)
if err != nil {
t.Fatalf("MigrateToV2: %v", err)
}
if migrated != 3 {
t.Fatalf("expected 3 entries migrated, got %d", migrated)
}
// All data should still be readable.
got, err := b.Get(ctx, "policy/rule1")
if err != nil {
t.Fatalf("Get policy after migration: %v", err)
}
if string(got) != "policy-data" {
t.Fatalf("policy data: got %q", got)
}
got, err = b.Get(ctx, "engine/ca/prod/config")
if err != nil {
t.Fatalf("Get engine data after migration: %v", err)
}
if string(got) != "ca-config" {
t.Fatalf("engine data: got %q", got)
}
// Running again should migrate 0 (all already v2).
migrated2, err := b.MigrateToV2(ctx)
if err != nil {
t.Fatalf("MigrateToV2 second run: %v", err)
}
if migrated2 != 0 {
t.Fatalf("expected 0 entries on second migration, got %d", migrated2)
}
}
func TestBarrierSealUnsealPreservesDEKs(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
// Create DEK and write data.
_ = b.CreateKey(ctx, "engine/ca/prod")
_ = b.Put(ctx, "engine/ca/prod/secret", []byte("my-secret"))
// Seal and unseal.
_ = b.Seal()
_ = b.Unseal(mek)
// Data should still be readable (DEKs reloaded from barrier_keys).
got, err := b.Get(ctx, "engine/ca/prod/secret")
if err != nil {
t.Fatalf("Get after seal/unseal: %v", err)
}
if string(got) != "my-secret" {
t.Fatalf("data after seal/unseal: got %q", got)
}
}
func TestBarrierReWrapKeys(t *testing.T) {
b, cleanup := setupBarrier(t)
defer cleanup()
ctx := context.Background()
mek, _ := crypto.GenerateKey()
_ = b.Unseal(mek)
_ = b.CreateKey(ctx, "system")
_ = b.CreateKey(ctx, "engine/ca/prod")
_ = b.Put(ctx, "policy/rule1", []byte("policy"))
_ = b.Put(ctx, "engine/ca/prod/cert", []byte("cert"))
// Re-wrap with new MEK.
newMEK, _ := crypto.GenerateKey()
if err := b.ReWrapKeys(ctx, newMEK); err != nil {
t.Fatalf("ReWrapKeys: %v", err)
}
// Data should still be readable.
got, _ := b.Get(ctx, "policy/rule1")
if string(got) != "policy" {
t.Fatalf("policy after rewrap: got %q", got)
}
got2, _ := b.Get(ctx, "engine/ca/prod/cert")
if string(got2) != "cert" {
t.Fatalf("cert after rewrap: got %q", got2)
}
// Seal and unseal with new MEK should work.
_ = b.Seal()
if err := b.Unseal(newMEK); err != nil {
t.Fatalf("Unseal with new MEK: %v", err)
}
got3, err := b.Get(ctx, "engine/ca/prod/cert")
if err != nil {
t.Fatalf("Get after unseal with new MEK: %v", err)
}
if string(got3) != "cert" {
t.Fatalf("data after new MEK unseal: got %q", got3)
}
}

View File

@@ -20,8 +20,16 @@ const (
// SaltSize is the size of Argon2id salts in bytes.
SaltSize = 32
// BarrierVersion is the version byte prefix for encrypted barrier entries.
BarrierVersion byte = 0x01
// BarrierVersionV1 is the v1 format: [version][nonce][ciphertext+tag].
BarrierVersionV1 byte = 0x01
// BarrierVersionV2 is the v2 format: [version][key_id_len][key_id][nonce][ciphertext+tag].
BarrierVersionV2 byte = 0x02
// BarrierVersion is kept for backward compatibility (alias for V1).
BarrierVersion = BarrierVersionV1
// MaxKeyIDLen is the maximum length of a key ID in the v2 format.
MaxKeyIDLen = 255
// Default Argon2id parameters.
DefaultArgon2Time = 3
@@ -32,6 +40,7 @@ const (
var (
ErrInvalidCiphertext = errors.New("crypto: invalid ciphertext")
ErrDecryptionFailed = errors.New("crypto: decryption failed")
ErrKeyIDTooLong = errors.New("crypto: key ID exceeds maximum length")
)
// Argon2Params holds Argon2id KDF parameters.
@@ -74,8 +83,10 @@ func GenerateSalt() ([]byte, error) {
}
// Encrypt encrypts plaintext with AES-256-GCM using the given key.
// The additionalData parameter is authenticated but not encrypted (AAD);
// pass nil when no binding context is needed.
// Returns: [version byte][12-byte nonce][ciphertext+tag]
func Encrypt(key, plaintext []byte) ([]byte, error) {
func Encrypt(key, plaintext, additionalData []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("crypto: new cipher: %w", err)
@@ -88,7 +99,7 @@ func Encrypt(key, plaintext []byte) ([]byte, error) {
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("crypto: generate nonce: %w", err)
}
ciphertext := gcm.Seal(nil, nonce, plaintext, nil)
ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData)
// Format: [version][nonce][ciphertext+tag]
result := make([]byte, 1+NonceSize+len(ciphertext))
@@ -99,7 +110,8 @@ func Encrypt(key, plaintext []byte) ([]byte, error) {
}
// Decrypt decrypts ciphertext produced by Encrypt.
func Decrypt(key, data []byte) ([]byte, error) {
// The additionalData must match the value provided during encryption.
func Decrypt(key, data, additionalData []byte) ([]byte, error) {
if len(data) < 1+NonceSize+aes.BlockSize {
return nil, ErrInvalidCiphertext
}
@@ -117,13 +129,114 @@ func Decrypt(key, data []byte) ([]byte, error) {
if err != nil {
return nil, fmt.Errorf("crypto: new gcm: %w", err)
}
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
plaintext, err := gcm.Open(nil, nonce, ciphertext, additionalData)
if err != nil {
return nil, ErrDecryptionFailed
}
return plaintext, nil
}
// EncryptV2 encrypts plaintext with AES-256-GCM, embedding a key ID in the ciphertext.
// Format: [0x02][key_id_len:1][key_id:N][nonce:12][ciphertext+tag]
func EncryptV2(key []byte, keyID string, plaintext, additionalData []byte) ([]byte, error) {
if len(keyID) > MaxKeyIDLen {
return nil, ErrKeyIDTooLong
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("crypto: new cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, fmt.Errorf("crypto: new gcm: %w", err)
}
nonce := make([]byte, NonceSize)
if _, err := rand.Read(nonce); err != nil {
return nil, fmt.Errorf("crypto: generate nonce: %w", err)
}
ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData)
kidLen := len(keyID)
// Format: [version][key_id_len][key_id][nonce][ciphertext+tag]
result := make([]byte, 1+1+kidLen+NonceSize+len(ciphertext))
result[0] = BarrierVersionV2
result[1] = byte(kidLen)
copy(result[2:2+kidLen], keyID)
copy(result[2+kidLen:2+kidLen+NonceSize], nonce)
copy(result[2+kidLen+NonceSize:], ciphertext)
return result, nil
}
// DecryptV2 decrypts ciphertext that may be in v1 or v2 format.
// For v2 format, it extracts the key ID and returns it alongside the plaintext.
// For v1 format, it returns an empty key ID.
func DecryptV2(key, data, additionalData []byte) (plaintext []byte, keyID string, err error) {
if len(data) < 1 {
return nil, "", ErrInvalidCiphertext
}
switch data[0] {
case BarrierVersionV1:
pt, err := Decrypt(key, data, additionalData)
return pt, "", err
case BarrierVersionV2:
if len(data) < 2 {
return nil, "", ErrInvalidCiphertext
}
kidLen := int(data[1])
headerLen := 2 + kidLen
if len(data) < headerLen+NonceSize+aes.BlockSize {
return nil, "", ErrInvalidCiphertext
}
keyID = string(data[2 : 2+kidLen])
nonce := data[headerLen : headerLen+NonceSize]
ciphertext := data[headerLen+NonceSize:]
block, err := aes.NewCipher(key)
if err != nil {
return nil, "", fmt.Errorf("crypto: new cipher: %w", err)
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return nil, "", fmt.Errorf("crypto: new gcm: %w", err)
}
pt, err := gcm.Open(nil, nonce, ciphertext, additionalData)
if err != nil {
return nil, "", ErrDecryptionFailed
}
return pt, keyID, nil
default:
return nil, "", fmt.Errorf("crypto: unsupported version: %d", data[0])
}
}
// ExtractKeyID returns the key ID from a v2 ciphertext without decrypting.
// Returns empty string for v1 ciphertext.
func ExtractKeyID(data []byte) (string, error) {
if len(data) < 1 {
return "", ErrInvalidCiphertext
}
switch data[0] {
case BarrierVersionV1:
return "", nil
case BarrierVersionV2:
if len(data) < 2 {
return "", ErrInvalidCiphertext
}
kidLen := int(data[1])
if len(data) < 2+kidLen {
return "", ErrInvalidCiphertext
}
return string(data[2 : 2+kidLen]), nil
default:
return "", fmt.Errorf("crypto: unsupported version: %d", data[0])
}
}
// Zeroize overwrites a byte slice with zeros.
func Zeroize(b []byte) {
for i := range b {

View File

@@ -34,7 +34,7 @@ func TestEncryptDecrypt(t *testing.T) {
key, _ := GenerateKey()
plaintext := []byte("hello, metacrypt!")
ciphertext, err := Encrypt(key, plaintext)
ciphertext, err := Encrypt(key, plaintext, nil)
if err != nil {
t.Fatalf("Encrypt: %v", err)
}
@@ -44,7 +44,7 @@ func TestEncryptDecrypt(t *testing.T) {
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersion)
}
decrypted, err := Decrypt(key, ciphertext)
decrypted, err := Decrypt(key, ciphertext, nil)
if err != nil {
t.Fatalf("Decrypt: %v", err)
}
@@ -54,13 +54,45 @@ func TestEncryptDecrypt(t *testing.T) {
}
}
func TestEncryptDecryptWithAAD(t *testing.T) {
key, _ := GenerateKey()
plaintext := []byte("hello, metacrypt!")
aad := []byte("engine/ca/pki/root/cert.pem")
ciphertext, err := Encrypt(key, plaintext, aad)
if err != nil {
t.Fatalf("Encrypt with AAD: %v", err)
}
// Decrypt with correct AAD succeeds.
decrypted, err := Decrypt(key, ciphertext, aad)
if err != nil {
t.Fatalf("Decrypt with AAD: %v", err)
}
if !bytes.Equal(plaintext, decrypted) {
t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
}
// Decrypt with wrong AAD fails.
_, err = Decrypt(key, ciphertext, []byte("wrong/path"))
if !errors.Is(err, ErrDecryptionFailed) {
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
}
// Decrypt with nil AAD fails.
_, err = Decrypt(key, ciphertext, nil)
if !errors.Is(err, ErrDecryptionFailed) {
t.Fatalf("expected ErrDecryptionFailed with nil AAD, got: %v", err)
}
}
func TestDecryptWrongKey(t *testing.T) {
key1, _ := GenerateKey()
key2, _ := GenerateKey()
plaintext := []byte("secret data")
ciphertext, _ := Encrypt(key1, plaintext)
_, err := Decrypt(key2, ciphertext)
ciphertext, _ := Encrypt(key1, plaintext, nil)
_, err := Decrypt(key2, ciphertext, nil)
if !errors.Is(err, ErrDecryptionFailed) {
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
}
@@ -68,7 +100,7 @@ func TestDecryptWrongKey(t *testing.T) {
func TestDecryptInvalidCiphertext(t *testing.T) {
key, _ := GenerateKey()
_, err := Decrypt(key, []byte("short"))
_, err := Decrypt(key, []byte("short"), nil)
if !errors.Is(err, ErrInvalidCiphertext) {
t.Fatalf("expected ErrInvalidCiphertext, got: %v", err)
}
@@ -124,10 +156,141 @@ func TestEncryptProducesDifferentCiphertext(t *testing.T) {
key, _ := GenerateKey()
plaintext := []byte("same data")
ct1, _ := Encrypt(key, plaintext)
ct2, _ := Encrypt(key, plaintext)
ct1, _ := Encrypt(key, plaintext, nil)
ct2, _ := Encrypt(key, plaintext, nil)
if bytes.Equal(ct1, ct2) {
t.Fatal("two encryptions of same plaintext produced identical ciphertext (nonce reuse)")
}
}
func TestV2EncryptDecryptRoundtrip(t *testing.T) {
key, _ := GenerateKey()
plaintext := []byte("v2 test data")
keyID := "engine/ca/prod"
aad := []byte("engine/ca/prod/config.json")
ciphertext, err := EncryptV2(key, keyID, plaintext, aad)
if err != nil {
t.Fatalf("EncryptV2: %v", err)
}
if ciphertext[0] != BarrierVersionV2 {
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersionV2)
}
pt, gotKeyID, err := DecryptV2(key, ciphertext, aad)
if err != nil {
t.Fatalf("DecryptV2: %v", err)
}
if gotKeyID != keyID {
t.Fatalf("key ID: got %q, want %q", gotKeyID, keyID)
}
if !bytes.Equal(plaintext, pt) {
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
}
}
func TestV2DecryptV1Compat(t *testing.T) {
key, _ := GenerateKey()
plaintext := []byte("v1 legacy data")
// Encrypt with v1.
v1ct, err := Encrypt(key, plaintext, nil)
if err != nil {
t.Fatalf("Encrypt v1: %v", err)
}
// DecryptV2 should handle v1 ciphertext.
pt, keyID, err := DecryptV2(key, v1ct, nil)
if err != nil {
t.Fatalf("DecryptV2 with v1 ciphertext: %v", err)
}
if keyID != "" {
t.Fatalf("expected empty key ID for v1, got %q", keyID)
}
if !bytes.Equal(plaintext, pt) {
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
}
}
func TestV2WrongAAD(t *testing.T) {
key, _ := GenerateKey()
plaintext := []byte("data")
aad := []byte("correct/path")
ct, _ := EncryptV2(key, "system", plaintext, aad)
_, _, err := DecryptV2(key, ct, []byte("wrong/path"))
if !errors.Is(err, ErrDecryptionFailed) {
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
}
}
func TestV2WrongKey(t *testing.T) {
key1, _ := GenerateKey()
key2, _ := GenerateKey()
plaintext := []byte("data")
ct, _ := EncryptV2(key1, "system", plaintext, nil)
_, _, err := DecryptV2(key2, ct, nil)
if !errors.Is(err, ErrDecryptionFailed) {
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
}
}
func TestExtractKeyID(t *testing.T) {
key, _ := GenerateKey()
// v1: empty key ID.
v1ct, _ := Encrypt(key, []byte("data"), nil)
kid, err := ExtractKeyID(v1ct)
if err != nil {
t.Fatalf("ExtractKeyID v1: %v", err)
}
if kid != "" {
t.Fatalf("expected empty key ID for v1, got %q", kid)
}
// v2: embedded key ID.
v2ct, _ := EncryptV2(key, "engine/transit/main", []byte("data"), nil)
kid, err = ExtractKeyID(v2ct)
if err != nil {
t.Fatalf("ExtractKeyID v2: %v", err)
}
if kid != "engine/transit/main" {
t.Fatalf("key ID: got %q, want %q", kid, "engine/transit/main")
}
}
func TestV2KeyIDTooLong(t *testing.T) {
key, _ := GenerateKey()
longID := string(make([]byte, MaxKeyIDLen+1))
_, err := EncryptV2(key, longID, []byte("data"), nil)
if !errors.Is(err, ErrKeyIDTooLong) {
t.Fatalf("expected ErrKeyIDTooLong, got: %v", err)
}
}
func TestV2EmptyKeyID(t *testing.T) {
key, _ := GenerateKey()
plaintext := []byte("data with empty key id")
ct, err := EncryptV2(key, "", plaintext, nil)
if err != nil {
t.Fatalf("EncryptV2 empty key ID: %v", err)
}
pt, keyID, err := DecryptV2(key, ct, nil)
if err != nil {
t.Fatalf("DecryptV2 empty key ID: %v", err)
}
if keyID != "" {
t.Fatalf("expected empty key ID, got %q", keyID)
}
if !bytes.Equal(plaintext, pt) {
t.Fatalf("roundtrip failed")
}
}

View File

@@ -20,7 +20,7 @@ func TestOpenAndMigrate(t *testing.T) {
}
// Verify tables exist.
tables := []string{"seal_config", "barrier_entries", "schema_migrations"}
tables := []string{"seal_config", "barrier_entries", "schema_migrations", "barrier_keys"}
for _, table := range tables {
var name string
err := database.QueryRow(
@@ -38,7 +38,7 @@ func TestOpenAndMigrate(t *testing.T) {
// Check migration version.
var version int
_ = database.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
if version != 1 {
t.Errorf("migration version: got %d, want 1", version)
if version != 2 {
t.Errorf("migration version: got %d, want 2", version)
}
}

View File

@@ -30,6 +30,15 @@ var migrations = []string{
version INTEGER PRIMARY KEY,
applied_at DATETIME NOT NULL DEFAULT (datetime('now'))
);`,
// Version 2: barrier key registry for per-engine DEKs
`CREATE TABLE IF NOT EXISTS barrier_keys (
key_id TEXT PRIMARY KEY,
version INTEGER NOT NULL DEFAULT 1,
encrypted_dek BLOB NOT NULL,
created_at DATETIME NOT NULL DEFAULT (datetime('now')),
rotated_at DATETIME NOT NULL DEFAULT (datetime('now'))
);`,
}
// Migrate applies all pending migrations.

View File

@@ -147,6 +147,14 @@ func (r *Registry) Mount(ctx context.Context, name string, engineType EngineType
eng := factory()
mountPath := fmt.Sprintf("engine/%s/%s/", engineType, name)
// Create a per-engine DEK in the barrier for this mount.
if aesBarrier, ok := r.barrier.(*barrier.AESGCMBarrier); ok {
dekKeyID := fmt.Sprintf("engine/%s/%s", engineType, name)
if err := aesBarrier.CreateKey(ctx, dekKeyID); err != nil {
return fmt.Errorf("engine: create DEK %q: %w", dekKeyID, err)
}
}
if err := eng.Initialize(ctx, r.barrier, mountPath, config); err != nil {
return fmt.Errorf("engine: initialize %q: %w", name, err)
}

View File

@@ -0,0 +1,85 @@
package grpcserver
import (
"context"
"errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
)
type barrierServer struct {
pb.UnimplementedBarrierServiceServer
s *GRPCServer
}
func (bs *barrierServer) ListKeys(ctx context.Context, _ *pb.ListKeysRequest) (*pb.ListKeysResponse, error) {
keys, err := bs.s.sealMgr.Barrier().ListKeys(ctx)
if err != nil {
bs.s.logger.Error("grpc: list barrier keys", "error", err)
return nil, status.Error(codes.Internal, "failed to list keys")
}
resp := &pb.ListKeysResponse{}
for _, k := range keys {
resp.Keys = append(resp.Keys, &pb.BarrierKeyInfo{
KeyId: k.KeyID,
Version: int32(k.Version),
CreatedAt: k.CreatedAt,
RotatedAt: k.RotatedAt,
})
}
return resp, nil
}
func (bs *barrierServer) RotateMEK(ctx context.Context, req *pb.RotateMEKRequest) (*pb.RotateMEKResponse, error) {
if req.Password == "" {
return nil, status.Error(codes.InvalidArgument, "password is required")
}
if err := bs.s.sealMgr.RotateMEK(ctx, []byte(req.Password)); err != nil {
if errors.Is(err, seal.ErrInvalidPassword) {
return nil, status.Error(codes.Unauthenticated, "invalid password")
}
if errors.Is(err, seal.ErrSealed) {
return nil, status.Error(codes.FailedPrecondition, "sealed")
}
bs.s.logger.Error("grpc: rotate MEK", "error", err)
return nil, status.Error(codes.Internal, "rotation failed")
}
bs.s.logger.Info("audit: MEK rotated")
return &pb.RotateMEKResponse{Ok: true}, nil
}
func (bs *barrierServer) RotateKey(ctx context.Context, req *pb.RotateKeyRequest) (*pb.RotateKeyResponse, error) {
if req.KeyId == "" {
return nil, status.Error(codes.InvalidArgument, "key_id is required")
}
if err := bs.s.sealMgr.Barrier().RotateKey(ctx, req.KeyId); err != nil {
if errors.Is(err, barrier.ErrKeyNotFound) {
return nil, status.Error(codes.NotFound, "key not found")
}
bs.s.logger.Error("grpc: rotate key", "key_id", req.KeyId, "error", err)
return nil, status.Error(codes.Internal, "rotation failed")
}
bs.s.logger.Info("audit: DEK rotated", "key_id", req.KeyId)
return &pb.RotateKeyResponse{Ok: true}, nil
}
func (bs *barrierServer) Migrate(ctx context.Context, _ *pb.MigrateBarrierRequest) (*pb.MigrateBarrierResponse, error) {
migrated, err := bs.s.sealMgr.Barrier().MigrateToV2(ctx)
if err != nil {
bs.s.logger.Error("grpc: barrier migration", "error", err)
return nil, status.Error(codes.Internal, "migration failed")
}
bs.s.logger.Info("audit: barrier migrated to v2", "entries_migrated", migrated)
return &pb.MigrateBarrierResponse{Migrated: int32(migrated)}, nil
}

View File

@@ -60,7 +60,7 @@ func (s *GRPCServer) Start() error {
}
tlsCfg := &tls.Config{
Certificates: []tls.Certificate{tlsCert},
MinVersion: tls.VersionTLS12,
MinVersion: tls.VersionTLS13,
}
creds := credentials.NewTLS(tlsCfg)
@@ -81,6 +81,7 @@ func (s *GRPCServer) Start() error {
pb.RegisterPKIServiceServer(s.srv, &pkiServer{s: s})
pb.RegisterCAServiceServer(s.srv, &caServer{s: s})
pb.RegisterPolicyServiceServer(s.srv, &policyServer{s: s})
pb.RegisterBarrierServiceServer(s.srv, &barrierServer{s: s})
pb.RegisterACMEServiceServer(s.srv, &acmeServer{s: s})
lis, err := net.Listen("tcp", s.cfg.Server.GRPCAddr)
@@ -136,7 +137,11 @@ func sealRequiredMethods() map[string]bool {
"/metacrypt.v2.ACMEService/CreateEAB": true,
"/metacrypt.v2.ACMEService/SetConfig": true,
"/metacrypt.v2.ACMEService/ListAccounts": true,
"/metacrypt.v2.ACMEService/ListOrders": true,
"/metacrypt.v2.ACMEService/ListOrders": true,
"/metacrypt.v2.BarrierService/ListKeys": true,
"/metacrypt.v2.BarrierService/RotateMEK": true,
"/metacrypt.v2.BarrierService/RotateKey": true,
"/metacrypt.v2.BarrierService/Migrate": true,
}
}
@@ -166,7 +171,11 @@ func authRequiredMethods() map[string]bool {
"/metacrypt.v2.ACMEService/CreateEAB": true,
"/metacrypt.v2.ACMEService/SetConfig": true,
"/metacrypt.v2.ACMEService/ListAccounts": true,
"/metacrypt.v2.ACMEService/ListOrders": true,
"/metacrypt.v2.ACMEService/ListOrders": true,
"/metacrypt.v2.BarrierService/ListKeys": true,
"/metacrypt.v2.BarrierService/RotateMEK": true,
"/metacrypt.v2.BarrierService/RotateKey": true,
"/metacrypt.v2.BarrierService/Migrate": true,
}
}
@@ -182,9 +191,15 @@ func adminRequiredMethods() map[string]bool {
"/metacrypt.v2.CAService/RevokeCert": true,
"/metacrypt.v2.CAService/DeleteCert": true,
"/metacrypt.v2.PolicyService/CreatePolicy": true,
"/metacrypt.v2.PolicyService/DeletePolicy": true,
"/metacrypt.v2.PolicyService/ListPolicies": true,
"/metacrypt.v2.PolicyService/GetPolicy": true,
"/metacrypt.v2.PolicyService/DeletePolicy": true,
"/metacrypt.v2.ACMEService/SetConfig": true,
"/metacrypt.v2.ACMEService/ListAccounts": true,
"/metacrypt.v2.ACMEService/ListOrders": true,
"/metacrypt.v2.ACMEService/ListOrders": true,
"/metacrypt.v2.BarrierService/ListKeys": true,
"/metacrypt.v2.BarrierService/RotateMEK": true,
"/metacrypt.v2.BarrierService/RotateKey": true,
"/metacrypt.v2.BarrierService/Migrate": true,
}
}

View File

@@ -22,6 +22,38 @@ const (
EffectDeny Effect = "deny"
)
// Action constants for policy evaluation.
const (
ActionAny = "any" // matches all non-admin actions
ActionRead = "read" // retrieve/list operations
ActionWrite = "write" // create/update/delete operations
ActionEncrypt = "encrypt" // encrypt data
ActionDecrypt = "decrypt" // decrypt data
ActionSign = "sign" // sign data
ActionVerify = "verify" // verify signatures
ActionHMAC = "hmac" // compute HMAC
ActionAdmin = "admin" // administrative operations (never matched by "any")
)
// validEffects is the set of recognized effects.
var validEffects = map[Effect]bool{
EffectAllow: true,
EffectDeny: true,
}
// validActions is the set of recognized actions.
var validActions = map[string]bool{
ActionAny: true,
ActionRead: true,
ActionWrite: true,
ActionEncrypt: true,
ActionDecrypt: true,
ActionSign: true,
ActionVerify: true,
ActionHMAC: true,
ActionAdmin: true,
}
// Rule is a policy rule stored in the barrier.
type Rule struct {
ID string `json:"id"`
@@ -88,8 +120,34 @@ func (e *Engine) Match(ctx context.Context, req *Request) (Effect, bool, error)
return EffectDeny, false, nil // default deny, no matching rule
}
// CreateRule stores a new policy rule.
// LintRule validates a rule's effect and actions. It returns a list of problems
// (empty if the rule is valid). This does not check resource patterns or other
// fields — only the enumerated values that must come from a known set.
func LintRule(rule *Rule) []string {
var problems []string
if rule.ID == "" {
problems = append(problems, "rule ID is required")
}
if !validEffects[rule.Effect] {
problems = append(problems, fmt.Sprintf("invalid effect %q (must be %q or %q)", rule.Effect, EffectAllow, EffectDeny))
}
for _, a := range rule.Actions {
if !validActions[strings.ToLower(a)] {
problems = append(problems, fmt.Sprintf("invalid action %q", a))
}
}
return problems
}
// CreateRule validates and stores a new policy rule.
func (e *Engine) CreateRule(ctx context.Context, rule *Rule) error {
if problems := LintRule(rule); len(problems) > 0 {
return fmt.Errorf("policy: invalid rule: %s", strings.Join(problems, "; "))
}
data, err := json.Marshal(rule)
if err != nil {
return fmt.Errorf("policy: marshal rule: %w", err)
@@ -157,14 +215,28 @@ func matchesRule(rule *Rule, req *Request) bool {
return false
}
// Check action match.
if len(rule.Actions) > 0 && !containsString(rule.Actions, req.Action) {
// Check action match. The "any" action matches all non-admin actions.
if len(rule.Actions) > 0 && !matchesAction(rule.Actions, req.Action) {
return false
}
return true
}
// matchesAction checks whether any of the rule's actions match the requested action.
// The special "any" action matches all actions except "admin".
func matchesAction(ruleActions []string, reqAction string) bool {
for _, a := range ruleActions {
if strings.EqualFold(a, reqAction) {
return true
}
if strings.EqualFold(a, ActionAny) && !strings.EqualFold(reqAction, ActionAdmin) {
return true
}
}
return false
}
func containsString(haystack []string, needle string) bool {
for _, s := range haystack {
if strings.EqualFold(s, needle) {

View File

@@ -141,6 +141,96 @@ func TestPolicyPriorityOrder(t *testing.T) {
}
}
func TestActionAnyMatchesNonAdmin(t *testing.T) {
e, cleanup := setupPolicy(t)
defer cleanup()
ctx := context.Background()
_ = e.CreateRule(ctx, &Rule{
ID: "any-rule",
Priority: 100,
Effect: EffectAllow,
Roles: []string{"user"},
Resources: []string{"transit/*"},
Actions: []string{ActionAny},
})
// "any" should match encrypt, decrypt, sign, verify, hmac, read, write.
for _, action := range []string{ActionEncrypt, ActionDecrypt, ActionSign, ActionVerify, ActionHMAC, ActionRead, ActionWrite} {
effect, _ := e.Evaluate(ctx, &Request{
Username: "alice",
Roles: []string{"user"},
Resource: "transit/default",
Action: action,
})
if effect != EffectAllow {
t.Errorf("action %q should be allowed by 'any', got: %s", action, effect)
}
}
// "any" must NOT match "admin".
effect, _ := e.Evaluate(ctx, &Request{
Username: "alice",
Roles: []string{"user"},
Resource: "transit/default",
Action: ActionAdmin,
})
if effect != EffectDeny {
t.Fatalf("action 'admin' should not be matched by 'any', got: %s", effect)
}
}
func TestLintRule(t *testing.T) {
// Valid rule.
problems := LintRule(&Rule{
ID: "ok",
Effect: EffectAllow,
Actions: []string{ActionAny, ActionEncrypt},
})
if len(problems) > 0 {
t.Errorf("expected no problems, got: %v", problems)
}
// Missing ID.
problems = LintRule(&Rule{Effect: EffectAllow})
if len(problems) != 1 {
t.Errorf("expected 1 problem for missing ID, got: %v", problems)
}
// Invalid effect.
problems = LintRule(&Rule{ID: "bad-effect", Effect: "maybe"})
if len(problems) != 1 {
t.Errorf("expected 1 problem for bad effect, got: %v", problems)
}
// Invalid action.
problems = LintRule(&Rule{ID: "bad-action", Effect: EffectAllow, Actions: []string{"destroy"}})
if len(problems) != 1 {
t.Errorf("expected 1 problem for bad action, got: %v", problems)
}
// Multiple problems.
problems = LintRule(&Rule{Effect: "bogus", Actions: []string{"nope"}})
if len(problems) != 3 { // missing ID + bad effect + bad action
t.Errorf("expected 3 problems, got: %v", problems)
}
}
func TestCreateRuleRejectsInvalid(t *testing.T) {
e, cleanup := setupPolicy(t)
defer cleanup()
ctx := context.Background()
err := e.CreateRule(ctx, &Rule{
ID: "bad",
Effect: EffectAllow,
Actions: []string{"obliterate"},
})
if err == nil {
t.Fatal("expected error for invalid action, got nil")
}
}
func TestPolicyUsernameMatch(t *testing.T) {
e, cleanup := setupPolicy(t)
defer cleanup()

View File

@@ -141,7 +141,7 @@ func (m *Manager) Initialize(ctx context.Context, password []byte, params crypto
defer crypto.Zeroize(kwk)
// Encrypt MEK with KWK.
encryptedMEK, err := crypto.Encrypt(kwk, mek)
encryptedMEK, err := crypto.Encrypt(kwk, mek, nil)
if err != nil {
crypto.Zeroize(mek)
return fmt.Errorf("seal: encrypt mek: %w", err)
@@ -220,7 +220,7 @@ func (m *Manager) Unseal(password []byte) error {
kwk := crypto.DeriveKey(password, salt, params)
defer crypto.Zeroize(kwk)
mek, err := crypto.Decrypt(kwk, encryptedMEK)
mek, err := crypto.Decrypt(kwk, encryptedMEK, nil)
if err != nil {
m.logger.Debug("unseal failed: invalid password")
return ErrInvalidPassword
@@ -239,6 +239,79 @@ func (m *Manager) Unseal(password []byte) error {
return nil
}
// RotateMEK generates a new MEK, re-wraps all DEKs in the barrier, and
// updates the encrypted MEK in seal_config. The password is required to
// derive the KWK for re-encrypting the new MEK.
func (m *Manager) RotateMEK(ctx context.Context, password []byte) error {
m.mu.Lock()
defer m.mu.Unlock()
if m.state != StateUnsealed {
return ErrSealed
}
// Read seal config for KDF params.
var (
salt []byte
argTime, argMem uint32
argThreads uint8
)
err := m.db.QueryRow(`
SELECT kdf_salt, argon2_time, argon2_memory, argon2_threads
FROM seal_config WHERE id = 1`).Scan(&salt, &argTime, &argMem, &argThreads)
if err != nil {
return fmt.Errorf("seal: read config: %w", err)
}
// Verify password by decrypting existing MEK.
params := crypto.Argon2Params{Time: argTime, Memory: argMem, Threads: argThreads}
kwk := crypto.DeriveKey(password, salt, params)
defer crypto.Zeroize(kwk)
var encryptedMEK []byte
err = m.db.QueryRow("SELECT encrypted_mek FROM seal_config WHERE id = 1").Scan(&encryptedMEK)
if err != nil {
return fmt.Errorf("seal: read encrypted mek: %w", err)
}
_, err = crypto.Decrypt(kwk, encryptedMEK, nil)
if err != nil {
return ErrInvalidPassword
}
// Generate new MEK.
newMEK, err := crypto.GenerateKey()
if err != nil {
return fmt.Errorf("seal: generate new mek: %w", err)
}
// Re-wrap all DEKs with new MEK.
if err := m.barrier.ReWrapKeys(ctx, newMEK); err != nil {
crypto.Zeroize(newMEK)
return fmt.Errorf("seal: re-wrap keys: %w", err)
}
// Encrypt new MEK with KWK.
newEncMEK, err := crypto.Encrypt(kwk, newMEK, nil)
if err != nil {
crypto.Zeroize(newMEK)
return fmt.Errorf("seal: encrypt new mek: %w", err)
}
// Update seal_config.
_, err = m.db.ExecContext(ctx,
"UPDATE seal_config SET encrypted_mek = ? WHERE id = 1", newEncMEK)
if err != nil {
crypto.Zeroize(newMEK)
return fmt.Errorf("seal: update seal config: %w", err)
}
// Swap in-memory MEK.
crypto.Zeroize(m.mek)
m.mek = newMEK
m.logger.Info("MEK rotated successfully")
return nil
}
// Seal seals the service: zeroizes MEK, seals the barrier.
func (m *Manager) Seal() error {
m.mu.Lock()

View File

@@ -120,6 +120,94 @@ func TestSealCheckInitializedPersists(t *testing.T) {
}
}
func TestSealRotateMEK(t *testing.T) {
mgr, cleanup := setupSeal(t)
defer cleanup()
ctx := context.Background()
_ = mgr.CheckInitialized()
password := []byte("test-password")
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
_ = mgr.Initialize(ctx, password, params)
// Create a DEK and write data through the barrier.
b := mgr.Barrier()
_ = b.CreateKey(ctx, "system")
_ = b.CreateKey(ctx, "engine/ca/prod")
_ = b.Put(ctx, "policy/rule1", []byte("policy-data"))
_ = b.Put(ctx, "engine/ca/prod/cert", []byte("cert-data"))
// Rotate MEK.
if err := mgr.RotateMEK(ctx, password); err != nil {
t.Fatalf("RotateMEK: %v", err)
}
// Data should still be readable.
got, err := b.Get(ctx, "policy/rule1")
if err != nil {
t.Fatalf("Get after MEK rotation: %v", err)
}
if string(got) != "policy-data" {
t.Fatalf("data: got %q", got)
}
got2, err := b.Get(ctx, "engine/ca/prod/cert")
if err != nil {
t.Fatalf("Get engine data after MEK rotation: %v", err)
}
if string(got2) != "cert-data" {
t.Fatalf("data: got %q", got2)
}
// Seal and unseal with the same password should work
// (the new MEK is now encrypted with the KWK).
if err := mgr.Seal(); err != nil {
t.Fatalf("Seal: %v", err)
}
if err := mgr.Unseal(password); err != nil {
t.Fatalf("Unseal after MEK rotation: %v", err)
}
got3, err := b.Get(ctx, "engine/ca/prod/cert")
if err != nil {
t.Fatalf("Get after seal/unseal: %v", err)
}
if string(got3) != "cert-data" {
t.Fatalf("data after seal/unseal: got %q", got3)
}
}
func TestSealRotateMEKWrongPassword(t *testing.T) {
mgr, cleanup := setupSeal(t)
defer cleanup()
ctx := context.Background()
_ = mgr.CheckInitialized()
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
_ = mgr.Initialize(ctx, []byte("correct"), params)
err := mgr.RotateMEK(ctx, []byte("wrong"))
if !errors.Is(err, ErrInvalidPassword) {
t.Fatalf("expected ErrInvalidPassword, got: %v", err)
}
}
func TestSealRotateMEKWhenSealed(t *testing.T) {
mgr, cleanup := setupSeal(t)
defer cleanup()
ctx := context.Background()
_ = mgr.CheckInitialized()
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
_ = mgr.Initialize(ctx, []byte("password"), params)
_ = mgr.Seal()
err := mgr.RotateMEK(ctx, []byte("password"))
if !errors.Is(err, ErrSealed) {
t.Fatalf("expected ErrSealed, got: %v", err)
}
}
func TestSealStateString(t *testing.T) {
tests := []struct {
want string

View File

@@ -11,6 +11,7 @@ import (
mcias "git.wntrmute.dev/kyle/mcias/clients/go"
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
"git.wntrmute.dev/kyle/metacrypt/internal/engine/ca"
@@ -45,6 +46,11 @@ func (s *Server) registerRoutes(r chi.Router) {
r.Get("/v1/pki/{mount}/issuer/{name}", s.requireUnseal(s.handlePKIIssuer))
r.Get("/v1/pki/{mount}/issuer/{name}/crl", s.requireUnseal(s.handlePKICRL))
r.Get("/v1/barrier/keys", s.requireAdmin(s.handleBarrierKeys))
r.Post("/v1/barrier/rotate-mek", s.requireAdmin(s.handleRotateMEK))
r.Post("/v1/barrier/rotate-key", s.requireAdmin(s.handleRotateKey))
r.Post("/v1/barrier/migrate", s.requireAdmin(s.handleBarrierMigrate))
r.HandleFunc("/v1/policy/rules", s.requireAuth(s.handlePolicyRules))
r.HandleFunc("/v1/policy/rule", s.requireAuth(s.handlePolicyRule))
s.registerACMERoutes(r)
@@ -253,6 +259,31 @@ func (s *Server) handleEngineUnmount(w http.ResponseWriter, r *http.Request) {
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
}
// adminOnlyOperations lists engine operations that require admin role.
// This enforces the same gates as the typed REST routes, ensuring the
// generic endpoint cannot bypass admin requirements.
var adminOnlyOperations = map[string]bool{
// CA engine.
"import-root": true,
"create-issuer": true,
"delete-issuer": true,
"revoke-cert": true,
"delete-cert": true,
// Transit engine.
"create-key": true,
"delete-key": true,
"rotate-key": true,
"update-key-config": true,
"trim-key": true,
// SSH CA engine.
"create-profile": true,
"update-profile": true,
"delete-profile": true,
// User engine.
"provision": true,
"delete-user": true,
}
func (s *Server) handleEngineRequest(w http.ResponseWriter, r *http.Request) {
var req struct {
Data map[string]interface{} `json:"data"`
@@ -271,6 +302,12 @@ func (s *Server) handleEngineRequest(w http.ResponseWriter, r *http.Request) {
info := TokenInfoFromContext(r.Context())
// Enforce admin requirement for operations that have admin-only typed routes.
if adminOnlyOperations[req.Operation] && !info.IsAdmin {
http.Error(w, `{"error":"forbidden: admin required"}`, http.StatusForbidden)
return
}
// Evaluate policy before dispatching to the engine.
policyReq := &policy.Request{
Username: info.Username,
@@ -412,6 +449,90 @@ func (s *Server) handlePolicyRule(w http.ResponseWriter, r *http.Request) {
}
}
// --- Barrier Key Management Handlers ---
func (s *Server) handleBarrierKeys(w http.ResponseWriter, r *http.Request) {
keys, err := s.seal.Barrier().ListKeys(r.Context())
if err != nil {
s.logger.Error("list barrier keys", "error", err)
http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError)
return
}
if keys == nil {
keys = []barrier.KeyInfo{}
}
writeJSON(w, http.StatusOK, keys)
}
func (s *Server) handleRotateMEK(w http.ResponseWriter, r *http.Request) {
var req struct {
Password string `json:"password"`
}
if err := readJSON(r, &req); err != nil {
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
return
}
if req.Password == "" {
http.Error(w, `{"error":"password is required"}`, http.StatusBadRequest)
return
}
if err := s.seal.RotateMEK(r.Context(), []byte(req.Password)); err != nil {
if errors.Is(err, seal.ErrInvalidPassword) {
http.Error(w, `{"error":"invalid password"}`, http.StatusUnauthorized)
return
}
if errors.Is(err, seal.ErrSealed) {
http.Error(w, `{"error":"sealed"}`, http.StatusServiceUnavailable)
return
}
s.logger.Error("rotate MEK", "error", err)
http.Error(w, `{"error":"rotation failed"}`, http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
}
func (s *Server) handleRotateKey(w http.ResponseWriter, r *http.Request) {
var req struct {
KeyID string `json:"key_id"`
}
if err := readJSON(r, &req); err != nil {
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
return
}
if req.KeyID == "" {
http.Error(w, `{"error":"key_id is required"}`, http.StatusBadRequest)
return
}
if err := s.seal.Barrier().RotateKey(r.Context(), req.KeyID); err != nil {
if errors.Is(err, barrier.ErrKeyNotFound) {
http.Error(w, `{"error":"key not found"}`, http.StatusNotFound)
return
}
s.logger.Error("rotate key", "key_id", req.KeyID, "error", err)
http.Error(w, `{"error":"rotation failed"}`, http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
}
func (s *Server) handleBarrierMigrate(w http.ResponseWriter, r *http.Request) {
migrated, err := s.seal.Barrier().MigrateToV2(r.Context())
if err != nil {
s.logger.Error("barrier migration", "error", err)
http.Error(w, `{"error":"migration failed"}`, http.StatusInternalServerError)
return
}
writeJSON(w, http.StatusOK, map[string]interface{}{
"migrated": migrated,
})
}
// --- CA Certificate Handlers ---
func (s *Server) handleGetCert(w http.ResponseWriter, r *http.Request) {
@@ -608,13 +729,29 @@ func (s *Server) getCAEngine(mountName string) (*ca.CAEngine, error) {
return caEng, nil
}
// operationAction maps an engine operation name to a policy action ("read" or "write").
// operationAction maps an engine operation name to a policy action.
func operationAction(op string) string {
switch op {
case "list-issuers", "list-certs", "get-cert", "get-root", "get-chain", "get-issuer":
return "read"
// Read operations.
case "list-issuers", "list-certs", "get-cert", "get-root", "get-chain", "get-issuer",
"list-keys", "get-key", "get-public-key", "list-users", "get-profile", "list-profiles":
return policy.ActionRead
// Granular cryptographic operations (including batch variants).
case "encrypt", "batch-encrypt":
return policy.ActionEncrypt
case "decrypt", "batch-decrypt":
return policy.ActionDecrypt
case "sign", "sign-host", "sign-user":
return policy.ActionSign
case "verify":
return policy.ActionVerify
case "hmac":
return policy.ActionHMAC
// Everything else is a write.
default:
return "write"
return policy.ActionWrite
}
}

View File

@@ -58,13 +58,7 @@ func (s *Server) Start() error {
s.registerRoutes(r)
tlsCfg := &tls.Config{
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
},
MinVersion: tls.VersionTLS13,
}
s.httpSrv = &http.Server{

View File

@@ -241,18 +241,84 @@ func TestEngineRequestPolicyAllowsWithRule(t *testing.T) {
}
}
// TestOperationAction verifies the read/write classification of operations.
func TestOperationAction(t *testing.T) {
readOps := []string{"list-issuers", "list-certs", "get-cert", "get-root", "get-chain", "get-issuer"}
for _, op := range readOps {
if got := operationAction(op); got != "read" {
t.Errorf("operationAction(%q) = %q, want %q", op, got, "read")
// TestEngineRequestAdminOnlyBlocksNonAdmin verifies that admin-only operations
// via the generic endpoint are rejected for non-admin users.
func TestEngineRequestAdminOnlyBlocksNonAdmin(t *testing.T) {
srv, sealMgr, _ := setupTestServer(t)
unsealServer(t, sealMgr, nil)
for _, op := range []string{"create-issuer", "delete-cert", "create-key", "rotate-key", "create-profile", "provision"} {
body := makeEngineRequest("test-mount", op)
req := httptest.NewRequest(http.MethodPost, "/v1/engine/request", strings.NewReader(body))
req = withTokenInfo(req, &auth.TokenInfo{Username: "alice", Roles: []string{"user"}, IsAdmin: false})
w := httptest.NewRecorder()
srv.handleEngineRequest(w, req)
if w.Code != http.StatusForbidden {
t.Errorf("operation %q: expected 403 for non-admin, got %d", op, w.Code)
}
}
writeOps := []string{"issue", "renew", "create-issuer", "delete-issuer", "sign-csr", "revoke"}
for _, op := range writeOps {
if got := operationAction(op); got != "write" {
t.Errorf("operationAction(%q) = %q, want %q", op, got, "write")
}
// TestEngineRequestAdminOnlyAllowsAdmin verifies that admin-only operations
// via the generic endpoint are allowed for admin users.
func TestEngineRequestAdminOnlyAllowsAdmin(t *testing.T) {
srv, sealMgr, _ := setupTestServer(t)
unsealServer(t, sealMgr, nil)
for _, op := range []string{"create-issuer", "delete-cert", "create-key", "rotate-key", "create-profile", "provision"} {
body := makeEngineRequest("test-mount", op)
req := httptest.NewRequest(http.MethodPost, "/v1/engine/request", strings.NewReader(body))
req = withTokenInfo(req, &auth.TokenInfo{Username: "admin", Roles: []string{"admin"}, IsAdmin: true})
w := httptest.NewRecorder()
srv.handleEngineRequest(w, req)
// Admin passes the admin check; will get 404 (mount not found) not 403.
if w.Code == http.StatusForbidden {
t.Errorf("operation %q: admin should not be forbidden, got 403", op)
}
}
}
// TestOperationAction verifies the action classification of operations.
func TestOperationAction(t *testing.T) {
tests := map[string]string{
// Read operations.
"list-issuers": policy.ActionRead,
"list-certs": policy.ActionRead,
"get-cert": policy.ActionRead,
"get-root": policy.ActionRead,
"get-chain": policy.ActionRead,
"get-issuer": policy.ActionRead,
"list-keys": policy.ActionRead,
"get-key": policy.ActionRead,
"get-public-key": policy.ActionRead,
"list-users": policy.ActionRead,
"get-profile": policy.ActionRead,
"list-profiles": policy.ActionRead,
// Granular crypto operations (including batch variants).
"encrypt": policy.ActionEncrypt,
"batch-encrypt": policy.ActionEncrypt,
"decrypt": policy.ActionDecrypt,
"batch-decrypt": policy.ActionDecrypt,
"sign": policy.ActionSign,
"sign-host": policy.ActionSign,
"sign-user": policy.ActionSign,
"verify": policy.ActionVerify,
"hmac": policy.ActionHMAC,
// Write operations.
"issue": policy.ActionWrite,
"renew": policy.ActionWrite,
"create-issuer": policy.ActionWrite,
"delete-issuer": policy.ActionWrite,
"sign-csr": policy.ActionWrite,
"revoke": policy.ActionWrite,
}
for op, want := range tests {
if got := operationAction(op); got != want {
t.Errorf("operationAction(%q) = %q, want %q", op, got, want)
}
}
}

View File

@@ -193,6 +193,7 @@ func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer {
vault: vault,
logger: slog.Default(),
staticFS: staticFS,
csrf: newCSRFProtect(),
}
}

View File

@@ -30,7 +30,7 @@ type VaultClient struct {
func NewVaultClient(addr, caCertPath string, logger *slog.Logger) (*VaultClient, error) {
logger.Debug("connecting to vault", "addr", addr, "ca_cert", caCertPath)
tlsCfg := &tls.Config{MinVersion: tls.VersionTLS12}
tlsCfg := &tls.Config{MinVersion: tls.VersionTLS13}
if caCertPath != "" {
logger.Debug("loading vault CA certificate", "path", caCertPath)
pemData, err := os.ReadFile(caCertPath) //nolint:gosec

119
internal/webserver/csrf.go Normal file
View File

@@ -0,0 +1,119 @@
package webserver
import (
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"net/http"
"strings"
"sync"
)
const (
csrfCookieName = "metacrypt_csrf"
csrfFieldName = "csrf_token"
csrfTokenLen = 32
)
// csrfProtect provides CSRF protection using the signed double-submit cookie
// pattern. A random secret is generated at startup. CSRF tokens are an HMAC of
// a random nonce, sent as both a cookie and a hidden form field. On POST the
// middleware verifies that the form field matches the cookie's HMAC.
type csrfProtect struct {
secret []byte
once sync.Once
}
func newCSRFProtect() *csrfProtect {
secret := make([]byte, 32)
if _, err := rand.Read(secret); err != nil {
panic("csrf: failed to generate secret: " + err.Error())
}
return &csrfProtect{secret: secret}
}
// generateToken creates a new CSRF token: base64(nonce) + "." + base64(hmac(nonce)).
func (c *csrfProtect) generateToken() (string, error) {
nonce := make([]byte, csrfTokenLen)
if _, err := rand.Read(nonce); err != nil {
return "", fmt.Errorf("csrf: generate nonce: %w", err)
}
nonceB64 := base64.RawURLEncoding.EncodeToString(nonce)
mac := hmac.New(sha256.New, c.secret)
mac.Write(nonce)
sigB64 := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
return nonceB64 + "." + sigB64, nil
}
// validToken checks that a token has a valid HMAC signature.
func (c *csrfProtect) validToken(token string) bool {
parts := strings.SplitN(token, ".", 2)
if len(parts) != 2 {
return false
}
nonce, err := base64.RawURLEncoding.DecodeString(parts[0])
if err != nil || len(nonce) != csrfTokenLen {
return false
}
sig, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return false
}
mac := hmac.New(sha256.New, c.secret)
mac.Write(nonce)
return hmac.Equal(mac.Sum(nil), sig)
}
// setToken generates a new CSRF token, sets it as a cookie, and returns it
// for embedding in a form.
func (c *csrfProtect) setToken(w http.ResponseWriter) string {
token, err := c.generateToken()
if err != nil {
return ""
}
http.SetCookie(w, &http.Cookie{
Name: csrfCookieName,
Value: token,
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteStrictMode,
})
return token
}
// middleware returns an HTTP middleware that enforces CSRF validation on
// mutation requests (POST, PUT, PATCH, DELETE). GET/HEAD/OPTIONS are passed
// through. The HTMX hx-post for /v1/seal is excluded since it hits the API
// server directly and uses token auth, not cookies.
func (c *csrfProtect) middleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case http.MethodGet, http.MethodHead, http.MethodOptions:
next.ServeHTTP(w, r)
return
}
// Read token from form field (works for both regular forms and
// multipart forms since ParseForm/ParseMultipartForm will have
// been called or the field is available via FormValue).
formToken := r.FormValue(csrfFieldName)
// Read token from cookie.
cookie, err := r.Cookie(csrfCookieName)
if err != nil || cookie.Value == "" {
http.Error(w, "CSRF validation failed", http.StatusForbidden)
return
}
// Both tokens must be present, match each other, and be validly signed.
if formToken == "" || formToken != cookie.Value || !c.validToken(formToken) {
http.Error(w, "CSRF validation failed", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}

View File

@@ -0,0 +1,105 @@
package webserver
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
)
func TestCSRFTokenGenerateAndValidate(t *testing.T) {
c := newCSRFProtect()
token, err := c.generateToken()
if err != nil {
t.Fatalf("generateToken: %v", err)
}
if !c.validToken(token) {
t.Fatal("valid token rejected")
}
}
func TestCSRFTokenInvalidFormats(t *testing.T) {
c := newCSRFProtect()
for _, bad := range []string{"", "nodot", "a.b.c", "abc.def"} {
if c.validToken(bad) {
t.Errorf("should reject %q", bad)
}
}
}
func TestCSRFTokenCrossSecret(t *testing.T) {
c1 := newCSRFProtect()
c2 := newCSRFProtect()
token, _ := c1.generateToken()
if c2.validToken(token) {
t.Fatal("token from different secret should be rejected")
}
}
func TestCSRFMiddlewareAllowsGET(t *testing.T) {
c := newCSRFProtect()
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodGet, "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("GET should pass through, got %d", w.Code)
}
}
func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
c := newCSRFProtect()
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Fatalf("POST without CSRF token should be forbidden, got %d", w.Code)
}
}
func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
c := newCSRFProtect()
token, _ := c.generateToken()
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
body := csrfFieldName + "=" + token
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token})
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
}
}
func TestCSRFMiddlewareRejectsMismatch(t *testing.T) {
c := newCSRFProtect()
token1, _ := c.generateToken()
token2, _ := c.generateToken()
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
body := csrfFieldName + "=" + token1
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token2})
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != http.StatusForbidden {
t.Fatalf("POST with mismatched tokens should be forbidden, got %d", w.Code)
}
}

View File

@@ -188,7 +188,7 @@ func (ws *WebServer) handleLogin(w http.ResponseWriter, r *http.Request) {
Path: "/",
HttpOnly: true,
Secure: true,
SameSite: http.SameSiteLaxMode,
SameSite: http.SameSiteStrictMode,
})
http.Redirect(w, r, "/dashboard", http.StatusFound)
default:

View File

@@ -70,6 +70,7 @@ type WebServer struct {
logger *slog.Logger
httpSrv *http.Server
staticFS fs.FS
csrf *csrfProtect
tgzCache sync.Map // key: UUID string → *tgzEntry
userCache sync.Map // key: UUID string → *cachedUsername
}
@@ -125,6 +126,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*WebServer, error) {
vault: vault,
logger: logger,
staticFS: staticFS,
csrf: newCSRFProtect(),
}
if tok := cfg.MCIAS.ServiceToken; tok != "" {
@@ -188,6 +190,7 @@ func (lw *loggingResponseWriter) Unwrap() http.ResponseWriter {
func (ws *WebServer) Start() error {
r := chi.NewRouter()
r.Use(ws.loggingMiddleware)
r.Use(ws.csrf.middleware)
ws.registerRoutes(r)
ws.httpSrv = &http.Server{
@@ -201,7 +204,7 @@ func (ws *WebServer) Start() error {
ws.logger.Info("starting web server", "addr", ws.cfg.Web.ListenAddr)
if ws.cfg.Web.TLSCert != "" && ws.cfg.Web.TLSKey != "" {
ws.httpSrv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
ws.httpSrv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS13}
err := ws.httpSrv.ListenAndServeTLS(ws.cfg.Web.TLSCert, ws.cfg.Web.TLSKey)
if err != nil && err != http.ErrServerClosed {
return fmt.Errorf("webserver: %w", err)
@@ -226,7 +229,18 @@ func (ws *WebServer) Shutdown(ctx context.Context) error {
}
func (ws *WebServer) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
tmpl, err := template.ParseFS(webui.FS,
csrfToken := ws.csrf.setToken(w)
funcMap := template.FuncMap{
"csrfField": func() template.HTML {
return template.HTML(fmt.Sprintf(
`<input type="hidden" name="%s" value="%s">`,
csrfFieldName, template.HTMLEscapeString(csrfToken),
))
},
}
tmpl, err := template.New("").Funcs(funcMap).ParseFS(webui.FS,
"templates/layout.html",
"templates/"+name,
)