package grpcserver

import (
	"context"
	"errors"
	"strings"
	"time"

	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	pb "git.wntrmute.dev/mc/metacrypt/gen/metacrypt/v2"
	"git.wntrmute.dev/mc/metacrypt/internal/auth"
	"git.wntrmute.dev/mc/metacrypt/internal/engine"
	"git.wntrmute.dev/mc/metacrypt/internal/engine/sshca"
	"git.wntrmute.dev/mc/metacrypt/internal/policy"
)

// sshcaServer implements the SSHCAService gRPC API by translating each RPC
// into an engine.Request and dispatching it through s.engines to the mount
// named in the request.
type sshcaServer struct {
	pb.UnimplementedSSHCAServiceServer
	s *GRPCServer
}

// sshcaHandleRequest dispatches req to the engine at mount. On failure it
// logs the error (operation is used only as part of the log message) and
// returns a gRPC status error whose code is chosen by sshcaErrorCode.
func (ss *sshcaServer) sshcaHandleRequest(ctx context.Context, mount, operation string, req *engine.Request) (*engine.Response, error) {
	resp, err := ss.s.engines.HandleRequest(ctx, mount, req)
	if err == nil {
		return resp, nil
	}
	ss.s.logger.Error("grpc: sshca "+operation, "mount", mount, "error", err)
	// NOTE(review): the raw error text is returned to the client for every
	// code, including Internal; consider a generic message for Internal so
	// engine/storage details are not exposed.
	return nil, status.Error(sshcaErrorCode(err), err.Error())
}

// sshcaErrorCode maps an engine error to the gRPC code returned to clients.
// Known sentinel errors are checked first; as a fallback, error text that
// contains "not found" or "forbidden" maps to NotFound or PermissionDenied.
// Anything else is Internal.
func sshcaErrorCode(err error) codes.Code {
	switch {
	case errors.Is(err, engine.ErrMountNotFound),
		errors.Is(err, sshca.ErrCertNotFound),
		errors.Is(err, sshca.ErrProfileNotFound):
		return codes.NotFound
	case errors.Is(err, sshca.ErrProfileExists):
		return codes.AlreadyExists
	case errors.Is(err, sshca.ErrUnauthorized):
		return codes.Unauthenticated
	case errors.Is(err, sshca.ErrForbidden):
		return codes.PermissionDenied
	// String matching is a best-effort fallback for errors that do not wrap
	// one of the sentinels above.
	case strings.Contains(err.Error(), "not found"):
		return codes.NotFound
	case strings.Contains(err.Error(), "forbidden"):
		return codes.PermissionDenied
	default:
		return codes.Internal
	}
}

// callerInfo builds engine caller information from the token info attached
// to ctx, or returns nil when there is none.
func (ss *sshcaServer) callerInfo(ctx context.Context) *engine.CallerInfo {
	ti := auth.TokenInfoFromContext(ctx)
	if ti == nil {
		return nil
	}
	return &engine.CallerInfo{
		Username: ti.Username,
		Roles:    ti.Roles,
		IsAdmin:  ti.IsAdmin,
	}
}

// policyChecker returns a callback that evaluates (resource, action) against
// the server's policy for the caller identified in ctx, or nil when ctx
// carries no token info. The callback returns the matched effect and whether
// any rule matched.
func (ss *sshcaServer) policyChecker(ctx context.Context) engine.PolicyChecker {
	caller := ss.callerInfo(ctx)
	if caller == nil {
		return nil
	}
	return func(resource, action string) (string, bool) {
		pReq := &policy.Request{
			Username: caller.Username,
			Roles:    caller.Roles,
			Resource: resource,
			Action:   action,
		}
		effect, matched, err := ss.s.policy.Match(ctx, pReq)
		if err != nil {
			// Previously this error was dropped silently; log it so policy
			// backend failures are visible.
			ss.s.logger.Error("grpc: sshca policy match", "resource", resource, "action", action, "username", caller.Username, "error", err)
			// NOTE(review): matched is false here, so whether a lookup error
			// fails closed depends on how the engine treats an unmatched
			// result — confirm.
			return string(policy.EffectDeny), false
		}
		return string(effect), matched
	}
}

