Files
metacrypt/internal/engine/transit/transit.go
Kyle Isom bbe382dc10 Migrate module path from kyle/ to mc/ org
All import paths updated to git.wntrmute.dev/mc/. Bumps mcdsl to v1.2.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:05:59 -07:00

1624 lines
38 KiB
Go

// Package transit implements the transit encryption engine for symmetric
// encryption, signing, and HMAC operations with versioned key management.
package transit
import (
"context"
"crypto"
"crypto/aes"
"crypto/cipher"
"crypto/ecdsa"
"crypto/ed25519"
"crypto/elliptic"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash"
"sort"
"strconv"
"strings"
"sync"
"golang.org/x/crypto/chacha20poly1305"
"git.wntrmute.dev/mc/metacrypt/internal/barrier"
mcrypto "git.wntrmute.dev/mc/metacrypt/internal/crypto"
"git.wntrmute.dev/mc/metacrypt/internal/engine"
)
// maxBatchSize is the largest number of items accepted by a single
// batch-encrypt, batch-decrypt or batch-rewrap request.
const maxBatchSize = 500

// Sentinel errors returned by the transit engine; match them with errors.Is.
var (
	// ErrSealed is returned when the engine has no config loaded.
	ErrSealed = errors.New("transit: engine is sealed")

	// ErrKeyNotFound is returned when the named key is not loaded.
	ErrKeyNotFound = errors.New("transit: key not found")

	// ErrKeyExists is returned by create-key for a name already in use.
	ErrKeyExists = errors.New("transit: key already exists")

	// ErrForbidden is returned when the caller is known but not permitted.
	ErrForbidden = errors.New("transit: forbidden")

	// ErrUnauthorized is returned when the request carries no caller info.
	ErrUnauthorized = errors.New("transit: authentication required")

	// ErrDeletionDenied is returned by delete-key when the key's
	// AllowDeletion flag is not set.
	ErrDeletionDenied = errors.New("transit: deletion not allowed")

	// ErrInvalidKeyType is returned by create-key for an unknown key type.
	ErrInvalidKeyType = errors.New("transit: invalid key type")

	// ErrUnsupportedOp is returned when an operation does not apply to the
	// key's type (e.g. encrypting with a signing key).
	ErrUnsupportedOp = errors.New("transit: unsupported operation for key type")

	// ErrDecryptVersion is returned when a ciphertext names a key version
	// below the key's minimum decryption version.
	ErrDecryptVersion = errors.New("transit: ciphertext version below minimum decryption version")

	// ErrInvalidFormat is returned for malformed ciphertext, e.g. data
	// shorter than the AEAD nonce.
	ErrInvalidFormat = errors.New("transit: invalid ciphertext format")

	// ErrBatchTooLarge is returned when a batch exceeds maxBatchSize items.
	ErrBatchTooLarge = errors.New("transit: batch size exceeds maximum")

	// ErrInvalidMinVer is returned by update-key-config for an
	// out-of-range min_decryption_version.
	ErrInvalidMinVer = errors.New("transit: min_decryption_version can only increase and cannot exceed current version")
)
// keyVersion holds a single version of key material.
//
// Which fields are populated depends on the key type:
//   - aes256-gcm, chacha20-poly, hmac-sha256, hmac-sha512: only key.
//   - ed25519: key holds the raw private key bytes, and privKey/pubKey are
//     also set from it.
//   - ecdsa-p256, ecdsa-p384: only privKey/pubKey (key is nil).
type keyVersion struct {
	version int               // version number; versions start at 1
	key     []byte            // raw key bytes (symmetric/HMAC key, or ed25519 private key)
	privKey crypto.PrivateKey // asymmetric (nil for symmetric)
	pubKey  crypto.PublicKey  // asymmetric (nil for symmetric)
}
// keyState holds in-memory state for a loaded key.
//
// config mirrors the key's persisted config.json. versions maps a version
// number to its material; versions removed by trim-key or by pruning are
// absent from the map.
type keyState struct {
	config   *KeyConfig
	versions map[int]*keyVersion
}
// TransitEngine implements the transit encryption engine.
//
// Initialize, Unseal and Seal write the fields below while holding mu.
// Request handlers take mu (read or write) before touching config or keys.
// The engine counts as sealed whenever config is nil (see sealed).
//
// NOTE(review): mountName reads mountPath and is called from
// requireUserWithPolicy before handlers acquire mu. Confirm that mountPath
// cannot change concurrently with request handling.
type TransitEngine struct {
	barrier   barrier.Barrier
	config    *TransitConfig
	keys      map[string]*keyState // loaded keys by name
	mountPath string               // barrier prefix for this mount; presumably "/"-terminated, since paths are built as mountPath+"config.json"
	mu        sync.RWMutex
}
// NewTransitEngine returns a transit engine with an empty key table.
//
// No config is loaded yet, so the engine starts out sealed until Initialize
// or Unseal is called.
func NewTransitEngine() engine.Engine {
	te := new(TransitEngine)
	te.keys = map[string]*keyState{}
	return te
}
// Type identifies this engine as the transit engine.
func (*TransitEngine) Type() engine.EngineType {
	return engine.EngineTypeTransit
}
// Initialize sets up the transit engine for first use on mountPath and
// persists its config to mountPath+"config.json".
//
// Recognised config entries:
//   - "max_key_versions": the maximum number of versions kept per key. It is
//     accepted as float64 (as produced by JSON decoding), int, int64 or
//     json.Number. Values <= 0 disable pruning on rotation.
//
// The engine's in-memory state (barrier, mount path, config, key table) is
// only replaced after the config has been persisted. A failed Initialize
// therefore does not leave the engine looking initialized.
func (e *TransitEngine) Initialize(ctx context.Context, b barrier.Barrier, mountPath string, config map[string]interface{}) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	cfg := &TransitConfig{}
	// Indexing a nil map is safe in Go, so no nil check on config is needed.
	if v, ok := config["max_key_versions"]; ok {
		switch val := v.(type) {
		case float64:
			cfg.MaxKeyVersions = int(val)
		case int:
			cfg.MaxKeyVersions = val
		case int64:
			cfg.MaxKeyVersions = int(val)
		case json.Number:
			n, err := val.Int64()
			if err != nil {
				return fmt.Errorf("transit: invalid max_key_versions: %w", err)
			}
			cfg.MaxKeyVersions = int(n)
		}
	}

	configData, err := json.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("transit: marshal config: %w", err)
	}
	if err := b.Put(ctx, mountPath+"config.json", configData); err != nil {
		return fmt.Errorf("transit: store config: %w", err)
	}

	// Publish state only after the config is durable.
	e.barrier = b
	e.mountPath = mountPath
	e.config = cfg
	e.keys = make(map[string]*keyState)
	return nil
}
// Unseal loads the transit state from the barrier into memory. That state is
// the engine config at mountPath+"config.json" plus every key found under
// mountPath+"keys/".
//
// Everything is loaded into locals first and published only once all keys
// have loaded. A failed Unseal therefore leaves the engine's previous state
// untouched: it never leaves a half-loaded engine that looks unsealed. Any key
// material loaded before the failure is zeroized.
func (e *TransitEngine) Unseal(ctx context.Context, b barrier.Barrier, mountPath string) error {
	e.mu.Lock()
	defer e.mu.Unlock()

	configData, err := b.Get(ctx, mountPath+"config.json")
	if err != nil {
		return fmt.Errorf("transit: load config: %w", err)
	}
	var cfg TransitConfig
	if err := json.Unmarshal(configData, &cfg); err != nil {
		return fmt.Errorf("transit: parse config: %w", err)
	}

	keys := make(map[string]*keyState)

	// A List error is treated as "no keys yet".
	// NOTE(review): this also hides genuine barrier failures; confirm what
	// List returns for a prefix with no entries.
	if keyPaths, listErr := b.List(ctx, mountPath+"keys/"); listErr == nil {
		// Collect unique key names from paths like "mykey/config.json", "mykey/v1.key".
		keyNames := make(map[string]struct{})
		for _, p := range keyPaths {
			if name := strings.SplitN(p, "/", 2)[0]; name != "" {
				keyNames[name] = struct{}{}
			}
		}

		for name := range keyNames {
			ks, err := e.loadKey(ctx, b, mountPath, name)
			if err != nil {
				// Do not leave already-loaded key material lying in memory.
				for _, loaded := range keys {
					for _, kv := range loaded.versions {
						if kv.key != nil {
							mcrypto.Zeroize(kv.key)
						}
						zeroizeKey(kv.privKey)
					}
				}
				return fmt.Errorf("transit: load key %q: %w", name, err)
			}
			keys[name] = ks
		}
	}

	e.barrier = b
	e.mountPath = mountPath
	e.config = &cfg
	e.keys = keys
	return nil
}
// loadKey reads the config of the key called name and every version of it
// that can still be loaded. Both live under mountPath+"keys/"+name+"/".
//
// Versions 1..CurrentVersion are tried in turn. A version that fails to load
// (for example because it was trimmed or pruned) is skipped, not reported.
// NOTE(review): malformed key material is skipped the same way as a missing
// version; confirm that is intended.
func (e *TransitEngine) loadKey(ctx context.Context, b barrier.Barrier, mountPath, name string) (*keyState, error) {
	prefix := mountPath + "keys/" + name + "/"

	raw, err := b.Get(ctx, prefix+"config.json")
	if err != nil {
		return nil, fmt.Errorf("load config: %w", err)
	}
	cfg := new(KeyConfig)
	if err = json.Unmarshal(raw, cfg); err != nil {
		return nil, fmt.Errorf("parse config: %w", err)
	}

	state := &keyState{config: cfg, versions: map[int]*keyVersion{}}
	for ver := 1; ver <= cfg.CurrentVersion; ver++ {
		if kv, loadErr := e.loadKeyVersion(ctx, b, prefix, cfg, ver); loadErr == nil {
			state.versions[ver] = kv
		}
	}
	return state, nil
}
// loadKeyVersion reads prefix+"v<version>.key" from the barrier and decodes
// it according to cfg.Type.
//
// Storage formats:
//   - symmetric and HMAC types: raw key bytes.
//   - ed25519: the raw private key.
//   - ECDSA types: a PKCS#8 private key.
//
// Malformed material is rejected here rather than at use time. A wrong-length
// ed25519 key would make PrivateKey.Public / ed25519.Sign panic, and an ECDSA
// key on the wrong curve would silently sign with the wrong algorithm.
func (e *TransitEngine) loadKeyVersion(ctx context.Context, b barrier.Barrier, prefix string, cfg *KeyConfig, version int) (*keyVersion, error) {
	path := fmt.Sprintf("%sv%d.key", prefix, version)
	data, err := b.Get(ctx, path)
	if err != nil {
		return nil, err
	}

	kv := &keyVersion{version: version}
	switch cfg.Type {
	case "aes256-gcm", "chacha20-poly", "hmac-sha256", "hmac-sha512":
		kv.key = data
	case "ed25519":
		if len(data) != ed25519.PrivateKeySize {
			return nil, fmt.Errorf("invalid ed25519 key length %d", len(data))
		}
		privKey := ed25519.PrivateKey(data)
		kv.key = data
		kv.privKey = privKey
		kv.pubKey = privKey.Public()
	case "ecdsa-p256", "ecdsa-p384":
		privKey, err := x509.ParsePKCS8PrivateKey(data)
		if err != nil {
			return nil, fmt.Errorf("parse PKCS8 key: %w", err)
		}
		ecKey, ok := privKey.(*ecdsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("expected ECDSA key, got %T", privKey)
		}
		wantCurve := elliptic.P256()
		if cfg.Type == "ecdsa-p384" {
			wantCurve = elliptic.P384()
		}
		if ecKey.Curve != wantCurve {
			return nil, fmt.Errorf("expected %s key, got curve %s", cfg.Type, ecKey.Curve.Params().Name)
		}
		kv.privKey = ecKey
		kv.pubKey = &ecKey.PublicKey
	default:
		return nil, fmt.Errorf("unknown key type: %s", cfg.Type)
	}
	return kv, nil
}
// Seal returns the engine to the sealed state. Every in-memory key version is
// wiped: raw bytes via mcrypto.Zeroize, asymmetric keys via zeroizeKey. The
// key table and config are then dropped. It always returns nil.
func (e *TransitEngine) Seal() error {
	e.mu.Lock()
	defer e.mu.Unlock()

	wipe := func(kv *keyVersion) {
		if kv.key != nil {
			mcrypto.Zeroize(kv.key)
		}
		zeroizeKey(kv.privKey)
	}

	for keyName, state := range e.keys {
		for _, ver := range state.versions {
			wipe(ver)
		}
		delete(e.keys, keyName)
	}
	e.config, e.keys = nil, nil
	return nil
}
// HandleRequest routes req to the handler for req.Operation. An unrecognised
// operation yields an error that names it.
func (e *TransitEngine) HandleRequest(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	var handler func(context.Context, *engine.Request) (*engine.Response, error)

	switch op := req.Operation; op {
	// Key management.
	case "create-key":
		handler = e.handleCreateKey
	case "delete-key":
		handler = e.handleDeleteKey
	case "get-key":
		handler = e.handleGetKey
	case "list-keys":
		handler = e.handleListKeys
	case "rotate-key":
		handler = e.handleRotateKey
	case "update-key-config":
		handler = e.handleUpdateKeyConfig
	case "trim-key":
		handler = e.handleTrimKey
	// Symmetric encryption, single and batched.
	case "encrypt":
		handler = e.handleEncrypt
	case "decrypt":
		handler = e.handleDecrypt
	case "rewrap":
		handler = e.handleRewrap
	case "batch-encrypt":
		handler = e.handleBatchEncrypt
	case "batch-decrypt":
		handler = e.handleBatchDecrypt
	case "batch-rewrap":
		handler = e.handleBatchRewrap
	// Signing, HMAC and public keys.
	case "sign":
		handler = e.handleSign
	case "verify":
		handler = e.handleVerify
	case "hmac":
		handler = e.handleHMAC
	case "get-public-key":
		handler = e.handleGetPublicKey
	default:
		return nil, fmt.Errorf("transit: unknown operation: %s", op)
	}
	return handler(ctx, req)
}
// --- Authorization helpers ---

