Files
metacrypt/internal/engine/transit/transit.go
Kyle Isom cbd77c58e8 Implement transit encryption engine with versioned key management
Add complete transit engine supporting symmetric encryption (AES-256-GCM,
XChaCha20-Poly1305), asymmetric signing (Ed25519, ECDSA P-256/P-384),
and HMAC (SHA-256/SHA-512) with versioned key rotation, min decryption
version enforcement, key trimming, batch operations, and rewrap.

Includes proto definitions, gRPC handlers, REST routes, and comprehensive
tests covering all 18 operations, auth enforcement, and edge cases.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-16 19:45:56 -07:00

1603 lines
37 KiB
Go

// Package transit implements the transit encryption engine for symmetric
// encryption, signing, and HMAC operations with versioned key management.
package transit
import (
"context"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"sort"
"strconv"
"strings"
"sync"
"golang.org/x/crypto/chacha20poly1305"
"git.wntrmute.dev/kyle/metacrypt/internal/barrier"
mcrypto "git.wntrmute.dev/kyle/metacrypt/internal/crypto"
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
)
// maxBatchSize caps the number of items accepted by a single
// batch-encrypt, batch-decrypt or batch-rewrap request.
const maxBatchSize = 500

// Sentinel errors returned by the transit engine; match them with errors.Is.
var (
	// Engine state and caller authorization.
	ErrSealed       = errors.New("transit: engine is sealed")
	ErrForbidden    = errors.New("transit: forbidden")
	ErrUnauthorized = errors.New("transit: authentication required")

	// Key management.
	ErrKeyNotFound    = errors.New("transit: key not found")
	ErrKeyExists      = errors.New("transit: key already exists")
	ErrDeletionDenied = errors.New("transit: deletion not allowed")
	ErrInvalidKeyType = errors.New("transit: invalid key type")
	ErrInvalidMinVer  = errors.New("transit: min_decryption_version can only increase and cannot exceed current version")

	// Cryptographic operations.
	ErrUnsupportedOp  = errors.New("transit: unsupported operation for key type")
	ErrDecryptVersion = errors.New("transit: ciphertext version below minimum decryption version")
	ErrInvalidFormat  = errors.New("transit: invalid ciphertext format")
	ErrBatchTooLarge  = errors.New("transit: batch size exceeds maximum")
)
// keyVersion holds a single version of key material.
//
// Which fields are populated depends on the key type (see
// generateKeyVersion and loadKeyVersion): symmetric and HMAC keys set only
// key; ed25519 sets key (the raw private key bytes), privKey and pubKey;
// ECDSA sets only privKey and pubKey.
type keyVersion struct {
version int // version number; new keys start at 1 and each rotation adds 1
key []byte // raw key bytes: symmetric/HMAC keys and the ed25519 private key
privKey crypto.PrivateKey // asymmetric (nil for symmetric)
pubKey crypto.PublicKey // asymmetric (nil for symmetric)
}
// keyState holds in-memory state for a loaded key: its configuration and the
// key versions currently loaded, indexed by version number. Versions that
// were trimmed or pruned (or failed to load) are absent from the map.
type keyState struct {
config *KeyConfig // persisted as keys/<name>/config.json under the mount
versions map[int]*keyVersion
}
// TransitEngine implements the transit encryption engine.
//
// config is nil while the engine is sealed (see sealed and Seal). Handlers
// take mu (a read lock for crypto and read operations, the write lock for
// key management) before touching config and keys.
type TransitEngine struct {
barrier barrier.Barrier // storage for the mount config and key material
config *TransitConfig // mount-wide settings; nil when sealed
keys map[string]*keyState // loaded keys by name
mountPath string // barrier prefix; presumably ends in "/" (paths are built as mountPath+"config.json")
mu sync.RWMutex
}
// NewTransitEngine returns a transit engine with an empty key table. Its
// configuration is unset, so it reports itself sealed until Initialize or
// Unseal is called.
func NewTransitEngine() engine.Engine {
	te := &TransitEngine{}
	te.keys = map[string]*keyState{}
	return te
}
// Type identifies this engine as the transit engine.
func (e *TransitEngine) Type() engine.EngineType {
	var kind engine.EngineType = engine.EngineTypeTransit
	return kind
}
// Initialize sets up the transit engine for first use at mountPath.
//
// config may carry "max_key_versions" as a float64 (as produced by JSON
// decoding), int, int64 or json.Number. When it is positive,
// handleRotateKey calls pruneVersions once a key holds more versions than
// that. Other entries are ignored.
//
// The engine configuration is written to the barrier before any in-memory
// state changes. A failed write therefore leaves the engine exactly as it
// was, instead of looking unsealed with an unpersisted config.
func (e *TransitEngine) Initialize(ctx context.Context, b barrier.Barrier, mountPath string, config map[string]interface{}) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	cfg := &TransitConfig{}
	// Indexing a nil map is safe and simply reports "not present".
	if v, ok := config["max_key_versions"]; ok {
		switch val := v.(type) {
		case float64:
			cfg.MaxKeyVersions = int(val)
		case int:
			cfg.MaxKeyVersions = val
		case int64:
			cfg.MaxKeyVersions = int(val)
		case json.Number:
			n, err := val.Int64()
			if err != nil {
				return fmt.Errorf("transit: invalid max_key_versions %q: %w", val, err)
			}
			cfg.MaxKeyVersions = int(n)
		}
	}

	configData, err := json.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("transit: marshal config: %w", err)
	}
	if err := b.Put(ctx, mountPath+"config.json", configData); err != nil {
		return fmt.Errorf("transit: store config: %w", err)
	}

	// Persisted successfully; only now expose the new state.
	e.barrier = b
	e.mountPath = mountPath
	e.config = cfg
	e.keys = make(map[string]*keyState)
	return nil
}
// Unseal loads the transit state for mountPath from the barrier into memory.
//
// It reads the mount configuration, then every key stored under "keys/".
// Key names are the first path segment of each listed entry, e.g. "mykey"
// from "mykey/config.json". A failure to list keys is treated as "no keys
// stored yet".
//
// Everything is assembled in locals and installed only after all keys have
// loaded. An error therefore leaves the engine as it was (still sealed on a
// first unseal) rather than half-populated, and any key material loaded
// before the failure is zeroized.
func (e *TransitEngine) Unseal(ctx context.Context, b barrier.Barrier, mountPath string) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	configData, err := b.Get(ctx, mountPath+"config.json")
	if err != nil {
		return fmt.Errorf("transit: load config: %w", err)
	}
	var cfg TransitConfig
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return fmt.Errorf("transit: parse config: %w", err)
	}

	keys := make(map[string]*keyState)

	keyPaths, err := b.List(ctx, mountPath+"keys/")
	if err != nil {
		// No keys yet: unseal with an empty key table.
		keyPaths = nil
	}

	// Collect unique key names from paths like "mykey/config.json", "mykey/v1.key".
	keyNames := make(map[string]bool)
	for _, p := range keyPaths {
		parts := strings.SplitN(p, "/", 2)
		if len(parts) > 0 && parts[0] != "" {
			keyNames[parts[0]] = true
		}
	}

	for name := range keyNames {
		ks, err := e.loadKey(ctx, b, mountPath, name)
		if err != nil {
			// Don't leave already-decoded material lying around in memory.
			for _, loaded := range keys {
				for _, kv := range loaded.versions {
					if kv.key != nil {
						mcrypto.Zeroize(kv.key)
					}
					zeroizeKey(kv.privKey)
				}
			}
			return fmt.Errorf("transit: load key %q: %w", name, err)
		}
		keys[name] = ks
	}

	e.barrier = b
	e.mountPath = mountPath
	e.config = &cfg
	e.keys = keys
	return nil
}
// loadKey reads the configuration of key name from
// "<mountPath>keys/<name>/config.json", then tries every version from 1
// through the configured current version.
//
// A version that fails to load is skipped. Trimming and pruning delete old
// version files, so gaps are expected. Note this also hides a version file
// that exists but cannot be decoded.
func (e *TransitEngine) loadKey(ctx context.Context, b barrier.Barrier, mountPath, name string) (*keyState, error) {
	base := mountPath + "keys/" + name + "/"

	raw, err := b.Get(ctx, base+"config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
	cfg := new(KeyConfig)
	if err = json.Unmarshal(raw, cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	loaded := make(map[int]*keyVersion)
	for ver := 1; ver <= cfg.CurrentVersion; ver++ {
		kv, verr := e.loadKeyVersion(ctx, b, base, cfg, ver)
		if verr != nil {
			continue // trimmed, pruned, or unreadable
		}
		loaded[ver] = kv
	}

	return &keyState{config: cfg, versions: loaded}, nil
}
// loadKeyVersion reads "<prefix>v<version>.key" from the barrier and decodes
// it according to cfg.Type:
//   - aes256-gcm, chacha20-poly, hmac-sha256, hmac-sha512: the raw bytes are
//     the key;
//   - ed25519: the raw private key, from which the public key is derived;
//   - ecdsa-p256 / ecdsa-p384: a PKCS#8-encoded ECDSA private key.
//
// Stored material is validated before use, and invalid material is returned
// as an error:
//   - an ed25519 key whose length is not ed25519.PrivateKeySize would
//     otherwise panic in Public() here, or later in ed25519.Sign;
//   - an ECDSA key whose curve does not match cfg.Type would otherwise sign
//     with the wrong curve/hash pairing.
func (e *TransitEngine) loadKeyVersion(ctx context.Context, b barrier.Barrier, prefix string, cfg *KeyConfig, version int) (*keyVersion, error) {
	path := fmt.Sprintf("%sv%d.key", prefix, version)
	data, err := b.Get(ctx, path)
	if err != nil {
		return nil, err
	}
	kv := &keyVersion{version: version}
	switch cfg.Type {
	case "aes256-gcm", "chacha20-poly", "hmac-sha256", "hmac-sha512":
		kv.key = data
	case "ed25519":
		if len(data) != ed25519.PrivateKeySize {
			return nil, fmt.Errorf("invalid ed25519 private key length %d", len(data))
		}
		privKey := ed25519.PrivateKey(data)
		kv.key = data
		kv.privKey = privKey
		kv.pubKey = privKey.Public()
	case "ecdsa-p256", "ecdsa-p384":
		privKey, err := x509.ParsePKCS8PrivateKey(data)
		if err != nil {
			return nil, fmt.Errorf("parse PKCS8 key: %w", err)
		}
		ecKey, ok := privKey.(*ecdsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("expected ECDSA key, got %T", privKey)
		}
		wantCurve := "P-256"
		if cfg.Type == "ecdsa-p384" {
			wantCurve = "P-384"
		}
		if got := ecKey.Curve.Params().Name; got != wantCurve {
			return nil, fmt.Errorf("expected %s curve for %s key, got %s", wantCurve, cfg.Type, got)
		}
		kv.privKey = ecKey
		kv.pubKey = &ecKey.PublicKey
	default:
		return nil, fmt.Errorf("unknown key type: %s", cfg.Type)
	}
	return kv, nil
}
// Seal returns the engine to the sealed state. Every loaded key version has
// its raw bytes and private key zeroized, each key is removed from the
// table, and the key table and engine configuration are dropped. It always
// returns nil.
func (e *TransitEngine) Seal() error {
	e.mu.Lock()
	defer e.mu.Unlock()

	for name := range e.keys {
		for _, kv := range e.keys[name].versions {
			if kv.key != nil {
				mcrypto.Zeroize(kv.key)
			}
			zeroizeKey(kv.privKey)
		}
		delete(e.keys, name)
	}

	e.keys, e.config = nil, nil
	return nil
}
// transitOperationHandlers maps each supported operation name to the
// handler method that serves it.
var transitOperationHandlers = map[string]func(*TransitEngine, context.Context, *engine.Request) (*engine.Response, error){
	"create-key":        (*TransitEngine).handleCreateKey,
	"delete-key":        (*TransitEngine).handleDeleteKey,
	"get-key":           (*TransitEngine).handleGetKey,
	"list-keys":         (*TransitEngine).handleListKeys,
	"rotate-key":        (*TransitEngine).handleRotateKey,
	"update-key-config": (*TransitEngine).handleUpdateKeyConfig,
	"trim-key":          (*TransitEngine).handleTrimKey,
	"encrypt":           (*TransitEngine).handleEncrypt,
	"decrypt":           (*TransitEngine).handleDecrypt,
	"rewrap":            (*TransitEngine).handleRewrap,
	"batch-encrypt":     (*TransitEngine).handleBatchEncrypt,
	"batch-decrypt":     (*TransitEngine).handleBatchDecrypt,
	"batch-rewrap":      (*TransitEngine).handleBatchRewrap,
	"sign":              (*TransitEngine).handleSign,
	"verify":            (*TransitEngine).handleVerify,
	"hmac":              (*TransitEngine).handleHMAC,
	"get-public-key":    (*TransitEngine).handleGetPublicKey,
}

// HandleRequest routes req to the handler registered for req.Operation.
// Unknown operations produce an error naming the operation.
func (e *TransitEngine) HandleRequest(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	handler, known := transitOperationHandlers[req.Operation]
	if !known {
		return nil, fmt.Errorf("transit: unknown operation: %s", req.Operation)
	}
	return handler(e, ctx, req)
}
// --- Authorization helpers ---
// requireAdmin admits only administrators. It returns ErrUnauthorized when
// the request carries no caller information and ErrForbidden for callers
// that are not admins.
func (e *TransitEngine) requireAdmin(req *engine.Request) error {
	switch caller := req.CallerInfo; {
	case caller == nil:
		return ErrUnauthorized
	case caller.IsAdmin:
		return nil
	default:
		return ErrForbidden
	}
}
// requireUser admits any caller for which CallerInfo.IsUser reports true.
// It returns ErrUnauthorized when the request carries no caller information
// and ErrForbidden otherwise.
func (e *TransitEngine) requireUser(req *engine.Request) error {
	caller := req.CallerInfo
	switch {
	case caller == nil:
		return ErrUnauthorized
	case caller.IsUser():
		return nil
	}
	return ErrForbidden
}
// requireUserWithPolicy authorizes a key-scoped operation on keyName.
//
//   - Missing caller information yields ErrUnauthorized.
//   - Admins are always allowed.
//   - Callers that are not users are rejected with ErrForbidden.
//   - For users, the optional req.CheckPolicy hook is consulted with the
//     resource "transit/<mount>/key/<keyName>" and the action from
//     operationToAction. A matching rule decides: "allow" admits, any other
//     effect is ErrForbidden.
//
// With no hook or no matching rule, users are allowed by default.
func (e *TransitEngine) requireUserWithPolicy(req *engine.Request, keyName string) error {
	caller := req.CallerInfo
	if caller == nil {
		return ErrUnauthorized
	}
	if caller.IsAdmin {
		return nil
	}
	if !caller.IsUser() {
		return ErrForbidden
	}
	if req.CheckPolicy == nil {
		return nil
	}

	resource := fmt.Sprintf("transit/%s/key/%s", e.mountName(), keyName)
	effect, matched := req.CheckPolicy(resource, operationToAction(req.Operation))
	if !matched || effect == "allow" {
		return nil
	}
	return ErrForbidden
}
// operationToAction maps a transit operation to the policy action it is
// checked against:
//   - "read" for key-metadata lookups;
//   - "decrypt" for operations that recover plaintext (including rewrap);
//   - "write" for everything else.
func operationToAction(op string) string {
	action := "write"
	switch op {
	case "get-key", "list-keys", "get-public-key":
		action = "read"
	case "decrypt", "rewrap", "batch-decrypt", "batch-rewrap":
		action = "decrypt"
	}
	return action
}
// mountName returns the user-facing mount name: the third slash-separated
// segment of the mount path, ignoring one trailing slash. It returns "" when
// the path has fewer than three segments.
func (e *TransitEngine) mountName() string {
	segments := strings.Split(strings.TrimSuffix(e.mountPath, "/"), "/")
	if len(segments) < 3 {
		return ""
	}
	return segments[2]
}
// sealed reports whether no engine configuration is loaded. That is the
// case before Initialize/Unseal and after Seal. Callers in this file hold
// e.mu.
func (e *TransitEngine) sealed() bool {
	hasConfig := e.config != nil
	return !hasConfig
}
// --- Key Management Operations ---
// handleCreateKey creates a new named key (admins only).
//
// "type" defaults to "aes256-gcm" and must pass isValidKeyType. Version 1 is
// generated immediately. The key config and the version-1 material are
// written to the barrier before the key is added to the in-memory table.
// New keys start with min_decryption_version 1 and deletion disabled.
func (e *TransitEngine) handleCreateKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}

	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	name, _ := req.Data["name"].(string)
	keyType, _ := req.Data["type"].(string)
	switch {
	case name == "":
		return nil, fmt.Errorf("transit: name is required")
	case keyType == "":
		keyType = "aes256-gcm"
	}
	if !isValidKeyType(keyType) {
		return nil, ErrInvalidKeyType
	}
	if _, taken := e.keys[name]; taken {
		return nil, ErrKeyExists
	}

	const firstVersion = 1
	kv, err := generateKeyVersion(keyType, firstVersion)
	if err != nil {
		return nil, fmt.Errorf("transit: generate key: %w", err)
	}
	cfg := &KeyConfig{
		Name:                 name,
		Type:                 keyType,
		CurrentVersion:       firstVersion,
		MinDecryptionVersion: firstVersion,
		AllowDeletion:        false,
	}

	prefix := e.mountPath + "keys/" + name + "/"
	if err := e.storeKeyConfig(ctx, prefix, cfg); err != nil {
		return nil, err
	}
	if err := e.storeKeyVersion(ctx, prefix, cfg, kv); err != nil {
		return nil, err
	}

	e.keys[name] = &keyState{
		config:   cfg,
		versions: map[int]*keyVersion{firstVersion: kv},
	}

	return &engine.Response{Data: map[string]interface{}{
		"name":    name,
		"type":    keyType,
		"version": firstVersion,
	}}, nil
}
// handleDeleteKey removes a named key (admins only), provided its config
// has allow_deletion set.
//
// Each version file and the key config are deleted from the barrier on a
// best-effort basis: delete errors are ignored. The in-memory material is
// zeroized and the key is dropped from the key table.
//
// NOTE(review): if a version delete fails while the config delete succeeds,
// Unseal will later list the orphaned version file, fail to load its config
// in loadKey, and abort — consider surfacing these errors.
func (e *TransitEngine) handleDeleteKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}

	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	ks, found := e.keys[name]
	switch {
	case !found:
		return nil, ErrKeyNotFound
	case !ks.config.AllowDeletion:
		return nil, ErrDeletionDenied
	}

	prefix := e.mountPath + "keys/" + name + "/"
	for ver, kv := range ks.versions {
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, ver))
		if kv.key != nil {
			mcrypto.Zeroize(kv.key)
		}
		zeroizeKey(kv.privKey)
	}
	_ = e.barrier.Delete(ctx, prefix+"config.json")
	delete(e.keys, name)

	return &engine.Response{Data: map[string]interface{}{"ok": true}}, nil
}
// handleGetKey returns a key's metadata to any user. The response contains
// the name, type, current version, min_decryption_version, the deletion
// flag, and the ascending list of versions currently loaded. No key material
// is included.
func (e *TransitEngine) handleGetKey(_ context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireUser(req); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	ks, found := e.keys[name]
	if !found {
		return nil, ErrKeyNotFound
	}

	held := make([]int, 0, len(ks.versions))
	for ver := range ks.versions {
		held = append(held, ver)
	}
	sort.Ints(held)

	cfg := ks.config
	info := map[string]interface{}{
		"name":                   cfg.Name,
		"type":                   cfg.Type,
		"current_version":        cfg.CurrentVersion,
		"min_decryption_version": cfg.MinDecryptionVersion,
		"allow_deletion":         cfg.AllowDeletion,
		"versions":               held,
	}
	return &engine.Response{Data: info}, nil
}
// handleListKeys returns the names of every key on this mount, sorted
// alphabetically, to any user.
func (e *TransitEngine) handleListKeys(_ context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireUser(req); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	names := make([]string, 0, len(e.keys))
	for n := range e.keys {
		names = append(names, n)
	}
	sort.Strings(names)

	out := map[string]interface{}{"keys": names}
	return &engine.Response{Data: out}, nil
}
// handleRotateKey generates the next version of a named key (admins only)
// and makes it the current version for new encryptions, signatures and
// HMACs. Older versions stay loaded. When the mount sets max_key_versions,
// pruneVersions may then drop the oldest versions that are below the key's
// min_decryption_version.
//
// Write order keeps the key usable at its previous version if anything
// fails:
//  1. The new version's material is written.
//  2. The updated key config is written.
//  3. In-memory state changes only after both writes succeed.
//
// If the config write fails, the new version file is removed on a
// best-effort basis. Even if it remains, Unseal ignores it (loadKey only
// reads versions up to the stored current version), and the next rotation
// overwrites it. Discarded key material is zeroized.
func (e *TransitEngine) handleRotateKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	ks, ok := e.keys[name]
	if !ok {
		return nil, ErrKeyNotFound
	}

	newVersion := ks.config.CurrentVersion + 1
	kv, err := generateKeyVersion(ks.config.Type, newVersion)
	if err != nil {
		return nil, fmt.Errorf("transit: generate key version: %w", err)
	}
	discard := func() {
		if kv.key != nil {
			mcrypto.Zeroize(kv.key)
		}
		zeroizeKey(kv.privKey)
	}

	// Stage the change on a copy so the live config is untouched on failure.
	updated := *ks.config
	updated.CurrentVersion = newVersion

	prefix := e.mountPath + "keys/" + name + "/"
	if err := e.storeKeyVersion(ctx, prefix, &updated, kv); err != nil {
		discard()
		return nil, err
	}
	if err := e.storeKeyConfig(ctx, prefix, &updated); err != nil {
		// Best effort: the orphan is harmless but holds key material.
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, newVersion))
		discard()
		return nil, err
	}

	ks.config.CurrentVersion = newVersion
	ks.versions[newVersion] = kv

	// Prune old versions if max_key_versions is set.
	if e.config.MaxKeyVersions > 0 && len(ks.versions) > e.config.MaxKeyVersions {
		e.pruneVersions(ctx, ks, prefix)
	}
	return &engine.Response{
		Data: map[string]interface{}{
			"name":    name,
			"version": newVersion,
		},
	}, nil
}
// pruneVersions removes the oldest versions of a key until at most
// e.config.MaxKeyVersions remain. It stops at the first version that is
// still at or above min_decryption_version, so such versions are never
// pruned. Each removed version is deleted from the barrier (best effort,
// errors ignored), zeroized in memory, and dropped from ks.versions.
func (e *TransitEngine) pruneVersions(ctx context.Context, ks *keyState, prefix string) {
	ordered := make([]int, 0, len(ks.versions))
	for ver := range ks.versions {
		ordered = append(ordered, ver)
	}
	sort.Ints(ordered)

	excess := len(ordered) - e.config.MaxKeyVersions
	for i := 0; i < excess; i++ {
		ver := ordered[i]
		if ver >= ks.config.MinDecryptionVersion {
			return
		}
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, ver))
		if kv, present := ks.versions[ver]; present {
			if kv.key != nil {
				mcrypto.Zeroize(kv.key)
			}
			zeroizeKey(kv.privKey)
		}
		delete(ks.versions, ver)
	}
}
// handleUpdateKeyConfig updates mutable key settings (admins only):
//   - "min_decryption_version": may only increase and may not exceed the
//     current version; otherwise ErrInvalidMinVer is returned.
//   - "allow_deletion": applied only when it is a bool; other types are
//     ignored.
//
// Changes are validated and staged on a copy of the config and persisted
// first. The in-memory config is replaced only after the barrier write
// succeeds, so a failed write leaves memory and storage consistent.
func (e *TransitEngine) handleUpdateKeyConfig(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	ks, ok := e.keys[name]
	if !ok {
		return nil, ErrKeyNotFound
	}

	updated := *ks.config
	if v, ok := req.Data["min_decryption_version"]; ok {
		newMin := toInt(v)
		if newMin < updated.MinDecryptionVersion || newMin > updated.CurrentVersion {
			return nil, ErrInvalidMinVer
		}
		updated.MinDecryptionVersion = newMin
	}
	if v, ok := req.Data["allow_deletion"]; ok {
		if allow, ok := v.(bool); ok {
			updated.AllowDeletion = allow
		}
	}

	prefix := e.mountPath + "keys/" + name + "/"
	if err := e.storeKeyConfig(ctx, prefix, &updated); err != nil {
		return nil, err
	}
	*ks.config = updated

	return &engine.Response{
		Data: map[string]interface{}{"ok": true},
	}, nil
}
// handleTrimKey permanently removes (admins only) every version of a key
// older than its min_decryption_version. Each such version file is deleted
// from the barrier (best effort, errors ignored), zeroized in memory, and
// dropped. The response reports how many versions were removed as
// "trimmed".
func (e *TransitEngine) handleTrimKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}

	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	ks, found := e.keys[name]
	if !found {
		return nil, ErrKeyNotFound
	}

	prefix := e.mountPath + "keys/" + name + "/"
	floor := ks.config.MinDecryptionVersion
	removed := 0
	for ver, kv := range ks.versions {
		if ver >= floor {
			continue
		}
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, ver))
		if kv.key != nil {
			mcrypto.Zeroize(kv.key)
		}
		zeroizeKey(kv.privKey)
		delete(ks.versions, ver)
		removed++
	}

	return &engine.Response{Data: map[string]interface{}{"trimmed": removed}}, nil
}
// --- Crypto Operations ---
// handleEncrypt encrypts the base64 "plaintext" with the current version of
// a symmetric key. An optional base64 "context" is bound as additional
// authenticated data. The key is named by "key", falling back to "name".
func (e *TransitEngine) handleEncrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	pt, _ := req.Data["plaintext"].(string)
	aad, _ := req.Data["context"].(string)
	ct, err := e.encryptWithKey(keyName, pt, aad)
	if err != nil {
		return nil, err
	}
	return &engine.Response{Data: map[string]interface{}{"ciphertext": ct}}, nil
}
// encryptWithKey encrypts base64-encoded plaintext under the current version
// of keyName. It returns the versioned string produced by formatCiphertext
// ("metacrypt:v<N>:<base64>"). A non-empty contextB64 is decoded and bound
// as AEAD additional data. Only symmetric key types are accepted. All
// callers in this file hold e.mu.
func (e *TransitEngine) encryptWithKey(keyName, plaintextB64, contextB64 string) (string, error) {
	ks, found := e.keys[keyName]
	if !found {
		return "", ErrKeyNotFound
	}
	if !isSymmetric(ks.config.Type) {
		return "", ErrUnsupportedOp
	}

	plaintext, err := base64.StdEncoding.DecodeString(plaintextB64)
	if err != nil {
		return "", fmt.Errorf("transit: invalid base64 plaintext: %w", err)
	}
	var aad []byte
	if contextB64 != "" {
		if aad, err = base64.StdEncoding.DecodeString(contextB64); err != nil {
			return "", fmt.Errorf("transit: invalid base64 context: %w", err)
		}
	}

	ver := ks.config.CurrentVersion
	kv, found := ks.versions[ver]
	if !found {
		return "", fmt.Errorf("transit: current key version %d not found", ver)
	}
	blob, err := encryptData(ks.config.Type, kv.key, plaintext, aad)
	if err != nil {
		return "", err
	}
	return formatCiphertext(ver, blob), nil
}
// handleDecrypt decrypts a versioned "ciphertext" with the matching version
// of a symmetric key. The same optional base64 "context" used at encryption
// time must be supplied. The plaintext is returned base64-encoded. The key
// is named by "key", falling back to "name".
func (e *TransitEngine) handleDecrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ct, _ := req.Data["ciphertext"].(string)
	aad, _ := req.Data["context"].(string)
	plaintext, err := e.decryptWithKey(keyName, ct, aad)
	if err != nil {
		return nil, err
	}
	encoded := base64.StdEncoding.EncodeToString(plaintext)
	return &engine.Response{Data: map[string]interface{}{"plaintext": encoded}}, nil
}
// decryptWithKey parses a versioned ciphertext and decrypts it with that
// version of keyName. A non-empty contextB64 is decoded and used as AEAD
// additional data. It returns ErrDecryptVersion for versions below the key's
// min_decryption_version. Only symmetric key types are accepted. All callers
// in this file hold e.mu.
func (e *TransitEngine) decryptWithKey(keyName, ciphertextStr, contextB64 string) ([]byte, error) {
	ks, found := e.keys[keyName]
	if !found {
		return nil, ErrKeyNotFound
	}
	if !isSymmetric(ks.config.Type) {
		return nil, ErrUnsupportedOp
	}

	ver, payload, err := parseCiphertext(ciphertextStr)
	if err != nil {
		return nil, err
	}
	if ver < ks.config.MinDecryptionVersion {
		return nil, ErrDecryptVersion
	}
	kv, found := ks.versions[ver]
	if !found {
		return nil, fmt.Errorf("transit: key version %d not found", ver)
	}

	var aad []byte
	if contextB64 != "" {
		if aad, err = base64.StdEncoding.DecodeString(contextB64); err != nil {
			return nil, fmt.Errorf("transit: invalid base64 context: %w", err)
		}
	}
	return decryptData(ks.config.Type, kv.key, payload, aad)
}
// handleRewrap re-encrypts a ciphertext under the current version of its key
// without revealing the plaintext to the caller. It decrypts with the
// ciphertext's own version (subject to min_decryption_version) and
// re-encrypts with the same optional "context". The key is named by "key",
// falling back to "name".
func (e *TransitEngine) handleRewrap(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	oldCT, _ := req.Data["ciphertext"].(string)
	aad, _ := req.Data["context"].(string)

	recovered, err := e.decryptWithKey(keyName, oldCT, aad)
	if err != nil {
		return nil, err
	}
	// encryptWithKey expects base64 input, so re-encode the recovered bytes.
	freshCT, err := e.encryptWithKey(keyName, base64.StdEncoding.EncodeToString(recovered), aad)
	if err != nil {
		return nil, err
	}
	return &engine.Response{Data: map[string]interface{}{"ciphertext": freshCT}}, nil
}
// --- Batch Operations ---
// batchItem is one entry of a batch request.
//   - Plaintext (base64) is read by batch-encrypt.
//   - Ciphertext (a versioned ciphertext string) is read by batch-decrypt
//     and batch-rewrap.
//   - Context is optional base64 additional authenticated data.
//   - Reference is a caller-chosen tag copied into the matching batchResult.
type batchItem struct {
Plaintext string `json:"plaintext"`
Ciphertext string `json:"ciphertext"`
Context string `json:"context"`
Reference string `json:"reference"`
}
// batchResult is the outcome of one batch item. Reference echoes the item's
// Reference. On success, Plaintext (base64, from batch-decrypt) or
// Ciphertext (from batch-encrypt / batch-rewrap) is set; on failure, Error
// holds the error text. Empty fields are omitted when marshalled to JSON.
type batchResult struct {
Plaintext string `json:"plaintext,omitempty"`
Ciphertext string `json:"ciphertext,omitempty"`
Reference string `json:"reference,omitempty"`
Error string `json:"error,omitempty"`
}
// handleBatchEncrypt encrypts up to maxBatchSize items with the current
// version of a symmetric key. Each item supplies Plaintext, an optional
// Context and a Reference (see batchItem). Results keep the item order and
// echo each Reference. A failing item reports its error in that result's
// Error field instead of failing the whole request.
//
// The key is named by "key", falling back to "name". A missing name is
// rejected up front with the same error as handleEncrypt, rather than
// producing one "key not found" error per item.
func (e *TransitEngine) handleBatchEncrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	keyName, _ := req.Data["key"].(string)
	if keyName == "" {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	if keyName == "" {
		return nil, fmt.Errorf("transit: key name is required")
	}
	items, err := extractBatchItems(req.Data["items"])
	if err != nil {
		return nil, err
	}
	if len(items) > maxBatchSize {
		return nil, ErrBatchTooLarge
	}
	results := make([]interface{}, len(items))
	for i, item := range items {
		ct, err := e.encryptWithKey(keyName, item.Plaintext, item.Context)
		r := batchResult{Reference: item.Reference}
		if err != nil {
			r.Error = err.Error()
		} else {
			r.Ciphertext = ct
		}
		results[i] = r
	}
	return &engine.Response{
		Data: map[string]interface{}{"results": results},
	}, nil
}
// handleBatchDecrypt decrypts up to maxBatchSize items with a symmetric key.
// Each item's Ciphertext is decrypted with its embedded version, subject to
// min_decryption_version, and its optional Context. Results keep the item
// order, echo each Reference, and carry base64 Plaintext on success or the
// error text in Error on failure; one bad item does not fail the request.
//
// The key is named by "key", falling back to "name". A missing name is
// rejected up front with the same error as handleDecrypt, rather than
// producing one "key not found" error per item.
func (e *TransitEngine) handleBatchDecrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	keyName, _ := req.Data["key"].(string)
	if keyName == "" {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	if keyName == "" {
		return nil, fmt.Errorf("transit: key name is required")
	}
	items, err := extractBatchItems(req.Data["items"])
	if err != nil {
		return nil, err
	}
	if len(items) > maxBatchSize {
		return nil, ErrBatchTooLarge
	}
	results := make([]interface{}, len(items))
	for i, item := range items {
		pt, err := e.decryptWithKey(keyName, item.Ciphertext, item.Context)
		r := batchResult{Reference: item.Reference}
		if err != nil {
			r.Error = err.Error()
		} else {
			r.Plaintext = base64.StdEncoding.EncodeToString(pt)
		}
		results[i] = r
	}
	return &engine.Response{
		Data: map[string]interface{}{"results": results},
	}, nil
}
// handleBatchRewrap re-encrypts up to maxBatchSize ciphertexts under the
// current version of a symmetric key. Each item's Ciphertext is decrypted
// (subject to min_decryption_version) and the recovered plaintext is
// encrypted again with the same Context. Results keep the item order, echo
// each Reference, and carry the new Ciphertext or the error text in Error;
// one bad item does not fail the request.
//
// The key is named by "key", falling back to "name". A missing name is
// rejected up front with the same error as handleRewrap, rather than
// producing one "key not found" error per item.
func (e *TransitEngine) handleBatchRewrap(_ context.Context, req *engine.Request) (*engine.Response, error) {
	keyName, _ := req.Data["key"].(string)
	if keyName == "" {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	if keyName == "" {
		return nil, fmt.Errorf("transit: key name is required")
	}
	items, err := extractBatchItems(req.Data["items"])
	if err != nil {
		return nil, err
	}
	if len(items) > maxBatchSize {
		return nil, ErrBatchTooLarge
	}
	results := make([]interface{}, len(items))
	for i, item := range items {
		r := batchResult{Reference: item.Reference}
		// Decrypt with the ciphertext's own version.
		pt, err := e.decryptWithKey(keyName, item.Ciphertext, item.Context)
		if err != nil {
			r.Error = err.Error()
			results[i] = r
			continue
		}
		// Re-encrypt with the latest version; encryptWithKey takes base64 input.
		ptB64 := base64.StdEncoding.EncodeToString(pt)
		ct, err := e.encryptWithKey(keyName, ptB64, item.Context)
		if err != nil {
			r.Error = err.Error()
		} else {
			r.Ciphertext = ct
		}
		results[i] = r
	}
	return &engine.Response{
		Data: map[string]interface{}{"results": results},
	}, nil
}
// --- Sign/Verify Operations ---
// handleSign signs the base64 "input" with the current version of an
// asymmetric key. Ed25519 signs the input directly. ECDSA P-256 and P-384
// sign its SHA-256 and SHA-384 digest respectively, producing ASN.1-encoded
// signatures. The signature is returned as a versioned string (see
// formatSignature). The key is named by "key", falling back to "name".
func (e *TransitEngine) handleSign(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ks, found := e.keys[keyName]
	if !found {
		return nil, ErrKeyNotFound
	}
	if !isAsymmetric(ks.config.Type) {
		return nil, ErrUnsupportedOp
	}

	encoded, _ := req.Data["input"].(string)
	input, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid base64 input: %w", err)
	}

	ver := ks.config.CurrentVersion
	kv, found := ks.versions[ver]
	if !found {
		return nil, fmt.Errorf("transit: current key version %d not found", ver)
	}

	var sig []byte
	switch ks.config.Type {
	case "ed25519":
		edKey, isEd := kv.privKey.(ed25519.PrivateKey)
		if !isEd {
			return nil, fmt.Errorf("transit: expected ed25519 key")
		}
		sig = ed25519.Sign(edKey, input)
	case "ecdsa-p256", "ecdsa-p384":
		ecKey, isEC := kv.privKey.(*ecdsa.PrivateKey)
		if !isEC {
			return nil, fmt.Errorf("transit: expected ECDSA key")
		}
		var digest []byte
		if ks.config.Type == "ecdsa-p256" {
			sum := sha256.Sum256(input)
			digest = sum[:]
		} else {
			sum := sha512.Sum384(input)
			digest = sum[:]
		}
		if sig, err = ecdsa.SignASN1(rand.Reader, ecKey, digest); err != nil {
			return nil, fmt.Errorf("transit: sign: %w", err)
		}
	default:
		return nil, ErrUnsupportedOp
	}

	return &engine.Response{Data: map[string]interface{}{
		"signature": formatSignature(ver, sig),
	}}, nil
}
// handleVerify checks a versioned "signature" over the base64 "input" using
// the public key of the version embedded in the signature, and returns
// {"valid": bool}. The following are reported as errors rather than
// valid=false: malformed input or signature, an unknown key, a non-signing
// key type, or an unloaded version. The key is named by "key", falling back
// to "name".
//
// NOTE(review): unlike decryption, verification does not check the version
// against min_decryption_version — confirm this is intended.
func (e *TransitEngine) handleVerify(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ks, found := e.keys[keyName]
	if !found {
		return nil, ErrKeyNotFound
	}
	if !isAsymmetric(ks.config.Type) {
		return nil, ErrUnsupportedOp
	}

	encoded, _ := req.Data["input"].(string)
	input, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid base64 input: %w", err)
	}

	sigField, _ := req.Data["signature"].(string)
	ver, sigBytes, err := parseVersionedData(sigField)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid signature format: %w", err)
	}
	kv, found := ks.versions[ver]
	if !found {
		return nil, fmt.Errorf("transit: key version %d not found", ver)
	}

	var valid bool
	switch ks.config.Type {
	case "ed25519":
		pub, isEd := kv.pubKey.(ed25519.PublicKey)
		if !isEd {
			return nil, fmt.Errorf("transit: expected ed25519 public key")
		}
		valid = ed25519.Verify(pub, input, sigBytes)
	case "ecdsa-p256", "ecdsa-p384":
		pub, isEC := kv.pubKey.(*ecdsa.PublicKey)
		if !isEC {
			return nil, fmt.Errorf("transit: expected ECDSA public key")
		}
		var digest []byte
		if ks.config.Type == "ecdsa-p256" {
			sum := sha256.Sum256(input)
			digest = sum[:]
		} else {
			sum := sha512.Sum384(input)
			digest = sum[:]
		}
		valid = ecdsa.VerifyASN1(pub, digest, sigBytes)
	default:
		return nil, ErrUnsupportedOp
	}

	return &engine.Response{Data: map[string]interface{}{"valid": valid}}, nil
}
// --- HMAC Operation ---
// handleHMAC computes or verifies an HMAC with an HMAC key. The key is named
// by "key", falling back to "name".
//   - Without a non-empty string "hmac", it returns the HMAC of the base64
//     "input" under the current key version, as a versioned string.
//   - With "hmac" set, it takes the version from that value, recomputes the
//     HMAC with that version, and reports {"valid": bool} using a
//     constant-time comparison.
//
// NOTE(review): verification does not check the version against
// min_decryption_version — confirm this is intended.
func (e *TransitEngine) handleHMAC(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()
	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ks, found := e.keys[keyName]
	if !found {
		return nil, ErrKeyNotFound
	}
	if !isHMAC(ks.config.Type) {
		return nil, ErrUnsupportedOp
	}

	encoded, _ := req.Data["input"].(string)
	input, err := base64.StdEncoding.DecodeString(encoded)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid base64 input: %w", err)
	}

	supplied, _ := req.Data["hmac"].(string)
	if supplied == "" {
		// Compute mode.
		ver := ks.config.CurrentVersion
		kv, present := ks.versions[ver]
		if !present {
			return nil, fmt.Errorf("transit: current key version %d not found", ver)
		}
		tag := computeHMAC(ks.config.Type, kv.key, input)
		return &engine.Response{Data: map[string]interface{}{
			"hmac": formatHMAC(ver, tag),
		}}, nil
	}

	// Verify mode.
	ver, macBytes, err := parseVersionedData(supplied)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid hmac format: %w", err)
	}
	kv, present := ks.versions[ver]
	if !present {
		return nil, fmt.Errorf("transit: key version %d not found", ver)
	}
	want := computeHMAC(ks.config.Type, kv.key, input)
	return &engine.Response{Data: map[string]interface{}{
		"valid": hmac.Equal(macBytes, want),
	}}, nil
}
// --- Get Public Key ---
// handleGetPublicKey returns, to any user, the public key of an asymmetric
// key as base64-encoded PKIX DER, together with its version and type.
// "version" selects a specific key version and defaults to the current one.
//
// The key is named by "name". "key" is also accepted as a fallback, since
// the sign/verify handlers accept both spellings; "name" wins when both are
// present, so existing callers are unaffected.
func (e *TransitEngine) handleGetPublicKey(_ context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireUser(req); err != nil {
		return nil, err
	}
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	keyName, _ := req.Data["name"].(string)
	if keyName == "" {
		keyName, _ = req.Data["key"].(string)
	}
	if keyName == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	ks, ok := e.keys[keyName]
	if !ok {
		return nil, ErrKeyNotFound
	}
	if !isAsymmetric(ks.config.Type) {
		return nil, ErrUnsupportedOp
	}
	version := ks.config.CurrentVersion
	if v, ok := req.Data["version"]; ok {
		version = toInt(v)
	}
	kv, ok := ks.versions[version]
	if !ok {
		return nil, fmt.Errorf("transit: key version %d not found", version)
	}
	pubKeyBytes, err := x509.MarshalPKIXPublicKey(kv.pubKey)
	if err != nil {
		return nil, fmt.Errorf("transit: marshal public key: %w", err)
	}
	return &engine.Response{
		Data: map[string]interface{}{
			"public_key": base64.StdEncoding.EncodeToString(pubKeyBytes),
			"version":    version,
			"type":       ks.config.Type,
		},
	}, nil
}
// --- Storage helpers ---
// storeKeyConfig writes cfg as JSON to "<prefix>config.json" in the barrier.
func (e *TransitEngine) storeKeyConfig(ctx context.Context, prefix string, cfg *KeyConfig) error {
	encoded, err := json.Marshal(cfg)
	if err == nil {
		return e.barrier.Put(ctx, prefix+"config.json", encoded)
	}
	return fmt.Errorf("transit: marshal key config: %w", err)
}
// storeKeyVersion writes one key version to "<prefix>v<N>.key" in the
// barrier:
//   - symmetric, HMAC and ed25519 keys are stored as their raw bytes (the
//     64-byte private key for ed25519);
//   - ECDSA keys are stored PKCS#8-encoded.
//
// Unknown key types are rejected.
func (e *TransitEngine) storeKeyVersion(ctx context.Context, prefix string, cfg *KeyConfig, kv *keyVersion) error {
	var (
		blob []byte
		err  error
	)
	switch cfg.Type {
	case "aes256-gcm", "chacha20-poly", "hmac-sha256", "hmac-sha512", "ed25519":
		blob = kv.key
	case "ecdsa-p256", "ecdsa-p384":
		if blob, err = x509.MarshalPKCS8PrivateKey(kv.privKey); err != nil {
			return fmt.Errorf("transit: marshal PKCS8 key: %w", err)
		}
	default:
		return fmt.Errorf("transit: unknown key type: %s", cfg.Type)
	}
	return e.barrier.Put(ctx, fmt.Sprintf("%sv%d.key", prefix, kv.version), blob)
}
// --- Key generation ---
// generateKeyVersion creates fresh key material of keyType for the given
// version number, drawing randomness from crypto/rand:
//   - aes256-gcm, chacha20-poly, hmac-sha256: 32 random bytes in key;
//   - hmac-sha512: 64 random bytes in key;
//   - ed25519: key (raw private key bytes), privKey and pubKey;
//   - ecdsa-p256 / ecdsa-p384: privKey and pubKey on the matching curve.
//
// Unknown key types are rejected with an error.
func generateKeyVersion(keyType string, version int) (*keyVersion, error) {
	randomKey := func(n int) ([]byte, error) {
		buf := make([]byte, n)
		if _, err := rand.Read(buf); err != nil {
			return nil, err
		}
		return buf, nil
	}

	kv := &keyVersion{version: version}
	var err error
	switch keyType {
	case "aes256-gcm", "chacha20-poly", "hmac-sha256":
		kv.key, err = randomKey(32)
	case "hmac-sha512":
		kv.key, err = randomKey(64)
	case "ed25519":
		var priv ed25519.PrivateKey
		if _, priv, err = ed25519.GenerateKey(rand.Reader); err == nil {
			kv.key = []byte(priv)
			kv.privKey = priv
			kv.pubKey = priv.Public()
		}
	case "ecdsa-p256", "ecdsa-p384":
		curve := elliptic.P256()
		if keyType == "ecdsa-p384" {
			curve = elliptic.P384()
		}
		var priv *ecdsa.PrivateKey
		if priv, err = ecdsa.GenerateKey(curve, rand.Reader); err == nil {
			kv.privKey = priv
			kv.pubKey = &priv.PublicKey
		}
	default:
		return nil, fmt.Errorf("unknown key type: %s", keyType)
	}
	if err != nil {
		return nil, err
	}
	return kv, nil
}
// --- Encryption/Decryption helpers ---
// encryptData seals plaintext (authenticating aad) with the AEAD selected by
// keyType and key, using a fresh random nonce.
//
// The result is nonce || ciphertext || tag, which is the layout decryptData
// expects. Errors from AEAD construction are returned unchanged; a failure
// of the random source is wrapped.
func encryptData(keyType string, key, plaintext, aad []byte) ([]byte, error) {
	aead, err := newAEAD(keyType, key)
	if err != nil {
		return nil, err
	}
	nonceSize := aead.NonceSize()
	// Size the buffer for the whole output up front so Seal can append the
	// ciphertext and tag directly after the nonce without reallocating.
	out := make([]byte, nonceSize, nonceSize+len(plaintext)+aead.Overhead())
	if _, err := rand.Read(out); err != nil {
		return nil, fmt.Errorf("transit: generate nonce: %w", err)
	}
	// out doubles as the nonce and as the dst prefix: Seal appends after it.
	return aead.Seal(out, out, plaintext, aad), nil
}
// decryptData opens data with the AEAD selected by keyType and key.
// aad must match the value used at encryption time.
//
// data is expected to be nonce || ciphertext || tag, as produced by
// encryptData. Input shorter than the AEAD's nonce yields ErrInvalidFormat.
// A failure from Open is wrapped as "transit: decryption failed".
func decryptData(keyType string, key, data, aad []byte) ([]byte, error) {
	aead, err := newAEAD(keyType, key)
	if err != nil {
		return nil, err
	}
	n := aead.NonceSize()
	if n > len(data) {
		return nil, ErrInvalidFormat
	}
	nonce, sealed := data[:n], data[n:]
	plaintext, openErr := aead.Open(nil, nonce, sealed, aad)
	if openErr != nil {
		return nil, fmt.Errorf("transit: decryption failed: %w", openErr)
	}
	return plaintext, nil
}
// newAEAD builds the AEAD for a symmetric key type:
//   - "aes256-gcm": AES-GCM. key must be exactly 32 bytes so the type
//     cannot silently degrade to AES-128 or AES-192.
//   - "chacha20-poly": XChaCha20-Poly1305 (extended 24-byte nonce). key must
//     be 32 bytes.
//
// Any other keyType, or a construction failure, is reported as an error
// prefixed with "transit:". Underlying errors are wrapped with %w.
func newAEAD(keyType string, key []byte) (cipher.AEAD, error) {
	switch keyType {
	case "aes256-gcm":
		// aes.NewCipher also accepts 16- and 24-byte keys; insist on AES-256.
		if len(key) != 32 {
			return nil, fmt.Errorf("transit: aes256-gcm requires a 32-byte key, got %d bytes", len(key))
		}
		block, err := aes.NewCipher(key)
		if err != nil {
			return nil, fmt.Errorf("transit: new AES cipher: %w", err)
		}
		aead, err := cipher.NewGCM(block)
		if err != nil {
			return nil, fmt.Errorf("transit: new GCM: %w", err)
		}
		return aead, nil
	case "chacha20-poly":
		aead, err := chacha20poly1305.NewX(key)
		if err != nil {
			return nil, fmt.Errorf("transit: new XChaCha20-Poly1305: %w", err)
		}
		return aead, nil
	default:
		return nil, fmt.Errorf("transit: unsupported encryption type: %s", keyType)
	}
}
// --- HMAC helpers ---
// computeHMAC returns the HMAC of input under key. It uses SHA-256 for
// keyType "hmac-sha256" and SHA-512 for "hmac-sha512". Any other keyType
// yields nil.
func computeHMAC(keyType string, key, input []byte) []byte {
	var mac hash.Hash
	switch keyType {
	case "hmac-sha512":
		mac = hmac.New(sha512.New, key)
	case "hmac-sha256":
		mac = hmac.New(sha256.New, key)
	default:
		return nil
	}
	// hash.Hash.Write is documented never to return an error.
	_, _ = mac.Write(input)
	return mac.Sum(nil)
}
// --- Format helpers ---
// formatVersionedData renders data in the "metacrypt:v<version>:<base64>"
// wire form shared by ciphertexts, signatures and HMACs; parseVersionedData
// reads the same form. The payload uses standard, padded base64.
func formatVersionedData(version int, data []byte) string {
	return "metacrypt:v" + strconv.Itoa(version) + ":" + base64.StdEncoding.EncodeToString(data)
}

