Files
metacrypt/internal/grpcserver/sshca.go
Kyle Isom 310ed83f28 Migrate gRPC server to mcdsl grpcserver package
Replace metacrypt's hand-rolled gRPC interceptor chain with the mcdsl
grpcserver package, which provides TLS setup, logging, and method-map
auth (public/auth-required/admin-required) out of the box.

Metacrypt-specific interceptors are preserved as hooks:
- sealInterceptor runs as a PreInterceptor (before logging/auth)
- auditInterceptor runs as a PostInterceptor (after auth)

The three legacy method maps (seal/auth/admin) are restructured into
mcdsl's MethodMap (Public/AuthRequired/AdminRequired) plus a separate
seal-required map for the PreInterceptor. Token context is now stored
via mcdsl/auth.ContextWithTokenInfo instead of a package-local key.

Bumps mcdsl from v1.0.0 to v1.0.1 (adds PreInterceptors/PostInterceptors
to grpcserver.Options).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 14:42:41 -07:00

462 lines
15 KiB
Go

package grpcserver
import (
"context"
"errors"
"strings"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
"git.wntrmute.dev/kyle/metacrypt/internal/auth"
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
"git.wntrmute.dev/kyle/metacrypt/internal/engine/sshca"
"git.wntrmute.dev/kyle/metacrypt/internal/policy"
)
// sshcaServer implements the SSHCAService gRPC API. Each RPC validates its
// request, translates it into an engine.Request, and dispatches it to the
// engine mounted at the request's mount point.
type sshcaServer struct {
// Generated stub embedded to satisfy the service interface; presumably
// answers RPCs not overridden here with codes.Unimplemented.
pb.UnimplementedSSHCAServiceServer
// s is the owning server; this type uses its engines, policy, and logger.
s *GRPCServer
}
// sshcaHandleRequest dispatches req to the engine mounted at mount and
// converts any engine failure into a gRPC status error.
//
// operation only labels the error log line. Known sentinel errors from the
// engine and sshca packages map to specific gRPC codes. As a fallback, error
// messages containing "not found" or "forbidden" map to NotFound and
// PermissionDenied; anything else becomes Internal.
//
// On success the returned response is guaranteed to be non-nil, so callers
// may read resp.Data directly.
func (ss *sshcaServer) sshcaHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) {
	resp, err := ss.s.engines.HandleRequest(ctx, mount, req)
	if err != nil {
		st := codes.Internal
		switch {
		case errors.Is(err, engine.ErrMountNotFound),
			errors.Is(err, sshca.ErrCertNotFound),
			errors.Is(err, sshca.ErrProfileNotFound):
			st = codes.NotFound
		case errors.Is(err, sshca.ErrProfileExists):
			st = codes.AlreadyExists
		case errors.Is(err, sshca.ErrUnauthorized):
			st = codes.Unauthenticated
		case errors.Is(err, sshca.ErrForbidden):
			st = codes.PermissionDenied
		// String matching is a last resort for engine errors that are not
		// wrapped sentinels.
		case strings.Contains(err.Error(), "not found"):
			st = codes.NotFound
		case strings.Contains(err.Error(), "forbidden"):
			st = codes.PermissionDenied
		}
		ss.s.logger.Error("grpc: sshca "+operation, "mount", mount, "error", err)
		return nil, status.Error(st, err.Error())
	}
	if resp == nil {
		// Every RPC handler reads resp.Data. A nil response paired with a nil
		// error would otherwise cause a nil-pointer panic in the handler.
		ss.s.logger.Error("grpc: sshca "+operation, "mount", mount, "error", "engine returned nil response")
		return nil, status.Error(codes.Internal, "engine returned no response")
	}
	return resp, nil
}
// callerInfo maps the token info carried by ctx onto the engine's
// CallerInfo (username, roles, admin flag). A context with no token info
// yields nil.
//
// NOTE(review): the mcdsl migration (commit 310ed83f28) stores token info via
// mcdsl/auth.ContextWithTokenInfo. Confirm that internal/auth's
// TokenInfoFromContext reads that same key; otherwise this always returns nil.
func (ss *sshcaServer) callerInfo(ctx context.Context) *engine.CallerInfo {
	if info := auth.TokenInfoFromContext(ctx); info != nil {
		return &engine.CallerInfo{
			Username: info.Username,
			Roles:    info.Roles,
			IsAdmin:  info.IsAdmin,
		}
	}
	return nil
}
// policyChecker returns a callback the engine uses to evaluate the caller's
// policy rules for a (resource, action) pair. It returns nil when ctx carries
// no token info.
//
// The callback returns the matched effect and whether any rule matched. If
// the policy store fails, the callback fails closed: it reports a matched
// deny rather than "no match". That way an evaluation error can never fall
// through to whatever default the engine applies when no rule matches. The
// failure is logged so it is not silently swallowed.
func (ss *sshcaServer) policyChecker(ctx context.Context) engine.PolicyChecker {
	caller := ss.callerInfo(ctx)
	if caller == nil {
		return nil
	}
	return func(resource, action string) (string, bool) {
		pReq := &policy.Request{
			Username: caller.Username,
			Roles:    caller.Roles,
			Resource: resource,
			Action:   action,
		}
		effect, matched, err := ss.s.policy.Match(ctx, pReq)
		if err != nil {
			ss.s.logger.Error("grpc: sshca policy evaluation failed",
				"resource", resource, "action", action,
				"username", caller.Username, "error", err)
			return string(policy.EffectDeny), true
		}
		return string(effect), matched
	}
}
// GetCAPublicKey returns the CA public key of the SSH CA mounted at
// req.Mount. The engine request carries no caller identity or policy check.
func (ss *sshcaServer) GetCAPublicKey(ctx context.Context, req *pb.SSHGetCAPublicKeyRequest) (*pb.SSHGetCAPublicKeyResponse, error) {
	mount := req.Mount
	if mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	const op = "get-ca-pubkey"
	resp, err := ss.sshcaHandleRequest(ctx, mount, op, &engine.Request{Operation: op})
	if err != nil {
		return nil, err
	}
	out := &pb.SSHGetCAPublicKeyResponse{}
	// A missing or non-string value leaves PublicKey empty.
	out.PublicKey, _ = resp.Data["public_key"].(string)
	return out, nil
}
// SignHost asks the engine to sign req.PublicKey as a host certificate for
// req.Hostname, with an optional TTL. It passes along the caller's identity
// and policy checker, then emits an audit log line on success.
func (ss *sshcaServer) SignHost(ctx context.Context, req *pb.SSHSignHostRequest) (*pb.SSHSignHostResponse, error) {
	missing := req.Mount == "" || req.PublicKey == "" || req.Hostname == ""
	if missing {
		return nil, status.Error(codes.InvalidArgument, "mount, public_key, and hostname are required")
	}

	payload := map[string]interface{}{
		"public_key": req.PublicKey,
		"hostname":   req.Hostname,
	}
	if ttl := req.Ttl; ttl != "" {
		payload["ttl"] = ttl
	}

	const op = "sign-host"
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:   op,
		CallerInfo:  ss.callerInfo(ctx),
		CheckPolicy: ss.policyChecker(ctx),
		Data:        payload,
	})
	if err != nil {
		return nil, err
	}

	d := resp.Data
	out := &pb.SSHSignHostResponse{
		Serial:     stringVal(d, "serial"),
		CertType:   stringVal(d, "cert_type"),
		Principals: toStringSliceFromInterface(d["principals"]),
		CertData:   stringVal(d, "cert_data"),
		KeyId:      stringVal(d, "key_id"),
		IssuedBy:   stringVal(d, "issued_by"),
		IssuedAt:   parseTimestamp(d, "issued_at"),
		ExpiresAt:  parseTimestamp(d, "expires_at"),
	}
	ss.s.logger.Info("audit: SSH host cert signed", "mount", req.Mount, "hostname", req.Hostname, "serial", out.Serial, "username", callerUsername(ctx))
	return out, nil
}
// SignUser asks the engine to sign req.PublicKey as a user certificate.
// Principals, profile, and TTL are forwarded only when set. The caller's
// identity and policy checker are attached, and an audit line is logged on
// success.
func (ss *sshcaServer) SignUser(ctx context.Context, req *pb.SSHSignUserRequest) (*pb.SSHSignUserResponse, error) {
	if req.Mount == "" || req.PublicKey == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and public_key are required")
	}

	payload := map[string]interface{}{"public_key": req.PublicKey}
	if n := len(req.Principals); n > 0 {
		// The engine payload is untyped, so box each principal individually.
		list := make([]interface{}, 0, n)
		for _, principal := range req.Principals {
			list = append(list, principal)
		}
		payload["principals"] = list
	}
	if req.Profile != "" {
		payload["profile"] = req.Profile
	}
	if req.Ttl != "" {
		payload["ttl"] = req.Ttl
	}

	const op = "sign-user"
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:   op,
		CallerInfo:  ss.callerInfo(ctx),
		CheckPolicy: ss.policyChecker(ctx),
		Data:        payload,
	})
	if err != nil {
		return nil, err
	}

	d := resp.Data
	out := &pb.SSHSignUserResponse{
		Serial:     stringVal(d, "serial"),
		CertType:   stringVal(d, "cert_type"),
		Principals: toStringSliceFromInterface(d["principals"]),
		CertData:   stringVal(d, "cert_data"),
		KeyId:      stringVal(d, "key_id"),
		Profile:    stringVal(d, "profile"),
		IssuedBy:   stringVal(d, "issued_by"),
		IssuedAt:   parseTimestamp(d, "issued_at"),
		ExpiresAt:  parseTimestamp(d, "expires_at"),
	}
	ss.s.logger.Info("audit: SSH user cert signed", "mount", req.Mount, "serial", out.Serial, "username", callerUsername(ctx))
	return out, nil
}
// sshcaProfileData builds the engine payload shared by create-profile and
// update-profile. Only fields the client set are included: empty maps,
// slices, and strings are left out of the payload rather than sent as empty
// values. Typed values are boxed into interface{} containers because
// engine.Request.Data is untyped.
func sshcaProfileData(name string, criticalOptions, extensions map[string]string, maxTTL string, allowedPrincipals []string) map[string]interface{} {
	data := map[string]interface{}{
		"name": name,
	}
	if len(criticalOptions) > 0 {
		opts := make(map[string]interface{}, len(criticalOptions))
		for k, v := range criticalOptions {
			opts[k] = v
		}
		data["critical_options"] = opts
	}
	if len(extensions) > 0 {
		exts := make(map[string]interface{}, len(extensions))
		for k, v := range extensions {
			exts[k] = v
		}
		data["extensions"] = exts
	}
	if maxTTL != "" {
		data["max_ttl"] = maxTTL
	}
	if len(allowedPrincipals) > 0 {
		principals := make([]interface{}, len(allowedPrincipals))
		for i, p := range allowedPrincipals {
			principals[i] = p
		}
		data["allowed_principals"] = principals
	}
	return data
}

