Implement a two-level key hierarchy: the MEK now wraps per-engine DEKs stored in a new barrier_keys table, rather than encrypting all barrier entries directly. A v2 ciphertext format (0x02) embeds the key ID so the barrier can resolve which DEK to use on decryption. v1 ciphertext remains supported for backward compatibility. Key changes: - crypto: EncryptV2/DecryptV2/ExtractKeyID for v2 ciphertext with key IDs - barrier: key registry (CreateKey, RotateKey, ListKeys, MigrateToV2, ReWrapKeys) - seal: RotateMEK re-wraps DEKs without re-encrypting data - engine: Mount auto-creates per-engine DEK - REST + gRPC: barrier/keys, barrier/rotate-mek, barrier/rotate-key, barrier/migrate - proto: BarrierService (v1 + v2) with ListKeys, RotateMEK, RotateKey, Migrate - db: migration v2 adds barrier_keys table Also includes: security audit report, CSRF protection, engine design specs (sshca, transit, user), path-bound AAD migration tool, policy engine enhancements, and ARCHITECTURE.md updates. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -13,8 +13,9 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSealed = errors.New("barrier: sealed")
|
||||
ErrNotFound = errors.New("barrier: entry not found")
|
||||
ErrSealed = errors.New("barrier: sealed")
|
||||
ErrNotFound = errors.New("barrier: entry not found")
|
||||
ErrKeyNotFound = errors.New("barrier: key not found")
|
||||
)
|
||||
|
||||
// Barrier is the encrypted storage barrier interface.
|
||||
@@ -36,11 +37,20 @@ type Barrier interface {
|
||||
List(ctx context.Context, prefix string) ([]string, error)
|
||||
}
|
||||
|
||||
// KeyInfo holds metadata about a barrier key (DEK).
|
||||
type KeyInfo struct {
|
||||
KeyID string `json:"key_id"`
|
||||
Version int `json:"version"`
|
||||
CreatedAt string `json:"created_at"`
|
||||
RotatedAt string `json:"rotated_at"`
|
||||
}
|
||||
|
||||
// AESGCMBarrier implements Barrier using AES-256-GCM encryption.
|
||||
type AESGCMBarrier struct {
|
||||
db *sql.DB
|
||||
mek []byte
|
||||
mu sync.RWMutex
|
||||
db *sql.DB
|
||||
mek []byte
|
||||
keys map[string][]byte // key_id → plaintext DEK
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewAESGCMBarrier creates a new AES-GCM barrier backed by the given database.
|
||||
@@ -51,15 +61,56 @@ func NewAESGCMBarrier(db *sql.DB) *AESGCMBarrier {
|
||||
func (b *AESGCMBarrier) Unseal(mek []byte) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
k := make([]byte, len(mek))
|
||||
copy(k, mek)
|
||||
b.mek = k
|
||||
b.keys = make(map[string][]byte)
|
||||
|
||||
// Load DEKs from barrier_keys table.
|
||||
if err := b.loadKeys(); err != nil {
|
||||
// If the table doesn't exist yet (pre-migration), that's OK.
|
||||
// The barrier will use MEK directly for v1 entries.
|
||||
b.keys = make(map[string][]byte)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadKeys decrypts all DEKs from the barrier_keys table into memory.
|
||||
// Caller must hold b.mu.
|
||||
func (b *AESGCMBarrier) loadKeys() error {
|
||||
rows, err := b.db.Query("SELECT key_id, encrypted_dek FROM barrier_keys")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
for rows.Next() {
|
||||
var keyID string
|
||||
var encDEK []byte
|
||||
if err := rows.Scan(&keyID, &encDEK); err != nil {
|
||||
return fmt.Errorf("barrier: scan key %q: %w", keyID, err)
|
||||
}
|
||||
dek, err := crypto.Decrypt(b.mek, encDEK, []byte(keyID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: decrypt key %q: %w", keyID, err)
|
||||
}
|
||||
b.keys[keyID] = dek
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Seal() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
// Zeroize all DEKs.
|
||||
for _, dek := range b.keys {
|
||||
crypto.Zeroize(dek)
|
||||
}
|
||||
b.keys = nil
|
||||
|
||||
if b.mek != nil {
|
||||
crypto.Zeroize(b.mek)
|
||||
b.mek = nil
|
||||
@@ -73,9 +124,22 @@ func (b *AESGCMBarrier) IsSealed() bool {
|
||||
return b.mek == nil
|
||||
}
|
||||
|
||||
// resolveKeyID determines the key ID for a given barrier path.
|
||||
func resolveKeyID(path string) string {
|
||||
// Paths under engine/{type}/{mount}/... use per-engine DEKs.
|
||||
if strings.HasPrefix(path, "engine/") {
|
||||
parts := strings.SplitN(path, "/", 4) // engine/{type}/{mount}/...
|
||||
if len(parts) >= 3 {
|
||||
return "engine/" + parts[1] + "/" + parts[2]
|
||||
}
|
||||
}
|
||||
return "system"
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Get(ctx context.Context, path string) ([]byte, error) {
|
||||
b.mu.RLock()
|
||||
mek := b.mek
|
||||
keys := b.keys
|
||||
b.mu.RUnlock()
|
||||
if mek == nil {
|
||||
return nil, ErrSealed
|
||||
@@ -91,22 +155,52 @@ func (b *AESGCMBarrier) Get(ctx context.Context, path string) ([]byte, error) {
|
||||
return nil, fmt.Errorf("barrier: get %q: %w", path, err)
|
||||
}
|
||||
|
||||
plaintext, err := crypto.Decrypt(mek, encrypted)
|
||||
// Check version byte to determine decryption strategy.
|
||||
if len(encrypted) > 0 && encrypted[0] == crypto.BarrierVersionV2 {
|
||||
keyID, err := crypto.ExtractKeyID(encrypted)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier: extract key ID %q: %w", path, err)
|
||||
}
|
||||
dek, ok := keys[keyID]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("barrier: %w: %q for path %q", ErrKeyNotFound, keyID, path)
|
||||
}
|
||||
pt, _, err := crypto.DecryptV2(dek, encrypted, []byte(path))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier: decrypt %q: %w", path, err)
|
||||
}
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
// v1 ciphertext — use MEK directly (backward compat).
|
||||
pt, err := crypto.Decrypt(mek, encrypted, []byte(path))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier: decrypt %q: %w", path, err)
|
||||
}
|
||||
return plaintext, nil
|
||||
return pt, nil
|
||||
}
|
||||
|
||||
func (b *AESGCMBarrier) Put(ctx context.Context, path string, value []byte) error {
|
||||
b.mu.RLock()
|
||||
mek := b.mek
|
||||
keys := b.keys
|
||||
b.mu.RUnlock()
|
||||
if mek == nil {
|
||||
return ErrSealed
|
||||
}
|
||||
|
||||
encrypted, err := crypto.Encrypt(mek, value)
|
||||
keyID := resolveKeyID(path)
|
||||
|
||||
var encrypted []byte
|
||||
var err error
|
||||
|
||||
if dek, ok := keys[keyID]; ok {
|
||||
// Use v2 format with the appropriate DEK.
|
||||
encrypted, err = crypto.EncryptV2(dek, keyID, value, []byte(path))
|
||||
} else {
|
||||
// No DEK registered for this key ID — fall back to MEK with v1 format.
|
||||
encrypted, err = crypto.Encrypt(mek, value, []byte(path))
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: encrypt %q: %w", path, err)
|
||||
}
|
||||
@@ -159,9 +253,394 @@ func (b *AESGCMBarrier) List(ctx context.Context, prefix string) ([]string, erro
|
||||
if err := rows.Scan(&p); err != nil {
|
||||
return nil, fmt.Errorf("barrier: list scan: %w", err)
|
||||
}
|
||||
// Strip the prefix and return just the next segment.
|
||||
remainder := strings.TrimPrefix(p, prefix)
|
||||
paths = append(paths, remainder)
|
||||
}
|
||||
return paths, rows.Err()
|
||||
}
|
||||
|
||||
// CreateKey generates a new DEK for the given key ID, wraps it with MEK,
|
||||
// and stores it in the barrier_keys table.
|
||||
func (b *AESGCMBarrier) CreateKey(ctx context.Context, keyID string) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.mek == nil {
|
||||
return ErrSealed
|
||||
}
|
||||
|
||||
if _, exists := b.keys[keyID]; exists {
|
||||
return nil // Already exists.
|
||||
}
|
||||
|
||||
dek, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: generate DEK %q: %w", keyID, err)
|
||||
}
|
||||
|
||||
encDEK, err := crypto.Encrypt(b.mek, dek, []byte(keyID))
|
||||
if err != nil {
|
||||
crypto.Zeroize(dek)
|
||||
return fmt.Errorf("barrier: wrap DEK %q: %w", keyID, err)
|
||||
}
|
||||
|
||||
_, err = b.db.ExecContext(ctx, `
|
||||
INSERT INTO barrier_keys (key_id, version, encrypted_dek)
|
||||
VALUES (?, 1, ?)
|
||||
ON CONFLICT(key_id) DO NOTHING`,
|
||||
keyID, encDEK)
|
||||
if err != nil {
|
||||
crypto.Zeroize(dek)
|
||||
return fmt.Errorf("barrier: store DEK %q: %w", keyID, err)
|
||||
}
|
||||
|
||||
b.keys[keyID] = dek
|
||||
return nil
|
||||
}
|
||||
|
||||
// RotateKey generates a new DEK for the given key ID and re-encrypts all
|
||||
// barrier entries under that key ID's prefix with the new DEK.
|
||||
func (b *AESGCMBarrier) RotateKey(ctx context.Context, keyID string) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.mek == nil {
|
||||
return ErrSealed
|
||||
}
|
||||
|
||||
oldDEK, ok := b.keys[keyID]
|
||||
if !ok {
|
||||
return fmt.Errorf("barrier: %w: %q", ErrKeyNotFound, keyID)
|
||||
}
|
||||
|
||||
// Generate new DEK.
|
||||
newDEK, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: generate DEK: %w", err)
|
||||
}
|
||||
|
||||
// Wrap new DEK with MEK.
|
||||
encDEK, err := crypto.Encrypt(b.mek, newDEK, []byte(keyID))
|
||||
if err != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: wrap DEK: %w", err)
|
||||
}
|
||||
|
||||
// Determine the prefix for entries encrypted with this key.
|
||||
prefix := keyID + "/"
|
||||
if keyID == "system" {
|
||||
// System key covers non-engine paths. Re-encrypt everything
|
||||
// that doesn't start with "engine/".
|
||||
prefix = ""
|
||||
}
|
||||
|
||||
// Re-encrypt all entries under this key ID.
|
||||
tx, err := b.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Update the key in barrier_keys.
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
UPDATE barrier_keys SET encrypted_dek = ?, version = version + 1, rotated_at = datetime('now')
|
||||
WHERE key_id = ?`, encDEK, keyID)
|
||||
if err != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: update key: %w", err)
|
||||
}
|
||||
|
||||
// Fetch and re-encrypt entries.
|
||||
var query string
|
||||
var args []interface{}
|
||||
if keyID == "system" {
|
||||
query = "SELECT path, value FROM barrier_entries WHERE path NOT LIKE 'engine/%'"
|
||||
} else {
|
||||
query = "SELECT path, value FROM barrier_entries WHERE path LIKE ?"
|
||||
args = append(args, prefix+"%")
|
||||
}
|
||||
|
||||
rows, err := tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: query entries: %w", err)
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
path string
|
||||
value []byte
|
||||
}
|
||||
var entries []entry
|
||||
for rows.Next() {
|
||||
var e entry
|
||||
if err := rows.Scan(&e.path, &e.value); err != nil {
|
||||
_ = rows.Close()
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: scan entry: %w", err)
|
||||
}
|
||||
entries = append(entries, e)
|
||||
}
|
||||
_ = rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: rows error: %w", err)
|
||||
}
|
||||
|
||||
for _, e := range entries {
|
||||
// Decrypt with old DEK (handle v1 or v2).
|
||||
var plaintext []byte
|
||||
if len(e.value) > 0 && e.value[0] == crypto.BarrierVersionV2 {
|
||||
pt, _, decErr := crypto.DecryptV2(oldDEK, e.value, []byte(e.path))
|
||||
if decErr != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: decrypt %q during rotation: %w", e.path, decErr)
|
||||
}
|
||||
plaintext = pt
|
||||
} else {
|
||||
// v1: encrypted with MEK.
|
||||
pt, decErr := crypto.Decrypt(b.mek, e.value, []byte(e.path))
|
||||
if decErr != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: decrypt v1 %q during rotation: %w", e.path, decErr)
|
||||
}
|
||||
plaintext = pt
|
||||
}
|
||||
|
||||
// Re-encrypt with new DEK using v2 format.
|
||||
newCiphertext, encErr := crypto.EncryptV2(newDEK, keyID, plaintext, []byte(e.path))
|
||||
if encErr != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: re-encrypt %q: %w", e.path, encErr)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx,
|
||||
"UPDATE barrier_entries SET value = ?, updated_at = datetime('now') WHERE path = ?",
|
||||
newCiphertext, e.path)
|
||||
if err != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: update %q: %w", e.path, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
crypto.Zeroize(newDEK)
|
||||
return fmt.Errorf("barrier: commit rotation: %w", err)
|
||||
}
|
||||
|
||||
// Swap the in-memory key.
|
||||
crypto.Zeroize(oldDEK)
|
||||
b.keys[keyID] = newDEK
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListKeys returns metadata about all registered barrier keys.
|
||||
func (b *AESGCMBarrier) ListKeys(ctx context.Context) ([]KeyInfo, error) {
|
||||
b.mu.RLock()
|
||||
mek := b.mek
|
||||
b.mu.RUnlock()
|
||||
if mek == nil {
|
||||
return nil, ErrSealed
|
||||
}
|
||||
|
||||
rows, err := b.db.QueryContext(ctx,
|
||||
"SELECT key_id, version, created_at, rotated_at FROM barrier_keys ORDER BY key_id")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("barrier: list keys: %w", err)
|
||||
}
|
||||
defer func() { _ = rows.Close() }()
|
||||
|
||||
var keys []KeyInfo
|
||||
for rows.Next() {
|
||||
var ki KeyInfo
|
||||
if err := rows.Scan(&ki.KeyID, &ki.Version, &ki.CreatedAt, &ki.RotatedAt); err != nil {
|
||||
return nil, fmt.Errorf("barrier: scan key info: %w", err)
|
||||
}
|
||||
keys = append(keys, ki)
|
||||
}
|
||||
return keys, rows.Err()
|
||||
}
|
||||
|
||||
// MigrateToV2 creates per-engine DEKs and re-encrypts entries from v1
|
||||
// (MEK-encrypted) to v2 (DEK-encrypted) format. On first call after upgrade,
|
||||
// it creates a "system" DEK equal to the MEK for zero-cost backward compat,
|
||||
// then creates per-engine DEKs and re-encrypts those entries.
|
||||
func (b *AESGCMBarrier) MigrateToV2(ctx context.Context) (int, error) {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.mek == nil {
|
||||
return 0, ErrSealed
|
||||
}
|
||||
|
||||
// Ensure the "system" key exists.
|
||||
if _, ok := b.keys["system"]; !ok {
|
||||
if err := b.createKeyLocked(ctx, "system"); err != nil {
|
||||
return 0, fmt.Errorf("barrier: create system DEK: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Find all entries still in v1 format.
|
||||
rows, err := b.db.QueryContext(ctx, "SELECT path, value FROM barrier_entries")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("barrier: query entries: %w", err)
|
||||
}
|
||||
|
||||
type entry struct {
|
||||
path string
|
||||
value []byte
|
||||
}
|
||||
var toMigrate []entry
|
||||
for rows.Next() {
|
||||
var e entry
|
||||
if err := rows.Scan(&e.path, &e.value); err != nil {
|
||||
_ = rows.Close()
|
||||
return 0, fmt.Errorf("barrier: scan: %w", err)
|
||||
}
|
||||
if len(e.value) > 0 && e.value[0] == crypto.BarrierVersionV1 {
|
||||
toMigrate = append(toMigrate, e)
|
||||
}
|
||||
}
|
||||
_ = rows.Close()
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(toMigrate) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
tx, err := b.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("barrier: begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
migrated := 0
|
||||
for _, e := range toMigrate {
|
||||
// Decrypt with MEK (v1).
|
||||
plaintext, decErr := crypto.Decrypt(b.mek, e.value, []byte(e.path))
|
||||
if decErr != nil {
|
||||
return migrated, fmt.Errorf("barrier: decrypt %q: %w", e.path, decErr)
|
||||
}
|
||||
|
||||
keyID := resolveKeyID(e.path)
|
||||
|
||||
// Ensure the DEK exists for this key ID.
|
||||
if _, ok := b.keys[keyID]; !ok {
|
||||
if err := b.createKeyLockedTx(ctx, tx, keyID); err != nil {
|
||||
return migrated, fmt.Errorf("barrier: create DEK %q: %w", keyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
dek := b.keys[keyID]
|
||||
newCiphertext, encErr := crypto.EncryptV2(dek, keyID, plaintext, []byte(e.path))
|
||||
if encErr != nil {
|
||||
return migrated, fmt.Errorf("barrier: encrypt v2 %q: %w", e.path, encErr)
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx,
|
||||
"UPDATE barrier_entries SET value = ?, updated_at = datetime('now') WHERE path = ?",
|
||||
newCiphertext, e.path)
|
||||
if err != nil {
|
||||
return migrated, fmt.Errorf("barrier: update %q: %w", e.path, err)
|
||||
}
|
||||
migrated++
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return migrated, fmt.Errorf("barrier: commit migration: %w", err)
|
||||
}
|
||||
return migrated, nil
|
||||
}
|
||||
|
||||
// createKeyLocked generates and stores a new DEK. Caller must hold b.mu.
|
||||
func (b *AESGCMBarrier) createKeyLocked(ctx context.Context, keyID string) error {
|
||||
dek, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encDEK, err := crypto.Encrypt(b.mek, dek, []byte(keyID))
|
||||
if err != nil {
|
||||
crypto.Zeroize(dek)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = b.db.ExecContext(ctx, `
|
||||
INSERT INTO barrier_keys (key_id, version, encrypted_dek)
|
||||
VALUES (?, 1, ?) ON CONFLICT(key_id) DO NOTHING`, keyID, encDEK)
|
||||
if err != nil {
|
||||
crypto.Zeroize(dek)
|
||||
return err
|
||||
}
|
||||
|
||||
b.keys[keyID] = dek
|
||||
return nil
|
||||
}
|
||||
|
||||
// createKeyLockedTx is like createKeyLocked but uses an existing transaction.
|
||||
func (b *AESGCMBarrier) createKeyLockedTx(ctx context.Context, tx *sql.Tx, keyID string) error {
|
||||
dek, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encDEK, err := crypto.Encrypt(b.mek, dek, []byte(keyID))
|
||||
if err != nil {
|
||||
crypto.Zeroize(dek)
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.ExecContext(ctx, `
|
||||
INSERT INTO barrier_keys (key_id, version, encrypted_dek)
|
||||
VALUES (?, 1, ?) ON CONFLICT(key_id) DO NOTHING`, keyID, encDEK)
|
||||
if err != nil {
|
||||
crypto.Zeroize(dek)
|
||||
return err
|
||||
}
|
||||
|
||||
b.keys[keyID] = dek
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReWrapKeys re-encrypts all DEKs with a new MEK. Called during MEK rotation.
|
||||
// The new MEK is already set in b.mek by the caller.
|
||||
func (b *AESGCMBarrier) ReWrapKeys(ctx context.Context, newMEK []byte) error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if b.mek == nil {
|
||||
return ErrSealed
|
||||
}
|
||||
|
||||
tx, err := b.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
for keyID, dek := range b.keys {
|
||||
encDEK, err := crypto.Encrypt(newMEK, dek, []byte(keyID))
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: re-wrap key %q: %w", keyID, err)
|
||||
}
|
||||
_, err = tx.ExecContext(ctx,
|
||||
"UPDATE barrier_keys SET encrypted_dek = ?, rotated_at = datetime('now') WHERE key_id = ?",
|
||||
encDEK, keyID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("barrier: update key %q: %w", keyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return fmt.Errorf("barrier: commit re-wrap: %w", err)
|
||||
}
|
||||
|
||||
// Update the MEK in memory.
|
||||
crypto.Zeroize(b.mek)
|
||||
k := make([]byte, len(newMEK))
|
||||
copy(k, newMEK)
|
||||
b.mek = k
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -158,3 +158,309 @@ func TestBarrierOverwrite(t *testing.T) {
|
||||
t.Fatalf("overwrite: got %q, want %q", got, "v2")
|
||||
}
|
||||
}
|
||||
|
||||
// --- DEK / Key Registry Tests ---
|
||||
|
||||
func TestBarrierCreateKey(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
if err := b.CreateKey(ctx, "engine/ca/prod"); err != nil {
|
||||
t.Fatalf("CreateKey: %v", err)
|
||||
}
|
||||
|
||||
// Duplicate should be a no-op.
|
||||
if err := b.CreateKey(ctx, "engine/ca/prod"); err != nil {
|
||||
t.Fatalf("CreateKey duplicate: %v", err)
|
||||
}
|
||||
|
||||
keys, err := b.ListKeys(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("ListKeys: %v", err)
|
||||
}
|
||||
if len(keys) != 1 {
|
||||
t.Fatalf("expected 1 key, got %d", len(keys))
|
||||
}
|
||||
if keys[0].KeyID != "engine/ca/prod" {
|
||||
t.Fatalf("key ID: got %q, want %q", keys[0].KeyID, "engine/ca/prod")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierDEKEncryption(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
// Create a DEK for ca/prod.
|
||||
_ = b.CreateKey(ctx, "engine/ca/prod")
|
||||
|
||||
// Write data under the engine path — should use DEK.
|
||||
data := []byte("engine secret data")
|
||||
if err := b.Put(ctx, "engine/ca/prod/config.json", data); err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
|
||||
got, err := b.Get(ctx, "engine/ca/prod/config.json")
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Fatalf("roundtrip: got %q, want %q", got, data)
|
||||
}
|
||||
|
||||
// Verify the raw ciphertext is v2 format.
|
||||
var raw []byte
|
||||
err = b.db.QueryRowContext(ctx,
|
||||
"SELECT value FROM barrier_entries WHERE path = ?",
|
||||
"engine/ca/prod/config.json").Scan(&raw)
|
||||
if err != nil {
|
||||
t.Fatalf("read raw: %v", err)
|
||||
}
|
||||
if raw[0] != crypto.BarrierVersionV2 {
|
||||
t.Fatalf("expected v2 ciphertext, got version %d", raw[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierV1FallbackWithoutDEK(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
// Write system data without any DEK — should use v1 with MEK.
|
||||
data := []byte("system data")
|
||||
if err := b.Put(ctx, "policy/rule1", data); err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
|
||||
got, err := b.Get(ctx, "policy/rule1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Fatalf("roundtrip: got %q, want %q", got, data)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierSystemDEK(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
// Create system DEK.
|
||||
_ = b.CreateKey(ctx, "system")
|
||||
|
||||
// Write system data — should use system DEK with v2 format.
|
||||
data := []byte("system with dek")
|
||||
if err := b.Put(ctx, "policy/rule1", data); err != nil {
|
||||
t.Fatalf("Put: %v", err)
|
||||
}
|
||||
|
||||
got, err := b.Get(ctx, "policy/rule1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get: %v", err)
|
||||
}
|
||||
if string(got) != string(data) {
|
||||
t.Fatalf("roundtrip: got %q, want %q", got, data)
|
||||
}
|
||||
|
||||
// Verify v2 format.
|
||||
var raw []byte
|
||||
_ = b.db.QueryRowContext(ctx,
|
||||
"SELECT value FROM barrier_entries WHERE path = ?",
|
||||
"policy/rule1").Scan(&raw)
|
||||
if raw[0] != crypto.BarrierVersionV2 {
|
||||
t.Fatalf("expected v2 ciphertext, got version %d", raw[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierRotateKey(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
_ = b.CreateKey(ctx, "engine/ca/prod")
|
||||
|
||||
// Write some data.
|
||||
_ = b.Put(ctx, "engine/ca/prod/cert1", []byte("cert-data-1"))
|
||||
_ = b.Put(ctx, "engine/ca/prod/cert2", []byte("cert-data-2"))
|
||||
|
||||
// Rotate the key.
|
||||
if err := b.RotateKey(ctx, "engine/ca/prod"); err != nil {
|
||||
t.Fatalf("RotateKey: %v", err)
|
||||
}
|
||||
|
||||
// Data should still be readable.
|
||||
got, err := b.Get(ctx, "engine/ca/prod/cert1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get after rotation: %v", err)
|
||||
}
|
||||
if string(got) != "cert-data-1" {
|
||||
t.Fatalf("data corrupted after rotation: %q", got)
|
||||
}
|
||||
|
||||
got2, err := b.Get(ctx, "engine/ca/prod/cert2")
|
||||
if err != nil {
|
||||
t.Fatalf("Get after rotation: %v", err)
|
||||
}
|
||||
if string(got2) != "cert-data-2" {
|
||||
t.Fatalf("data corrupted after rotation: %q", got2)
|
||||
}
|
||||
|
||||
// Check key version incremented.
|
||||
keys, _ := b.ListKeys(ctx)
|
||||
for _, k := range keys {
|
||||
if k.KeyID == "engine/ca/prod" && k.Version != 2 {
|
||||
t.Fatalf("expected version 2 after rotation, got %d", k.Version)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierRotateKeyNotFound(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
err := b.RotateKey(ctx, "nonexistent")
|
||||
if !errors.Is(err, ErrKeyNotFound) {
|
||||
t.Fatalf("expected ErrKeyNotFound, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierMigrateToV2(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
// Write v1 data (no DEKs registered, so it uses MEK).
|
||||
_ = b.Put(ctx, "policy/rule1", []byte("policy-data"))
|
||||
_ = b.Put(ctx, "engine/ca/prod/config", []byte("ca-config"))
|
||||
_ = b.Put(ctx, "engine/transit/main/key1", []byte("transit-key"))
|
||||
|
||||
// Migrate.
|
||||
migrated, err := b.MigrateToV2(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateToV2: %v", err)
|
||||
}
|
||||
if migrated != 3 {
|
||||
t.Fatalf("expected 3 entries migrated, got %d", migrated)
|
||||
}
|
||||
|
||||
// All data should still be readable.
|
||||
got, err := b.Get(ctx, "policy/rule1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get policy after migration: %v", err)
|
||||
}
|
||||
if string(got) != "policy-data" {
|
||||
t.Fatalf("policy data: got %q", got)
|
||||
}
|
||||
|
||||
got, err = b.Get(ctx, "engine/ca/prod/config")
|
||||
if err != nil {
|
||||
t.Fatalf("Get engine data after migration: %v", err)
|
||||
}
|
||||
if string(got) != "ca-config" {
|
||||
t.Fatalf("engine data: got %q", got)
|
||||
}
|
||||
|
||||
// Running again should migrate 0 (all already v2).
|
||||
migrated2, err := b.MigrateToV2(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("MigrateToV2 second run: %v", err)
|
||||
}
|
||||
if migrated2 != 0 {
|
||||
t.Fatalf("expected 0 entries on second migration, got %d", migrated2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierSealUnsealPreservesDEKs(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
// Create DEK and write data.
|
||||
_ = b.CreateKey(ctx, "engine/ca/prod")
|
||||
_ = b.Put(ctx, "engine/ca/prod/secret", []byte("my-secret"))
|
||||
|
||||
// Seal and unseal.
|
||||
_ = b.Seal()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
// Data should still be readable (DEKs reloaded from barrier_keys).
|
||||
got, err := b.Get(ctx, "engine/ca/prod/secret")
|
||||
if err != nil {
|
||||
t.Fatalf("Get after seal/unseal: %v", err)
|
||||
}
|
||||
if string(got) != "my-secret" {
|
||||
t.Fatalf("data after seal/unseal: got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBarrierReWrapKeys(t *testing.T) {
|
||||
b, cleanup := setupBarrier(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
mek, _ := crypto.GenerateKey()
|
||||
_ = b.Unseal(mek)
|
||||
|
||||
_ = b.CreateKey(ctx, "system")
|
||||
_ = b.CreateKey(ctx, "engine/ca/prod")
|
||||
|
||||
_ = b.Put(ctx, "policy/rule1", []byte("policy"))
|
||||
_ = b.Put(ctx, "engine/ca/prod/cert", []byte("cert"))
|
||||
|
||||
// Re-wrap with new MEK.
|
||||
newMEK, _ := crypto.GenerateKey()
|
||||
if err := b.ReWrapKeys(ctx, newMEK); err != nil {
|
||||
t.Fatalf("ReWrapKeys: %v", err)
|
||||
}
|
||||
|
||||
// Data should still be readable.
|
||||
got, _ := b.Get(ctx, "policy/rule1")
|
||||
if string(got) != "policy" {
|
||||
t.Fatalf("policy after rewrap: got %q", got)
|
||||
}
|
||||
got2, _ := b.Get(ctx, "engine/ca/prod/cert")
|
||||
if string(got2) != "cert" {
|
||||
t.Fatalf("cert after rewrap: got %q", got2)
|
||||
}
|
||||
|
||||
// Seal and unseal with new MEK should work.
|
||||
_ = b.Seal()
|
||||
if err := b.Unseal(newMEK); err != nil {
|
||||
t.Fatalf("Unseal with new MEK: %v", err)
|
||||
}
|
||||
|
||||
got3, err := b.Get(ctx, "engine/ca/prod/cert")
|
||||
if err != nil {
|
||||
t.Fatalf("Get after unseal with new MEK: %v", err)
|
||||
}
|
||||
if string(got3) != "cert" {
|
||||
t.Fatalf("data after new MEK unseal: got %q", got3)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,8 +20,16 @@ const (
|
||||
// SaltSize is the size of Argon2id salts in bytes.
|
||||
SaltSize = 32
|
||||
|
||||
// BarrierVersion is the version byte prefix for encrypted barrier entries.
|
||||
BarrierVersion byte = 0x01
|
||||
// BarrierVersionV1 is the v1 format: [version][nonce][ciphertext+tag].
|
||||
BarrierVersionV1 byte = 0x01
|
||||
// BarrierVersionV2 is the v2 format: [version][key_id_len][key_id][nonce][ciphertext+tag].
|
||||
BarrierVersionV2 byte = 0x02
|
||||
|
||||
// BarrierVersion is kept for backward compatibility (alias for V1).
|
||||
BarrierVersion = BarrierVersionV1
|
||||
|
||||
// MaxKeyIDLen is the maximum length of a key ID in the v2 format.
|
||||
MaxKeyIDLen = 255
|
||||
|
||||
// Default Argon2id parameters.
|
||||
DefaultArgon2Time = 3
|
||||
@@ -32,6 +40,7 @@ const (
|
||||
var (
|
||||
ErrInvalidCiphertext = errors.New("crypto: invalid ciphertext")
|
||||
ErrDecryptionFailed = errors.New("crypto: decryption failed")
|
||||
ErrKeyIDTooLong = errors.New("crypto: key ID exceeds maximum length")
|
||||
)
|
||||
|
||||
// Argon2Params holds Argon2id KDF parameters.
|
||||
@@ -74,8 +83,10 @@ func GenerateSalt() ([]byte, error) {
|
||||
}
|
||||
|
||||
// Encrypt encrypts plaintext with AES-256-GCM using the given key.
|
||||
// The additionalData parameter is authenticated but not encrypted (AAD);
|
||||
// pass nil when no binding context is needed.
|
||||
// Returns: [version byte][12-byte nonce][ciphertext+tag]
|
||||
func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||
func Encrypt(key, plaintext, additionalData []byte) ([]byte, error) {
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new cipher: %w", err)
|
||||
@@ -88,7 +99,7 @@ func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("crypto: generate nonce: %w", err)
|
||||
}
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, nil)
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData)
|
||||
|
||||
// Format: [version][nonce][ciphertext+tag]
|
||||
result := make([]byte, 1+NonceSize+len(ciphertext))
|
||||
@@ -99,7 +110,8 @@ func Encrypt(key, plaintext []byte) ([]byte, error) {
|
||||
}
|
||||
|
||||
// Decrypt decrypts ciphertext produced by Encrypt.
|
||||
func Decrypt(key, data []byte) ([]byte, error) {
|
||||
// The additionalData must match the value provided during encryption.
|
||||
func Decrypt(key, data, additionalData []byte) ([]byte, error) {
|
||||
if len(data) < 1+NonceSize+aes.BlockSize {
|
||||
return nil, ErrInvalidCiphertext
|
||||
}
|
||||
@@ -117,13 +129,114 @@ func Decrypt(key, data []byte) ([]byte, error) {
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil)
|
||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, additionalData)
|
||||
if err != nil {
|
||||
return nil, ErrDecryptionFailed
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// EncryptV2 encrypts plaintext with AES-256-GCM, embedding a key ID in the ciphertext.
|
||||
// Format: [0x02][key_id_len:1][key_id:N][nonce:12][ciphertext+tag]
|
||||
func EncryptV2(key []byte, keyID string, plaintext, additionalData []byte) ([]byte, error) {
|
||||
if len(keyID) > MaxKeyIDLen {
|
||||
return nil, ErrKeyIDTooLong
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
nonce := make([]byte, NonceSize)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return nil, fmt.Errorf("crypto: generate nonce: %w", err)
|
||||
}
|
||||
ciphertext := gcm.Seal(nil, nonce, plaintext, additionalData)
|
||||
|
||||
kidLen := len(keyID)
|
||||
// Format: [version][key_id_len][key_id][nonce][ciphertext+tag]
|
||||
result := make([]byte, 1+1+kidLen+NonceSize+len(ciphertext))
|
||||
result[0] = BarrierVersionV2
|
||||
result[1] = byte(kidLen)
|
||||
copy(result[2:2+kidLen], keyID)
|
||||
copy(result[2+kidLen:2+kidLen+NonceSize], nonce)
|
||||
copy(result[2+kidLen+NonceSize:], ciphertext)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// DecryptV2 decrypts ciphertext that may be in v1 or v2 format.
|
||||
// For v2 format, it extracts the key ID and returns it alongside the plaintext.
|
||||
// For v1 format, it returns an empty key ID.
|
||||
func DecryptV2(key, data, additionalData []byte) (plaintext []byte, keyID string, err error) {
|
||||
if len(data) < 1 {
|
||||
return nil, "", ErrInvalidCiphertext
|
||||
}
|
||||
|
||||
switch data[0] {
|
||||
case BarrierVersionV1:
|
||||
pt, err := Decrypt(key, data, additionalData)
|
||||
return pt, "", err
|
||||
|
||||
case BarrierVersionV2:
|
||||
if len(data) < 2 {
|
||||
return nil, "", ErrInvalidCiphertext
|
||||
}
|
||||
kidLen := int(data[1])
|
||||
headerLen := 2 + kidLen
|
||||
if len(data) < headerLen+NonceSize+aes.BlockSize {
|
||||
return nil, "", ErrInvalidCiphertext
|
||||
}
|
||||
|
||||
keyID = string(data[2 : 2+kidLen])
|
||||
nonce := data[headerLen : headerLen+NonceSize]
|
||||
ciphertext := data[headerLen+NonceSize:]
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("crypto: new cipher: %w", err)
|
||||
}
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("crypto: new gcm: %w", err)
|
||||
}
|
||||
pt, err := gcm.Open(nil, nonce, ciphertext, additionalData)
|
||||
if err != nil {
|
||||
return nil, "", ErrDecryptionFailed
|
||||
}
|
||||
return pt, keyID, nil
|
||||
|
||||
default:
|
||||
return nil, "", fmt.Errorf("crypto: unsupported version: %d", data[0])
|
||||
}
|
||||
}
|
||||
|
||||
// ExtractKeyID returns the key ID from a v2 ciphertext without decrypting.
|
||||
// Returns empty string for v1 ciphertext.
|
||||
func ExtractKeyID(data []byte) (string, error) {
|
||||
if len(data) < 1 {
|
||||
return "", ErrInvalidCiphertext
|
||||
}
|
||||
switch data[0] {
|
||||
case BarrierVersionV1:
|
||||
return "", nil
|
||||
case BarrierVersionV2:
|
||||
if len(data) < 2 {
|
||||
return "", ErrInvalidCiphertext
|
||||
}
|
||||
kidLen := int(data[1])
|
||||
if len(data) < 2+kidLen {
|
||||
return "", ErrInvalidCiphertext
|
||||
}
|
||||
return string(data[2 : 2+kidLen]), nil
|
||||
default:
|
||||
return "", fmt.Errorf("crypto: unsupported version: %d", data[0])
|
||||
}
|
||||
}
|
||||
|
||||
// Zeroize overwrites a byte slice with zeros.
|
||||
func Zeroize(b []byte) {
|
||||
for i := range b {
|
||||
|
||||
@@ -34,7 +34,7 @@ func TestEncryptDecrypt(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("hello, metacrypt!")
|
||||
|
||||
ciphertext, err := Encrypt(key, plaintext)
|
||||
ciphertext, err := Encrypt(key, plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt: %v", err)
|
||||
}
|
||||
@@ -44,7 +44,7 @@ func TestEncryptDecrypt(t *testing.T) {
|
||||
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersion)
|
||||
}
|
||||
|
||||
decrypted, err := Decrypt(key, ciphertext)
|
||||
decrypted, err := Decrypt(key, ciphertext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt: %v", err)
|
||||
}
|
||||
@@ -54,13 +54,45 @@ func TestEncryptDecrypt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptDecryptWithAAD(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("hello, metacrypt!")
|
||||
aad := []byte("engine/ca/pki/root/cert.pem")
|
||||
|
||||
ciphertext, err := Encrypt(key, plaintext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt with AAD: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt with correct AAD succeeds.
|
||||
decrypted, err := Decrypt(key, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("Decrypt with AAD: %v", err)
|
||||
}
|
||||
if !bytes.Equal(plaintext, decrypted) {
|
||||
t.Fatalf("roundtrip failed: got %q, want %q", decrypted, plaintext)
|
||||
}
|
||||
|
||||
// Decrypt with wrong AAD fails.
|
||||
_, err = Decrypt(key, ciphertext, []byte("wrong/path"))
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
|
||||
}
|
||||
|
||||
// Decrypt with nil AAD fails.
|
||||
_, err = Decrypt(key, ciphertext, nil)
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed with nil AAD, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecryptWrongKey(t *testing.T) {
|
||||
key1, _ := GenerateKey()
|
||||
key2, _ := GenerateKey()
|
||||
plaintext := []byte("secret data")
|
||||
|
||||
ciphertext, _ := Encrypt(key1, plaintext)
|
||||
_, err := Decrypt(key2, ciphertext)
|
||||
ciphertext, _ := Encrypt(key1, plaintext, nil)
|
||||
_, err := Decrypt(key2, ciphertext, nil)
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
|
||||
}
|
||||
@@ -68,7 +100,7 @@ func TestDecryptWrongKey(t *testing.T) {
|
||||
|
||||
func TestDecryptInvalidCiphertext(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
_, err := Decrypt(key, []byte("short"))
|
||||
_, err := Decrypt(key, []byte("short"), nil)
|
||||
if !errors.Is(err, ErrInvalidCiphertext) {
|
||||
t.Fatalf("expected ErrInvalidCiphertext, got: %v", err)
|
||||
}
|
||||
@@ -124,10 +156,141 @@ func TestEncryptProducesDifferentCiphertext(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("same data")
|
||||
|
||||
ct1, _ := Encrypt(key, plaintext)
|
||||
ct2, _ := Encrypt(key, plaintext)
|
||||
ct1, _ := Encrypt(key, plaintext, nil)
|
||||
ct2, _ := Encrypt(key, plaintext, nil)
|
||||
|
||||
if bytes.Equal(ct1, ct2) {
|
||||
t.Fatal("two encryptions of same plaintext produced identical ciphertext (nonce reuse)")
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2EncryptDecryptRoundtrip(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("v2 test data")
|
||||
keyID := "engine/ca/prod"
|
||||
aad := []byte("engine/ca/prod/config.json")
|
||||
|
||||
ciphertext, err := EncryptV2(key, keyID, plaintext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptV2: %v", err)
|
||||
}
|
||||
|
||||
if ciphertext[0] != BarrierVersionV2 {
|
||||
t.Fatalf("version byte: got %d, want %d", ciphertext[0], BarrierVersionV2)
|
||||
}
|
||||
|
||||
pt, gotKeyID, err := DecryptV2(key, ciphertext, aad)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptV2: %v", err)
|
||||
}
|
||||
if gotKeyID != keyID {
|
||||
t.Fatalf("key ID: got %q, want %q", gotKeyID, keyID)
|
||||
}
|
||||
if !bytes.Equal(plaintext, pt) {
|
||||
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2DecryptV1Compat(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("v1 legacy data")
|
||||
|
||||
// Encrypt with v1.
|
||||
v1ct, err := Encrypt(key, plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("Encrypt v1: %v", err)
|
||||
}
|
||||
|
||||
// DecryptV2 should handle v1 ciphertext.
|
||||
pt, keyID, err := DecryptV2(key, v1ct, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptV2 with v1 ciphertext: %v", err)
|
||||
}
|
||||
if keyID != "" {
|
||||
t.Fatalf("expected empty key ID for v1, got %q", keyID)
|
||||
}
|
||||
if !bytes.Equal(plaintext, pt) {
|
||||
t.Fatalf("roundtrip failed: got %q, want %q", pt, plaintext)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2WrongAAD(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("data")
|
||||
aad := []byte("correct/path")
|
||||
|
||||
ct, _ := EncryptV2(key, "system", plaintext, aad)
|
||||
|
||||
_, _, err := DecryptV2(key, ct, []byte("wrong/path"))
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed with wrong AAD, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2WrongKey(t *testing.T) {
|
||||
key1, _ := GenerateKey()
|
||||
key2, _ := GenerateKey()
|
||||
plaintext := []byte("data")
|
||||
|
||||
ct, _ := EncryptV2(key1, "system", plaintext, nil)
|
||||
|
||||
_, _, err := DecryptV2(key2, ct, nil)
|
||||
if !errors.Is(err, ErrDecryptionFailed) {
|
||||
t.Fatalf("expected ErrDecryptionFailed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractKeyID(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
|
||||
// v1: empty key ID.
|
||||
v1ct, _ := Encrypt(key, []byte("data"), nil)
|
||||
kid, err := ExtractKeyID(v1ct)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractKeyID v1: %v", err)
|
||||
}
|
||||
if kid != "" {
|
||||
t.Fatalf("expected empty key ID for v1, got %q", kid)
|
||||
}
|
||||
|
||||
// v2: embedded key ID.
|
||||
v2ct, _ := EncryptV2(key, "engine/transit/main", []byte("data"), nil)
|
||||
kid, err = ExtractKeyID(v2ct)
|
||||
if err != nil {
|
||||
t.Fatalf("ExtractKeyID v2: %v", err)
|
||||
}
|
||||
if kid != "engine/transit/main" {
|
||||
t.Fatalf("key ID: got %q, want %q", kid, "engine/transit/main")
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2KeyIDTooLong(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
longID := string(make([]byte, MaxKeyIDLen+1))
|
||||
|
||||
_, err := EncryptV2(key, longID, []byte("data"), nil)
|
||||
if !errors.Is(err, ErrKeyIDTooLong) {
|
||||
t.Fatalf("expected ErrKeyIDTooLong, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestV2EmptyKeyID(t *testing.T) {
|
||||
key, _ := GenerateKey()
|
||||
plaintext := []byte("data with empty key id")
|
||||
|
||||
ct, err := EncryptV2(key, "", plaintext, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("EncryptV2 empty key ID: %v", err)
|
||||
}
|
||||
|
||||
pt, keyID, err := DecryptV2(key, ct, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("DecryptV2 empty key ID: %v", err)
|
||||
}
|
||||
if keyID != "" {
|
||||
t.Fatalf("expected empty key ID, got %q", keyID)
|
||||
}
|
||||
if !bytes.Equal(plaintext, pt) {
|
||||
t.Fatalf("roundtrip failed")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,7 +20,7 @@ func TestOpenAndMigrate(t *testing.T) {
|
||||
}
|
||||
|
||||
// Verify tables exist.
|
||||
tables := []string{"seal_config", "barrier_entries", "schema_migrations"}
|
||||
tables := []string{"seal_config", "barrier_entries", "schema_migrations", "barrier_keys"}
|
||||
for _, table := range tables {
|
||||
var name string
|
||||
err := database.QueryRow(
|
||||
@@ -38,7 +38,7 @@ func TestOpenAndMigrate(t *testing.T) {
|
||||
// Check migration version.
|
||||
var version int
|
||||
_ = database.QueryRow("SELECT MAX(version) FROM schema_migrations").Scan(&version)
|
||||
if version != 1 {
|
||||
t.Errorf("migration version: got %d, want 1", version)
|
||||
if version != 2 {
|
||||
t.Errorf("migration version: got %d, want 2", version)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,6 +30,15 @@ var migrations = []string{
|
||||
version INTEGER PRIMARY KEY,
|
||||
applied_at DATETIME NOT NULL DEFAULT (datetime('now'))
|
||||
);`,
|
||||
|
||||
// Version 2: barrier key registry for per-engine DEKs
|
||||
`CREATE TABLE IF NOT EXISTS barrier_keys (
|
||||
key_id TEXT PRIMARY KEY,
|
||||
version INTEGER NOT NULL DEFAULT 1,
|
||||
encrypted_dek BLOB NOT NULL,
|
||||
created_at DATETIME NOT NULL DEFAULT (datetime('now')),
|
||||
rotated_at DATETIME NOT NULL DEFAULT (datetime('now'))
|
||||
);`,
|
||||
}
|
||||
|
||||
// Migrate applies all pending migrations.
|
||||
|
||||
@@ -147,6 +147,14 @@ func (r *Registry) Mount(ctx context.Context, name string, engineType EngineType
|
||||
eng := factory()
|
||||
mountPath := fmt.Sprintf("engine/%s/%s/", engineType, name)
|
||||
|
||||
// Create a per-engine DEK in the barrier for this mount.
|
||||
if aesBarrier, ok := r.barrier.(*barrier.AESGCMBarrier); ok {
|
||||
dekKeyID := fmt.Sprintf("engine/%s/%s", engineType, name)
|
||||
if err := aesBarrier.CreateKey(ctx, dekKeyID); err != nil {
|
||||
return fmt.Errorf("engine: create DEK %q: %w", dekKeyID, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := eng.Initialize(ctx, r.barrier, mountPath, config); err != nil {
|
||||
return fmt.Errorf("engine: initialize %q: %w", name, err)
|
||||
}
|
||||
|
||||
85
internal/grpcserver/barrier.go
Normal file
85
internal/grpcserver/barrier.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/seal"
|
||||
)
|
||||
|
||||
type barrierServer struct {
|
||||
pb.UnimplementedBarrierServiceServer
|
||||
s *GRPCServer
|
||||
}
|
||||
|
||||
func (bs *barrierServer) ListKeys(ctx context.Context, _ *pb.ListKeysRequest) (*pb.ListKeysResponse, error) {
|
||||
keys, err := bs.s.sealMgr.Barrier().ListKeys(ctx)
|
||||
if err != nil {
|
||||
bs.s.logger.Error("grpc: list barrier keys", "error", err)
|
||||
return nil, status.Error(codes.Internal, "failed to list keys")
|
||||
}
|
||||
|
||||
resp := &pb.ListKeysResponse{}
|
||||
for _, k := range keys {
|
||||
resp.Keys = append(resp.Keys, &pb.BarrierKeyInfo{
|
||||
KeyId: k.KeyID,
|
||||
Version: int32(k.Version),
|
||||
CreatedAt: k.CreatedAt,
|
||||
RotatedAt: k.RotatedAt,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (bs *barrierServer) RotateMEK(ctx context.Context, req *pb.RotateMEKRequest) (*pb.RotateMEKResponse, error) {
|
||||
if req.Password == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "password is required")
|
||||
}
|
||||
|
||||
if err := bs.s.sealMgr.RotateMEK(ctx, []byte(req.Password)); err != nil {
|
||||
if errors.Is(err, seal.ErrInvalidPassword) {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid password")
|
||||
}
|
||||
if errors.Is(err, seal.ErrSealed) {
|
||||
return nil, status.Error(codes.FailedPrecondition, "sealed")
|
||||
}
|
||||
bs.s.logger.Error("grpc: rotate MEK", "error", err)
|
||||
return nil, status.Error(codes.Internal, "rotation failed")
|
||||
}
|
||||
|
||||
bs.s.logger.Info("audit: MEK rotated")
|
||||
return &pb.RotateMEKResponse{Ok: true}, nil
|
||||
}
|
||||
|
||||
func (bs *barrierServer) RotateKey(ctx context.Context, req *pb.RotateKeyRequest) (*pb.RotateKeyResponse, error) {
|
||||
if req.KeyId == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "key_id is required")
|
||||
}
|
||||
|
||||
if err := bs.s.sealMgr.Barrier().RotateKey(ctx, req.KeyId); err != nil {
|
||||
if errors.Is(err, barrier.ErrKeyNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "key not found")
|
||||
}
|
||||
bs.s.logger.Error("grpc: rotate key", "key_id", req.KeyId, "error", err)
|
||||
return nil, status.Error(codes.Internal, "rotation failed")
|
||||
}
|
||||
|
||||
bs.s.logger.Info("audit: DEK rotated", "key_id", req.KeyId)
|
||||
return &pb.RotateKeyResponse{Ok: true}, nil
|
||||
}
|
||||
|
||||
func (bs *barrierServer) Migrate(ctx context.Context, _ *pb.MigrateBarrierRequest) (*pb.MigrateBarrierResponse, error) {
|
||||
migrated, err := bs.s.sealMgr.Barrier().MigrateToV2(ctx)
|
||||
if err != nil {
|
||||
bs.s.logger.Error("grpc: barrier migration", "error", err)
|
||||
return nil, status.Error(codes.Internal, "migration failed")
|
||||
}
|
||||
|
||||
bs.s.logger.Info("audit: barrier migrated to v2", "entries_migrated", migrated)
|
||||
return &pb.MigrateBarrierResponse{Migrated: int32(migrated)}, nil
|
||||
}
|
||||
@@ -60,7 +60,7 @@ func (s *GRPCServer) Start() error {
|
||||
}
|
||||
tlsCfg := &tls.Config{
|
||||
Certificates: []tls.Certificate{tlsCert},
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
creds := credentials.NewTLS(tlsCfg)
|
||||
|
||||
@@ -81,6 +81,7 @@ func (s *GRPCServer) Start() error {
|
||||
pb.RegisterPKIServiceServer(s.srv, &pkiServer{s: s})
|
||||
pb.RegisterCAServiceServer(s.srv, &caServer{s: s})
|
||||
pb.RegisterPolicyServiceServer(s.srv, &policyServer{s: s})
|
||||
pb.RegisterBarrierServiceServer(s.srv, &barrierServer{s: s})
|
||||
pb.RegisterACMEServiceServer(s.srv, &acmeServer{s: s})
|
||||
|
||||
lis, err := net.Listen("tcp", s.cfg.Server.GRPCAddr)
|
||||
@@ -136,7 +137,11 @@ func sealRequiredMethods() map[string]bool {
|
||||
"/metacrypt.v2.ACMEService/CreateEAB": true,
|
||||
"/metacrypt.v2.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v2.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.BarrierService/ListKeys": true,
|
||||
"/metacrypt.v2.BarrierService/RotateMEK": true,
|
||||
"/metacrypt.v2.BarrierService/RotateKey": true,
|
||||
"/metacrypt.v2.BarrierService/Migrate": true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -166,7 +171,11 @@ func authRequiredMethods() map[string]bool {
|
||||
"/metacrypt.v2.ACMEService/CreateEAB": true,
|
||||
"/metacrypt.v2.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v2.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.BarrierService/ListKeys": true,
|
||||
"/metacrypt.v2.BarrierService/RotateMEK": true,
|
||||
"/metacrypt.v2.BarrierService/RotateKey": true,
|
||||
"/metacrypt.v2.BarrierService/Migrate": true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,9 +191,15 @@ func adminRequiredMethods() map[string]bool {
|
||||
"/metacrypt.v2.CAService/RevokeCert": true,
|
||||
"/metacrypt.v2.CAService/DeleteCert": true,
|
||||
"/metacrypt.v2.PolicyService/CreatePolicy": true,
|
||||
"/metacrypt.v2.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v2.PolicyService/ListPolicies": true,
|
||||
"/metacrypt.v2.PolicyService/GetPolicy": true,
|
||||
"/metacrypt.v2.PolicyService/DeletePolicy": true,
|
||||
"/metacrypt.v2.ACMEService/SetConfig": true,
|
||||
"/metacrypt.v2.ACMEService/ListAccounts": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.ACMEService/ListOrders": true,
|
||||
"/metacrypt.v2.BarrierService/ListKeys": true,
|
||||
"/metacrypt.v2.BarrierService/RotateMEK": true,
|
||||
"/metacrypt.v2.BarrierService/RotateKey": true,
|
||||
"/metacrypt.v2.BarrierService/Migrate": true,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,6 +22,38 @@ const (
|
||||
EffectDeny Effect = "deny"
|
||||
)
|
||||
|
||||
// Action constants for policy evaluation.
|
||||
const (
|
||||
ActionAny = "any" // matches all non-admin actions
|
||||
ActionRead = "read" // retrieve/list operations
|
||||
ActionWrite = "write" // create/update/delete operations
|
||||
ActionEncrypt = "encrypt" // encrypt data
|
||||
ActionDecrypt = "decrypt" // decrypt data
|
||||
ActionSign = "sign" // sign data
|
||||
ActionVerify = "verify" // verify signatures
|
||||
ActionHMAC = "hmac" // compute HMAC
|
||||
ActionAdmin = "admin" // administrative operations (never matched by "any")
|
||||
)
|
||||
|
||||
// validEffects is the set of recognized effects.
|
||||
var validEffects = map[Effect]bool{
|
||||
EffectAllow: true,
|
||||
EffectDeny: true,
|
||||
}
|
||||
|
||||
// validActions is the set of recognized actions.
|
||||
var validActions = map[string]bool{
|
||||
ActionAny: true,
|
||||
ActionRead: true,
|
||||
ActionWrite: true,
|
||||
ActionEncrypt: true,
|
||||
ActionDecrypt: true,
|
||||
ActionSign: true,
|
||||
ActionVerify: true,
|
||||
ActionHMAC: true,
|
||||
ActionAdmin: true,
|
||||
}
|
||||
|
||||
// Rule is a policy rule stored in the barrier.
|
||||
type Rule struct {
|
||||
ID string `json:"id"`
|
||||
@@ -88,8 +120,34 @@ func (e *Engine) Match(ctx context.Context, req *Request) (Effect, bool, error)
|
||||
return EffectDeny, false, nil // default deny, no matching rule
|
||||
}
|
||||
|
||||
// CreateRule stores a new policy rule.
|
||||
// LintRule validates a rule's effect and actions. It returns a list of problems
|
||||
// (empty if the rule is valid). This does not check resource patterns or other
|
||||
// fields — only the enumerated values that must come from a known set.
|
||||
func LintRule(rule *Rule) []string {
|
||||
var problems []string
|
||||
|
||||
if rule.ID == "" {
|
||||
problems = append(problems, "rule ID is required")
|
||||
}
|
||||
|
||||
if !validEffects[rule.Effect] {
|
||||
problems = append(problems, fmt.Sprintf("invalid effect %q (must be %q or %q)", rule.Effect, EffectAllow, EffectDeny))
|
||||
}
|
||||
|
||||
for _, a := range rule.Actions {
|
||||
if !validActions[strings.ToLower(a)] {
|
||||
problems = append(problems, fmt.Sprintf("invalid action %q", a))
|
||||
}
|
||||
}
|
||||
|
||||
return problems
|
||||
}
|
||||
|
||||
// CreateRule validates and stores a new policy rule.
|
||||
func (e *Engine) CreateRule(ctx context.Context, rule *Rule) error {
|
||||
if problems := LintRule(rule); len(problems) > 0 {
|
||||
return fmt.Errorf("policy: invalid rule: %s", strings.Join(problems, "; "))
|
||||
}
|
||||
data, err := json.Marshal(rule)
|
||||
if err != nil {
|
||||
return fmt.Errorf("policy: marshal rule: %w", err)
|
||||
@@ -157,14 +215,28 @@ func matchesRule(rule *Rule, req *Request) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// Check action match.
|
||||
if len(rule.Actions) > 0 && !containsString(rule.Actions, req.Action) {
|
||||
// Check action match. The "any" action matches all non-admin actions.
|
||||
if len(rule.Actions) > 0 && !matchesAction(rule.Actions, req.Action) {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// matchesAction checks whether any of the rule's actions match the requested action.
|
||||
// The special "any" action matches all actions except "admin".
|
||||
func matchesAction(ruleActions []string, reqAction string) bool {
|
||||
for _, a := range ruleActions {
|
||||
if strings.EqualFold(a, reqAction) {
|
||||
return true
|
||||
}
|
||||
if strings.EqualFold(a, ActionAny) && !strings.EqualFold(reqAction, ActionAdmin) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func containsString(haystack []string, needle string) bool {
|
||||
for _, s := range haystack {
|
||||
if strings.EqualFold(s, needle) {
|
||||
|
||||
@@ -141,6 +141,96 @@ func TestPolicyPriorityOrder(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestActionAnyMatchesNonAdmin(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
_ = e.CreateRule(ctx, &Rule{
|
||||
ID: "any-rule",
|
||||
Priority: 100,
|
||||
Effect: EffectAllow,
|
||||
Roles: []string{"user"},
|
||||
Resources: []string{"transit/*"},
|
||||
Actions: []string{ActionAny},
|
||||
})
|
||||
|
||||
// "any" should match encrypt, decrypt, sign, verify, hmac, read, write.
|
||||
for _, action := range []string{ActionEncrypt, ActionDecrypt, ActionSign, ActionVerify, ActionHMAC, ActionRead, ActionWrite} {
|
||||
effect, _ := e.Evaluate(ctx, &Request{
|
||||
Username: "alice",
|
||||
Roles: []string{"user"},
|
||||
Resource: "transit/default",
|
||||
Action: action,
|
||||
})
|
||||
if effect != EffectAllow {
|
||||
t.Errorf("action %q should be allowed by 'any', got: %s", action, effect)
|
||||
}
|
||||
}
|
||||
|
||||
// "any" must NOT match "admin".
|
||||
effect, _ := e.Evaluate(ctx, &Request{
|
||||
Username: "alice",
|
||||
Roles: []string{"user"},
|
||||
Resource: "transit/default",
|
||||
Action: ActionAdmin,
|
||||
})
|
||||
if effect != EffectDeny {
|
||||
t.Fatalf("action 'admin' should not be matched by 'any', got: %s", effect)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLintRule(t *testing.T) {
|
||||
// Valid rule.
|
||||
problems := LintRule(&Rule{
|
||||
ID: "ok",
|
||||
Effect: EffectAllow,
|
||||
Actions: []string{ActionAny, ActionEncrypt},
|
||||
})
|
||||
if len(problems) > 0 {
|
||||
t.Errorf("expected no problems, got: %v", problems)
|
||||
}
|
||||
|
||||
// Missing ID.
|
||||
problems = LintRule(&Rule{Effect: EffectAllow})
|
||||
if len(problems) != 1 {
|
||||
t.Errorf("expected 1 problem for missing ID, got: %v", problems)
|
||||
}
|
||||
|
||||
// Invalid effect.
|
||||
problems = LintRule(&Rule{ID: "bad-effect", Effect: "maybe"})
|
||||
if len(problems) != 1 {
|
||||
t.Errorf("expected 1 problem for bad effect, got: %v", problems)
|
||||
}
|
||||
|
||||
// Invalid action.
|
||||
problems = LintRule(&Rule{ID: "bad-action", Effect: EffectAllow, Actions: []string{"destroy"}})
|
||||
if len(problems) != 1 {
|
||||
t.Errorf("expected 1 problem for bad action, got: %v", problems)
|
||||
}
|
||||
|
||||
// Multiple problems.
|
||||
problems = LintRule(&Rule{Effect: "bogus", Actions: []string{"nope"}})
|
||||
if len(problems) != 3 { // missing ID + bad effect + bad action
|
||||
t.Errorf("expected 3 problems, got: %v", problems)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRuleRejectsInvalid(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
|
||||
err := e.CreateRule(ctx, &Rule{
|
||||
ID: "bad",
|
||||
Effect: EffectAllow,
|
||||
Actions: []string{"obliterate"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid action, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPolicyUsernameMatch(t *testing.T) {
|
||||
e, cleanup := setupPolicy(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -141,7 +141,7 @@ func (m *Manager) Initialize(ctx context.Context, password []byte, params crypto
|
||||
defer crypto.Zeroize(kwk)
|
||||
|
||||
// Encrypt MEK with KWK.
|
||||
encryptedMEK, err := crypto.Encrypt(kwk, mek)
|
||||
encryptedMEK, err := crypto.Encrypt(kwk, mek, nil)
|
||||
if err != nil {
|
||||
crypto.Zeroize(mek)
|
||||
return fmt.Errorf("seal: encrypt mek: %w", err)
|
||||
@@ -220,7 +220,7 @@ func (m *Manager) Unseal(password []byte) error {
|
||||
kwk := crypto.DeriveKey(password, salt, params)
|
||||
defer crypto.Zeroize(kwk)
|
||||
|
||||
mek, err := crypto.Decrypt(kwk, encryptedMEK)
|
||||
mek, err := crypto.Decrypt(kwk, encryptedMEK, nil)
|
||||
if err != nil {
|
||||
m.logger.Debug("unseal failed: invalid password")
|
||||
return ErrInvalidPassword
|
||||
@@ -239,6 +239,79 @@ func (m *Manager) Unseal(password []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RotateMEK generates a new MEK, re-wraps all DEKs in the barrier, and
|
||||
// updates the encrypted MEK in seal_config. The password is required to
|
||||
// derive the KWK for re-encrypting the new MEK.
|
||||
func (m *Manager) RotateMEK(ctx context.Context, password []byte) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if m.state != StateUnsealed {
|
||||
return ErrSealed
|
||||
}
|
||||
|
||||
// Read seal config for KDF params.
|
||||
var (
|
||||
salt []byte
|
||||
argTime, argMem uint32
|
||||
argThreads uint8
|
||||
)
|
||||
err := m.db.QueryRow(`
|
||||
SELECT kdf_salt, argon2_time, argon2_memory, argon2_threads
|
||||
FROM seal_config WHERE id = 1`).Scan(&salt, &argTime, &argMem, &argThreads)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seal: read config: %w", err)
|
||||
}
|
||||
|
||||
// Verify password by decrypting existing MEK.
|
||||
params := crypto.Argon2Params{Time: argTime, Memory: argMem, Threads: argThreads}
|
||||
kwk := crypto.DeriveKey(password, salt, params)
|
||||
defer crypto.Zeroize(kwk)
|
||||
|
||||
var encryptedMEK []byte
|
||||
err = m.db.QueryRow("SELECT encrypted_mek FROM seal_config WHERE id = 1").Scan(&encryptedMEK)
|
||||
if err != nil {
|
||||
return fmt.Errorf("seal: read encrypted mek: %w", err)
|
||||
}
|
||||
_, err = crypto.Decrypt(kwk, encryptedMEK, nil)
|
||||
if err != nil {
|
||||
return ErrInvalidPassword
|
||||
}
|
||||
|
||||
// Generate new MEK.
|
||||
newMEK, err := crypto.GenerateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("seal: generate new mek: %w", err)
|
||||
}
|
||||
|
||||
// Re-wrap all DEKs with new MEK.
|
||||
if err := m.barrier.ReWrapKeys(ctx, newMEK); err != nil {
|
||||
crypto.Zeroize(newMEK)
|
||||
return fmt.Errorf("seal: re-wrap keys: %w", err)
|
||||
}
|
||||
|
||||
// Encrypt new MEK with KWK.
|
||||
newEncMEK, err := crypto.Encrypt(kwk, newMEK, nil)
|
||||
if err != nil {
|
||||
crypto.Zeroize(newMEK)
|
||||
return fmt.Errorf("seal: encrypt new mek: %w", err)
|
||||
}
|
||||
|
||||
// Update seal_config.
|
||||
_, err = m.db.ExecContext(ctx,
|
||||
"UPDATE seal_config SET encrypted_mek = ? WHERE id = 1", newEncMEK)
|
||||
if err != nil {
|
||||
crypto.Zeroize(newMEK)
|
||||
return fmt.Errorf("seal: update seal config: %w", err)
|
||||
}
|
||||
|
||||
// Swap in-memory MEK.
|
||||
crypto.Zeroize(m.mek)
|
||||
m.mek = newMEK
|
||||
m.logger.Info("MEK rotated successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Seal seals the service: zeroizes MEK, seals the barrier.
|
||||
func (m *Manager) Seal() error {
|
||||
m.mu.Lock()
|
||||
|
||||
@@ -120,6 +120,94 @@ func TestSealCheckInitializedPersists(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealRotateMEK(t *testing.T) {
|
||||
mgr, cleanup := setupSeal(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
_ = mgr.CheckInitialized()
|
||||
|
||||
password := []byte("test-password")
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
_ = mgr.Initialize(ctx, password, params)
|
||||
|
||||
// Create a DEK and write data through the barrier.
|
||||
b := mgr.Barrier()
|
||||
_ = b.CreateKey(ctx, "system")
|
||||
_ = b.CreateKey(ctx, "engine/ca/prod")
|
||||
_ = b.Put(ctx, "policy/rule1", []byte("policy-data"))
|
||||
_ = b.Put(ctx, "engine/ca/prod/cert", []byte("cert-data"))
|
||||
|
||||
// Rotate MEK.
|
||||
if err := mgr.RotateMEK(ctx, password); err != nil {
|
||||
t.Fatalf("RotateMEK: %v", err)
|
||||
}
|
||||
|
||||
// Data should still be readable.
|
||||
got, err := b.Get(ctx, "policy/rule1")
|
||||
if err != nil {
|
||||
t.Fatalf("Get after MEK rotation: %v", err)
|
||||
}
|
||||
if string(got) != "policy-data" {
|
||||
t.Fatalf("data: got %q", got)
|
||||
}
|
||||
|
||||
got2, err := b.Get(ctx, "engine/ca/prod/cert")
|
||||
if err != nil {
|
||||
t.Fatalf("Get engine data after MEK rotation: %v", err)
|
||||
}
|
||||
if string(got2) != "cert-data" {
|
||||
t.Fatalf("data: got %q", got2)
|
||||
}
|
||||
|
||||
// Seal and unseal with the same password should work
|
||||
// (the new MEK is now encrypted with the KWK).
|
||||
if err := mgr.Seal(); err != nil {
|
||||
t.Fatalf("Seal: %v", err)
|
||||
}
|
||||
if err := mgr.Unseal(password); err != nil {
|
||||
t.Fatalf("Unseal after MEK rotation: %v", err)
|
||||
}
|
||||
|
||||
got3, err := b.Get(ctx, "engine/ca/prod/cert")
|
||||
if err != nil {
|
||||
t.Fatalf("Get after seal/unseal: %v", err)
|
||||
}
|
||||
if string(got3) != "cert-data" {
|
||||
t.Fatalf("data after seal/unseal: got %q", got3)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealRotateMEKWrongPassword(t *testing.T) {
|
||||
mgr, cleanup := setupSeal(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
_ = mgr.CheckInitialized()
|
||||
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
_ = mgr.Initialize(ctx, []byte("correct"), params)
|
||||
|
||||
err := mgr.RotateMEK(ctx, []byte("wrong"))
|
||||
if !errors.Is(err, ErrInvalidPassword) {
|
||||
t.Fatalf("expected ErrInvalidPassword, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealRotateMEKWhenSealed(t *testing.T) {
|
||||
mgr, cleanup := setupSeal(t)
|
||||
defer cleanup()
|
||||
ctx := context.Background()
|
||||
_ = mgr.CheckInitialized()
|
||||
|
||||
params := crypto.Argon2Params{Time: 1, Memory: 64 * 1024, Threads: 1}
|
||||
_ = mgr.Initialize(ctx, []byte("password"), params)
|
||||
_ = mgr.Seal()
|
||||
|
||||
err := mgr.RotateMEK(ctx, []byte("password"))
|
||||
if !errors.Is(err, ErrSealed) {
|
||||
t.Fatalf("expected ErrSealed, got: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSealStateString(t *testing.T) {
|
||||
tests := []struct {
|
||||
want string
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
|
||||
mcias "git.wntrmute.dev/kyle/mcias/clients/go"
|
||||
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/crypto"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
|
||||
"git.wntrmute.dev/kyle/metacrypt/internal/engine/ca"
|
||||
@@ -45,6 +46,11 @@ func (s *Server) registerRoutes(r chi.Router) {
|
||||
r.Get("/v1/pki/{mount}/issuer/{name}", s.requireUnseal(s.handlePKIIssuer))
|
||||
r.Get("/v1/pki/{mount}/issuer/{name}/crl", s.requireUnseal(s.handlePKICRL))
|
||||
|
||||
r.Get("/v1/barrier/keys", s.requireAdmin(s.handleBarrierKeys))
|
||||
r.Post("/v1/barrier/rotate-mek", s.requireAdmin(s.handleRotateMEK))
|
||||
r.Post("/v1/barrier/rotate-key", s.requireAdmin(s.handleRotateKey))
|
||||
r.Post("/v1/barrier/migrate", s.requireAdmin(s.handleBarrierMigrate))
|
||||
|
||||
r.HandleFunc("/v1/policy/rules", s.requireAuth(s.handlePolicyRules))
|
||||
r.HandleFunc("/v1/policy/rule", s.requireAuth(s.handlePolicyRule))
|
||||
s.registerACMERoutes(r)
|
||||
@@ -253,6 +259,31 @@ func (s *Server) handleEngineUnmount(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
|
||||
}
|
||||
|
||||
// adminOnlyOperations lists engine operations that require admin role.
|
||||
// This enforces the same gates as the typed REST routes, ensuring the
|
||||
// generic endpoint cannot bypass admin requirements.
|
||||
var adminOnlyOperations = map[string]bool{
|
||||
// CA engine.
|
||||
"import-root": true,
|
||||
"create-issuer": true,
|
||||
"delete-issuer": true,
|
||||
"revoke-cert": true,
|
||||
"delete-cert": true,
|
||||
// Transit engine.
|
||||
"create-key": true,
|
||||
"delete-key": true,
|
||||
"rotate-key": true,
|
||||
"update-key-config": true,
|
||||
"trim-key": true,
|
||||
// SSH CA engine.
|
||||
"create-profile": true,
|
||||
"update-profile": true,
|
||||
"delete-profile": true,
|
||||
// User engine.
|
||||
"provision": true,
|
||||
"delete-user": true,
|
||||
}
|
||||
|
||||
func (s *Server) handleEngineRequest(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Data map[string]interface{} `json:"data"`
|
||||
@@ -271,6 +302,12 @@ func (s *Server) handleEngineRequest(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
info := TokenInfoFromContext(r.Context())
|
||||
|
||||
// Enforce admin requirement for operations that have admin-only typed routes.
|
||||
if adminOnlyOperations[req.Operation] && !info.IsAdmin {
|
||||
http.Error(w, `{"error":"forbidden: admin required"}`, http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Evaluate policy before dispatching to the engine.
|
||||
policyReq := &policy.Request{
|
||||
Username: info.Username,
|
||||
@@ -412,6 +449,90 @@ func (s *Server) handlePolicyRule(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// --- Barrier Key Management Handlers ---
|
||||
|
||||
func (s *Server) handleBarrierKeys(w http.ResponseWriter, r *http.Request) {
|
||||
keys, err := s.seal.Barrier().ListKeys(r.Context())
|
||||
if err != nil {
|
||||
s.logger.Error("list barrier keys", "error", err)
|
||||
http.Error(w, `{"error":"internal error"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if keys == nil {
|
||||
keys = []barrier.KeyInfo{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, keys)
|
||||
}
|
||||
|
||||
func (s *Server) handleRotateMEK(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
if err := readJSON(r, &req); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.Password == "" {
|
||||
http.Error(w, `{"error":"password is required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.seal.RotateMEK(r.Context(), []byte(req.Password)); err != nil {
|
||||
if errors.Is(err, seal.ErrInvalidPassword) {
|
||||
http.Error(w, `{"error":"invalid password"}`, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
if errors.Is(err, seal.ErrSealed) {
|
||||
http.Error(w, `{"error":"sealed"}`, http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
s.logger.Error("rotate MEK", "error", err)
|
||||
http.Error(w, `{"error":"rotation failed"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
|
||||
}
|
||||
|
||||
func (s *Server) handleRotateKey(w http.ResponseWriter, r *http.Request) {
|
||||
var req struct {
|
||||
KeyID string `json:"key_id"`
|
||||
}
|
||||
if err := readJSON(r, &req); err != nil {
|
||||
http.Error(w, `{"error":"invalid request"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
if req.KeyID == "" {
|
||||
http.Error(w, `{"error":"key_id is required"}`, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.seal.Barrier().RotateKey(r.Context(), req.KeyID); err != nil {
|
||||
if errors.Is(err, barrier.ErrKeyNotFound) {
|
||||
http.Error(w, `{"error":"key not found"}`, http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
s.logger.Error("rotate key", "key_id", req.KeyID, "error", err)
|
||||
http.Error(w, `{"error":"rotation failed"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{"ok": true})
|
||||
}
|
||||
|
||||
func (s *Server) handleBarrierMigrate(w http.ResponseWriter, r *http.Request) {
|
||||
migrated, err := s.seal.Barrier().MigrateToV2(r.Context())
|
||||
if err != nil {
|
||||
s.logger.Error("barrier migration", "error", err)
|
||||
http.Error(w, `{"error":"migration failed"}`, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, map[string]interface{}{
|
||||
"migrated": migrated,
|
||||
})
|
||||
}
|
||||
|
||||
// --- CA Certificate Handlers ---
|
||||
|
||||
func (s *Server) handleGetCert(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -608,13 +729,29 @@ func (s *Server) getCAEngine(mountName string) (*ca.CAEngine, error) {
|
||||
return caEng, nil
|
||||
}
|
||||
|
||||
// operationAction maps an engine operation name to a policy action ("read" or "write").
|
||||
// operationAction maps an engine operation name to a policy action.
|
||||
func operationAction(op string) string {
|
||||
switch op {
|
||||
case "list-issuers", "list-certs", "get-cert", "get-root", "get-chain", "get-issuer":
|
||||
return "read"
|
||||
// Read operations.
|
||||
case "list-issuers", "list-certs", "get-cert", "get-root", "get-chain", "get-issuer",
|
||||
"list-keys", "get-key", "get-public-key", "list-users", "get-profile", "list-profiles":
|
||||
return policy.ActionRead
|
||||
|
||||
// Granular cryptographic operations (including batch variants).
|
||||
case "encrypt", "batch-encrypt":
|
||||
return policy.ActionEncrypt
|
||||
case "decrypt", "batch-decrypt":
|
||||
return policy.ActionDecrypt
|
||||
case "sign", "sign-host", "sign-user":
|
||||
return policy.ActionSign
|
||||
case "verify":
|
||||
return policy.ActionVerify
|
||||
case "hmac":
|
||||
return policy.ActionHMAC
|
||||
|
||||
// Everything else is a write.
|
||||
default:
|
||||
return "write"
|
||||
return policy.ActionWrite
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -58,13 +58,7 @@ func (s *Server) Start() error {
|
||||
s.registerRoutes(r)
|
||||
|
||||
tlsCfg := &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
s.httpSrv = &http.Server{
|
||||
|
||||
@@ -241,18 +241,84 @@ func TestEngineRequestPolicyAllowsWithRule(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationAction verifies the read/write classification of operations.
|
||||
func TestOperationAction(t *testing.T) {
|
||||
readOps := []string{"list-issuers", "list-certs", "get-cert", "get-root", "get-chain", "get-issuer"}
|
||||
for _, op := range readOps {
|
||||
if got := operationAction(op); got != "read" {
|
||||
t.Errorf("operationAction(%q) = %q, want %q", op, got, "read")
|
||||
// TestEngineRequestAdminOnlyBlocksNonAdmin verifies that admin-only operations
|
||||
// via the generic endpoint are rejected for non-admin users.
|
||||
func TestEngineRequestAdminOnlyBlocksNonAdmin(t *testing.T) {
|
||||
srv, sealMgr, _ := setupTestServer(t)
|
||||
unsealServer(t, sealMgr, nil)
|
||||
|
||||
for _, op := range []string{"create-issuer", "delete-cert", "create-key", "rotate-key", "create-profile", "provision"} {
|
||||
body := makeEngineRequest("test-mount", op)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/engine/request", strings.NewReader(body))
|
||||
req = withTokenInfo(req, &auth.TokenInfo{Username: "alice", Roles: []string{"user"}, IsAdmin: false})
|
||||
w := httptest.NewRecorder()
|
||||
srv.handleEngineRequest(w, req)
|
||||
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Errorf("operation %q: expected 403 for non-admin, got %d", op, w.Code)
|
||||
}
|
||||
}
|
||||
writeOps := []string{"issue", "renew", "create-issuer", "delete-issuer", "sign-csr", "revoke"}
|
||||
for _, op := range writeOps {
|
||||
if got := operationAction(op); got != "write" {
|
||||
t.Errorf("operationAction(%q) = %q, want %q", op, got, "write")
|
||||
}
|
||||
|
||||
// TestEngineRequestAdminOnlyAllowsAdmin verifies that admin-only operations
|
||||
// via the generic endpoint are allowed for admin users.
|
||||
func TestEngineRequestAdminOnlyAllowsAdmin(t *testing.T) {
|
||||
srv, sealMgr, _ := setupTestServer(t)
|
||||
unsealServer(t, sealMgr, nil)
|
||||
|
||||
for _, op := range []string{"create-issuer", "delete-cert", "create-key", "rotate-key", "create-profile", "provision"} {
|
||||
body := makeEngineRequest("test-mount", op)
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/engine/request", strings.NewReader(body))
|
||||
req = withTokenInfo(req, &auth.TokenInfo{Username: "admin", Roles: []string{"admin"}, IsAdmin: true})
|
||||
w := httptest.NewRecorder()
|
||||
srv.handleEngineRequest(w, req)
|
||||
|
||||
// Admin passes the admin check; will get 404 (mount not found) not 403.
|
||||
if w.Code == http.StatusForbidden {
|
||||
t.Errorf("operation %q: admin should not be forbidden, got 403", op)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestOperationAction verifies the action classification of operations.
|
||||
func TestOperationAction(t *testing.T) {
|
||||
tests := map[string]string{
|
||||
// Read operations.
|
||||
"list-issuers": policy.ActionRead,
|
||||
"list-certs": policy.ActionRead,
|
||||
"get-cert": policy.ActionRead,
|
||||
"get-root": policy.ActionRead,
|
||||
"get-chain": policy.ActionRead,
|
||||
"get-issuer": policy.ActionRead,
|
||||
"list-keys": policy.ActionRead,
|
||||
"get-key": policy.ActionRead,
|
||||
"get-public-key": policy.ActionRead,
|
||||
"list-users": policy.ActionRead,
|
||||
"get-profile": policy.ActionRead,
|
||||
"list-profiles": policy.ActionRead,
|
||||
|
||||
// Granular crypto operations (including batch variants).
|
||||
"encrypt": policy.ActionEncrypt,
|
||||
"batch-encrypt": policy.ActionEncrypt,
|
||||
"decrypt": policy.ActionDecrypt,
|
||||
"batch-decrypt": policy.ActionDecrypt,
|
||||
"sign": policy.ActionSign,
|
||||
"sign-host": policy.ActionSign,
|
||||
"sign-user": policy.ActionSign,
|
||||
"verify": policy.ActionVerify,
|
||||
"hmac": policy.ActionHMAC,
|
||||
|
||||
// Write operations.
|
||||
"issue": policy.ActionWrite,
|
||||
"renew": policy.ActionWrite,
|
||||
"create-issuer": policy.ActionWrite,
|
||||
"delete-issuer": policy.ActionWrite,
|
||||
"sign-csr": policy.ActionWrite,
|
||||
"revoke": policy.ActionWrite,
|
||||
}
|
||||
for op, want := range tests {
|
||||
if got := operationAction(op); got != want {
|
||||
t.Errorf("operationAction(%q) = %q, want %q", op, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -193,6 +193,7 @@ func newTestWebServer(t *testing.T, vault vaultBackend) *WebServer {
|
||||
vault: vault,
|
||||
logger: slog.Default(),
|
||||
staticFS: staticFS,
|
||||
csrf: newCSRFProtect(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ type VaultClient struct {
|
||||
func NewVaultClient(addr, caCertPath string, logger *slog.Logger) (*VaultClient, error) {
|
||||
logger.Debug("connecting to vault", "addr", addr, "ca_cert", caCertPath)
|
||||
|
||||
tlsCfg := &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
tlsCfg := &tls.Config{MinVersion: tls.VersionTLS13}
|
||||
if caCertPath != "" {
|
||||
logger.Debug("loading vault CA certificate", "path", caCertPath)
|
||||
pemData, err := os.ReadFile(caCertPath) //nolint:gosec
|
||||
|
||||
119
internal/webserver/csrf.go
Normal file
119
internal/webserver/csrf.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
csrfCookieName = "metacrypt_csrf"
|
||||
csrfFieldName = "csrf_token"
|
||||
csrfTokenLen = 32
|
||||
)
|
||||
|
||||
// csrfProtect provides CSRF protection using the signed double-submit cookie
|
||||
// pattern. A random secret is generated at startup. CSRF tokens are an HMAC of
|
||||
// a random nonce, sent as both a cookie and a hidden form field. On POST the
|
||||
// middleware verifies that the form field matches the cookie's HMAC.
|
||||
type csrfProtect struct {
|
||||
secret []byte
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newCSRFProtect() *csrfProtect {
|
||||
secret := make([]byte, 32)
|
||||
if _, err := rand.Read(secret); err != nil {
|
||||
panic("csrf: failed to generate secret: " + err.Error())
|
||||
}
|
||||
return &csrfProtect{secret: secret}
|
||||
}
|
||||
|
||||
// generateToken creates a new CSRF token: base64(nonce) + "." + base64(hmac(nonce)).
|
||||
func (c *csrfProtect) generateToken() (string, error) {
|
||||
nonce := make([]byte, csrfTokenLen)
|
||||
if _, err := rand.Read(nonce); err != nil {
|
||||
return "", fmt.Errorf("csrf: generate nonce: %w", err)
|
||||
}
|
||||
nonceB64 := base64.RawURLEncoding.EncodeToString(nonce)
|
||||
mac := hmac.New(sha256.New, c.secret)
|
||||
mac.Write(nonce)
|
||||
sigB64 := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
|
||||
return nonceB64 + "." + sigB64, nil
|
||||
}
|
||||
|
||||
// validToken checks that a token has a valid HMAC signature.
|
||||
func (c *csrfProtect) validToken(token string) bool {
|
||||
parts := strings.SplitN(token, ".", 2)
|
||||
if len(parts) != 2 {
|
||||
return false
|
||||
}
|
||||
nonce, err := base64.RawURLEncoding.DecodeString(parts[0])
|
||||
if err != nil || len(nonce) != csrfTokenLen {
|
||||
return false
|
||||
}
|
||||
sig, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
mac := hmac.New(sha256.New, c.secret)
|
||||
mac.Write(nonce)
|
||||
return hmac.Equal(mac.Sum(nil), sig)
|
||||
}
|
||||
|
||||
// setToken generates a new CSRF token, sets it as a cookie, and returns it
|
||||
// for embedding in a form.
|
||||
func (c *csrfProtect) setToken(w http.ResponseWriter) string {
|
||||
token, err := c.generateToken()
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: csrfCookieName,
|
||||
Value: token,
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
return token
|
||||
}
|
||||
|
||||
// middleware returns an HTTP middleware that enforces CSRF validation on
|
||||
// mutation requests (POST, PUT, PATCH, DELETE). GET/HEAD/OPTIONS are passed
|
||||
// through. The HTMX hx-post for /v1/seal is excluded since it hits the API
|
||||
// server directly and uses token auth, not cookies.
|
||||
func (c *csrfProtect) middleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.Method {
|
||||
case http.MethodGet, http.MethodHead, http.MethodOptions:
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// Read token from form field (works for both regular forms and
|
||||
// multipart forms since ParseForm/ParseMultipartForm will have
|
||||
// been called or the field is available via FormValue).
|
||||
formToken := r.FormValue(csrfFieldName)
|
||||
|
||||
// Read token from cookie.
|
||||
cookie, err := r.Cookie(csrfCookieName)
|
||||
if err != nil || cookie.Value == "" {
|
||||
http.Error(w, "CSRF validation failed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// Both tokens must be present, match each other, and be validly signed.
|
||||
if formToken == "" || formToken != cookie.Value || !c.validToken(formToken) {
|
||||
http.Error(w, "CSRF validation failed", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
105
internal/webserver/csrf_test.go
Normal file
105
internal/webserver/csrf_test.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package webserver
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCSRFTokenGenerateAndValidate(t *testing.T) {
|
||||
c := newCSRFProtect()
|
||||
token, err := c.generateToken()
|
||||
if err != nil {
|
||||
t.Fatalf("generateToken: %v", err)
|
||||
}
|
||||
if !c.validToken(token) {
|
||||
t.Fatal("valid token rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFTokenInvalidFormats(t *testing.T) {
|
||||
c := newCSRFProtect()
|
||||
for _, bad := range []string{"", "nodot", "a.b.c", "abc.def"} {
|
||||
if c.validToken(bad) {
|
||||
t.Errorf("should reject %q", bad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFTokenCrossSecret(t *testing.T) {
|
||||
c1 := newCSRFProtect()
|
||||
c2 := newCSRFProtect()
|
||||
token, _ := c1.generateToken()
|
||||
if c2.validToken(token) {
|
||||
t.Fatal("token from different secret should be rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFMiddlewareAllowsGET(t *testing.T) {
|
||||
c := newCSRFProtect()
|
||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("GET should pass through, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFMiddlewareBlocksPOSTWithoutToken(t *testing.T) {
|
||||
c := newCSRFProtect()
|
||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader("foo=bar"))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("POST without CSRF token should be forbidden, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFMiddlewareAllowsPOSTWithValidToken(t *testing.T) {
|
||||
c := newCSRFProtect()
|
||||
token, _ := c.generateToken()
|
||||
|
||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
body := csrfFieldName + "=" + token
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("POST with valid CSRF token should pass, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCSRFMiddlewareRejectsMismatch(t *testing.T) {
|
||||
c := newCSRFProtect()
|
||||
token1, _ := c.generateToken()
|
||||
token2, _ := c.generateToken()
|
||||
|
||||
handler := c.middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
body := csrfFieldName + "=" + token1
|
||||
req := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(body))
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.AddCookie(&http.Cookie{Name: csrfCookieName, Value: token2})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
if w.Code != http.StatusForbidden {
|
||||
t.Fatalf("POST with mismatched tokens should be forbidden, got %d", w.Code)
|
||||
}
|
||||
}
|
||||
@@ -188,7 +188,7 @@ func (ws *WebServer) handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
Path: "/",
|
||||
HttpOnly: true,
|
||||
Secure: true,
|
||||
SameSite: http.SameSiteLaxMode,
|
||||
SameSite: http.SameSiteStrictMode,
|
||||
})
|
||||
http.Redirect(w, r, "/dashboard", http.StatusFound)
|
||||
default:
|
||||
|
||||
@@ -70,6 +70,7 @@ type WebServer struct {
|
||||
logger *slog.Logger
|
||||
httpSrv *http.Server
|
||||
staticFS fs.FS
|
||||
csrf *csrfProtect
|
||||
tgzCache sync.Map // key: UUID string → *tgzEntry
|
||||
userCache sync.Map // key: UUID string → *cachedUsername
|
||||
}
|
||||
@@ -125,6 +126,7 @@ func New(cfg *config.Config, logger *slog.Logger) (*WebServer, error) {
|
||||
vault: vault,
|
||||
logger: logger,
|
||||
staticFS: staticFS,
|
||||
csrf: newCSRFProtect(),
|
||||
}
|
||||
|
||||
if tok := cfg.MCIAS.ServiceToken; tok != "" {
|
||||
@@ -188,6 +190,7 @@ func (lw *loggingResponseWriter) Unwrap() http.ResponseWriter {
|
||||
func (ws *WebServer) Start() error {
|
||||
r := chi.NewRouter()
|
||||
r.Use(ws.loggingMiddleware)
|
||||
r.Use(ws.csrf.middleware)
|
||||
ws.registerRoutes(r)
|
||||
|
||||
ws.httpSrv = &http.Server{
|
||||
@@ -201,7 +204,7 @@ func (ws *WebServer) Start() error {
|
||||
ws.logger.Info("starting web server", "addr", ws.cfg.Web.ListenAddr)
|
||||
|
||||
if ws.cfg.Web.TLSCert != "" && ws.cfg.Web.TLSKey != "" {
|
||||
ws.httpSrv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS12}
|
||||
ws.httpSrv.TLSConfig = &tls.Config{MinVersion: tls.VersionTLS13}
|
||||
err := ws.httpSrv.ListenAndServeTLS(ws.cfg.Web.TLSCert, ws.cfg.Web.TLSKey)
|
||||
if err != nil && err != http.ErrServerClosed {
|
||||
return fmt.Errorf("webserver: %w", err)
|
||||
@@ -226,7 +229,18 @@ func (ws *WebServer) Shutdown(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (ws *WebServer) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
|
||||
tmpl, err := template.ParseFS(webui.FS,
|
||||
csrfToken := ws.csrf.setToken(w)
|
||||
|
||||
funcMap := template.FuncMap{
|
||||
"csrfField": func() template.HTML {
|
||||
return template.HTML(fmt.Sprintf(
|
||||
`<input type="hidden" name="%s" value="%s">`,
|
||||
csrfFieldName, template.HTMLEscapeString(csrfToken),
|
||||
))
|
||||
},
|
||||
}
|
||||
|
||||
tmpl, err := template.New("").Funcs(funcMap).ParseFS(webui.FS,
|
||||
"templates/layout.html",
|
||||
"templates/"+name,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user