Files
metacrypt/internal/grpcserver/transit.go
Kyle Isom bbe382dc10 Migrate module path from kyle/ to mc/ org
All import paths updated to git.wntrmute.dev/mc/. Bumps mcdsl to v1.2.0.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 02:05:59 -07:00

488 lines
16 KiB
Go

package grpcserver
import (
"context"
"errors"
"strings"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/mc/metacrypt/gen/metacrypt/v2"
"git.wntrmute.dev/mc/metacrypt/internal/auth"
"git.wntrmute.dev/mc/metacrypt/internal/engine"
"git.wntrmute.dev/mc/metacrypt/internal/engine/transit"
"git.wntrmute.dev/mc/metacrypt/internal/policy"
)
// transitServer implements the gRPC TransitService. Each RPC validates its
// required request fields, builds an engine.Request, and dispatches it via
// the parent server's engines to the engine mounted at the request's mount.
type transitServer struct {
	// Embedded generated base type. Per the usual grpc-go pattern it supplies
	// Unimplemented responses for service methods not defined in this file.
	pb.UnimplementedTransitServiceServer
	// s is the owning server; the handlers here use its engines, policy, and
	// logger.
	s *GRPCServer
}
// transitHandleRequest dispatches req to the engine mounted at mount. On
// failure it logs the error and converts it into a gRPC status error.
//
// operation is used only to label the error log line. Known sentinel errors
// from the engine and transit packages map to specific gRPC codes. Context
// cancellation and deadline expiry map to Canceled and DeadlineExceeded, so a
// client that gave up is not reported as a server fault. As a last resort the
// error text is matched for "not found" / "forbidden"; anything else is
// Internal. The original error text becomes the status message.
func (ts *transitServer) transitHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) {
	resp, err := ts.s.engines.HandleRequest(ctx, mount, req)
	if err == nil {
		return resp, nil
	}

	st := codes.Internal
	// Case order is significant: the first match wins, so more specific
	// sentinels are checked before the generic string fallbacks.
	switch {
	case errors.Is(err, engine.ErrMountNotFound):
		st = codes.NotFound
	case errors.Is(err, transit.ErrKeyNotFound):
		st = codes.NotFound
	case errors.Is(err, transit.ErrKeyExists):
		st = codes.AlreadyExists
	case errors.Is(err, transit.ErrUnauthorized):
		st = codes.Unauthenticated
	case errors.Is(err, transit.ErrForbidden):
		st = codes.PermissionDenied
	case errors.Is(err, transit.ErrDeletionDenied):
		st = codes.FailedPrecondition
	case errors.Is(err, transit.ErrUnsupportedOp):
		st = codes.InvalidArgument
	case errors.Is(err, transit.ErrDecryptVersion):
		st = codes.FailedPrecondition
	case errors.Is(err, transit.ErrInvalidFormat):
		st = codes.InvalidArgument
	case errors.Is(err, transit.ErrBatchTooLarge):
		st = codes.InvalidArgument
	case errors.Is(err, transit.ErrInvalidMinVer):
		st = codes.InvalidArgument
	case errors.Is(err, context.Canceled):
		// The caller's context was cancelled; this is not a server fault.
		st = codes.Canceled
	case errors.Is(err, context.DeadlineExceeded):
		st = codes.DeadlineExceeded
	case strings.Contains(err.Error(), "not found"):
		st = codes.NotFound
	case strings.Contains(err.Error(), "forbidden"):
		st = codes.PermissionDenied
	}
	ts.s.logger.Error("grpc: transit "+operation, "mount", mount, "error", err)
	return nil, status.Error(st, err.Error())
}
// callerInfo builds the engine-facing description of the authenticated
// caller from the token info attached to ctx. It returns nil when ctx
// carries no token info.
func (ts *transitServer) callerInfo(ctx context.Context) *engine.CallerInfo {
	if tok := auth.TokenInfoFromContext(ctx); tok != nil {
		return &engine.CallerInfo{
			IsAdmin:  tok.IsAdmin,
			Roles:    tok.Roles,
			Username: tok.Username,
		}
	}
	return nil
}
// policyChecker returns an engine.PolicyChecker that evaluates
// (resource, action) pairs against the server's policy store on behalf of
// the caller identified in ctx. It returns nil when ctx carries no token info.
//
// The returned closure reports (effect, matched) as produced by
// ts.s.policy.Match. If the lookup itself fails, the error is logged and the
// closure reports a deny effect with matched=false.
func (ts *transitServer) policyChecker(ctx context.Context) engine.PolicyChecker {
	caller := ts.callerInfo(ctx)
	if caller == nil {
		return nil
	}
	return func(resource, action string) (string, bool) {
		pReq := &policy.Request{
			Username: caller.Username,
			Roles:    caller.Roles,
			Resource: resource,
			Action:   action,
		}
		effect, matched, err := ts.s.policy.Match(ctx, pReq)
		if err != nil {
			// Log the failure; otherwise a broken policy store is
			// indistinguishable from "no rule matched".
			// NOTE(review): matched=false may let the engine fall back to its
			// default decision instead of honouring the deny effect. Confirm
			// that an unmatched result is treated as fail-closed.
			ts.s.logger.Error("grpc: transit policy check failed",
				"resource", resource, "action", action,
				"username", caller.Username, "error", err)
			return string(policy.EffectDeny), false
		}
		return string(effect), matched
	}
}
// CreateKey asks the transit engine at req.Mount to create a key named
// req.Name of type req.Type. On success it writes an audit log entry and
// returns the name, type, and version reported by the engine.
func (ts *transitServer) CreateKey(ctx context.Context, req *pb.CreateTransitKeyRequest) (*pb.CreateTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	const op = "create-key"
	params := map[string]interface{}{
		"name": req.Name,
		"type": req.Type,
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:  op,
		CallerInfo: ts.callerInfo(ctx),
		Data:       params,
	})
	if err != nil {
		return nil, err
	}

	// Missing or mistyped response fields degrade to zero values.
	keyName, _ := result.Data["name"].(string)
	keyType, _ := result.Data["type"].(string)
	keyVersion, _ := result.Data["version"].(int)

	ts.s.logger.Info("audit: transit key created", "mount", req.Mount, "key", keyName, "type", keyType, "username", callerUsername(ctx))
	return &pb.CreateTransitKeyResponse{
		Name:    keyName,
		Type:    keyType,
		Version: int32(keyVersion),
	}, nil
}
// DeleteKey asks the transit engine at req.Mount to delete the key named
// req.Name and writes an audit log entry on success.
func (ts *transitServer) DeleteKey(ctx context.Context, req *pb.DeleteTransitKeyRequest) (*pb.DeleteTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	const op = "delete-key"
	engReq := &engine.Request{
		Operation:  op,
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	}
	if _, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq); err != nil {
		return nil, err
	}

	ts.s.logger.Info("audit: transit key deleted", "mount", req.Mount, "key", req.Name, "username", callerUsername(ctx))
	return &pb.DeleteTransitKeyResponse{}, nil
}
// GetKey fetches metadata for the key named req.Name on req.Mount: its type,
// current and minimum decryption versions, deletion flag, and version list.
// Response fields the engine omits, or returns with an unexpected type,
// become zero values.
func (ts *transitServer) GetKey(ctx context.Context, req *pb.GetTransitKeyRequest) (*pb.GetTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	const op = "get-key"
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:  op,
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}

	fields := result.Data
	curVer, _ := fields["current_version"].(int)
	minVer, _ := fields["min_decryption_version"].(int)
	allVers, _ := fields["versions"].([]int)

	out := &pb.GetTransitKeyResponse{
		CurrentVersion:       int32(curVer),
		MinDecryptionVersion: int32(minVer),
		Versions:             make([]int32, 0, len(allVers)),
	}
	out.Name, _ = fields["name"].(string)
	out.Type, _ = fields["type"].(string)
	out.AllowDeletion, _ = fields["allow_deletion"].(bool)
	for _, v := range allVers {
		out.Versions = append(out.Versions, int32(v))
	}
	return out, nil
}
// ListKeys returns the names of the transit keys on req.Mount, as reported
// in the engine response's "keys" entry.
func (ts *transitServer) ListKeys(ctx context.Context, req *pb.ListTransitKeysRequest) (*pb.ListTransitKeysResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}

	const op = "list-keys"
	engReq := &engine.Request{
		Operation:  op,
		CallerInfo: ts.callerInfo(ctx),
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq)
	if err != nil {
		return nil, err
	}
	return &pb.ListTransitKeysResponse{
		Keys: toStringSliceFromInterface(result.Data["keys"]),
	}, nil
}
// RotateKey asks the transit engine at req.Mount to rotate the key named
// req.Name. On success it writes an audit log entry and returns the key name
// and the version reported by the engine.
func (ts *transitServer) RotateKey(ctx context.Context, req *pb.RotateTransitKeyRequest) (*pb.RotateTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	const op = "rotate-key"
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:  op,
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}

	keyName, _ := result.Data["name"].(string)
	newVersion, _ := result.Data["version"].(int)
	ts.s.logger.Info("audit: transit key rotated", "mount", req.Mount, "key", keyName, "version", newVersion, "username", callerUsername(ctx))
	return &pb.RotateTransitKeyResponse{
		Name:    keyName,
		Version: int32(newVersion),
	}, nil
}
// UpdateKeyConfig updates the configuration of the key named req.Name on
// req.Mount. It forwards min_decryption_version (only when non-zero) and
// allow_deletion (always), then writes an audit log entry on success.
// Changing deletion or decryption-version policy is security-relevant, so it
// is audited like the other key-mutating RPCs.
//
// NOTE(review): allow_deletion is sent unconditionally. If AllowDeletion is a
// plain proto3 bool, a request that only sets min_decryption_version also
// sends allow_deletion=false. Confirm this is the intended contract.
func (ts *transitServer) UpdateKeyConfig(ctx context.Context, req *pb.UpdateTransitKeyConfigRequest) (*pb.UpdateTransitKeyConfigResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	data := map[string]interface{}{"name": req.Name}
	if req.MinDecryptionVersion != 0 {
		// Passed as float64.
		// NOTE(review): presumably to match the numeric type the engine
		// reads from JSON-decoded input; confirm.
		data["min_decryption_version"] = float64(req.MinDecryptionVersion)
	}
	data["allow_deletion"] = req.AllowDeletion
	_, err := ts.transitHandleRequest(ctx, req.Mount, "update-key-config", &engine.Request{
		Operation:  "update-key-config",
		CallerInfo: ts.callerInfo(ctx),
		Data:       data,
	})
	if err != nil {
		return nil, err
	}
	ts.s.logger.Info("audit: transit key config updated",
		"mount", req.Mount,
		"key", req.Name,
		"min_decryption_version", req.MinDecryptionVersion,
		"allow_deletion", req.AllowDeletion,
		"username", callerUsername(ctx))
	return &pb.UpdateTransitKeyConfigResponse{}, nil
}
// TrimKey asks the transit engine at req.Mount to trim the key named
// req.Name. It returns the "trimmed" count reported by the engine, or 0 if
// that entry is missing or not an int. Trimming mutates stored key material
// (presumably discarding old versions), so a successful trim is audited like
// the other key-mutating RPCs.
func (ts *transitServer) TrimKey(ctx context.Context, req *pb.TrimTransitKeyRequest) (*pb.TrimTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "trim-key", &engine.Request{
		Operation:  "trim-key",
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	trimmed, _ := resp.Data["trimmed"].(int)
	ts.s.logger.Info("audit: transit key trimmed", "mount", req.Mount, "key", req.Name, "trimmed", trimmed, "username", callerUsername(ctx))
	return &pb.TrimTransitKeyResponse{Trimmed: int32(trimmed)}, nil
}
// Encrypt encrypts req.Plaintext with the transit key req.Key on req.Mount,
// subject to the caller's policy. The optional req.Context is forwarded only
// when non-empty.
func (ts *transitServer) Encrypt(ctx context.Context, req *pb.TransitEncryptRequest) (*pb.TransitEncryptResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "encrypt"
	payload := map[string]interface{}{"key": req.Key, "plaintext": req.Plaintext}
	if c := req.Context; c != "" {
		payload["context"] = c
	}
	engReq := &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        payload,
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq)
	if err != nil {
		return nil, err
	}
	ciphertext, _ := result.Data["ciphertext"].(string)
	return &pb.TransitEncryptResponse{Ciphertext: ciphertext}, nil
}
// Decrypt decrypts req.Ciphertext with the transit key req.Key on req.Mount,
// subject to the caller's policy. The optional req.Context is forwarded only
// when non-empty.
func (ts *transitServer) Decrypt(ctx context.Context, req *pb.TransitDecryptRequest) (*pb.TransitDecryptResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "decrypt"
	payload := map[string]interface{}{"key": req.Key, "ciphertext": req.Ciphertext}
	if c := req.Context; c != "" {
		payload["context"] = c
	}
	engReq := &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        payload,
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq)
	if err != nil {
		return nil, err
	}
	plaintext, _ := result.Data["plaintext"].(string)
	return &pb.TransitDecryptResponse{Plaintext: plaintext}, nil
}
// Rewrap sends req.Ciphertext to the engine's "rewrap" operation for the key
// req.Key on req.Mount, subject to the caller's policy, and returns the
// resulting ciphertext. The optional req.Context is forwarded only when
// non-empty.
func (ts *transitServer) Rewrap(ctx context.Context, req *pb.TransitRewrapRequest) (*pb.TransitRewrapResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "rewrap"
	payload := map[string]interface{}{"key": req.Key, "ciphertext": req.Ciphertext}
	if c := req.Context; c != "" {
		payload["context"] = c
	}
	engReq := &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        payload,
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq)
	if err != nil {
		return nil, err
	}
	rewrapped, _ := result.Data["ciphertext"].(string)
	return &pb.TransitRewrapResponse{Ciphertext: rewrapped}, nil
}
// BatchEncrypt runs the engine's "batch-encrypt" operation over req.Items
// with the key req.Key on req.Mount, subject to the caller's policy, and
// converts the engine's per-item results into a TransitBatchResponse.
func (ts *transitServer) BatchEncrypt(ctx context.Context, req *pb.TransitBatchEncryptRequest) (*pb.TransitBatchResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "batch-encrypt"
	batch := protoItemsToInterface(req.Items)
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data: map[string]interface{}{
			"key":   req.Key,
			"items": batch,
		},
	})
	if err != nil {
		return nil, err
	}
	return toBatchResponse(result), nil
}
// BatchDecrypt runs the engine's "batch-decrypt" operation over req.Items
// with the key req.Key on req.Mount, subject to the caller's policy, and
// converts the engine's per-item results into a TransitBatchResponse.
func (ts *transitServer) BatchDecrypt(ctx context.Context, req *pb.TransitBatchDecryptRequest) (*pb.TransitBatchResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "batch-decrypt"
	batch := protoItemsToInterface(req.Items)
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data: map[string]interface{}{
			"key":   req.Key,
			"items": batch,
		},
	})
	if err != nil {
		return nil, err
	}
	return toBatchResponse(result), nil
}
// BatchRewrap runs the engine's "batch-rewrap" operation over req.Items with
// the key req.Key on req.Mount, subject to the caller's policy, and converts
// the engine's per-item results into a TransitBatchResponse.
func (ts *transitServer) BatchRewrap(ctx context.Context, req *pb.TransitBatchRewrapRequest) (*pb.TransitBatchResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "batch-rewrap"
	batch := protoItemsToInterface(req.Items)
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data: map[string]interface{}{
			"key":   req.Key,
			"items": batch,
		},
	})
	if err != nil {
		return nil, err
	}
	return toBatchResponse(result), nil
}
// Sign asks the engine to sign req.Input with the key req.Key on req.Mount,
// subject to the caller's policy, and returns the engine's "signature"
// value. The signature is "" if that value is missing or not a string.
func (ts *transitServer) Sign(ctx context.Context, req *pb.TransitSignRequest) (*pb.TransitSignResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "sign"
	engReq := &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data: map[string]interface{}{
			"key":   req.Key,
			"input": req.Input,
		},
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq)
	if err != nil {
		return nil, err
	}
	signature, _ := result.Data["signature"].(string)
	return &pb.TransitSignResponse{Signature: signature}, nil
}
// Verify asks the engine to check req.Signature over req.Input with the key
// req.Key on req.Mount, subject to the caller's policy. It returns the
// engine's "valid" flag, or false if that flag is missing or not a bool.
func (ts *transitServer) Verify(ctx context.Context, req *pb.TransitVerifyRequest) (*pb.TransitVerifyResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "verify"
	params := map[string]interface{}{
		"key":       req.Key,
		"input":     req.Input,
		"signature": req.Signature,
	}
	engReq := &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        params,
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq)
	if err != nil {
		return nil, err
	}
	isValid, _ := result.Data["valid"].(bool)
	return &pb.TransitVerifyResponse{Valid: isValid}, nil
}
// Hmac calls the engine's "hmac" operation for req.Input with the key
// req.Key on req.Mount, subject to the caller's policy. An existing MAC in
// req.Hmac is forwarded only when non-empty (presumably switching the engine
// to verification). Both the engine's "hmac" and "valid" values are
// returned, zero-valued if absent.
func (ts *transitServer) Hmac(ctx context.Context, req *pb.TransitHmacRequest) (*pb.TransitHmacResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}

	const op = "hmac"
	params := map[string]interface{}{"key": req.Key, "input": req.Input}
	if mac := req.Hmac; mac != "" {
		params["hmac"] = mac
	}
	engReq := &engine.Request{
		Operation:   op,
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        params,
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, engReq)
	if err != nil {
		return nil, err
	}

	out := &pb.TransitHmacResponse{}
	out.Hmac, _ = result.Data["hmac"].(string)
	out.Valid, _ = result.Data["valid"].(bool)
	return out, nil
}
// GetPublicKey fetches the public key for the key named req.Name on
// req.Mount. A non-zero req.Version selects a specific version; otherwise no
// version is sent. It returns the key, version, and type reported by the
// engine.
func (ts *transitServer) GetPublicKey(ctx context.Context, req *pb.GetTransitPublicKeyRequest) (*pb.GetTransitPublicKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	const op = "get-public-key"
	params := map[string]interface{}{"name": req.Name}
	if v := req.Version; v != 0 {
		// The version is passed to the engine as a float64.
		params["version"] = float64(v)
	}
	result, err := ts.transitHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:  op,
		CallerInfo: ts.callerInfo(ctx),
		Data:       params,
	})
	if err != nil {
		return nil, err
	}

	pubKey, _ := result.Data["public_key"].(string)
	ver, _ := result.Data["version"].(int)
	kind, _ := result.Data["type"].(string)
	return &pb.GetTransitPublicKeyResponse{
		Type:      kind,
		PublicKey: pubKey,
		Version:   int32(ver),
	}, nil
}
// --- helpers ---