// requireAdmin returns ErrUnauthorized when the request has no caller
// information and ErrForbidden when the caller is not an admin.
func (e *TransitEngine) requireAdmin(req *engine.Request) error {
	switch caller := req.CallerInfo; {
	case caller == nil:
		return ErrUnauthorized
	case !caller.IsAdmin:
		return ErrForbidden
	default:
		return nil
	}
}
// requireUser returns ErrUnauthorized when the request has no caller
// information and ErrForbidden when the caller is not a user.
func (e *TransitEngine) requireUser(req *engine.Request) error {
	switch caller := req.CallerInfo; {
	case caller == nil:
		return ErrUnauthorized
	case !caller.IsUser():
		return ErrForbidden
	default:
		return nil
	}
}
// requireUserWithPolicy authorizes an operation on one named key.
//
// Admins are always allowed, and callers that are not users are forbidden.
// For users, the request's CheckPolicy hook (if any) is consulted with the
// resource "transit/<mount>/key/<keyName>" and the action from
// operationToAction. A matching rule decides: "allow" permits, anything else
// forbids. When no hook is set, or no rule matches, users are allowed by
// default.
func (e *TransitEngine) requireUserWithPolicy(req *engine.Request, keyName string) error {
	caller := req.CallerInfo
	if caller == nil {
		return ErrUnauthorized
	}
	if caller.IsAdmin {
		return nil
	}
	if !caller.IsUser() {
		return ErrForbidden
	}
	if req.CheckPolicy == nil {
		return nil
	}

	resource := "transit/" + e.mountName() + "/key/" + keyName
	effect, matched := req.CheckPolicy(resource, operationToAction(req.Operation))
	if !matched || effect == "allow" {
		return nil
	}
	return ErrForbidden
}
// operationToAction maps a transit operation to the policy action checked for
// it. Metadata reads map to "read", anything that recovers plaintext maps to
// "decrypt", and every other operation (including sign/verify/hmac) maps to
// "write".
func operationToAction(op string) string {
	action := "write"
	switch op {
	case "get-key", "list-keys", "get-public-key":
		action = "read"
	case "decrypt", "rewrap", "batch-decrypt", "batch-rewrap":
		action = "decrypt"
	}
	return action
}
// mountName extracts the user-facing mount name from the mount path. It is
// the third "/"-separated segment once any trailing slash is removed, or ""
// when the path has fewer than three segments. (Presumably mount paths look
// like "<a>/<b>/<name>/"; confirm against the mount table.)
func (e *TransitEngine) mountName() string {
	segments := strings.Split(strings.TrimSuffix(e.mountPath, "/"), "/")
	if len(segments) < 3 {
		return ""
	}
	return segments[2]
}
// sealed reports whether the engine has no config loaded. That is the state
// before Initialize/Unseal and after Seal. Callers hold e.mu.
func (e *TransitEngine) sealed() bool {
	loaded := e.config != nil
	return !loaded
}
// --- Key Management Operations ---