// GetCAPublicKey returns the CA public key of the given mount. No caller
// info is attached to the engine request.
func (ss *sshcaServer) GetCAPublicKey(ctx context.Context, req *pb.SSHGetCAPublicKeyRequest) (*pb.SSHGetCAPublicKeyResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "get-ca-pubkey", &engine.Request{
		Operation: "get-ca-pubkey",
	})
	if err != nil {
		return nil, err
	}
	return &pb.SSHGetCAPublicKeyResponse{PublicKey: stringVal(resp.Data, "public_key")}, nil
}

// SignHost asks the engine to sign a host public key for req.Hostname, with
// an optional TTL, and writes an audit log entry on success.
func (ss *sshcaServer) SignHost(ctx context.Context, req *pb.SSHSignHostRequest) (*pb.SSHSignHostResponse, error) {
	if req.Mount == "" || req.PublicKey == "" || req.Hostname == "" {
		return nil, status.Error(codes.InvalidArgument, "mount, public_key, and hostname are required")
	}
	data := map[string]interface{}{
		"public_key": req.PublicKey,
		"hostname":   req.Hostname,
	}
	if req.Ttl != "" {
		data["ttl"] = req.Ttl
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "sign-host", &engine.Request{
		Operation:   "sign-host",
		CallerInfo:  ss.callerInfo(ctx),
		CheckPolicy: ss.policyChecker(ctx),
		Data:        data,
	})
	if err != nil {
		return nil, err
	}
	out := &pb.SSHSignHostResponse{
		Serial:     stringVal(resp.Data, "serial"),
		CertType:   stringVal(resp.Data, "cert_type"),
		Principals: toStringSliceFromInterface(resp.Data["principals"]),
		CertData:   stringVal(resp.Data, "cert_data"),
		KeyId:      stringVal(resp.Data, "key_id"),
		IssuedBy:   stringVal(resp.Data, "issued_by"),
		IssuedAt:   parseTimestamp(resp.Data, "issued_at"),
		ExpiresAt:  parseTimestamp(resp.Data, "expires_at"),
	}
	ss.s.logger.Info("audit: SSH host cert signed", "mount", req.Mount, "hostname", req.Hostname, "serial", out.Serial, "username", callerUsername(ctx))
	return out, nil
}

// SignUser asks the engine to sign a user public key. Principals, profile
// and TTL are optional and forwarded only when set. Writes an audit log
// entry on success.
func (ss *sshcaServer) SignUser(ctx context.Context, req *pb.SSHSignUserRequest) (*pb.SSHSignUserResponse, error) {
	if req.Mount == "" || req.PublicKey == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and public_key are required")
	}
	data := map[string]interface{}{
		"public_key": req.PublicKey,
	}
	if len(req.Principals) > 0 {
		data["principals"] = sshcaInterfaceSlice(req.Principals)
	}
	if req.Profile != "" {
		data["profile"] = req.Profile
	}
	if req.Ttl != "" {
		data["ttl"] = req.Ttl
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "sign-user", &engine.Request{
		Operation:   "sign-user",
		CallerInfo:  ss.callerInfo(ctx),
		CheckPolicy: ss.policyChecker(ctx),
		Data:        data,
	})
	if err != nil {
		return nil, err
	}
	out := &pb.SSHSignUserResponse{
		Serial:     stringVal(resp.Data, "serial"),
		CertType:   stringVal(resp.Data, "cert_type"),
		Principals: toStringSliceFromInterface(resp.Data["principals"]),
		CertData:   stringVal(resp.Data, "cert_data"),
		KeyId:      stringVal(resp.Data, "key_id"),
		Profile:    stringVal(resp.Data, "profile"),
		IssuedBy:   stringVal(resp.Data, "issued_by"),
		IssuedAt:   parseTimestamp(resp.Data, "issued_at"),
		ExpiresAt:  parseTimestamp(resp.Data, "expires_at"),
	}
	ss.s.logger.Info("audit: SSH user cert signed", "mount", req.Mount, "serial", out.Serial, "username", callerUsername(ctx))
	return out, nil
}

