package grpcserver

import (
	"context"
	"errors"
	"strings"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"

	pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
	"git.wntrmute.dev/kyle/metacrypt/internal/engine"
	"git.wntrmute.dev/kyle/metacrypt/internal/engine/transit"
	"git.wntrmute.dev/kyle/metacrypt/internal/policy"
)

// transitServer exposes the transit engine over gRPC. It embeds the generated
// UnimplementedTransitServiceServer and forwards every RPC to the engine
// registry of the parent GRPCServer as an engine.Request.
type transitServer struct {
	pb.UnimplementedTransitServiceServer
	s *GRPCServer
}

// transitHandleRequest dispatches req to the engine mounted at mount and
// translates any engine error into a gRPC status error.
//
// operation is used only to label the error log line. On failure the error is
// logged and returned as a status error carrying the mapped code and the
// original error text. On success the returned response is never nil, so
// callers may read resp.Data without a nil check.
func (ts *transitServer) transitHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) {
	resp, err := ts.s.engines.HandleRequest(ctx, mount, req)
	if err != nil {
		// Anything that is not recognized below is reported as Internal.
		st := codes.Internal
		switch {
		case errors.Is(err, engine.ErrMountNotFound),
			errors.Is(err, transit.ErrKeyNotFound):
			st = codes.NotFound
		case errors.Is(err, transit.ErrKeyExists):
			st = codes.AlreadyExists
		case errors.Is(err, transit.ErrUnauthorized):
			st = codes.Unauthenticated
		case errors.Is(err, transit.ErrForbidden):
			st = codes.PermissionDenied
		case errors.Is(err, transit.ErrDeletionDenied):
			st = codes.FailedPrecondition
		case errors.Is(err, transit.ErrUnsupportedOp):
			st = codes.InvalidArgument
		case errors.Is(err, transit.ErrDecryptVersion):
			st = codes.FailedPrecondition
		case errors.Is(err, transit.ErrInvalidFormat),
			errors.Is(err, transit.ErrBatchTooLarge),
			errors.Is(err, transit.ErrInvalidMinVer):
			st = codes.InvalidArgument
		// Fallbacks for errors that do not wrap one of the sentinels above.
		case strings.Contains(err.Error(), "not found"):
			st = codes.NotFound
		case strings.Contains(err.Error(), "forbidden"):
			st = codes.PermissionDenied
		}
		ts.s.logger.Error("grpc: transit "+operation, "mount", mount, "error", err)
		// NOTE(review): the raw error text is sent to the client even for
		// codes.Internal; consider a generic message for that case.
		return nil, status.Error(st, err.Error())
	}
	if resp == nil {
		// Every caller indexes resp.Data; a nil response with a nil error
		// would otherwise panic inside the RPC handler.
		resp = &engine.Response{}
	}
	return resp, nil
}

// callerInfo builds the engine.CallerInfo for the token attached to ctx.
// It returns nil when tokenInfoFromContext finds no token information.
func (ts *transitServer) callerInfo(ctx context.Context) *engine.CallerInfo {
	ti := tokenInfoFromContext(ctx)
	if ti == nil {
		return nil
	}
	return &engine.CallerInfo{
		Username: ti.Username,
		Roles:    ti.Roles,
		IsAdmin:  ti.IsAdmin,
	}
}

// policyChecker returns an engine.PolicyChecker that evaluates policy rules
// for the caller identified by ctx. It returns nil when ctx carries no caller.
//
// The returned function reports the matched effect and whether any rule
// matched. If policy evaluation itself fails it reports the deny effect with
// matched == false.
func (ts *transitServer) policyChecker(ctx context.Context) engine.PolicyChecker {
	caller := ts.callerInfo(ctx)
	if caller == nil {
		return nil
	}
	return func(resource, action string) (string, bool) {
		pReq := &policy.Request{
			Username: caller.Username,
			Roles:    caller.Roles,
			Resource: resource,
			Action:   action,
		}
		effect, matched, err := ts.s.policy.Match(ctx, pReq)
		if err != nil {
			return string(policy.EffectDeny), false
		}
		return string(effect), matched
	}
}