// handleCreateKey creates a new named key at version 1. Admin only.
//
// Request fields: "name" (required, validated with engine.ValidateName) and
// "type" (optional, default "aes256-gcm"). New keys start with
// AllowDeletion=false.
//
// The key's config is written before its first version. If the version write
// fails, the config is removed again (best effort) so a restart does not
// resurrect a key with no usable versions. The generated material is also
// zeroized on any failure.
func (e *TransitEngine) handleCreateKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	name, _ := req.Data["name"].(string)
	keyType, _ := req.Data["type"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	if err := engine.ValidateName(name); err != nil {
		return nil, fmt.Errorf("transit: %w", err)
	}
	if keyType == "" {
		keyType = "aes256-gcm"
	}
	if !isValidKeyType(keyType) {
		return nil, ErrInvalidKeyType
	}
	if _, exists := e.keys[name]; exists {
		return nil, ErrKeyExists
	}

	// Generate key version 1.
	kv, err := generateKeyVersion(keyType, 1)
	if err != nil {
		return nil, fmt.Errorf("transit: generate key: %w", err)
	}
	wipe := func() {
		if kv.key != nil {
			mcrypto.Zeroize(kv.key)
		}
		zeroizeKey(kv.privKey)
	}

	cfg := &KeyConfig{
		Name:                 name,
		Type:                 keyType,
		CurrentVersion:       1,
		MinDecryptionVersion: 1,
		AllowDeletion:        false,
	}

	prefix := e.mountPath + "keys/" + name + "/"
	if err := e.storeKeyConfig(ctx, prefix, cfg); err != nil {
		wipe()
		return nil, err
	}
	if err := e.storeKeyVersion(ctx, prefix, cfg, kv); err != nil {
		// Roll back the config; the store error is what the caller needs,
		// so a failed cleanup is deliberately ignored.
		_ = e.barrier.Delete(ctx, prefix+"config.json")
		wipe()
		return nil, err
	}

	e.keys[name] = &keyState{
		config:   cfg,
		versions: map[int]*keyVersion{1: kv},
	}
	return &engine.Response{
		Data: map[string]interface{}{
			"name":    name,
			"type":    keyType,
			"version": 1,
		},
	}, nil
}
// handleDeleteKey removes a key from the barrier (every loaded version plus
// its config.json) and wipes its in-memory material. Admin only. The key's
// AllowDeletion flag must be set.
//
// Barrier deletes are best effort: their errors are ignored.
// NOTE(review): if a version file survives while config.json is removed,
// Unseal will list the leftover path and then fail to load that key's config.
// Consider surfacing delete errors.
func (e *TransitEngine) handleDeleteKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	if e.sealed() {
		return nil, ErrSealed
	}
	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	if err := engine.ValidateName(name); err != nil {
		return nil, err
	}

	ks, found := e.keys[name]
	switch {
	case !found:
		return nil, ErrKeyNotFound
	case !ks.config.AllowDeletion:
		return nil, ErrDeletionDenied
	}

	prefix := e.mountPath + "keys/" + name + "/"
	for ver := range ks.versions {
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, ver))
	}
	_ = e.barrier.Delete(ctx, prefix+"config.json")

	for _, ver := range ks.versions {
		if ver.key != nil {
			mcrypto.Zeroize(ver.key)
		}
		zeroizeKey(ver.privKey)
	}
	delete(e.keys, name)

	return &engine.Response{Data: map[string]interface{}{"ok": true}}, nil
}
// handleGetKey returns metadata for one key, never key material. The
// metadata is its name, type, current and minimum decryption versions,
// deletion flag, and the sorted versions held in memory. Requires a user
// caller.
func (e *TransitEngine) handleGetKey(_ context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireUser(req); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	if e.sealed() {
		return nil, ErrSealed
	}
	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	if err := engine.ValidateName(name); err != nil {
		return nil, err
	}
	ks, found := e.keys[name]
	if !found {
		return nil, ErrKeyNotFound
	}

	held := make([]int, 0, len(ks.versions))
	for ver := range ks.versions {
		held = append(held, ver)
	}
	sort.Ints(held)

	cfg := ks.config
	info := map[string]interface{}{
		"name":                   cfg.Name,
		"type":                   cfg.Type,
		"current_version":        cfg.CurrentVersion,
		"min_decryption_version": cfg.MinDecryptionVersion,
		"allow_deletion":         cfg.AllowDeletion,
		"versions":               held,
	}
	return &engine.Response{Data: info}, nil
}
// handleListKeys returns the names of all loaded keys in sorted order under
// "keys". Requires a user caller.
func (e *TransitEngine) handleListKeys(_ context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireUser(req); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	if e.sealed() {
		return nil, ErrSealed
	}

	names := make([]string, 0, len(e.keys))
	for keyName := range e.keys {
		names = append(names, keyName)
	}
	sort.Strings(names)

	return &engine.Response{Data: map[string]interface{}{"keys": names}}, nil
}
// handleRotateKey generates and stores a new version of a key. That version
// becomes the current one used for encryption, signing and HMAC. Admin only.
// When the engine config sets MaxKeyVersions > 0, old versions are then
// pruned (see pruneVersions).
//
// The write order ensures a barrier failure cannot leave the key pointing at
// a version that does not exist:
//  1. The new version file is written first.
//  2. Then the config that references it is written.
//  3. Only then is in-memory state advanced.
//
// An orphaned vN.key left by a failed step 2 is harmless: loadKey only reads
// versions up to CurrentVersion, and the next rotation writes the same path.
func (e *TransitEngine) handleRotateKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	if err := engine.ValidateName(name); err != nil {
		return nil, err
	}
	ks, ok := e.keys[name]
	if !ok {
		return nil, ErrKeyNotFound
	}

	newVersion := ks.config.CurrentVersion + 1
	kv, err := generateKeyVersion(ks.config.Type, newVersion)
	if err != nil {
		return nil, fmt.Errorf("transit: generate key version: %w", err)
	}
	wipe := func() {
		if kv.key != nil {
			mcrypto.Zeroize(kv.key)
		}
		zeroizeKey(kv.privKey)
	}

	prefix := e.mountPath + "keys/" + name + "/"
	if err := e.storeKeyVersion(ctx, prefix, ks.config, kv); err != nil {
		wipe()
		return nil, err
	}

	updated := *ks.config
	updated.CurrentVersion = newVersion
	if err := e.storeKeyConfig(ctx, prefix, &updated); err != nil {
		// Best-effort cleanup of the orphaned version; the store error is
		// what the caller needs.
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, newVersion))
		wipe()
		return nil, err
	}

	// Both writes succeeded: commit in memory.
	ks.config.CurrentVersion = newVersion
	ks.versions[newVersion] = kv

	// Prune old versions if max_key_versions is set.
	if e.config.MaxKeyVersions > 0 && len(ks.versions) > e.config.MaxKeyVersions {
		e.pruneVersions(ctx, ks, prefix)
	}

	return &engine.Response{
		Data: map[string]interface{}{
			"name":    name,
			"version": newVersion,
		},
	}, nil
}
// pruneVersions drops the oldest versions of ks until at most
// e.config.MaxKeyVersions remain.
//
// Only versions strictly below the key's MinDecryptionVersion are eligible.
// Pruning stops early at the first version that is still decryptable.
// Removed versions are deleted from the barrier (errors ignored), zeroized,
// and removed from memory. The caller holds e.mu for writing.
func (e *TransitEngine) pruneVersions(ctx context.Context, ks *keyState, prefix string) {
	held := make([]int, 0, len(ks.versions))
	for ver := range ks.versions {
		held = append(held, ver)
	}
	sort.Ints(held)

	excess := len(held) - e.config.MaxKeyVersions
	for i := 0; i < excess; i++ {
		oldest := held[i]
		if oldest >= ks.config.MinDecryptionVersion {
			return
		}
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, oldest))
		if kv, present := ks.versions[oldest]; present {
			if kv.key != nil {
				mcrypto.Zeroize(kv.key)
			}
			zeroizeKey(kv.privKey)
		}
		delete(ks.versions, oldest)
	}
}
// handleUpdateKeyConfig changes a key's mutable settings. Admin only.
//
// Recognised request fields:
//   - "min_decryption_version": may only increase, and may not exceed the
//     current version (otherwise ErrInvalidMinVer).
//   - "allow_deletion": must be a bool when present and non-null.
//
// Changes are applied to a copy of the config, persisted, and only then
// copied into the in-memory key. A rejected request or a failed barrier write
// therefore leaves the key's settings unchanged, both in memory and on disk.
func (e *TransitEngine) handleUpdateKeyConfig(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}
	e.mu.Lock()
	defer e.mu.Unlock()
	if e.sealed() {
		return nil, ErrSealed
	}

	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	if err := engine.ValidateName(name); err != nil {
		return nil, err
	}
	ks, ok := e.keys[name]
	if !ok {
		return nil, ErrKeyNotFound
	}

	updated := *ks.config
	if v, ok := req.Data["min_decryption_version"]; ok {
		newMin := toInt(v)
		if newMin < updated.MinDecryptionVersion || newMin > updated.CurrentVersion {
			return nil, ErrInvalidMinVer
		}
		updated.MinDecryptionVersion = newMin
	}
	if v, ok := req.Data["allow_deletion"]; ok && v != nil {
		allow, isBool := v.(bool)
		if !isBool {
			// Previously ignored while still reporting ok; reject instead so
			// the caller knows the setting did not change.
			return nil, fmt.Errorf("transit: allow_deletion must be a boolean")
		}
		updated.AllowDeletion = allow
	}

	prefix := e.mountPath + "keys/" + name + "/"
	if err := e.storeKeyConfig(ctx, prefix, &updated); err != nil {
		return nil, err
	}
	*ks.config = updated

	return &engine.Response{
		Data: map[string]interface{}{"ok": true},
	}, nil
}
// handleTrimKey permanently removes every version below the key's
// MinDecryptionVersion, from both the barrier and memory. It reports how many
// were removed under "trimmed". Admin only. Barrier delete errors are ignored.
func (e *TransitEngine) handleTrimKey(ctx context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireAdmin(req); err != nil {
		return nil, err
	}

	e.mu.Lock()
	defer e.mu.Unlock()

	if e.sealed() {
		return nil, ErrSealed
	}
	name, _ := req.Data["name"].(string)
	if name == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	if err := engine.ValidateName(name); err != nil {
		return nil, err
	}
	ks, found := e.keys[name]
	if !found {
		return nil, ErrKeyNotFound
	}

	minVer := ks.config.MinDecryptionVersion
	var stale []int
	for ver := range ks.versions {
		if ver < minVer {
			stale = append(stale, ver)
		}
	}

	prefix := e.mountPath + "keys/" + name + "/"
	for _, ver := range stale {
		_ = e.barrier.Delete(ctx, fmt.Sprintf("%sv%d.key", prefix, ver))
		kv := ks.versions[ver]
		if kv.key != nil {
			mcrypto.Zeroize(kv.key)
		}
		zeroizeKey(kv.privKey)
		delete(ks.versions, ver)
	}

	return &engine.Response{Data: map[string]interface{}{"trimmed": len(stale)}}, nil
}
// --- Crypto Operations ---