// CreateProfile creates a signing profile on the mount and writes an audit
// log entry on success. It returns the profile name reported by the engine.
func (ss *sshcaServer) CreateProfile(ctx context.Context, req *pb.SSHCreateProfileRequest) (*pb.SSHCreateProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	data := sshcaProfileData(req.Name, req.CriticalOptions, req.Extensions, req.MaxTtl, req.AllowedPrincipals)
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "create-profile", &engine.Request{
		Operation:  "create-profile",
		CallerInfo: ss.callerInfo(ctx),
		Data:       data,
	})
	if err != nil {
		return nil, err
	}
	name := stringVal(resp.Data, "name")
	ss.s.logger.Info("audit: SSH CA profile created", "mount", req.Mount, "profile", name, "username", callerUsername(ctx))
	return &pb.SSHCreateProfileResponse{Name: name}, nil
}

// UpdateProfile updates a signing profile on the mount and writes an audit
// log entry on success. It returns the profile name reported by the engine.
func (ss *sshcaServer) UpdateProfile(ctx context.Context, req *pb.SSHUpdateProfileRequest) (*pb.SSHUpdateProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	data := sshcaProfileData(req.Name, req.CriticalOptions, req.Extensions, req.MaxTtl, req.AllowedPrincipals)
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "update-profile", &engine.Request{
		Operation:  "update-profile",
		CallerInfo: ss.callerInfo(ctx),
		Data:       data,
	})
	if err != nil {
		return nil, err
	}
	name := stringVal(resp.Data, "name")
	ss.s.logger.Info("audit: SSH CA profile updated", "mount", req.Mount, "profile", name, "username", callerUsername(ctx))
	return &pb.SSHUpdateProfileResponse{Name: name}, nil
}

// GetProfile returns a signing profile by name.
func (ss *sshcaServer) GetProfile(ctx context.Context, req *pb.SSHGetProfileRequest) (*pb.SSHGetProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "get-profile", &engine.Request{
		Operation:  "get-profile",
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	return &pb.SSHGetProfileResponse{
		Name:              stringVal(resp.Data, "name"),
		MaxTtl:            stringVal(resp.Data, "max_ttl"),
		AllowedPrincipals: toStringSliceFromInterface(resp.Data["allowed_principals"]),
		// Accept both map[string]string and map[string]interface{} so the
		// options are not silently dropped if the engine uses the latter.
		CriticalOptions: sshcaStringMap(resp.Data["critical_options"]),
		Extensions:      sshcaStringMap(resp.Data["extensions"]),
	}, nil
}

// ListProfiles returns the profile names reported by the engine.
func (ss *sshcaServer) ListProfiles(ctx context.Context, req *pb.SSHListProfilesRequest) (*pb.SSHListProfilesResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "list-profiles", &engine.Request{
		Operation:  "list-profiles",
		CallerInfo: ss.callerInfo(ctx),
	})
	if err != nil {
		return nil, err
	}
	return &pb.SSHListProfilesResponse{Profiles: toStringSliceFromInterface(resp.Data["profiles"])}, nil
}

// DeleteProfile deletes a signing profile by name and writes an audit log
// entry on success.
func (ss *sshcaServer) DeleteProfile(ctx context.Context, req *pb.SSHDeleteProfileRequest) (*pb.SSHDeleteProfileResponse, error) {
	if req.Mount == "" || req.Name == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and name are required")
	}
	_, err := ss.sshcaHandleRequest(ctx, req.Mount, "delete-profile", &engine.Request{
		Operation:  "delete-profile",
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"name": req.Name},
	})
	if err != nil {
		return nil, err
	}
	ss.s.logger.Info("audit: SSH CA profile deleted", "mount", req.Mount, "profile", req.Name, "username", callerUsername(ctx))
	return &pb.SSHDeleteProfileResponse{}, nil
}