// CreateKey dispatches the "create-key" operation with the requested name and
// type, writes an audit log entry on success, and returns the name, type and
// version reported by the engine. Mount and name are required.
func (ts *transitServer) CreateKey(ctx context.Context, req *pb.CreateTransitKeyRequest) (*pb.CreateTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "create-key", &engine.Request{
		Operation:  "create-key",
		CallerInfo: ts.callerInfo(ctx),
		Data: map[string]interface{}{
			"name": req.Name,
			"type": req.Type,
		},
	})
	if err != nil {
		return nil, err
	}
	// Missing or differently typed fields fall back to their zero values.
	name, _ := resp.Data["name"].(string)
	keyType, _ := resp.Data["type"].(string)
	version, _ := resp.Data["version"].(int)
	ts.s.logger.Info("audit: transit key created", "mount", req.Mount, "key", name, "type", keyType, "username", callerUsername(ctx))
	return &pb.CreateTransitKeyResponse{Name: name, Type: keyType, Version: int32(version)}, nil
}

// DeleteKey dispatches the "delete-key" operation for the named key and
// writes an audit log entry on success. Mount and name are required.
func (ts *transitServer) DeleteKey(ctx context.Context, req *pb.DeleteTransitKeyRequest) (*pb.DeleteTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	_, err := ts.transitHandleRequest(ctx, req.Mount, "delete-key", &engine.Request{
		Operation:  "delete-key",
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	ts.s.logger.Info("audit: transit key deleted", "mount", req.Mount, "key", req.Name, "username", callerUsername(ctx))
	return &pb.DeleteTransitKeyResponse{}, nil
}

// GetKey dispatches the "get-key" operation and returns the key metadata
// reported by the engine: name, type, current version, minimum decryption
// version, deletion flag and the list of versions. Mount and name are
// required.
func (ts *transitServer) GetKey(ctx context.Context, req *pb.GetTransitKeyRequest) (*pb.GetTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "get-key", &engine.Request{
		Operation:  "get-key",
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	// Missing or differently typed fields fall back to their zero values.
	name, _ := resp.Data["name"].(string)
	keyType, _ := resp.Data["type"].(string)
	currentVersion, _ := resp.Data["current_version"].(int)
	minDecryptionVersion, _ := resp.Data["min_decryption_version"].(int)
	allowDeletion, _ := resp.Data["allow_deletion"].(bool)
	rawVersions, _ := resp.Data["versions"].([]int)
	versions := make([]int32, len(rawVersions))
	for i, v := range rawVersions {
		versions[i] = int32(v)
	}
	return &pb.GetTransitKeyResponse{
		Name:                 name,
		Type:                 keyType,
		CurrentVersion:       int32(currentVersion),
		MinDecryptionVersion: int32(minDecryptionVersion),
		AllowDeletion:        allowDeletion,
		Versions:             versions,
	}, nil
}

// ListKeys dispatches the "list-keys" operation and returns the key names
// found under "keys" in the engine response. Mount is required.
func (ts *transitServer) ListKeys(ctx context.Context, req *pb.ListTransitKeysRequest) (*pb.ListTransitKeysResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "list-keys", &engine.Request{
		Operation:  "list-keys",
		CallerInfo: ts.callerInfo(ctx),
	})
	if err != nil {
		return nil, err
	}
	keys := toStringSliceFromInterface(resp.Data["keys"])
	return &pb.ListTransitKeysResponse{Keys: keys}, nil
}

// RotateKey dispatches the "rotate-key" operation for the named key, writes
// an audit log entry on success, and returns the name and version reported by
// the engine. Mount and name are required.
func (ts *transitServer) RotateKey(ctx context.Context, req *pb.RotateTransitKeyRequest) (*pb.RotateTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "rotate-key", &engine.Request{
		Operation:  "rotate-key",
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	name, _ := resp.Data["name"].(string)
	version, _ := resp.Data["version"].(int)
	ts.s.logger.Info("audit: transit key rotated", "mount", req.Mount, "key", name, "version", version, "username", callerUsername(ctx))
	return &pb.RotateTransitKeyResponse{Name: name, Version: int32(version)}, nil
}

// UpdateKeyConfig dispatches the "update-key-config" operation for the named
// key and writes an audit log entry on success. Mount and name are required.
//
// min_decryption_version is only sent when it is non-zero; allow_deletion is
// always sent.
func (ts *transitServer) UpdateKeyConfig(ctx context.Context, req *pb.UpdateTransitKeyConfigRequest) (*pb.UpdateTransitKeyConfigResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	data := map[string]interface{}{"name": req.Name}
	if req.MinDecryptionVersion != 0 {
		// Sent as float64; presumably the engine expects JSON-style numbers
		// here - TODO confirm against the transit engine.
		data["min_decryption_version"] = float64(req.MinDecryptionVersion)
	}
	// NOTE(review): an unset AllowDeletion is indistinguishable from false, so
	// a request that only changes the minimum version also sends
	// allow_deletion=false - confirm this is intended.
	data["allow_deletion"] = req.AllowDeletion
	_, err := ts.transitHandleRequest(ctx, req.Mount, "update-key-config", &engine.Request{
		Operation:  "update-key-config",
		CallerInfo: ts.callerInfo(ctx),
		Data:       data,
	})
	if err != nil {
		return nil, err
	}
	// Audit entry, consistent with the other key-mutating RPCs in this file.
	ts.s.logger.Info("audit: transit key config updated", "mount", req.Mount, "key", req.Name, "username", callerUsername(ctx))
	return &pb.UpdateTransitKeyConfigResponse{}, nil
}

// TrimKey dispatches the "trim-key" operation for the named key, writes an
// audit log entry on success, and returns the "trimmed" count reported by the
// engine. Mount and name are required.
func (ts *transitServer) TrimKey(ctx context.Context, req *pb.TrimTransitKeyRequest) (*pb.TrimTransitKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "trim-key", &engine.Request{
		Operation:  "trim-key",
		CallerInfo: ts.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	trimmed, _ := resp.Data["trimmed"].(int)
	// Audit entry, consistent with the other key-mutating RPCs in this file.
	ts.s.logger.Info("audit: transit key trimmed", "mount", req.Mount, "key", req.Name, "trimmed", trimmed, "username", callerUsername(ctx))
	return &pb.TrimTransitKeyResponse{Trimmed: int32(trimmed)}, nil
}

// Encrypt dispatches the "encrypt" operation with the given key and plaintext
// and returns the resulting ciphertext. The optional context is forwarded
// only when non-empty. Mount and key are required.
func (ts *transitServer) Encrypt(ctx context.Context, req *pb.TransitEncryptRequest) (*pb.TransitEncryptResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	data := map[string]interface{}{
		"key":       req.Key,
		"plaintext": req.Plaintext,
	}
	if req.Context != "" {
		data["context"] = req.Context
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "encrypt", &engine.Request{
		Operation:   "encrypt",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        data,
	})
	if err != nil {
		return nil, err
	}
	ct, _ := resp.Data["ciphertext"].(string)
	return &pb.TransitEncryptResponse{Ciphertext: ct}, nil
}

// Decrypt dispatches the "decrypt" operation with the given key and
// ciphertext and returns the resulting plaintext. The optional context is
// forwarded only when non-empty. Mount and key are required.
func (ts *transitServer) Decrypt(ctx context.Context, req *pb.TransitDecryptRequest) (*pb.TransitDecryptResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	data := map[string]interface{}{
		"key":        req.Key,
		"ciphertext": req.Ciphertext,
	}
	if req.Context != "" {
		data["context"] = req.Context
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "decrypt", &engine.Request{
		Operation:   "decrypt",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        data,
	})
	if err != nil {
		return nil, err
	}
	pt, _ := resp.Data["plaintext"].(string)
	return &pb.TransitDecryptResponse{Plaintext: pt}, nil
}

// Rewrap dispatches the "rewrap" operation with the given key and ciphertext
// and returns the ciphertext reported by the engine. The optional context is
// forwarded only when non-empty. Mount and key are required.
func (ts *transitServer) Rewrap(ctx context.Context, req *pb.TransitRewrapRequest) (*pb.TransitRewrapResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	data := map[string]interface{}{
		"key":        req.Key,
		"ciphertext": req.Ciphertext,
	}
	if req.Context != "" {
		data["context"] = req.Context
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "rewrap", &engine.Request{
		Operation:   "rewrap",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        data,
	})
	if err != nil {
		return nil, err
	}
	ct, _ := resp.Data["ciphertext"].(string)
	return &pb.TransitRewrapResponse{Ciphertext: ct}, nil
}

// BatchEncrypt dispatches the "batch-encrypt" operation for all request items
// under a single key and returns the per-item results. Mount and key are
// required.
func (ts *transitServer) BatchEncrypt(ctx context.Context, req *pb.TransitBatchEncryptRequest) (*pb.TransitBatchResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	items := protoItemsToInterface(req.Items)
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "batch-encrypt", &engine.Request{
		Operation:   "batch-encrypt",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        map[string]interface{}{"key": req.Key, "items": items},
	})
	if err != nil {
		return nil, err
	}
	return toBatchResponse(resp), nil
}

// BatchDecrypt dispatches the "batch-decrypt" operation for all request items
// under a single key and returns the per-item results. Mount and key are
// required.
func (ts *transitServer) BatchDecrypt(ctx context.Context, req *pb.TransitBatchDecryptRequest) (*pb.TransitBatchResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	items := protoItemsToInterface(req.Items)
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "batch-decrypt", &engine.Request{
		Operation:   "batch-decrypt",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        map[string]interface{}{"key": req.Key, "items": items},
	})
	if err != nil {
		return nil, err
	}
	return toBatchResponse(resp), nil
}

// BatchRewrap dispatches the "batch-rewrap" operation for all request items
// under a single key and returns the per-item results. Mount and key are
// required.
func (ts *transitServer) BatchRewrap(ctx context.Context, req *pb.TransitBatchRewrapRequest) (*pb.TransitBatchResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	items := protoItemsToInterface(req.Items)
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "batch-rewrap", &engine.Request{
		Operation:   "batch-rewrap",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        map[string]interface{}{"key": req.Key, "items": items},
	})
	if err != nil {
		return nil, err
	}
	return toBatchResponse(resp), nil
}

// Sign dispatches the "sign" operation with the given key and input and
// returns the signature reported by the engine. Mount and key are required.
func (ts *transitServer) Sign(ctx context.Context, req *pb.TransitSignRequest) (*pb.TransitSignResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "sign", &engine.Request{
		Operation:   "sign",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        map[string]interface{}{"key": req.Key, "input": req.Input},
	})
	if err != nil {
		return nil, err
	}
	sig, _ := resp.Data["signature"].(string)
	return &pb.TransitSignResponse{Signature: sig}, nil
}