// handleEncrypt encrypts the base64 "plaintext" under the current version of
// a symmetric key. The key is named by "key" (falling back to "name"), and an
// optional base64 "context" is bound as AEAD additional data. Authorization
// goes through requireUserWithPolicy.
func (e *TransitEngine) handleEncrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	plaintextB64, _ := req.Data["plaintext"].(string)
	contextB64, _ := req.Data["context"].(string)
	ct, err := e.encryptWithKey(keyName, plaintextB64, contextB64)
	if err != nil {
		return nil, err
	}
	resp := &engine.Response{Data: map[string]interface{}{"ciphertext": ct}}
	return resp, nil
}
// encryptWithKey encrypts base64-encoded plaintext under the current version
// of the named symmetric key and returns a versioned ciphertext string. A
// non-empty base64 context is used as AEAD additional data. It does not lock;
// callers in this file invoke it with e.mu held.
func (e *TransitEngine) encryptWithKey(keyName, plaintextB64, contextB64 string) (string, error) {
	ks, found := e.keys[keyName]
	switch {
	case !found:
		return "", ErrKeyNotFound
	case !isSymmetric(ks.config.Type):
		return "", ErrUnsupportedOp
	}

	plaintext, err := base64.StdEncoding.DecodeString(plaintextB64)
	if err != nil {
		return "", fmt.Errorf("transit: invalid base64 plaintext: %w", err)
	}
	var aad []byte
	if contextB64 != "" {
		if aad, err = base64.StdEncoding.DecodeString(contextB64); err != nil {
			return "", fmt.Errorf("transit: invalid base64 context: %w", err)
		}
	}

	current := ks.config.CurrentVersion
	kv, found := ks.versions[current]
	if !found {
		return "", fmt.Errorf("transit: current key version %d not found", current)
	}
	blob, err := encryptData(ks.config.Type, kv.key, plaintext, aad)
	if err != nil {
		return "", err
	}
	return formatCiphertext(current, blob), nil
}
// handleDecrypt decrypts a versioned "ciphertext" with the named symmetric
// key and returns the plaintext base64-encoded. The key is named by "key",
// falling back to "name". The same base64 "context" used at encryption time
// must be supplied.
func (e *TransitEngine) handleDecrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ciphertextStr, _ := req.Data["ciphertext"].(string)
	contextB64, _ := req.Data["context"].(string)
	pt, err := e.decryptWithKey(keyName, ciphertextStr, contextB64)
	if err != nil {
		return nil, err
	}
	encoded := base64.StdEncoding.EncodeToString(pt)
	return &engine.Response{Data: map[string]interface{}{"plaintext": encoded}}, nil
}
// decryptWithKey decrypts a versioned ciphertext string with the named
// symmetric key. Ciphertexts whose version is below MinDecryptionVersion are
// refused with ErrDecryptVersion. A non-empty base64 context is used as AEAD
// additional data. It does not lock; callers in this file hold e.mu.
func (e *TransitEngine) decryptWithKey(keyName, ciphertextStr, contextB64 string) ([]byte, error) {
	ks, found := e.keys[keyName]
	switch {
	case !found:
		return nil, ErrKeyNotFound
	case !isSymmetric(ks.config.Type):
		return nil, ErrUnsupportedOp
	}

	version, payload, err := parseCiphertext(ciphertextStr)
	if err != nil {
		return nil, err
	}
	if version < ks.config.MinDecryptionVersion {
		return nil, ErrDecryptVersion
	}
	kv, found := ks.versions[version]
	if !found {
		return nil, fmt.Errorf("transit: key version %d not found", version)
	}

	var aad []byte
	if contextB64 != "" {
		if aad, err = base64.StdEncoding.DecodeString(contextB64); err != nil {
			return nil, fmt.Errorf("transit: invalid base64 context: %w", err)
		}
	}
	return decryptData(ks.config.Type, kv.key, payload, aad)
}
// handleRewrap re-encrypts a ciphertext under the key's current version
// without exposing the plaintext to the caller. It decrypts with whichever
// version the ciphertext names, then encrypts again with the current one,
// using the same context.
func (e *TransitEngine) handleRewrap(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ciphertextStr, _ := req.Data["ciphertext"].(string)
	contextB64, _ := req.Data["context"].(string)

	recovered, err := e.decryptWithKey(keyName, ciphertextStr, contextB64)
	if err != nil {
		return nil, err
	}
	// encryptWithKey takes base64 input, so the recovered bytes are re-encoded.
	rewrapped, err := e.encryptWithKey(keyName, base64.StdEncoding.EncodeToString(recovered), contextB64)
	if err != nil {
		return nil, err
	}
	return &engine.Response{Data: map[string]interface{}{"ciphertext": rewrapped}}, nil
}
// --- Batch Operations ---