// CreateProfile creates a named signing profile on the SSH CA at req.Mount
// and logs an audit line on success. The response echoes the profile name
// returned by the engine.
func (ss *sshcaServer) CreateProfile(ctx context.Context, req *pb.SSHCreateProfileRequest) (*pb.SSHCreateProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "create-profile", &engine.Request{
		Operation:  "create-profile",
		CallerInfo: ss.callerInfo(ctx),
		Data:       sshcaProfileData(req.Name, req.CriticalOptions, req.Extensions, req.MaxTtl, req.AllowedPrincipals),
	})
	if err != nil {
		return nil, err
	}
	name, _ := resp.Data["name"].(string)
	ss.s.logger.Info("audit: SSH CA profile created", "mount", req.Mount, "profile", name, "username", callerUsername(ctx))
	return &pb.SSHCreateProfileResponse{Name: name}, nil
}

// UpdateProfile updates the named signing profile on the SSH CA at
// req.Mount and logs an audit line on success. Fields left empty in the
// request are omitted from the engine payload.
func (ss *sshcaServer) UpdateProfile(ctx context.Context, req *pb.SSHUpdateProfileRequest) (*pb.SSHUpdateProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "update-profile", &engine.Request{
		Operation:  "update-profile",
		CallerInfo: ss.callerInfo(ctx),
		Data:       sshcaProfileData(req.Name, req.CriticalOptions, req.Extensions, req.MaxTtl, req.AllowedPrincipals),
	})
	if err != nil {
		return nil, err
	}
	name, _ := resp.Data["name"].(string)
	ss.s.logger.Info("audit: SSH CA profile updated", "mount", req.Mount, "profile", name, "username", callerUsername(ctx))
	return &pb.SSHUpdateProfileResponse{Name: name}, nil
}
// GetProfile fetches one signing profile by name from the SSH CA at
// req.Mount. Critical options and extensions are copied only when the engine
// returns them as map[string]string; any other representation leaves those
// fields unset.
func (ss *sshcaServer) GetProfile(ctx context.Context, req *pb.SSHGetProfileRequest) (*pb.SSHGetProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	const op = "get-profile"
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:  op,
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	d := resp.Data
	out := &pb.SSHGetProfileResponse{
		Name:              stringVal(d, "name"),
		MaxTtl:            stringVal(d, "max_ttl"),
		AllowedPrincipals: toStringSliceFromInterface(d["allowed_principals"]),
	}
	// A failed assertion yields a nil map, identical to leaving the field unset.
	out.CriticalOptions, _ = d["critical_options"].(map[string]string)
	out.Extensions, _ = d["extensions"].(map[string]string)
	return out, nil
}
// ListProfiles returns the names of all signing profiles on the SSH CA at
// req.Mount.
func (ss *sshcaServer) ListProfiles(ctx context.Context, req *pb.SSHListProfilesRequest) (*pb.SSHListProfilesResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	const op = "list-profiles"
	engineReq := &engine.Request{
		Operation:  op,
		CallerInfo: ss.callerInfo(ctx),
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, op, engineReq)
	if err != nil {
		return nil, err
	}
	return &pb.SSHListProfilesResponse{
		Profiles: toStringSliceFromInterface(resp.Data["profiles"]),
	}, nil
}
// DeleteProfile removes the named signing profile from the SSH CA at
// req.Mount and logs an audit line on success.
func (ss *sshcaServer) DeleteProfile(ctx context.Context, req *pb.SSHDeleteProfileRequest) (*pb.SSHDeleteProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	const op = "delete-profile"
	engineReq := &engine.Request{
		Operation:  op,
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	}
	if _, err := ss.sshcaHandleRequest(ctx, req.Mount, op, engineReq); err != nil {
		return nil, err
	}
	ss.s.logger.Info("audit: SSH CA profile deleted", "mount", req.Mount, "profile", req.Name, "username", callerUsername(ctx))
	return &pb.SSHDeleteProfileResponse{}, nil
}
// GetCert returns the full certificate record identified by req.Serial from
// the SSH CA at req.Mount.
func (ss *sshcaServer) GetCert(ctx context.Context, req *pb.SSHGetCertRequest) (*pb.SSHGetCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}
	const op = "get-cert"
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:  op,
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	})
	if err != nil {
		return nil, err
	}
	record := sshCertRecordFromData(resp.Data)
	return &pb.SSHGetCertResponse{Cert: record}, nil
}
// ListCerts returns summaries of the certificates issued by the SSH CA at
// req.Mount. Entries in the engine's "certs" list that are not
// map[string]interface{} are skipped. The result is always a non-nil
// (possibly empty) slice.
func (ss *sshcaServer) ListCerts(ctx context.Context, req *pb.SSHListCertsRequest) (*pb.SSHListCertsResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	const op = "list-certs"
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, op, &engine.Request{
		Operation:  op,
		CallerInfo: ss.callerInfo(ctx),
	})
	if err != nil {
		return nil, err
	}
	items, _ := resp.Data["certs"].([]interface{})
	out := &pb.SSHListCertsResponse{Certs: make([]*pb.SSHCertSummary, 0, len(items))}
	for _, item := range items {
		if entry, ok := item.(map[string]interface{}); ok {
			out.Certs = append(out.Certs, sshCertSummaryFromData(entry))
		}
	}
	return out, nil
}
// RevokeCert revokes the certificate identified by req.Serial on the SSH CA
// at req.Mount and logs an audit line on success. The response carries the
// serial and revocation time reported by the engine. RevokedAt is nil when
// the engine's "revoked_at" value is missing or not RFC 3339.
//
// It uses the file's stringVal and parseTimestamp helpers, so revocation
// timestamps are decoded exactly like every other timestamp in this service.
func (ss *sshcaServer) RevokeCert(ctx context.Context, req *pb.SSHRevokeCertRequest) (*pb.SSHRevokeCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "revoke-cert", &engine.Request{
		Operation:  "revoke-cert",
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	})
	if err != nil {
		return nil, err
	}
	serial := stringVal(resp.Data, "serial")
	revokedAt := parseTimestamp(resp.Data, "revoked_at")
	ss.s.logger.Info("audit: SSH cert revoked", "mount", req.Mount, "serial", serial, "username", callerUsername(ctx))
	return &pb.SSHRevokeCertResponse{Serial: serial, RevokedAt: revokedAt}, nil
}
// DeleteCert removes the certificate record identified by req.Serial from
// the SSH CA at req.Mount and logs an audit line on success.
func (ss *sshcaServer) DeleteCert(ctx context.Context, req *pb.SSHDeleteCertRequest) (*pb.SSHDeleteCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}
	const op = "delete-cert"
	engineReq := &engine.Request{
		Operation:  op,
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	}
	if _, err := ss.sshcaHandleRequest(ctx, req.Mount, op, engineReq); err != nil {
		return nil, err
	}
	ss.s.logger.Info("audit: SSH cert deleted", "mount", req.Mount, "serial", req.Serial, "username", callerUsername(ctx))
	return &pb.SSHDeleteCertResponse{}, nil
}
// GetKRL returns the key revocation list of the SSH CA at req.Mount. The
// engine request carries no caller identity. The engine's "krl" value is
// expected as a string; anything else yields an empty byte slice.
func (ss *sshcaServer) GetKRL(ctx context.Context, req *pb.SSHGetKRLRequest) (*pb.SSHGetKRLResponse, error) {
	mount := req.Mount
	if mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	const op = "get-krl"
	resp, err := ss.sshcaHandleRequest(ctx, mount, op, &engine.Request{Operation: op})
	if err != nil {
		return nil, err
	}
	encoded, _ := resp.Data["krl"].(string)
	return &pb.SSHGetKRLResponse{Krl: []byte(encoded)}, nil
}
// --- helpers ---
// stringVal returns d[key] when it holds a string, and "" otherwise: a
// missing key, a nil map, or a value of any other type all yield "".
func stringVal(d map[string]interface{}, key string) string {
	if s, ok := d[key].(string); ok {
		return s
	}
	return ""
}
// parseTimestamp reads d[key] as an RFC 3339 string and converts it to a
// protobuf Timestamp. It returns nil when the key is absent, the value is not
// a string, or the string fails to parse.
func parseTimestamp(d map[string]interface{}, key string) *timestamppb.Timestamp {
	raw, ok := d[key].(string)
	if !ok {
		return nil
	}
	parsed, err := time.Parse(time.RFC3339, raw)
	if err != nil {
		return nil
	}
	return timestamppb.New(parsed)
}
// sshCertRecordFromData converts an engine certificate map into a full
// SSHCertRecord. Missing or mistyped fields fall back to their zero values
// (empty string, false, nil timestamp).
func sshCertRecordFromData(d map[string]interface{}) *pb.SSHCertRecord {
	rec := &pb.SSHCertRecord{
		Serial:     stringVal(d, "serial"),
		CertType:   stringVal(d, "cert_type"),
		Principals: toStringSliceFromInterface(d["principals"]),
		CertData:   stringVal(d, "cert_data"),
		KeyId:      stringVal(d, "key_id"),
		Profile:    stringVal(d, "profile"),
		IssuedBy:   stringVal(d, "issued_by"),
		IssuedAt:   parseTimestamp(d, "issued_at"),
		ExpiresAt:  parseTimestamp(d, "expires_at"),
		RevokedAt:  parseTimestamp(d, "revoked_at"),
		RevokedBy:  stringVal(d, "revoked_by"),
	}
	rec.Revoked, _ = d["revoked"].(bool)
	return rec
}
// sshCertSummaryFromData converts an engine certificate map into the
// abbreviated SSHCertSummary used by list responses. It omits the cert data
// and revocation details. Missing or mistyped fields fall back to zero values.
func sshCertSummaryFromData(d map[string]interface{}) *pb.SSHCertSummary {
	summary := &pb.SSHCertSummary{
		Serial:     stringVal(d, "serial"),
		CertType:   stringVal(d, "cert_type"),
		Principals: toStringSliceFromInterface(d["principals"]),
		KeyId:      stringVal(d, "key_id"),
		Profile:    stringVal(d, "profile"),
		IssuedBy:   stringVal(d, "issued_by"),
		IssuedAt:   parseTimestamp(d, "issued_at"),
		ExpiresAt:  parseTimestamp(d, "expires_at"),
	}
	summary.Revoked, _ = d["revoked"].(bool)
	return summary
}