// Verify dispatches the "verify" operation with the given key, input and
// signature and returns the "valid" flag reported by the engine. Mount and
// key are required.
func (ts *transitServer) Verify(ctx context.Context, req *pb.TransitVerifyRequest) (*pb.TransitVerifyResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "verify", &engine.Request{
		Operation:   "verify",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data: map[string]interface{}{
			"key":       req.Key,
			"input":     req.Input,
			"signature": req.Signature,
		},
	})
	if err != nil {
		return nil, err
	}
	valid, _ := resp.Data["valid"].(bool)
	return &pb.TransitVerifyResponse{Valid: valid}, nil
}

// Hmac dispatches the "hmac" operation with the given key and input. When the
// request carries a non-empty hmac value it is forwarded as well. The
// response carries both the "hmac" string and the "valid" flag from the
// engine response, each defaulting to its zero value when absent. Mount and
// key are required.
func (ts *transitServer) Hmac(ctx context.Context, req *pb.TransitHmacRequest) (*pb.TransitHmacResponse, error) {
	if req.Mount == "" || req.Key == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and key are required")
	}
	data := map[string]interface{}{
		"key":   req.Key,
		"input": req.Input,
	}
	if req.Hmac != "" {
		data["hmac"] = req.Hmac
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "hmac", &engine.Request{
		Operation:   "hmac",
		CallerInfo:  ts.callerInfo(ctx),
		CheckPolicy: ts.policyChecker(ctx),
		Data:        data,
	})
	if err != nil {
		return nil, err
	}
	hmacStr, _ := resp.Data["hmac"].(string)
	valid, _ := resp.Data["valid"].(bool)
	return &pb.TransitHmacResponse{Hmac: hmacStr, Valid: valid}, nil
}