// batchItem is one entry of a batch request's "items" list. Which fields are
// read depends on the operation: batch-encrypt uses Plaintext, while
// batch-decrypt and batch-rewrap use Ciphertext.
type batchItem struct {
	Plaintext  string `json:"plaintext"`  // base64-encoded plaintext
	Ciphertext string `json:"ciphertext"` // versioned ciphertext string
	Context    string `json:"context"`    // optional base64 AEAD context
	Reference  string `json:"reference"`  // caller-chosen tag echoed in the result
}
// batchResult is the outcome of one batch item, at the same index as its
// input. On success Plaintext (base64) or Ciphertext is set; on failure
// Error holds the error message.
type batchResult struct {
	Plaintext  string `json:"plaintext,omitempty"`
	Ciphertext string `json:"ciphertext,omitempty"`
	Reference  string `json:"reference,omitempty"` // copied from the input item
	Error      string `json:"error,omitempty"`
}
// handleBatchEncrypt encrypts up to maxBatchSize items with the current
// version of one symmetric key, named by "key" (falling back to "name").
// Per-item failures are reported in that item's result instead of failing the
// whole request. Each result echoes its item's reference.
//
// As with the single-item operations, an empty key name is rejected up front
// rather than producing one "key not found" error per item.
func (e *TransitEngine) handleBatchEncrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	keyName, _ := req.Data["key"].(string)
	if keyName == "" {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	if keyName == "" {
		return nil, fmt.Errorf("transit: key name is required")
	}

	items, err := extractBatchItems(req.Data["items"])
	if err != nil {
		return nil, err
	}
	if len(items) > maxBatchSize {
		return nil, ErrBatchTooLarge
	}

	results := make([]interface{}, len(items))
	for i, item := range items {
		ct, err := e.encryptWithKey(keyName, item.Plaintext, item.Context)
		r := batchResult{Reference: item.Reference}
		if err != nil {
			r.Error = err.Error()
		} else {
			r.Ciphertext = ct
		}
		results[i] = r
	}
	return &engine.Response{
		Data: map[string]interface{}{"results": results},
	}, nil
}
// handleBatchDecrypt decrypts up to maxBatchSize ciphertexts with one
// symmetric key, named by "key" (falling back to "name"). Plaintexts are
// returned base64-encoded. Per-item failures are reported in that item's
// result, and each result echoes its item's reference.
//
// As with the single-item operations, an empty key name is rejected up front
// rather than producing one "key not found" error per item.
func (e *TransitEngine) handleBatchDecrypt(_ context.Context, req *engine.Request) (*engine.Response, error) {
	keyName, _ := req.Data["key"].(string)
	if keyName == "" {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	if keyName == "" {
		return nil, fmt.Errorf("transit: key name is required")
	}

	items, err := extractBatchItems(req.Data["items"])
	if err != nil {
		return nil, err
	}
	if len(items) > maxBatchSize {
		return nil, ErrBatchTooLarge
	}

	results := make([]interface{}, len(items))
	for i, item := range items {
		pt, err := e.decryptWithKey(keyName, item.Ciphertext, item.Context)
		r := batchResult{Reference: item.Reference}
		if err != nil {
			r.Error = err.Error()
		} else {
			r.Plaintext = base64.StdEncoding.EncodeToString(pt)
		}
		results[i] = r
	}
	return &engine.Response{
		Data: map[string]interface{}{"results": results},
	}, nil
}
// handleBatchRewrap re-encrypts up to maxBatchSize ciphertexts under the
// current version of one symmetric key, named by "key" (falling back to
// "name"). Per-item failures (decrypt or re-encrypt) are reported in that
// item's result, and each result echoes its item's reference.
//
// As with the single-item operations, an empty key name is rejected up front
// rather than producing one "key not found" error per item.
func (e *TransitEngine) handleBatchRewrap(_ context.Context, req *engine.Request) (*engine.Response, error) {
	keyName, _ := req.Data["key"].(string)
	if keyName == "" {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}
	e.mu.RLock()
	defer e.mu.RUnlock()
	if e.sealed() {
		return nil, ErrSealed
	}
	if keyName == "" {
		return nil, fmt.Errorf("transit: key name is required")
	}

	items, err := extractBatchItems(req.Data["items"])
	if err != nil {
		return nil, err
	}
	if len(items) > maxBatchSize {
		return nil, ErrBatchTooLarge
	}

	results := make([]interface{}, len(items))
	for i, item := range items {
		r := batchResult{Reference: item.Reference}
		// Decrypt with the version named in the ciphertext.
		pt, err := e.decryptWithKey(keyName, item.Ciphertext, item.Context)
		if err != nil {
			r.Error = err.Error()
			results[i] = r
			continue
		}
		// Re-encrypt with the current version.
		ptB64 := base64.StdEncoding.EncodeToString(pt)
		ct, err := e.encryptWithKey(keyName, ptB64, item.Context)
		if err != nil {
			r.Error = err.Error()
		} else {
			r.Ciphertext = ct
		}
		results[i] = r
	}
	return &engine.Response{
		Data: map[string]interface{}{"results": results},
	}, nil
}
// --- Sign/Verify Operations ---

// handleSign signs the base64 "input" with the current version of an
// asymmetric key and returns a versioned signature string.
//
// Ed25519 signs the raw input. ECDSA signs the SHA-256 (P-256) or SHA-384
// (P-384) digest of the input and produces an ASN.1 DER signature.
func (e *TransitEngine) handleSign(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ks, found := e.keys[keyName]
	if !found {
		return nil, ErrKeyNotFound
	}
	keyType := ks.config.Type
	if !isAsymmetric(keyType) {
		return nil, ErrUnsupportedOp
	}

	inputB64, _ := req.Data["input"].(string)
	input, err := base64.StdEncoding.DecodeString(inputB64)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid base64 input: %w", err)
	}

	current := ks.config.CurrentVersion
	kv, found := ks.versions[current]
	if !found {
		return nil, fmt.Errorf("transit: current key version %d not found", current)
	}

	var sig []byte
	switch keyType {
	case "ed25519":
		edKey, isEd := kv.privKey.(ed25519.PrivateKey)
		if !isEd {
			return nil, fmt.Errorf("transit: expected ed25519 key")
		}
		sig = ed25519.Sign(edKey, input)
	case "ecdsa-p256", "ecdsa-p384":
		ecKey, isEC := kv.privKey.(*ecdsa.PrivateKey)
		if !isEC {
			return nil, fmt.Errorf("transit: expected ECDSA key")
		}
		var digest []byte
		if keyType == "ecdsa-p256" {
			sum := sha256.Sum256(input)
			digest = sum[:]
		} else {
			sum := sha512.Sum384(input)
			digest = sum[:]
		}
		if sig, err = ecdsa.SignASN1(rand.Reader, ecKey, digest); err != nil {
			return nil, fmt.Errorf("transit: sign: %w", err)
		}
	default:
		return nil, ErrUnsupportedOp
	}

	return &engine.Response{
		Data: map[string]interface{}{"signature": formatSignature(current, sig)},
	}, nil
}
// handleVerify checks a versioned "signature" over the base64 "input". It uses
// the key version named in the signature and reports the result as "valid".
// Digest handling mirrors handleSign: raw input for ed25519, SHA-256/SHA-384
// for ECDSA P-256/P-384.
func (e *TransitEngine) handleVerify(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ks, found := e.keys[keyName]
	if !found {
		return nil, ErrKeyNotFound
	}
	keyType := ks.config.Type
	if !isAsymmetric(keyType) {
		return nil, ErrUnsupportedOp
	}

	inputB64, _ := req.Data["input"].(string)
	input, err := base64.StdEncoding.DecodeString(inputB64)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid base64 input: %w", err)
	}

	signatureStr, _ := req.Data["signature"].(string)
	version, sigBytes, err := parseVersionedData(signatureStr)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid signature format: %w", err)
	}
	kv, found := ks.versions[version]
	if !found {
		return nil, fmt.Errorf("transit: key version %d not found", version)
	}

	var valid bool
	switch keyType {
	case "ed25519":
		edPub, isEd := kv.pubKey.(ed25519.PublicKey)
		if !isEd {
			return nil, fmt.Errorf("transit: expected ed25519 public key")
		}
		valid = ed25519.Verify(edPub, input, sigBytes)
	case "ecdsa-p256", "ecdsa-p384":
		ecPub, isEC := kv.pubKey.(*ecdsa.PublicKey)
		if !isEC {
			return nil, fmt.Errorf("transit: expected ECDSA public key")
		}
		var digest []byte
		if keyType == "ecdsa-p256" {
			sum := sha256.Sum256(input)
			digest = sum[:]
		} else {
			sum := sha512.Sum384(input)
			digest = sum[:]
		}
		valid = ecdsa.VerifyASN1(ecPub, digest, sigBytes)
	default:
		return nil, ErrUnsupportedOp
	}

	return &engine.Response{Data: map[string]interface{}{"valid": valid}}, nil
}
// --- HMAC Operation ---