// formatCiphertext encodes encrypted data tagged with key version version.
func formatCiphertext(version int, data []byte) string {
	return formatVersionedData(version, data)
}

// formatSignature encodes a signature tagged with key version version.
func formatSignature(version int, sig []byte) string {
	return formatVersionedData(version, sig)
}

// formatHMAC encodes a MAC tagged with key version version.
func formatHMAC(version int, mac []byte) string {
	return formatVersionedData(version, mac)
}
// parseCiphertext decodes a "metacrypt:v<N>:<base64>" ciphertext string into
// its key version and raw payload. Ciphertexts use the same framing as
// signatures and HMACs, so this delegates to parseVersionedData.
func parseCiphertext(s string) (version int, payload []byte, err error) {
	version, payload, err = parseVersionedData(s)
	return version, payload, err
}
// parseVersionedData splits a "metacrypt:v<version>:<base64>" string into
// its key version and decoded payload.
//
// The version must be an unsigned decimal number that fits in an int, so
// signed forms such as "v-1" or "v+1" are rejected. Structural problems
// yield ErrInvalidFormat. A payload that is not valid standard base64 yields
// a wrapped decoding error.
func parseVersionedData(s string) (int, []byte, error) {
	parts := strings.SplitN(s, ":", 3)
	if len(parts) != 3 || parts[0] != "metacrypt" || !strings.HasPrefix(parts[1], "v") {
		return 0, nil, ErrInvalidFormat
	}
	// ParseUint rejects sign prefixes (unlike Atoi). A bit size of
	// IntSize-1 keeps the value within the non-negative int range.
	version, err := strconv.ParseUint(parts[1][1:], 10, strconv.IntSize-1)
	if err != nil {
		return 0, nil, ErrInvalidFormat
	}
	data, err := base64.StdEncoding.DecodeString(parts[2])
	if err != nil {
		return 0, nil, fmt.Errorf("transit: invalid base64: %w", err)
	}
	return int(version), data, nil
}
// --- Type helpers ---
// isValidKeyType reports whether t is one of the key types the transit
// engine knows how to generate and use.
func isValidKeyType(t string) bool {
	switch t {
	case "aes256-gcm", "chacha20-poly": // symmetric AEAD
		return true
	case "ed25519", "ecdsa-p256", "ecdsa-p384": // asymmetric signing
		return true
	case "hmac-sha256", "hmac-sha512": // keyed MAC
		return true
	default:
		return false
	}
}
// isSymmetric reports whether t names a symmetric encryption (AEAD) key type.
func isSymmetric(t string) bool {
	switch t {
	case "aes256-gcm", "chacha20-poly":
		return true
	default:
		return false
	}
}
// isAsymmetric reports whether t names a signing key-pair type (Ed25519 or
// ECDSA on P-256/P-384).
func isAsymmetric(t string) bool {
	switch t {
	case "ed25519", "ecdsa-p256", "ecdsa-p384":
		return true
	default:
		return false
	}
}
// isHMAC reports whether t names an HMAC key type.
func isHMAC(t string) bool {
	switch t {
	case "hmac-sha256", "hmac-sha512":
		return true
	default:
		return false
	}
}
// --- Utility ---
// toInt converts a loosely typed numeric request value to an int.
//
// It accepts float64 (what encoding/json produces for numbers decoded into
// interface{}), int, int64 and json.Number. Fractional values are truncated
// toward zero. A json.Number that is not an integer literal (e.g. "2.0" or
// "1e3") is parsed as a float so it behaves like the float64 case instead
// of silently becoming 0. Any other type, or an unparseable number, yields 0.
func toInt(v interface{}) int {
	switch val := v.(type) {
	case float64:
		return int(val)
	case int:
		return val
	case int64:
		return int(val)
	case json.Number:
		if n, err := val.Int64(); err == nil {
			return int(n)
		}
		f, err := val.Float64()
		if err != nil {
			return 0
		}
		return int(f)
	}
	return 0
}
// extractBatchItems converts a decoded "items" request value into batchItems.
//
// v must be a []interface{} whose elements are map[string]interface{}
// objects. The string fields "plaintext", "ciphertext", "context" and
// "reference" are copied when present and of string type. Missing or
// wrongly typed fields are left empty without error. A non-array v, or a
// non-object element, is an error.
func extractBatchItems(v interface{}) ([]batchItem, error) {
	list, isList := v.([]interface{})
	if !isList {
		return nil, fmt.Errorf("transit: items must be an array")
	}
	out := make([]batchItem, 0, len(list))
	for idx, elem := range list {
		obj, isObj := elem.(map[string]interface{})
		if !isObj {
			return nil, fmt.Errorf("transit: item %d is not an object", idx)
		}
		var item batchItem
		item.Plaintext, _ = obj["plaintext"].(string)
		item.Ciphertext, _ = obj["ciphertext"].(string)
		item.Context, _ = obj["context"].(string)
		item.Reference, _ = obj["reference"].(string)
		out = append(out, item)
	}
	return out, nil
}
// zeroizeKey overwrites the secret material of key in place, on a
// best-effort basis. It handles *ecdsa.PrivateKey (the scalar D) and
// ed25519.PrivateKey. nil keys, keys with a nil D, and any other types are
// ignored.
//
// For ECDSA, big.Int.SetInt64(0) on its own only shortens the internal word
// slice and leaves the secret limbs in the backing array. The words returned
// by Bits share that array, so they are cleared explicitly first.
// NOTE(review): the standard library may hold other internal copies of the
// key material that this function cannot reach.
func zeroizeKey(key crypto.PrivateKey) {
	switch k := key.(type) {
	case *ecdsa.PrivateKey:
		if k == nil || k.D == nil {
			return
		}
		words := k.D.Bits()
		for i := range words {
			words[i] = 0
		}
		k.D.SetInt64(0)
	case ed25519.PrivateKey:
		for i := range k {
			k[i] = 0
		}
	}
}