// GetCert returns the full record of the certificate with the given serial.
func (ss *sshcaServer) GetCert(ctx context.Context, req *pb.SSHGetCertRequest) (*pb.SSHGetCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "get-cert", &engine.Request{
		Operation:  "get-cert",
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	})
	if err != nil {
		return nil, err
	}
	return &pb.SSHGetCertResponse{Cert: sshCertRecordFromData(resp.Data)}, nil
}

// ListCerts returns summaries of the certificates reported by the engine.
// Entries in resp.Data["certs"] that are not maps are skipped.
func (ss *sshcaServer) ListCerts(ctx context.Context, req *pb.SSHListCertsRequest) (*pb.SSHListCertsResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "list-certs", &engine.Request{
		Operation:  "list-certs",
		CallerInfo: ss.callerInfo(ctx),
	})
	if err != nil {
		return nil, err
	}
	raw, _ := resp.Data["certs"].([]interface{})
	summaries := make([]*pb.SSHCertSummary, 0, len(raw))
	for _, item := range raw {
		m, ok := item.(map[string]interface{})
		if !ok {
			continue
		}
		summaries = append(summaries, sshCertSummaryFromData(m))
	}
	return &pb.SSHListCertsResponse{Certs: summaries}, nil
}

// RevokeCert revokes the certificate with the given serial and writes an
// audit log entry on success. RevokedAt is nil if the engine's revoked_at
// value is missing or not RFC 3339.
func (ss *sshcaServer) RevokeCert(ctx context.Context, req *pb.SSHRevokeCertRequest) (*pb.SSHRevokeCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "revoke-cert", &engine.Request{
		Operation:  "revoke-cert",
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	})
	if err != nil {
		return nil, err
	}
	serial := stringVal(resp.Data, "serial")
	revokedAt := parseTimestamp(resp.Data, "revoked_at")
	ss.s.logger.Info("audit: SSH cert revoked", "mount", req.Mount, "serial", serial, "username", callerUsername(ctx))
	return &pb.SSHRevokeCertResponse{Serial: serial, RevokedAt: revokedAt}, nil
}

// DeleteCert deletes the certificate with the given serial and writes an
// audit log entry on success.
func (ss *sshcaServer) DeleteCert(ctx context.Context, req *pb.SSHDeleteCertRequest) (*pb.SSHDeleteCertResponse, error) {
	if req.Mount == "" || req.Serial == "" {
		return nil, status.Error(codes.InvalidArgument, "mount and serial are required")
	}
	_, err := ss.sshcaHandleRequest(ctx, req.Mount, "delete-cert", &engine.Request{
		Operation:  "delete-cert",
		CallerInfo: ss.callerInfo(ctx),
		Data:       map[string]interface{}{"serial": req.Serial},
	})
	if err != nil {
		return nil, err
	}
	ss.s.logger.Info("audit: SSH cert deleted", "mount", req.Mount, "serial", req.Serial, "username", callerUsername(ctx))
	return &pb.SSHDeleteCertResponse{}, nil
}

// GetKRL returns the mount's key revocation list. No caller info is attached
// to the engine request.
func (ss *sshcaServer) GetKRL(ctx context.Context, req *pb.SSHGetKRLRequest) (*pb.SSHGetKRLResponse, error) {
	if req.Mount == "" {
		return nil, status.Error(codes.InvalidArgument, "mount is required")
	}
	resp, err := ss.sshcaHandleRequest(ctx, req.Mount, "get-krl", &engine.Request{
		Operation: "get-krl",
	})
	if err != nil {
		return nil, err
	}
	// Accept the KRL as either a string or []byte so it is not silently
	// dropped when the engine uses raw bytes; a missing value yields an
	// empty (non-nil) slice.
	var krl []byte
	switch v := resp.Data["krl"].(type) {
	case []byte:
		krl = v
	case string:
		krl = []byte(v)
	default:
		krl = []byte{}
	}
	return &pb.SSHGetKRLResponse{Krl: krl}, nil
}

// --- helpers ---