// handleHMAC computes or verifies an HMAC over the base64 "input" with an HMAC
// key.
//
// If a non-empty versioned "hmac" is supplied, it is checked in constant time
// against the MAC from the version it names, and the result is reported as
// "valid". Otherwise a new MAC is computed with the current version and
// returned as a versioned string.
func (e *TransitEngine) handleHMAC(_ context.Context, req *engine.Request) (*engine.Response, error) {
	var keyName string
	if k, _ := req.Data["key"].(string); k != "" {
		keyName = k
	} else {
		keyName, _ = req.Data["name"].(string)
	}
	if err := e.requireUserWithPolicy(req, keyName); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch {
	case e.sealed():
		return nil, ErrSealed
	case keyName == "":
		return nil, fmt.Errorf("transit: key name is required")
	}

	ks, found := e.keys[keyName]
	if !found {
		return nil, ErrKeyNotFound
	}
	keyType := ks.config.Type
	if !isHMAC(keyType) {
		return nil, ErrUnsupportedOp
	}

	inputB64, _ := req.Data["input"].(string)
	input, err := base64.StdEncoding.DecodeString(inputB64)
	if err != nil {
		return nil, fmt.Errorf("transit: invalid base64 input: %w", err)
	}

	if given, isStr := req.Data["hmac"].(string); isStr && given != "" {
		version, macBytes, err := parseVersionedData(given)
		if err != nil {
			return nil, fmt.Errorf("transit: invalid hmac format: %w", err)
		}
		kv, found := ks.versions[version]
		if !found {
			return nil, fmt.Errorf("transit: key version %d not found", version)
		}
		matches := hmac.Equal(macBytes, computeHMAC(keyType, kv.key, input))
		return &engine.Response{Data: map[string]interface{}{"valid": matches}}, nil
	}

	current := ks.config.CurrentVersion
	kv, found := ks.versions[current]
	if !found {
		return nil, fmt.Errorf("transit: current key version %d not found", current)
	}
	tag := computeHMAC(keyType, kv.key, input)
	return &engine.Response{
		Data: map[string]interface{}{"hmac": formatHMAC(current, tag)},
	}, nil
}
// --- Get Public Key ---

