// policyServiceServer implements mciasv1.PolicyServiceServer. // All handlers are admin-only and delegate to the same db package used by // the REST policy handlers in internal/server/handlers_policy.go. package grpcserver import ( "context" "encoding/json" "errors" "fmt" "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" mciasv1 "git.wntrmute.dev/kyle/mcias/gen/mcias/v1" "git.wntrmute.dev/kyle/mcias/internal/db" "git.wntrmute.dev/kyle/mcias/internal/model" "git.wntrmute.dev/kyle/mcias/internal/policy" ) type policyServiceServer struct { mciasv1.UnimplementedPolicyServiceServer s *Server } // policyRuleToProto converts a model.PolicyRuleRecord to the wire representation. func policyRuleToProto(rec *model.PolicyRuleRecord) *mciasv1.PolicyRule { r := &mciasv1.PolicyRule{ Id: rec.ID, Description: rec.Description, Priority: int32(rec.Priority), //nolint:gosec // priority is a small positive integer Enabled: rec.Enabled, RuleJson: rec.RuleJSON, CreatedAt: rec.CreatedAt.UTC().Format(time.RFC3339), UpdatedAt: rec.UpdatedAt.UTC().Format(time.RFC3339), } if rec.NotBefore != nil { r.NotBefore = rec.NotBefore.UTC().Format(time.RFC3339) } if rec.ExpiresAt != nil { r.ExpiresAt = rec.ExpiresAt.UTC().Format(time.RFC3339) } return r } // validateRuleJSON ensures the JSON string is valid and contains a recognised // effect. It mirrors the validation in the REST handleCreatePolicyRule handler. func validateRuleJSON(ruleJSON string) error { var body policy.RuleBody if err := json.Unmarshal([]byte(ruleJSON), &body); err != nil { return fmt.Errorf("rule_json is not valid JSON: %w", err) } if body.Effect != policy.Allow && body.Effect != policy.Deny { return fmt.Errorf("rule.effect must be %q or %q", policy.Allow, policy.Deny) } return nil } // ListPolicyRules returns all policy rules. Admin only. func (p *policyServiceServer) ListPolicyRules(ctx context.Context, _ *mciasv1.ListPolicyRulesRequest) (*mciasv1.ListPolicyRulesResponse, error) { if err := p.s.requireAdmin(ctx); err != nil { return nil, err } rules, err := p.s.db.ListPolicyRules(false) if err != nil { p.s.logger.Error("list policy rules", "error", err) return nil, status.Error(codes.Internal, "internal error") } resp := &mciasv1.ListPolicyRulesResponse{ Rules: make([]*mciasv1.PolicyRule, 0, len(rules)), } for _, rec := range rules { resp.Rules = append(resp.Rules, policyRuleToProto(rec)) } return resp, nil } // CreatePolicyRule creates a new policy rule. Admin only. func (p *policyServiceServer) CreatePolicyRule(ctx context.Context, req *mciasv1.CreatePolicyRuleRequest) (*mciasv1.CreatePolicyRuleResponse, error) { if err := p.s.requireAdmin(ctx); err != nil { return nil, err } if req.Description == "" { return nil, status.Error(codes.InvalidArgument, "description is required") } if req.RuleJson == "" { return nil, status.Error(codes.InvalidArgument, "rule_json is required") } if err := validateRuleJSON(req.RuleJson); err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } priority := int(req.Priority) if priority == 0 { priority = 100 // default, matching REST handler } var notBefore, expiresAt *time.Time if req.NotBefore != "" { t, err := time.Parse(time.RFC3339, req.NotBefore) if err != nil { return nil, status.Error(codes.InvalidArgument, "not_before must be RFC3339") } notBefore = &t } if req.ExpiresAt != "" { t, err := time.Parse(time.RFC3339, req.ExpiresAt) if err != nil { return nil, status.Error(codes.InvalidArgument, "expires_at must be RFC3339") } expiresAt = &t } if notBefore != nil && expiresAt != nil && !expiresAt.After(*notBefore) { return nil, status.Error(codes.InvalidArgument, "expires_at must be after not_before") } claims := claimsFromContext(ctx) var createdBy *int64 if claims != nil { if actor, err := p.s.db.GetAccountByUUID(claims.Subject); err == nil { createdBy = &actor.ID } } rec, err := p.s.db.CreatePolicyRule(req.Description, priority, req.RuleJson, createdBy, notBefore, expiresAt) if err != nil { p.s.logger.Error("create policy rule", "error", err) return nil, status.Error(codes.Internal, "internal error") } p.s.db.WriteAuditEvent(model.EventPolicyRuleCreated, createdBy, nil, peerIP(ctx), //nolint:errcheck fmt.Sprintf(`{"rule_id":%d,"description":%q}`, rec.ID, rec.Description)) return &mciasv1.CreatePolicyRuleResponse{Rule: policyRuleToProto(rec)}, nil } // GetPolicyRule returns a single policy rule by ID. Admin only. func (p *policyServiceServer) GetPolicyRule(ctx context.Context, req *mciasv1.GetPolicyRuleRequest) (*mciasv1.GetPolicyRuleResponse, error) { if err := p.s.requireAdmin(ctx); err != nil { return nil, err } if req.Id == 0 { return nil, status.Error(codes.InvalidArgument, "id is required") } rec, err := p.s.db.GetPolicyRule(req.Id) if err != nil { if errors.Is(err, db.ErrNotFound) { return nil, status.Error(codes.NotFound, "policy rule not found") } p.s.logger.Error("get policy rule", "error", err) return nil, status.Error(codes.Internal, "internal error") } return &mciasv1.GetPolicyRuleResponse{Rule: policyRuleToProto(rec)}, nil } // UpdatePolicyRule applies a partial update to a policy rule. Admin only. func (p *policyServiceServer) UpdatePolicyRule(ctx context.Context, req *mciasv1.UpdatePolicyRuleRequest) (*mciasv1.UpdatePolicyRuleResponse, error) { if err := p.s.requireAdmin(ctx); err != nil { return nil, err } if req.Id == 0 { return nil, status.Error(codes.InvalidArgument, "id is required") } // Verify the rule exists before applying updates. if _, err := p.s.db.GetPolicyRule(req.Id); err != nil { if errors.Is(err, db.ErrNotFound) { return nil, status.Error(codes.NotFound, "policy rule not found") } p.s.logger.Error("get policy rule for update", "error", err) return nil, status.Error(codes.Internal, "internal error") } // Build optional update fields — nil means "do not change". var priority *int if req.Priority != nil { v := int(req.GetPriority()) priority = &v } // Double-pointer semantics for time fields: nil outer = no change; // non-nil outer with nil inner = set to NULL; non-nil both = set value. var notBefore, expiresAt **time.Time if req.ClearNotBefore { var nilTime *time.Time notBefore = &nilTime } else if req.NotBefore != "" { t, err := time.Parse(time.RFC3339, req.NotBefore) if err != nil { return nil, status.Error(codes.InvalidArgument, "not_before must be RFC3339") } tp := &t notBefore = &tp } if req.ClearExpiresAt { var nilTime *time.Time expiresAt = &nilTime } else if req.ExpiresAt != "" { t, err := time.Parse(time.RFC3339, req.ExpiresAt) if err != nil { return nil, status.Error(codes.InvalidArgument, "expires_at must be RFC3339") } tp := &t expiresAt = &tp } if err := p.s.db.UpdatePolicyRule(req.Id, nil, priority, nil, notBefore, expiresAt); err != nil { p.s.logger.Error("update policy rule", "error", err) return nil, status.Error(codes.Internal, "internal error") } if req.Enabled != nil { if err := p.s.db.SetPolicyRuleEnabled(req.Id, req.GetEnabled()); err != nil { p.s.logger.Error("set policy rule enabled", "error", err) return nil, status.Error(codes.Internal, "internal error") } } claims := claimsFromContext(ctx) var actorID *int64 if claims != nil { if actor, err := p.s.db.GetAccountByUUID(claims.Subject); err == nil { actorID = &actor.ID } } p.s.db.WriteAuditEvent(model.EventPolicyRuleUpdated, actorID, nil, peerIP(ctx), //nolint:errcheck fmt.Sprintf(`{"rule_id":%d}`, req.Id)) updated, err := p.s.db.GetPolicyRule(req.Id) if err != nil { p.s.logger.Error("get updated policy rule", "error", err) return nil, status.Error(codes.Internal, "internal error") } return &mciasv1.UpdatePolicyRuleResponse{Rule: policyRuleToProto(updated)}, nil } // DeletePolicyRule permanently removes a policy rule. Admin only. func (p *policyServiceServer) DeletePolicyRule(ctx context.Context, req *mciasv1.DeletePolicyRuleRequest) (*mciasv1.DeletePolicyRuleResponse, error) { if err := p.s.requireAdmin(ctx); err != nil { return nil, err } if req.Id == 0 { return nil, status.Error(codes.InvalidArgument, "id is required") } rec, err := p.s.db.GetPolicyRule(req.Id) if err != nil { if errors.Is(err, db.ErrNotFound) { return nil, status.Error(codes.NotFound, "policy rule not found") } p.s.logger.Error("get policy rule for delete", "error", err) return nil, status.Error(codes.Internal, "internal error") } if err := p.s.db.DeletePolicyRule(req.Id); err != nil { p.s.logger.Error("delete policy rule", "error", err) return nil, status.Error(codes.Internal, "internal error") } claims := claimsFromContext(ctx) var actorID *int64 if claims != nil { if actor, err := p.s.db.GetAccountByUUID(claims.Subject); err == nil { actorID = &actor.ID } } p.s.db.WriteAuditEvent(model.EventPolicyRuleDeleted, actorID, nil, peerIP(ctx), //nolint:errcheck fmt.Sprintf(`{"rule_id":%d,"description":%q}`, rec.ID, rec.Description)) return &mciasv1.DeletePolicyRuleResponse{}, nil }