// GetPublicKey dispatches the "get-public-key" operation for the named key
// and returns the public key, version and type reported by the engine. The
// version is forwarded only when non-zero. Mount and name are required.
func (ts *transitServer) GetPublicKey(ctx context.Context, req *pb.GetTransitPublicKeyRequest) (*pb.GetTransitPublicKeyResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	data := map[string]interface{}{"name": req.Name}
	if req.Version != 0 {
		// Sent as float64; presumably the engine expects JSON-style numbers
		// here - TODO confirm against the transit engine.
		data["version"] = float64(req.Version)
	}
	resp, err := ts.transitHandleRequest(ctx, req.Mount, "get-public-key", &engine.Request{
		Operation:  "get-public-key",
		CallerInfo: ts.callerInfo(ctx),
		Data:       data,
	})
	if err != nil {
		return nil, err
	}
	pk, _ := resp.Data["public_key"].(string)
	version, _ := resp.Data["version"].(int)
	keyType, _ := resp.Data["type"].(string)
	return &pb.GetTransitPublicKeyResponse{
		PublicKey: pk,
		Version:   int32(version),
		Type:      keyType,
	}, nil
}

// --- helpers ---

// protoItemsToInterface converts batch items into the generic slice-of-maps
// form placed under "items" in the engine request. Only non-empty fields are
// copied. A nil item produces an empty map so that output positions stay
// aligned with the input.
func protoItemsToInterface(items []*pb.TransitBatchItem) []interface{} {
	out := make([]interface{}, len(items))
	for i, item := range items {
		m := map[string]interface{}{}
		if item == nil {
			// Avoid a nil dereference on a malformed batch entry.
			out[i] = m
			continue
		}
		if item.Plaintext != "" {
			m["plaintext"] = item.Plaintext
		}
		if item.Ciphertext != "" {
			m["ciphertext"] = item.Ciphertext
		}
		if item.Context != "" {
			m["context"] = item.Context
		}
		if item.Reference != "" {
			m["reference"] = item.Reference
		}
		out[i] = m
	}
	return out
}

// toBatchResponse converts the "results" entry of an engine response into a
// pb.TransitBatchResponse. Entries that are not map[string]interface{} are
// skipped; missing or non-string fields become empty strings. A nil response
// or a missing "results" entry yields an empty result list.
func toBatchResponse(resp *engine.Response) *pb.TransitBatchResponse {
	var raw []interface{}
	if resp != nil {
		raw, _ = resp.Data["results"].([]interface{})
	}
	results := make([]*pb.TransitBatchResultItem, 0, len(raw))
	for _, item := range raw {
		r, ok := item.(map[string]interface{})
		if !ok {
			continue
		}
		pt, _ := r["plaintext"].(string)
		ct, _ := r["ciphertext"].(string)
		ref, _ := r["reference"].(string)
		errStr, _ := r["error"].(string)
		results = append(results, &pb.TransitBatchResultItem{
			Plaintext:  pt,
			Ciphertext: ct,
			Reference:  ref,
			Error:      errStr,
		})
	}
	return &pb.TransitBatchResponse{Results: results}
}