// handleGetPublicKey returns the base64 PKIX (DER) public key of an
// asymmetric key together with its version and type. The optional "version"
// selects a specific version; otherwise the current one is used. Requires a
// user caller.
func (e *TransitEngine) handleGetPublicKey(_ context.Context, req *engine.Request) (*engine.Response, error) {
	if err := e.requireUser(req); err != nil {
		return nil, err
	}

	e.mu.RLock()
	defer e.mu.RUnlock()

	if e.sealed() {
		return nil, ErrSealed
	}
	keyName, _ := req.Data["name"].(string)
	if keyName == "" {
		return nil, fmt.Errorf("transit: name is required")
	}
	if err := engine.ValidateName(keyName); err != nil {
		return nil, err
	}

	ks, found := e.keys[keyName]
	switch {
	case !found:
		return nil, ErrKeyNotFound
	case !isAsymmetric(ks.config.Type):
		return nil, ErrUnsupportedOp
	}

	version := ks.config.CurrentVersion
	if requested, given := req.Data["version"]; given {
		version = toInt(requested)
	}
	kv, found := ks.versions[version]
	if !found {
		return nil, fmt.Errorf("transit: key version %d not found", version)
	}

	der, err := x509.MarshalPKIXPublicKey(kv.pubKey)
	if err != nil {
		return nil, fmt.Errorf("transit: marshal public key: %w", err)
	}
	out := map[string]interface{}{
		"public_key": base64.StdEncoding.EncodeToString(der),
		"version":    version,
		"type":       ks.config.Type,
	}
	return &engine.Response{Data: out}, nil
}
// --- Storage helpers ---

// storeKeyConfig writes cfg as JSON to prefix+"config.json" in the barrier.
func (e *TransitEngine) storeKeyConfig(ctx context.Context, prefix string, cfg *KeyConfig) error {
	encoded, err := json.Marshal(cfg)
	if err != nil {
		return fmt.Errorf("transit: marshal key config: %w", err)
	}
	target := prefix + "config.json"
	return e.barrier.Put(ctx, target, encoded)
}
// storeKeyVersion writes kv's material to prefix+"v<version>.key".
//
// Symmetric, HMAC and ed25519 keys are stored as their raw key bytes (for
// ed25519, the 64-byte private key). ECDSA keys are stored as PKCS#8.
func (e *TransitEngine) storeKeyVersion(ctx context.Context, prefix string, cfg *KeyConfig, kv *keyVersion) error {
	target := fmt.Sprintf("%sv%d.key", prefix, kv.version)

	var payload []byte
	switch cfg.Type {
	case "aes256-gcm", "chacha20-poly", "hmac-sha256", "hmac-sha512", "ed25519":
		payload = kv.key
	case "ecdsa-p256", "ecdsa-p384":
		der, err := x509.MarshalPKCS8PrivateKey(kv.privKey)
		if err != nil {
			return fmt.Errorf("transit: marshal PKCS8 key: %w", err)
		}
		payload = der
	default:
		return fmt.Errorf("transit: unknown key type: %s", cfg.Type)
	}
	return e.barrier.Put(ctx, target, payload)
}
// --- Key generation ---

// generateKeyVersion creates fresh key material of keyType, tagged with
// version.
//
// Material by type:
//   - aes256-gcm, chacha20-poly, hmac-sha256: 32 random bytes.
//   - hmac-sha512: 64 random bytes.
//   - ed25519: a new key pair; key also holds the raw private key.
//   - ecdsa-p256 / ecdsa-p384: a new key pair on the matching curve.
func generateKeyVersion(keyType string, version int) (*keyVersion, error) {
	randomBytes := func(n int) ([]byte, error) {
		buf := make([]byte, n)
		if _, err := rand.Read(buf); err != nil {
			return nil, err
		}
		return buf, nil
	}

	kv := &keyVersion{version: version}
	switch keyType {
	case "aes256-gcm", "chacha20-poly", "hmac-sha256":
		secret, err := randomBytes(32)
		if err != nil {
			return nil, err
		}
		kv.key = secret
	case "hmac-sha512":
		secret, err := randomBytes(64)
		if err != nil {
			return nil, err
		}
		kv.key = secret
	case "ed25519":
		_, priv, err := ed25519.GenerateKey(rand.Reader)
		if err != nil {
			return nil, err
		}
		kv.key = []byte(priv)
		kv.privKey = priv
		kv.pubKey = priv.Public()
	case "ecdsa-p256", "ecdsa-p384":
		curve := elliptic.P256()
		if keyType == "ecdsa-p384" {
			curve = elliptic.P384()
		}
		priv, err := ecdsa.GenerateKey(curve, rand.Reader)
		if err != nil {
			return nil, err
		}
		kv.privKey = priv
		kv.pubKey = &priv.PublicKey
	default:
		return nil, fmt.Errorf("unknown key type: %s", keyType)
	}
	return kv, nil
}
// --- Encryption/Decryption helpers ---
// encryptData seals plaintext under key with the AEAD selected by keyType
// (via newAEAD), authenticating aad alongside it, using a fresh random nonce.
//
// The result is laid out as nonce || ciphertext || tag, which is the layout
// decryptData splits apart.
func encryptData(keyType string, key, plaintext, aad []byte) ([]byte, error) {
	aead, err := newAEAD(keyType, key)
	if err != nil {
		return nil, err
	}
	nonceSize := aead.NonceSize()
	// Allocate the final buffer once: the nonce fills the first nonceSize
	// bytes, and Seal appends ciphertext+tag into the reserved capacity, so
	// no second buffer or copy is needed.
	out := make([]byte, nonceSize, nonceSize+len(plaintext)+aead.Overhead())
	if _, err := rand.Read(out); err != nil {
		return nil, fmt.Errorf("transit: generate nonce: %w", err)
	}
	// Seal appends after out's current length, so the nonce prefix is not
	// overwritten; out is freshly allocated and cannot overlap plaintext/aad.
	return aead.Seal(out, out, plaintext, aad), nil
}
// decryptData opens data of the form nonce || ciphertext || tag (the layout
// produced by encryptData) under key with the AEAD selected by keyType,
// verifying aad.
//
// It returns ErrInvalidFormat if data is too short to hold a nonce, and a
// wrapped "decryption failed" error if authentication does not succeed.
func decryptData(keyType string, key, data, aad []byte) ([]byte, error) {
	aead, err := newAEAD(keyType, key)
	if err != nil {
		return nil, err
	}
	n := aead.NonceSize()
	if n > len(data) {
		return nil, ErrInvalidFormat
	}
	plaintext, openErr := aead.Open(nil, data[:n], data[n:], aad)
	if openErr != nil {
		return nil, fmt.Errorf("transit: decryption failed: %w", openErr)
	}
	return plaintext, nil
}
// newAEAD builds the AEAD cipher for an encryption key type:
// AES-GCM for "aes256-gcm" and XChaCha20-Poly1305 for "chacha20-poly".
// Any other keyType is reported as unsupported.
//
// NOTE(review): aes.NewCipher also accepts 16- and 24-byte keys, so the
// 256-bit size is not enforced here; it relies on callers supplying 32 bytes.
func newAEAD(keyType string, key []byte) (cipher.AEAD, error) {
	if keyType == "chacha20-poly" {
		return chacha20poly1305.NewX(key)
	}
	if keyType != "aes256-gcm" {
		return nil, fmt.Errorf("transit: unsupported encryption type: %s", keyType)
	}
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("transit: new AES cipher: %w", err)
	}
	return cipher.NewGCM(block)
}
// --- HMAC helpers ---
// computeHMAC returns the HMAC of input under key, using SHA-256 for
// "hmac-sha256" and SHA-512 for "hmac-sha512". Any other keyType yields nil.
func computeHMAC(keyType string, key, input []byte) []byte {
	var mac hash.Hash
	switch keyType {
	case "hmac-sha512":
		mac = hmac.New(sha512.New, key)
	case "hmac-sha256":
		mac = hmac.New(sha256.New, key)
	default:
		return nil
	}
	// hash.Hash's Write is documented to never return an error.
	mac.Write(input)
	return mac.Sum(nil)
}
// --- Format helpers ---
// formatCiphertext renders data as "metacrypt:v<version>:<base64>" using
// standard padded base64. This is the format parseVersionedData reads back.
func formatCiphertext(version int, data []byte) string {
	encoded := base64.StdEncoding.EncodeToString(data)
	return fmt.Sprintf("metacrypt:v%d:%s", version, encoded)
}

