Files
metacrypt/internal/grpcserver/ca.go
Kyle Isom a5bb366558 Allow system accounts to issue certificates
Service tokens from MCIAS have account_type "system" but no roles.
Thread AccountType through CallerInfo and treat system accounts as
users for certificate issuance. This allows services to request
their own TLS certificates without admin credentials.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 20:07:22 -07:00

571 lines
17 KiB
Go

package grpcserver
import (
"context"
"errors"
"strings"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/kyle/metacrypt/gen/metacrypt/v2"
"git.wntrmute.dev/kyle/metacrypt/internal/engine"
"git.wntrmute.dev/kyle/metacrypt/internal/engine/ca"
"git.wntrmute.dev/kyle/metacrypt/internal/policy"
)
// caServer holds the CA service gRPC handlers defined in this file. Every
// handler forwards its work to the parent GRPCServer.
type caServer struct {
// Generated stub type from the pb package, embedded by value.
// NOTE(review): presumably supplies default implementations for RPCs not
// defined here (the usual grpc-go pattern) — confirm against the generated code.
pb.UnimplementedCAServiceServer
// s is the parent server; handlers use its engines, policy and logger fields.
s *GRPCServer
}
// caHandleRequest forwards req to the engine mounted at mount and returns the
// engine's response unchanged on success.
//
// On failure the error is logged (tagged with operation and mount) and
// converted into a gRPC status error whose message is the engine error text.
// Known sentinel errors are checked first; an error whose text contains
// "not found" is reported as NotFound; everything else becomes Internal.
func (cs *caServer) caHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) {
	result, reqErr := cs.s.engines.HandleRequest(ctx, mount, req)
	if reqErr == nil {
		return result, nil
	}

	// Sentinel checks come before the substring fallback so that typed errors
	// always win over message matching.
	var code codes.Code
	switch {
	case errors.Is(reqErr, engine.ErrMountNotFound),
		errors.Is(reqErr, ca.ErrIssuerNotFound),
		errors.Is(reqErr, ca.ErrCertNotFound):
		code = codes.NotFound
	case errors.Is(reqErr, ca.ErrIssuerExists),
		errors.Is(reqErr, ca.ErrIdentifierInUse):
		code = codes.AlreadyExists
	case errors.Is(reqErr, ca.ErrUnauthorized):
		code = codes.Unauthenticated
	case errors.Is(reqErr, ca.ErrForbidden):
		code = codes.PermissionDenied
	case strings.Contains(reqErr.Error(), "not found"):
		code = codes.NotFound
	default:
		code = codes.Internal
	}

	cs.s.logger.Error("grpc: ca "+operation, "mount", mount, "error", reqErr)
	return nil, status.Error(code, reqErr.Error())
}
// callerUsername returns the username from the token info stored in ctx, or
// the empty string when ctx carries no token info.
func callerUsername(ctx context.Context) string {
	if info := tokenInfoFromContext(ctx); info != nil {
		return info.Username
	}
	return ""
}
// callerInfo copies the username, account type, roles and admin flag from the
// token info stored in ctx into an engine.CallerInfo. It returns nil when ctx
// carries no token info.
func (cs *caServer) callerInfo(ctx context.Context) *engine.CallerInfo {
	token := tokenInfoFromContext(ctx)
	if token == nil {
		return nil
	}

	info := new(engine.CallerInfo)
	info.Username = token.Username
	info.AccountType = token.AccountType
	info.Roles = token.Roles
	info.IsAdmin = token.IsAdmin
	return info
}
// policyChecker returns a callback that evaluates (resource, action) against
// the server's policy engine on behalf of the caller found in ctx. It returns
// nil when ctx carries no caller.
//
// The callback returns the matched effect as a string plus whether any rule
// matched. If the policy engine itself fails, the failure is logged and the
// callback returns the deny effect with matched=false.
func (cs *caServer) policyChecker(ctx context.Context) engine.PolicyChecker {
	caller := cs.callerInfo(ctx)
	if caller == nil {
		return nil
	}
	return func(resource, action string) (string, bool) {
		pReq := &policy.Request{
			Username: caller.Username,
			Roles:    caller.Roles,
			Resource: resource,
			Action:   action,
		}
		effect, matched, err := cs.s.policy.Match(ctx, pReq)
		if err != nil {
			// Log the evaluation failure so a broken policy backend can be told
			// apart from "no rule matched"; previously the error was dropped.
			cs.s.logger.Error("grpc: ca policy check",
				"username", caller.Username,
				"resource", resource,
				"action", action,
				"error", err)
			// NOTE(review): matched=false is kept for compatibility; confirm the
			// engine treats an unmatched deny as a denial rather than falling
			// through to its default rules.
			return string(policy.EffectDeny), false
		}
		return string(effect), matched
	}
}
// ImportRoot sends an "import-root" request carrying the supplied certificate
// and key PEM to the engine at req.Mount.
//
// It returns InvalidArgument when the mount is empty. On success it writes an
// audit log entry and returns the common name reported by the engine, plus the
// expiry when the engine supplied one as an RFC 3339 string.
func (cs *caServer) ImportRoot(ctx context.Context, req *pb.ImportRootRequest) (*pb.ImportRootResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}

	engReq := &engine.Request{
		Operation:  "import-root",
		CallerInfo: cs.callerInfo(ctx),
		Data: map[string]interface{}{
			"cert_pem": string(req.CertPem),
			"key_pem":  string(req.KeyPem),
		},
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "import-root", engReq)
	if err != nil {
		return nil, err
	}

	out := &pb.ImportRootResponse{}
	out.CommonName, _ = resp.Data["cn"].(string)
	// A missing or unparseable expiry leaves ExpiresAt unset.
	if raw, ok := resp.Data["expires_at"].(string); ok {
		if ts, parseErr := time.Parse(time.RFC3339, raw); parseErr == nil {
			out.ExpiresAt = timestamppb.New(ts)
		}
	}

	cs.s.logger.Info("audit: root CA imported", "mount", req.Mount, "cn", out.CommonName, "username", callerUsername(ctx))
	return out, nil
}
// GetRoot sends a "get-root" request to the engine at req.Mount and returns
// the certificate PEM from the response. No caller info is attached to the
// engine request. It returns InvalidArgument when the mount is empty.
func (cs *caServer) GetRoot(ctx context.Context, req *pb.GetRootRequest) (*pb.GetRootResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}

	engReq := &engine.Request{Operation: "get-root"}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "get-root", engReq)
	if err != nil {
		return nil, err
	}

	pemText, _ := resp.Data["cert_pem"].(string)
	out := &pb.GetRootResponse{CertPem: []byte(pemText)}
	return out, nil
}
// CreateIssuer sends a "create-issuer" request to the engine at req.Mount.
//
// The issuer name is always sent; key algorithm, key size, expiry and max TTL
// are sent only when set in the request. It returns InvalidArgument when the
// mount or name is empty. On success it writes an audit log entry and returns
// the issuer name and certificate PEM reported by the engine.
func (cs *caServer) CreateIssuer(ctx context.Context, req *pb.CreateIssuerRequest) (*pb.CreateIssuerResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	payload := map[string]interface{}{"name": req.Name}
	optional := []struct{ key, val string }{
		{"key_algorithm", req.KeyAlgorithm},
		{"expiry", req.Expiry},
		{"max_ttl", req.MaxTtl},
	}
	for _, field := range optional {
		if field.val != "" {
			payload[field.key] = field.val
		}
	}
	if req.KeySize != 0 {
		// The key size is handed to the engine as a float64.
		payload["key_size"] = float64(req.KeySize)
	}

	engReq := &engine.Request{
		Operation:  "create-issuer",
		CallerInfo: cs.callerInfo(ctx),
		Data:       payload,
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "create-issuer", engReq)
	if err != nil {
		return nil, err
	}

	issuerName, _ := resp.Data["name"].(string)
	pemText, _ := resp.Data["cert_pem"].(string)
	cs.s.logger.Info("audit: issuer created", "mount", req.Mount, "issuer", issuerName, "username", callerUsername(ctx))

	out := &pb.CreateIssuerResponse{Name: issuerName, CertPem: []byte(pemText)}
	return out, nil
}
// DeleteIssuer sends a "delete-issuer" request for req.Name to the engine at
// req.Mount. It returns InvalidArgument when the mount or name is empty, and
// writes an audit log entry on success.
func (cs *caServer) DeleteIssuer(ctx context.Context, req *pb.DeleteIssuerRequest) (*pb.DeleteIssuerResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	engReq := &engine.Request{
		Operation:  "delete-issuer",
		CallerInfo: cs.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	}
	if _, err := cs.caHandleRequest(ctx, req.Mount, "delete-issuer", engReq); err != nil {
		return nil, err
	}

	cs.s.logger.Info("audit: issuer deleted", "mount", req.Mount, "issuer", req.Name, "username", callerUsername(ctx))
	return &pb.DeleteIssuerResponse{}, nil
}
// ListIssuers sends a "list-issuers" request to the engine at req.Mount and
// returns the string entries of the response's "issuers" list; non-string
// entries are skipped. It returns InvalidArgument when the mount is empty.
func (cs *caServer) ListIssuers(ctx context.Context, req *pb.ListIssuersRequest) (*pb.ListIssuersResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}

	engReq := &engine.Request{
		Operation:  "list-issuers",
		CallerInfo: cs.callerInfo(ctx),
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "list-issuers", engReq)
	if err != nil {
		return nil, err
	}

	items, _ := resp.Data["issuers"].([]interface{})
	names := make([]string, 0, len(items))
	for i := range items {
		name, isString := items[i].(string)
		if !isString {
			continue
		}
		names = append(names, name)
	}
	return &pb.ListIssuersResponse{Issuers: names}, nil
}
// GetIssuer sends a "get-issuer" request to the engine at req.Mount, passing
// the issuer name in the request's Path field, and returns the certificate PEM
// from the response. No caller info is attached. It returns InvalidArgument
// when the mount or name is empty.
func (cs *caServer) GetIssuer(ctx context.Context, req *pb.GetIssuerRequest) (*pb.GetIssuerResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}

	engReq := &engine.Request{
		Operation: "get-issuer",
		Path:      req.Name,
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "get-issuer", engReq)
	if err != nil {
		return nil, err
	}

	pemText, _ := resp.Data["cert_pem"].(string)
	out := &pb.GetIssuerResponse{CertPem: []byte(pemText)}
	return out, nil
}
// GetChain sends a "get-chain" request for req.Issuer to the engine at
// req.Mount and returns the chain PEM from the response. No caller info is
// attached. It returns InvalidArgument when the mount or issuer is empty.
func (cs *caServer) GetChain(ctx context.Context, req *pb.CAServiceGetChainRequest) (*pb.CAServiceGetChainResponse, error) {
	if req.Mount == "" || req.Issuer == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and issuer are required")
	}

	engReq := &engine.Request{
		Operation: "get-chain",
		Data:      map[string]interface{}{"issuer": req.Issuer},
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "get-chain", engReq)
	if err != nil {
		return nil, err
	}

	pemText, _ := resp.Data["chain_pem"].(string)
	out := &pb.CAServiceGetChainResponse{ChainPem: []byte(pemText)}
	return out, nil
}
// IssueCert sends an "issue" request to the engine at req.Mount, attaching
// the caller info and a policy checker built from ctx.
//
// The issuer name is always sent; profile, common name, DNS names, IP
// addresses, TTL, key algorithm, key size, key usages and extended key usages
// are sent only when set in the request. List values are sent as
// []interface{} and the key size as float64.
//
// It returns InvalidArgument when the mount or issuer is empty. On success it
// writes an audit log entry and returns the certificate details, private key
// PEM and chain PEM reported by the engine. ExpiresAt is left unset when the
// engine's "expires_at" value is missing or is not an RFC 3339 string.
func (cs *caServer) IssueCert(ctx context.Context, req *pb.IssueCertRequest) (*pb.IssueCertResponse, error) {
	if req.Mount == "" || req.Issuer == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and issuer are required")
	}
	data := map[string]interface{}{
		"issuer": req.Issuer,
	}
	if req.Profile != "" {
		data["profile"] = req.Profile
	}
	if req.CommonName != "" {
		data["common_name"] = req.CommonName
	}
	if len(req.DnsNames) > 0 {
		dns := make([]interface{}, len(req.DnsNames))
		for i, v := range req.DnsNames {
			dns[i] = v
		}
		data["dns_names"] = dns
	}
	if len(req.IpAddresses) > 0 {
		ips := make([]interface{}, len(req.IpAddresses))
		for i, v := range req.IpAddresses {
			ips[i] = v
		}
		data["ip_addresses"] = ips
	}
	if req.Ttl != "" {
		data["ttl"] = req.Ttl
	}
	if req.KeyAlgorithm != "" {
		data["key_algorithm"] = req.KeyAlgorithm
	}
	if req.KeySize != 0 {
		data["key_size"] = float64(req.KeySize)
	}
	if len(req.KeyUsages) > 0 {
		ku := make([]interface{}, len(req.KeyUsages))
		for i, v := range req.KeyUsages {
			ku[i] = v
		}
		data["key_usages"] = ku
	}
	if len(req.ExtKeyUsages) > 0 {
		eku := make([]interface{}, len(req.ExtKeyUsages))
		for i, v := range req.ExtKeyUsages {
			eku[i] = v
		}
		data["ext_key_usages"] = eku
	}

	resp, err := cs.caHandleRequest(ctx, req.Mount, "issue", &engine.Request{
		Operation:   "issue",
		CallerInfo:  cs.callerInfo(ctx),
		CheckPolicy: cs.policyChecker(ctx),
		Data:        data,
	})
	if err != nil {
		return nil, err
	}

	serial, _ := resp.Data["serial"].(string)
	cn, _ := resp.Data["cn"].(string)
	issuedBy, _ := resp.Data["issued_by"].(string)
	certPEM, _ := resp.Data["cert_pem"].(string)
	keyPEM, _ := resp.Data["key_pem"].(string)
	chainPEM, _ := resp.Data["chain_pem"].(string)
	sans := toStringSliceFromInterface(resp.Data["sans"])
	var expiresAt *timestamppb.Timestamp
	if s, ok := resp.Data["expires_at"].(string); ok {
		if t, err := time.Parse(time.RFC3339, s); err == nil {
			expiresAt = timestamppb.New(t)
		}
	}

	// The "issuer" audit field carries the requested issuer name, matching the
	// CSR-signing audit entry. It previously carried the engine's "issued_by"
	// value, which is a separate field from the issuer in certificate records.
	cs.s.logger.Info("audit: certificate issued", "mount", req.Mount, "issuer", req.Issuer, "serial", serial, "cn", cn, "sans", sans, "username", callerUsername(ctx))
	return &pb.IssueCertResponse{
		Serial:     serial,
		CommonName: cn,
		Sans:       sans,
		IssuedBy:   issuedBy,
		ExpiresAt:  expiresAt,
		CertPem:    []byte(certPEM),
		KeyPem:     []byte(keyPEM),
		ChainPem:   []byte(chainPEM),
	}, nil
}
// GetCert sends a "get-cert" request for req.Serial to the engine at
// req.Mount and returns the response converted to a certificate record. It
// returns InvalidArgument when the mount or serial is empty.
func (cs *caServer) GetCert(ctx context.Context, req *pb.GetCertRequest) (*pb.GetCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}

	engReq := &engine.Request{
		Operation:  "get-cert",
		CallerInfo: cs.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "get-cert", engReq)
	if err != nil {
		return nil, err
	}

	return &pb.GetCertResponse{Cert: certRecordFromData(resp.Data)}, nil
}
// ListCerts sends a "list-certs" request to the engine at req.Mount and
// converts each map entry of the response's "certs" list into a certificate
// summary; entries that are not maps are skipped. It returns InvalidArgument
// when the mount is empty.
func (cs *caServer) ListCerts(ctx context.Context, req *pb.ListCertsRequest) (*pb.ListCertsResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}

	engReq := &engine.Request{
		Operation:  "list-certs",
		CallerInfo: cs.callerInfo(ctx),
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "list-certs", engReq)
	if err != nil {
		return nil, err
	}

	entries, _ := resp.Data["certs"].([]interface{})
	out := &pb.ListCertsResponse{Certs: make([]*pb.CertSummary, 0, len(entries))}
	for i := range entries {
		if fields, isMap := entries[i].(map[string]interface{}); isMap {
			out.Certs = append(out.Certs, certSummaryFromData(fields))
		}
	}
	return out, nil
}
// RenewCert sends a "renew" request for req.Serial to the engine at
// req.Mount, attaching the caller info and a policy checker built from ctx.
//
// It returns InvalidArgument when the mount or serial is empty. On success it
// writes an audit log entry with both the old and new serials and returns the
// certificate details, private key PEM and chain PEM reported by the engine.
func (cs *caServer) RenewCert(ctx context.Context, req *pb.RenewCertRequest) (*pb.RenewCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}

	engReq := &engine.Request{
		Operation:   "renew",
		CallerInfo:  cs.callerInfo(ctx),
		CheckPolicy: cs.policyChecker(ctx),
		Data:        map[string]interface{}{"serial": req.Serial},
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "renew", engReq)
	if err != nil {
		return nil, err
	}

	// field reads a string entry from the response; absent or non-string
	// entries yield "".
	field := func(key string) string {
		val, _ := resp.Data[key].(string)
		return val
	}

	out := &pb.RenewCertResponse{
		Serial:     field("serial"),
		CommonName: field("cn"),
		Sans:       toStringSliceFromInterface(resp.Data["sans"]),
		IssuedBy:   field("issued_by"),
		CertPem:    []byte(field("cert_pem")),
		KeyPem:     []byte(field("key_pem")),
		ChainPem:   []byte(field("chain_pem")),
	}
	// A missing or unparseable expiry leaves ExpiresAt unset.
	if raw, ok := resp.Data["expires_at"].(string); ok {
		if ts, parseErr := time.Parse(time.RFC3339, raw); parseErr == nil {
			out.ExpiresAt = timestamppb.New(ts)
		}
	}

	cs.s.logger.Info("audit: certificate renewed", "mount", req.Mount, "old_serial", req.Serial, "new_serial", out.Serial, "cn", out.CommonName, "issued_by", out.IssuedBy, "username", callerUsername(ctx))
	return out, nil
}
// SignCSR sends a "sign-csr" request with the supplied CSR PEM to the engine
// at req.Mount, attaching the caller info and a policy checker built from ctx.
// Profile and TTL are sent only when set in the request.
//
// It returns InvalidArgument when the mount, issuer or CSR is empty. On
// success it writes an audit log entry and returns the certificate details and
// chain PEM reported by the engine; no private key is returned.
func (cs *caServer) SignCSR(ctx context.Context, req *pb.SignCSRRequest) (*pb.SignCSRResponse, error) {
	if req.Mount == "" || req.Issuer == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and issuer are required")
	}
	if len(req.CsrPem) == 0 {
		return nil, status.Error(codes.InvalidArgument, "csr_pem is required")
	}

	payload := map[string]interface{}{
		"issuer":  req.Issuer,
		"csr_pem": string(req.CsrPem),
	}
	if profile := req.Profile; profile != "" {
		payload["profile"] = profile
	}
	if ttl := req.Ttl; ttl != "" {
		payload["ttl"] = ttl
	}

	engReq := &engine.Request{
		Operation:   "sign-csr",
		CallerInfo:  cs.callerInfo(ctx),
		CheckPolicy: cs.policyChecker(ctx),
		Data:        payload,
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "sign-csr", engReq)
	if err != nil {
		return nil, err
	}

	// field reads a string entry from the response; absent or non-string
	// entries yield "".
	field := func(key string) string {
		val, _ := resp.Data[key].(string)
		return val
	}

	out := &pb.SignCSRResponse{
		Serial:     field("serial"),
		CommonName: field("cn"),
		Sans:       toStringSliceFromInterface(resp.Data["sans"]),
		IssuedBy:   field("issued_by"),
		CertPem:    []byte(field("cert_pem")),
		ChainPem:   []byte(field("chain_pem")),
	}
	// A missing or unparseable expiry leaves ExpiresAt unset.
	if raw, ok := resp.Data["expires_at"].(string); ok {
		if ts, parseErr := time.Parse(time.RFC3339, raw); parseErr == nil {
			out.ExpiresAt = timestamppb.New(ts)
		}
	}

	cs.s.logger.Info("audit: CSR signed", "mount", req.Mount, "issuer", req.Issuer, "cn", out.CommonName, "serial", out.Serial, "username", callerUsername(ctx))
	return out, nil
}
// RevokeCert sends a "revoke-cert" request for req.Serial to the engine at
// req.Mount.
//
// It returns InvalidArgument when the mount or serial is empty. On success it
// writes an audit log entry and returns the serial reported by the engine,
// plus the revocation time when the engine supplied one as an RFC 3339 string.
func (cs *caServer) RevokeCert(ctx context.Context, req *pb.RevokeCertRequest) (*pb.RevokeCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}

	engReq := &engine.Request{
		Operation:  "revoke-cert",
		CallerInfo: cs.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	}
	resp, err := cs.caHandleRequest(ctx, req.Mount, "revoke-cert", engReq)
	if err != nil {
		return nil, err
	}

	out := &pb.RevokeCertResponse{}
	out.Serial, _ = resp.Data["serial"].(string)
	// A missing or unparseable timestamp leaves RevokedAt unset.
	if raw, ok := resp.Data["revoked_at"].(string); ok {
		if ts, parseErr := time.Parse(time.RFC3339, raw); parseErr == nil {
			out.RevokedAt = timestamppb.New(ts)
		}
	}

	cs.s.logger.Info("audit: certificate revoked", "mount", req.Mount, "serial", out.Serial, "username", callerUsername(ctx))
	return out, nil
}
// DeleteCert sends a "delete-cert" request for req.Serial to the engine at
// req.Mount. It returns InvalidArgument when the mount or serial is empty, and
// writes an audit log entry on success.
func (cs *caServer) DeleteCert(ctx context.Context, req *pb.DeleteCertRequest) (*pb.DeleteCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}

	engReq := &engine.Request{
		Operation:  "delete-cert",
		CallerInfo: cs.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	}
	if _, err := cs.caHandleRequest(ctx, req.Mount, "delete-cert", engReq); err != nil {
		return nil, err
	}

	cs.s.logger.Info("audit: certificate deleted", "mount", req.Mount, "serial", req.Serial, "username", callerUsername(ctx))
	return &pb.DeleteCertResponse{}, nil
}
// --- helpers ---
// certRecordFromData builds a pb.CertRecord from an engine response map.
//
// String fields that are absent or not strings become "", and "revoked"
// becomes false unless it is a bool. The "issued_at", "expires_at" and
// "revoked_at" timestamps are set only when the value is an RFC 3339 string;
// otherwise they stay nil.
func certRecordFromData(d map[string]interface{}) *pb.CertRecord {
	text := func(key string) string {
		val, _ := d[key].(string)
		return val
	}
	stamp := func(key string) *timestamppb.Timestamp {
		raw, ok := d[key].(string)
		if !ok {
			return nil
		}
		parsed, err := time.Parse(time.RFC3339, raw)
		if err != nil {
			return nil
		}
		return timestamppb.New(parsed)
	}

	rec := &pb.CertRecord{
		Serial:     text("serial"),
		Issuer:     text("issuer"),
		CommonName: text("cn"),
		Sans:       toStringSliceFromInterface(d["sans"]),
		Profile:    text("profile"),
		IssuedBy:   text("issued_by"),
		IssuedAt:   stamp("issued_at"),
		ExpiresAt:  stamp("expires_at"),
		CertPem:    []byte(text("cert_pem")),
		RevokedAt:  stamp("revoked_at"),
		RevokedBy:  text("revoked_by"),
	}
	rec.Revoked, _ = d["revoked"].(bool)
	return rec
}
// certSummaryFromData builds a pb.CertSummary from one entry of an engine
// certificate listing.
//
// String fields that are absent or not strings become "". The "issued_at" and
// "expires_at" timestamps are set only when the value is an RFC 3339 string;
// otherwise they stay nil.
func certSummaryFromData(d map[string]interface{}) *pb.CertSummary {
	text := func(key string) string {
		val, _ := d[key].(string)
		return val
	}
	stamp := func(key string) *timestamppb.Timestamp {
		raw, ok := d[key].(string)
		if !ok {
			return nil
		}
		parsed, err := time.Parse(time.RFC3339, raw)
		if err != nil {
			return nil
		}
		return timestamppb.New(parsed)
	}

	return &pb.CertSummary{
		Serial:     text("serial"),
		Issuer:     text("issuer"),
		CommonName: text("cn"),
		Profile:    text("profile"),
		IssuedBy:   text("issued_by"),
		IssuedAt:   stamp("issued_at"),
		ExpiresAt:  stamp("expires_at"),
	}
}
// toStringSliceFromInterface converts a loosely typed list into a []string.
//
// Accepted inputs:
//   - []interface{}: string elements are kept in order, all others are skipped.
//   - []string: returned as a copy, so the caller cannot alias the input.
//
// Any other value, including nil, yields an empty slice. The result is never
// nil.
func toStringSliceFromInterface(v interface{}) []string {
	switch vals := v.(type) {
	case []string:
		// Previously a []string fell through to the empty result, silently
		// dropping every element.
		out := make([]string, len(vals))
		copy(out, vals)
		return out
	case []interface{}:
		out := make([]string, 0, len(vals))
		for _, item := range vals {
			if s, ok := item.(string); ok {
				out = append(out, s)
			}
		}
		return out
	default:
		return []string{}
	}
}