package grpcserver import ( "context" "errors" "strings" "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" pb "git.wntrmute.dev/mc/metacrypt/gen/metacrypt/v2" "git.wntrmute.dev/mc/metacrypt/internal/auth" "git.wntrmute.dev/mc/metacrypt/internal/engine" "git.wntrmute.dev/mc/metacrypt/internal/engine/ca" "git.wntrmute.dev/mc/metacrypt/internal/policy" ) type caServer struct { pb.UnimplementedCAServiceServer s *GRPCServer } // caHandleRequest is a helper that dispatches a CA engine request and maps // common errors to gRPC status codes. func (cs *caServer) caHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) { resp, err := cs.s.engines.HandleRequest(ctx, mount, req) if err != nil { st := codes.Internal switch { case errors.Is(err, engine.ErrMountNotFound): st = codes.NotFound case errors.Is(err, ca.ErrIssuerNotFound): st = codes.NotFound case errors.Is(err, ca.ErrCertNotFound): st = codes.NotFound case errors.Is(err, ca.ErrIssuerExists): st = codes.AlreadyExists case errors.Is(err, ca.ErrIdentifierInUse): st = codes.AlreadyExists case errors.Is(err, ca.ErrUnauthorized): st = codes.Unauthenticated case errors.Is(err, ca.ErrForbidden): st = codes.PermissionDenied case strings.Contains(err.Error(), "not found"): st = codes.NotFound } cs.s.logger.Error("grpc: ca "+operation, "mount", mount, "error", err) return nil, status.Error(st, err.Error()) } return resp, nil } func (cs *caServer) callerInfo(ctx context.Context) *engine.CallerInfo { ti := auth.TokenInfoFromContext(ctx) if ti == nil { return nil } return &engine.CallerInfo{ Username: ti.Username, AccountType: ti.AccountType, Roles: ti.Roles, IsAdmin: ti.IsAdmin, } } func (cs *caServer) policyChecker(ctx context.Context) engine.PolicyChecker { caller := cs.callerInfo(ctx) if caller == nil { return nil } return func(resource, action string) (string, bool) { pReq := &policy.Request{ Username: caller.Username, Roles: caller.Roles, Resource: resource, Action: action, } effect, matched, err := cs.s.policy.Match(ctx, pReq) if err != nil { return string(policy.EffectDeny), false } return string(effect), matched } } func (cs *caServer) ImportRoot(ctx context.Context, req *pb.ImportRootRequest) (*pb.ImportRootResponse, error) { if req.Mount == "" { return nil, status.Error(codes.InvalidArgument, "mount is required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "import-root", &engine.Request{ Operation: "import-root", CallerInfo: cs.callerInfo(ctx), Data: map[string]interface{}{ "cert_pem": string(req.CertPem), "key_pem": string(req.KeyPem), }, }) if err != nil { return nil, err } cn, _ := resp.Data["cn"].(string) var expiresAt *timestamppb.Timestamp if s, ok := resp.Data["expires_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { expiresAt = timestamppb.New(t) } } cs.s.logger.Info("audit: root CA imported", "mount", req.Mount, "cn", cn, "username", callerUsername(ctx)) return &pb.ImportRootResponse{CommonName: cn, ExpiresAt: expiresAt}, nil } func (cs *caServer) GetRoot(ctx context.Context, req *pb.GetRootRequest) (*pb.GetRootResponse, error) { if req.Mount == "" { return nil, status.Error(codes.InvalidArgument, "mount is required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "get-root", &engine.Request{ Operation: "get-root", }) if err != nil { return nil, err } certPEM, _ := resp.Data["cert_pem"].(string) return &pb.GetRootResponse{CertPem: []byte(certPEM)}, nil } func (cs *caServer) CreateIssuer(ctx context.Context, req *pb.CreateIssuerRequest) (*pb.CreateIssuerResponse, error) { if req.Mount == "" || req.Name == "" { return nil, status.Error(codes.InvalidArgument, "mount and name are required") } data := map[string]interface{}{ "name": req.Name, } if req.KeyAlgorithm != "" { data["key_algorithm"] = req.KeyAlgorithm } if req.KeySize != 0 { data["key_size"] = float64(req.KeySize) } if req.Expiry != "" { data["expiry"] = req.Expiry } if req.MaxTtl != "" { data["max_ttl"] = req.MaxTtl } resp, err := cs.caHandleRequest(ctx, req.Mount, "create-issuer", &engine.Request{ Operation: "create-issuer", CallerInfo: cs.callerInfo(ctx), Data: data, }) if err != nil { return nil, err } name, _ := resp.Data["name"].(string) certPEM, _ := resp.Data["cert_pem"].(string) cs.s.logger.Info("audit: issuer created", "mount", req.Mount, "issuer", name, "username", callerUsername(ctx)) return &pb.CreateIssuerResponse{Name: name, CertPem: []byte(certPEM)}, nil } func (cs *caServer) DeleteIssuer(ctx context.Context, req *pb.DeleteIssuerRequest) (*pb.DeleteIssuerResponse, error) { if req.Mount == "" || req.Name == "" { return nil, status.Error(codes.InvalidArgument, "mount and name are required") } _, err := cs.caHandleRequest(ctx, req.Mount, "delete-issuer", &engine.Request{ Operation: "delete-issuer", CallerInfo: cs.callerInfo(ctx), Data: map[string]interface{}{"name": req.Name}, }) if err != nil { return nil, err } cs.s.logger.Info("audit: issuer deleted", "mount", req.Mount, "issuer", req.Name, "username", callerUsername(ctx)) return &pb.DeleteIssuerResponse{}, nil } func (cs *caServer) ListIssuers(ctx context.Context, req *pb.ListIssuersRequest) (*pb.ListIssuersResponse, error) { if req.Mount == "" { return nil, status.Error(codes.InvalidArgument, "mount is required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "list-issuers", &engine.Request{ Operation: "list-issuers", CallerInfo: cs.callerInfo(ctx), }) if err != nil { return nil, err } raw, _ := resp.Data["issuers"].([]interface{}) issuers := make([]string, 0, len(raw)) for _, v := range raw { if s, ok := v.(string); ok { issuers = append(issuers, s) } } return &pb.ListIssuersResponse{Issuers: issuers}, nil } func (cs *caServer) GetIssuer(ctx context.Context, req *pb.GetIssuerRequest) (*pb.GetIssuerResponse, error) { if req.Mount == "" || req.Name == "" { return nil, status.Error(codes.InvalidArgument, "mount and name are required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "get-issuer", &engine.Request{ Operation: "get-issuer", Path: req.Name, }) if err != nil { return nil, err } certPEM, _ := resp.Data["cert_pem"].(string) return &pb.GetIssuerResponse{CertPem: []byte(certPEM)}, nil } func (cs *caServer) GetChain(ctx context.Context, req *pb.CAServiceGetChainRequest) (*pb.CAServiceGetChainResponse, error) { if req.Mount == "" || req.Issuer == "" { return nil, status.Error(codes.InvalidArgument, "mount and issuer are required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "get-chain", &engine.Request{ Operation: "get-chain", Data: map[string]interface{}{"issuer": req.Issuer}, }) if err != nil { return nil, err } chainPEM, _ := resp.Data["chain_pem"].(string) return &pb.CAServiceGetChainResponse{ChainPem: []byte(chainPEM)}, nil } func (cs *caServer) IssueCert(ctx context.Context, req *pb.IssueCertRequest) (*pb.IssueCertResponse, error) { if req.Mount == "" || req.Issuer == "" { return nil, status.Error(codes.InvalidArgument, "mount and issuer are required") } data := map[string]interface{}{ "issuer": req.Issuer, } if req.Profile != "" { data["profile"] = req.Profile } if req.CommonName != "" { data["common_name"] = req.CommonName } if len(req.DnsNames) > 0 { dns := make([]interface{}, len(req.DnsNames)) for i, v := range req.DnsNames { dns[i] = v } data["dns_names"] = dns } if len(req.IpAddresses) > 0 { ips := make([]interface{}, len(req.IpAddresses)) for i, v := range req.IpAddresses { ips[i] = v } data["ip_addresses"] = ips } if req.Ttl != "" { data["ttl"] = req.Ttl } if req.KeyAlgorithm != "" { data["key_algorithm"] = req.KeyAlgorithm } if req.KeySize != 0 { data["key_size"] = float64(req.KeySize) } if len(req.KeyUsages) > 0 { ku := make([]interface{}, len(req.KeyUsages)) for i, v := range req.KeyUsages { ku[i] = v } data["key_usages"] = ku } if len(req.ExtKeyUsages) > 0 { eku := make([]interface{}, len(req.ExtKeyUsages)) for i, v := range req.ExtKeyUsages { eku[i] = v } data["ext_key_usages"] = eku } resp, err := cs.caHandleRequest(ctx, req.Mount, "issue", &engine.Request{ Operation: "issue", CallerInfo: cs.callerInfo(ctx), CheckPolicy: cs.policyChecker(ctx), Data: data, }) if err != nil { return nil, err } serial, _ := resp.Data["serial"].(string) cn, _ := resp.Data["cn"].(string) issuedBy, _ := resp.Data["issued_by"].(string) certPEM, _ := resp.Data["cert_pem"].(string) keyPEM, _ := resp.Data["key_pem"].(string) chainPEM, _ := resp.Data["chain_pem"].(string) sans := toStringSliceFromInterface(resp.Data["sans"]) var expiresAt *timestamppb.Timestamp if s, ok := resp.Data["expires_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { expiresAt = timestamppb.New(t) } } cs.s.logger.Info("audit: certificate issued", "mount", req.Mount, "issuer", issuedBy, "serial", serial, "cn", cn, "sans", sans, "username", callerUsername(ctx)) return &pb.IssueCertResponse{ Serial: serial, CommonName: cn, Sans: sans, IssuedBy: issuedBy, ExpiresAt: expiresAt, CertPem: []byte(certPEM), KeyPem: []byte(keyPEM), ChainPem: []byte(chainPEM), }, nil } func (cs *caServer) GetCert(ctx context.Context, req *pb.GetCertRequest) (*pb.GetCertResponse, error) { if req.Mount == "" || req.Serial == "" { return nil, status.Error(codes.InvalidArgument, "mount and serial are required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "get-cert", &engine.Request{ Operation: "get-cert", CallerInfo: cs.callerInfo(ctx), Data: map[string]interface{}{"serial": req.Serial}, }) if err != nil { return nil, err } rec := certRecordFromData(resp.Data) return &pb.GetCertResponse{Cert: rec}, nil } func (cs *caServer) ListCerts(ctx context.Context, req *pb.ListCertsRequest) (*pb.ListCertsResponse, error) { if req.Mount == "" { return nil, status.Error(codes.InvalidArgument, "mount is required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "list-certs", &engine.Request{ Operation: "list-certs", CallerInfo: cs.callerInfo(ctx), }) if err != nil { return nil, err } raw, _ := resp.Data["certs"].([]interface{}) summaries := make([]*pb.CertSummary, 0, len(raw)) for _, item := range raw { m, ok := item.(map[string]interface{}) if !ok { continue } summaries = append(summaries, certSummaryFromData(m)) } return &pb.ListCertsResponse{Certs: summaries}, nil } func (cs *caServer) RenewCert(ctx context.Context, req *pb.RenewCertRequest) (*pb.RenewCertResponse, error) { if req.Mount == "" || req.Serial == "" { return nil, status.Error(codes.InvalidArgument, "mount and serial are required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "renew", &engine.Request{ Operation: "renew", CallerInfo: cs.callerInfo(ctx), CheckPolicy: cs.policyChecker(ctx), Data: map[string]interface{}{"serial": req.Serial}, }) if err != nil { return nil, err } serial, _ := resp.Data["serial"].(string) cn, _ := resp.Data["cn"].(string) issuedBy, _ := resp.Data["issued_by"].(string) certPEM, _ := resp.Data["cert_pem"].(string) keyPEM, _ := resp.Data["key_pem"].(string) chainPEM, _ := resp.Data["chain_pem"].(string) sans := toStringSliceFromInterface(resp.Data["sans"]) var expiresAt *timestamppb.Timestamp if s, ok := resp.Data["expires_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { expiresAt = timestamppb.New(t) } } cs.s.logger.Info("audit: certificate renewed", "mount", req.Mount, "old_serial", req.Serial, "new_serial", serial, "cn", cn, "issued_by", issuedBy, "username", callerUsername(ctx)) return &pb.RenewCertResponse{ Serial: serial, CommonName: cn, Sans: sans, IssuedBy: issuedBy, ExpiresAt: expiresAt, CertPem: []byte(certPEM), KeyPem: []byte(keyPEM), ChainPem: []byte(chainPEM), }, nil } func (cs *caServer) SignCSR(ctx context.Context, req *pb.SignCSRRequest) (*pb.SignCSRResponse, error) { if req.Mount == "" || req.Issuer == "" { return nil, status.Error(codes.InvalidArgument, "mount and issuer are required") } if len(req.CsrPem) == 0 { return nil, status.Error(codes.InvalidArgument, "csr_pem is required") } data := map[string]interface{}{ "issuer": req.Issuer, "csr_pem": string(req.CsrPem), } if req.Profile != "" { data["profile"] = req.Profile } if req.Ttl != "" { data["ttl"] = req.Ttl } resp, err := cs.caHandleRequest(ctx, req.Mount, "sign-csr", &engine.Request{ Operation: "sign-csr", CallerInfo: cs.callerInfo(ctx), CheckPolicy: cs.policyChecker(ctx), Data: data, }) if err != nil { return nil, err } serial, _ := resp.Data["serial"].(string) cn, _ := resp.Data["cn"].(string) issuedBy, _ := resp.Data["issued_by"].(string) certPEM, _ := resp.Data["cert_pem"].(string) chainPEM, _ := resp.Data["chain_pem"].(string) sans := toStringSliceFromInterface(resp.Data["sans"]) var expiresAt *timestamppb.Timestamp if s, ok := resp.Data["expires_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { expiresAt = timestamppb.New(t) } } cs.s.logger.Info("audit: CSR signed", "mount", req.Mount, "issuer", req.Issuer, "cn", cn, "serial", serial, "username", callerUsername(ctx)) return &pb.SignCSRResponse{ Serial: serial, CommonName: cn, Sans: sans, IssuedBy: issuedBy, ExpiresAt: expiresAt, CertPem: []byte(certPEM), ChainPem: []byte(chainPEM), }, nil } func (cs *caServer) RevokeCert(ctx context.Context, req *pb.RevokeCertRequest) (*pb.RevokeCertResponse, error) { if req.Mount == "" || req.Serial == "" { return nil, status.Error(codes.InvalidArgument, "mount and serial are required") } resp, err := cs.caHandleRequest(ctx, req.Mount, "revoke-cert", &engine.Request{ Operation: "revoke-cert", CallerInfo: cs.callerInfo(ctx), Data: map[string]interface{}{"serial": req.Serial}, }) if err != nil { return nil, err } serial, _ := resp.Data["serial"].(string) var revokedAt *timestamppb.Timestamp if s, ok := resp.Data["revoked_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { revokedAt = timestamppb.New(t) } } cs.s.logger.Info("audit: certificate revoked", "mount", req.Mount, "serial", serial, "username", callerUsername(ctx)) return &pb.RevokeCertResponse{Serial: serial, RevokedAt: revokedAt}, nil } func (cs *caServer) DeleteCert(ctx context.Context, req *pb.DeleteCertRequest) (*pb.DeleteCertResponse, error) { if req.Mount == "" || req.Serial == "" { return nil, status.Error(codes.InvalidArgument, "mount and serial are required") } _, err := cs.caHandleRequest(ctx, req.Mount, "delete-cert", &engine.Request{ Operation: "delete-cert", CallerInfo: cs.callerInfo(ctx), Data: map[string]interface{}{"serial": req.Serial}, }) if err != nil { return nil, err } cs.s.logger.Info("audit: certificate deleted", "mount", req.Mount, "serial", req.Serial, "username", callerUsername(ctx)) return &pb.DeleteCertResponse{}, nil } // --- helpers --- func certRecordFromData(d map[string]interface{}) *pb.CertRecord { serial, _ := d["serial"].(string) issuer, _ := d["issuer"].(string) cn, _ := d["cn"].(string) profile, _ := d["profile"].(string) issuedBy, _ := d["issued_by"].(string) certPEM, _ := d["cert_pem"].(string) revoked, _ := d["revoked"].(bool) revokedBy, _ := d["revoked_by"].(string) sans := toStringSliceFromInterface(d["sans"]) var issuedAt, expiresAt, revokedAt *timestamppb.Timestamp if s, ok := d["issued_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { issuedAt = timestamppb.New(t) } } if s, ok := d["expires_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { expiresAt = timestamppb.New(t) } } if s, ok := d["revoked_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { revokedAt = timestamppb.New(t) } } return &pb.CertRecord{ Serial: serial, Issuer: issuer, CommonName: cn, Sans: sans, Profile: profile, IssuedBy: issuedBy, IssuedAt: issuedAt, ExpiresAt: expiresAt, CertPem: []byte(certPEM), Revoked: revoked, RevokedAt: revokedAt, RevokedBy: revokedBy, } } func certSummaryFromData(d map[string]interface{}) *pb.CertSummary { serial, _ := d["serial"].(string) issuer, _ := d["issuer"].(string) cn, _ := d["cn"].(string) profile, _ := d["profile"].(string) issuedBy, _ := d["issued_by"].(string) var issuedAt, expiresAt *timestamppb.Timestamp if s, ok := d["issued_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { issuedAt = timestamppb.New(t) } } if s, ok := d["expires_at"].(string); ok { if t, err := time.Parse(time.RFC3339, s); err == nil { expiresAt = timestamppb.New(t) } } return &pb.CertSummary{ Serial: serial, Issuer: issuer, CommonName: cn, Profile: profile, IssuedBy: issuedBy, IssuedAt: issuedAt, ExpiresAt: expiresAt, } } func toStringSliceFromInterface(v interface{}) []string { raw, _ := v.([]interface{}) out := make([]string, 0, len(raw)) for _, item := range raw { if s, ok := item.(string); ok { out = append(out, s) } } return out }