// formatSignature renders a signature in the same versioned format as
// formatCiphertext.
func formatSignature(version int, sig []byte) string {
	return formatCiphertext(version, sig)
}

// formatHMAC renders a MAC in the same versioned format as
// formatCiphertext.
func formatHMAC(version int, mac []byte) string {
	return formatCiphertext(version, mac)
}
// parseCiphertext decodes a "metacrypt:v<N>:<base64>" ciphertext into its key
// version and raw bytes. It is a thin alias of parseVersionedData.
func parseCiphertext(s string) (int, []byte, error) {
	return parseVersionedData(s)
}

// parseVersionedData splits s of the form "metacrypt:v<N>:<base64>" into the
// version N and the base64-decoded payload (standard, padded alphabet).
//
// Every malformed input, including an invalid base64 payload, yields an
// error that matches ErrInvalidFormat under errors.Is, so callers can treat
// all parse failures uniformly.
func parseVersionedData(s string) (int, []byte, error) {
	parts := strings.SplitN(s, ":", 3)
	if len(parts) != 3 || parts[0] != "metacrypt" {
		return 0, nil, ErrInvalidFormat
	}
	if !strings.HasPrefix(parts[1], "v") {
		return 0, nil, ErrInvalidFormat
	}
	digits := parts[1][1:]
	// strconv.Atoi alone would also accept a leading sign ("v-1", "v+1");
	// require plain ASCII digits so only non-negative versions parse.
	if digits == "" || strings.TrimLeft(digits, "0123456789") != "" {
		return 0, nil, ErrInvalidFormat
	}
	version, err := strconv.Atoi(digits)
	if err != nil {
		// Only reachable on overflow at this point.
		return 0, nil, ErrInvalidFormat
	}
	data, err := base64.StdEncoding.DecodeString(parts[2])
	if err != nil {
		return 0, nil, fmt.Errorf("%w: invalid base64: %v", ErrInvalidFormat, err)
	}
	return version, data, nil
}
// --- Type helpers ---
// isValidKeyType reports whether t is a key type this engine supports: any
// symmetric AEAD, asymmetric signing, or HMAC type.
func isValidKeyType(t string) bool {
	return isSymmetric(t) || isAsymmetric(t) || isHMAC(t)
}

// isSymmetric reports whether t is an AEAD encryption key type.
func isSymmetric(t string) bool {
	switch t {
	case "aes256-gcm", "chacha20-poly":
		return true
	default:
		return false
	}
}

// isAsymmetric reports whether t is a public-key signing key type.
func isAsymmetric(t string) bool {
	switch t {
	case "ed25519", "ecdsa-p256", "ecdsa-p384":
		return true
	default:
		return false
	}
}

// isHMAC reports whether t is an HMAC key type.
func isHMAC(t string) bool {
	switch t {
	case "hmac-sha256", "hmac-sha512":
		return true
	default:
		return false
	}
}
// --- Utility ---
// toInt coerces a loosely-typed numeric value (for example a number decoded
// from JSON into interface{}) to an int.
//
// Floating-point values, including json.Number values written with a
// fraction or exponent, are truncated toward zero. Unsupported types and
// non-numeric json.Number values yield 0.
func toInt(v interface{}) int {
	switch val := v.(type) {
	case float64:
		return int(val)
	case float32:
		return int(val)
	case int:
		return val
	case int64:
		return int(val)
	case int32:
		return int(val)
	case json.Number:
		if n, err := val.Int64(); err == nil {
			return int(n)
		}
		// Int64 rejects forms like "1.5" or "1e3"; fall back to float
		// parsing so these behave like the float64 case instead of
		// silently becoming 0.
		if f, err := val.Float64(); err == nil {
			return int(f)
		}
	}
	return 0
}
// extractBatchItems converts a decoded JSON array (a []interface{}) into
// batch items. Every element must be a JSON object; its "plaintext",
// "ciphertext", "context", and "reference" fields are copied when they are
// strings and left empty otherwise. Other fields are ignored.
//
// NOTE(review): the array length is not checked against maxBatchSize here;
// presumably callers enforce it — confirm.
func extractBatchItems(v interface{}) ([]batchItem, error) {
	raw, isArray := v.([]interface{})
	if !isArray {
		return nil, fmt.Errorf("transit: items must be an array")
	}
	items := make([]batchItem, 0, len(raw))
	for idx, elem := range raw {
		obj, isObject := elem.(map[string]interface{})
		if !isObject {
			return nil, fmt.Errorf("transit: item %d is not an object", idx)
		}
		var item batchItem
		// Absent or non-string fields deliberately fall back to "".
		item.Plaintext, _ = obj["plaintext"].(string)
		item.Ciphertext, _ = obj["ciphertext"].(string)
		item.Context, _ = obj["context"].(string)
		item.Reference, _ = obj["reference"].(string)
		items = append(items, item)
	}
	return items, nil
}
// zeroizeKey makes a best-effort attempt to overwrite private key material in
// place. It handles *ecdsa.PrivateKey and ed25519.PrivateKey; nil and other
// key types are ignored, and a nil or partially-initialized ECDSA key does
// not panic.
//
// For ECDSA, the scalar's backing words are cleared through big.Int.Bits
// before the value is reset. SetInt64(0) alone only truncates the big.Int's
// length and leaves the secret words in memory. Copies held elsewhere, such
// as other references or anything the standard library may cache internally,
// are not reachable from here.
func zeroizeKey(key crypto.PrivateKey) {
	if key == nil {
		return
	}
	switch k := key.(type) {
	case *ecdsa.PrivateKey:
		if k == nil || k.D == nil {
			return
		}
		// Bits shares the underlying array with k.D, so this wipes the secret.
		words := k.D.Bits()
		for i := range words {
			words[i] = 0
		}
		k.D.SetInt64(0)
	case ed25519.PrivateKey:
		for i := range k {
			k[i] = 0
		}
	}
}