// sshcaProfileData builds the engine request data shared by create-profile
// and update-profile. "name" is always set; the other fields are included
// only when non-empty.
func sshcaProfileData(name string, criticalOptions, extensions map[string]string, maxTTL string, allowedPrincipals []string) map[string]interface{} {
	data := map[string]interface{}{
		"name": name,
	}
	if len(criticalOptions) > 0 {
		data["critical_options"] = sshcaInterfaceMap(criticalOptions)
	}
	if len(extensions) > 0 {
		data["extensions"] = sshcaInterfaceMap(extensions)
	}
	if maxTTL != "" {
		data["max_ttl"] = maxTTL
	}
	if len(allowedPrincipals) > 0 {
		data["allowed_principals"] = sshcaInterfaceSlice(allowedPrincipals)
	}
	return data
}

// sshcaInterfaceSlice copies a []string into a []interface{} for engine
// request data.
func sshcaInterfaceSlice(in []string) []interface{} {
	out := make([]interface{}, len(in))
	for i, v := range in {
		out[i] = v
	}
	return out
}

// sshcaInterfaceMap copies a map[string]string into a map[string]interface{}
// for engine request data.
func sshcaInterfaceMap(in map[string]string) map[string]interface{} {
	out := make(map[string]interface{}, len(in))
	for k, v := range in {
		out[k] = v
	}
	return out
}

// sshcaStringMap extracts a string map from engine response data. A
// map[string]string is returned as-is; a map[string]interface{} is copied,
// skipping non-string values; any other type yields nil.
func sshcaStringMap(v interface{}) map[string]string {
	switch m := v.(type) {
	case map[string]string:
		return m
	case map[string]interface{}:
		out := make(map[string]string, len(m))
		for k, val := range m {
			if s, ok := val.(string); ok {
				out[k] = s
			}
		}
		return out
	default:
		return nil
	}
}

// stringVal returns d[key] if it is a string, or "" otherwise.
func stringVal(d map[string]interface{}, key string) string {
	v, _ := d[key].(string)
	return v
}

// parseTimestamp returns d[key] as a protobuf timestamp when it is a string
// parseable as RFC 3339, or nil otherwise.
func parseTimestamp(d map[string]interface{}, key string) *timestamppb.Timestamp {
	if s, ok := d[key].(string); ok {
		if t, err := time.Parse(time.RFC3339, s); err == nil {
			return timestamppb.New(t)
		}
	}
	return nil
}

// sshCertRecordFromData converts an engine certificate map into a full
// SSHCertRecord. Missing or mistyped fields are left at their zero value.
func sshCertRecordFromData(d map[string]interface{}) *pb.SSHCertRecord {
	revoked, _ := d["revoked"].(bool)
	return &pb.SSHCertRecord{
		Serial:     stringVal(d, "serial"),
		CertType:   stringVal(d, "cert_type"),
		Principals: toStringSliceFromInterface(d["principals"]),
		CertData:   stringVal(d, "cert_data"),
		KeyId:      stringVal(d, "key_id"),
		Profile:    stringVal(d, "profile"),
		IssuedBy:   stringVal(d, "issued_by"),
		IssuedAt:   parseTimestamp(d, "issued_at"),
		ExpiresAt:  parseTimestamp(d, "expires_at"),
		Revoked:    revoked,
		RevokedAt:  parseTimestamp(d, "revoked_at"),
		RevokedBy:  stringVal(d, "revoked_by"),
	}
}

// sshCertSummaryFromData converts an engine certificate map into an
// SSHCertSummary (no cert data or revocation details). Missing or mistyped
// fields are left at their zero value.
func sshCertSummaryFromData(d map[string]interface{}) *pb.SSHCertSummary {
	revoked, _ := d["revoked"].(bool)
	return &pb.SSHCertSummary{
		Serial:     stringVal(d, "serial"),
		CertType:   stringVal(d, "cert_type"),
		Principals: toStringSliceFromInterface(d["principals"]),
		KeyId:      stringVal(d, "key_id"),
		Profile:    stringVal(d, "profile"),
		IssuedBy:   stringVal(d, "issued_by"),
		IssuedAt:   parseTimestamp(d, "issued_at"),
		ExpiresAt:  parseTimestamp(d, "expires_at"),
		Revoked:    revoked,
	}
}