// protoItemsToInterface converts protobuf batch items into the generic form
// the engine accepts: one map per item, containing only the non-empty fields
// among plaintext, ciphertext, context, and reference. The output has the
// same length and order as items.
func protoItemsToInterface(items []*pb.TransitBatchItem) []interface{} {
	converted := make([]interface{}, 0, len(items))
	for _, it := range items {
		fields := map[string]interface{}{}
		for _, kv := range [...]struct{ key, val string }{
			{"plaintext", it.Plaintext},
			{"ciphertext", it.Ciphertext},
			{"context", it.Context},
			{"reference", it.Reference},
		} {
			if kv.val != "" {
				fields[kv.key] = kv.val
			}
		}
		converted = append(converted, fields)
	}
	return converted
}
// toBatchResponse converts the engine's batch results (resp.Data["results"],
// expected as []interface{}) into a TransitBatchResponse. Each entry that is
// a map[string]interface{} yields one result item. Its plaintext, ciphertext,
// reference, and error fields become "" when missing or not strings. Entries
// of any other type are skipped.
//
// NOTE(review): a skipped entry shifts result positions relative to the
// request items. Callers presumably correlate via Reference; confirm.
func toBatchResponse(resp *engine.Response) *pb.TransitBatchResponse {
	entries, _ := resp.Data["results"].([]interface{})
	out := &pb.TransitBatchResponse{
		Results: make([]*pb.TransitBatchResultItem, 0, len(entries)),
	}
	for _, entry := range entries {
		fields, ok := entry.(map[string]interface{})
		if !ok {
			continue
		}
		item := &pb.TransitBatchResultItem{}
		item.Plaintext, _ = fields["plaintext"].(string)
		item.Ciphertext, _ = fields["ciphertext"].(string)
		item.Reference, _ = fields["reference"].(string)
		item.Error, _ = fields["error"].(string)
		out.Results = append(out.Results, item)
	}
	return out
}