package grpcserver import ( "context" "errors" "fmt" "strconv" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" pb "git.wntrmute.dev/kyle/mcr/gen/mcr/v1" "git.wntrmute.dev/kyle/mcr/internal/auth" "git.wntrmute.dev/kyle/mcr/internal/db" "git.wntrmute.dev/kyle/mcr/internal/policy" ) var validActions = map[string]bool{ string(policy.ActionVersionCheck): true, string(policy.ActionPull): true, string(policy.ActionPush): true, string(policy.ActionDelete): true, string(policy.ActionCatalog): true, string(policy.ActionPolicyManage): true, } // policyService implements pb.PolicyServiceServer. type policyService struct { pb.UnimplementedPolicyServiceServer db *db.DB engine PolicyReloader auditFn AuditFunc } func (s *policyService) ListPolicyRules(_ context.Context, req *pb.ListPolicyRulesRequest) (*pb.ListPolicyRulesResponse, error) { limit := int32(50) offset := int32(0) if req.GetPagination() != nil { if req.Pagination.Limit > 0 { limit = req.Pagination.Limit } if req.Pagination.Offset >= 0 { offset = req.Pagination.Offset } } rules, err := s.db.ListPolicyRules(int(limit), int(offset)) if err != nil { return nil, status.Errorf(codes.Internal, "internal error") } var result []*pb.PolicyRule for _, r := range rules { result = append(result, policyRuleRowToProto(&r)) } return &pb.ListPolicyRulesResponse{Rules: result}, nil } func (s *policyService) CreatePolicyRule(ctx context.Context, req *pb.CreatePolicyRuleRequest) (*pb.PolicyRule, error) { if req.GetPriority() < 1 { return nil, status.Errorf(codes.InvalidArgument, "priority must be >= 1 (0 is reserved for built-ins)") } if req.GetDescription() == "" { return nil, status.Errorf(codes.InvalidArgument, "description is required") } if err := validateEffect(req.GetEffect()); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error()) } if len(req.GetActions()) == 0 { return nil, status.Errorf(codes.InvalidArgument, "at least one action is required") } if err := validateActions(req.GetActions()); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error()) } claims := auth.ClaimsFromContext(ctx) createdBy := "" if claims != nil { createdBy = claims.Subject } row := db.PolicyRuleRow{ Priority: int(req.Priority), Description: req.Description, Effect: req.Effect, Roles: req.Roles, AccountTypes: req.AccountTypes, SubjectUUID: req.SubjectUuid, Actions: req.Actions, Repositories: req.Repositories, Enabled: req.Enabled, CreatedBy: createdBy, } id, err := s.db.CreatePolicyRule(row) if err != nil { return nil, status.Errorf(codes.Internal, "internal error") } // Reload policy engine. if s.engine != nil { _ = s.engine.Reload(s.db) } if s.auditFn != nil { s.auditFn("policy_rule_created", createdBy, "", "", "", map[string]string{ "rule_id": strconv.FormatInt(id, 10), }) } created, err := s.db.GetPolicyRule(id) if err != nil { return nil, status.Errorf(codes.Internal, "internal error") } return policyRuleRowToProto(created), nil } func (s *policyService) GetPolicyRule(_ context.Context, req *pb.GetPolicyRuleRequest) (*pb.PolicyRule, error) { if req.GetId() == 0 { return nil, status.Errorf(codes.InvalidArgument, "rule ID required") } rule, err := s.db.GetPolicyRule(req.Id) if err != nil { if errors.Is(err, db.ErrPolicyRuleNotFound) { return nil, status.Errorf(codes.NotFound, "policy rule not found") } return nil, status.Errorf(codes.Internal, "internal error") } return policyRuleRowToProto(rule), nil } func (s *policyService) UpdatePolicyRule(ctx context.Context, req *pb.UpdatePolicyRuleRequest) (*pb.PolicyRule, error) { if req.GetId() == 0 { return nil, status.Errorf(codes.InvalidArgument, "rule ID required") } mask := make(map[string]bool, len(req.GetUpdateMask())) for _, f := range req.GetUpdateMask() { mask[f] = true } // Validate fields if they are in the update mask. if mask["priority"] && req.Priority < 1 { return nil, status.Errorf(codes.InvalidArgument, "priority must be >= 1 (0 is reserved for built-ins)") } if mask["effect"] { if err := validateEffect(req.GetEffect()); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error()) } } if mask["actions"] { if len(req.GetActions()) == 0 { return nil, status.Errorf(codes.InvalidArgument, "at least one action is required") } if err := validateActions(req.GetActions()); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%s", err.Error()) } } updates := db.PolicyRuleRow{} if mask["priority"] { updates.Priority = int(req.Priority) } if mask["description"] { updates.Description = req.Description } if mask["effect"] { updates.Effect = req.Effect } if mask["roles"] { updates.Roles = req.Roles } if mask["account_types"] { updates.AccountTypes = req.AccountTypes } if mask["subject_uuid"] { updates.SubjectUUID = req.SubjectUuid } if mask["actions"] { updates.Actions = req.Actions } if mask["repositories"] { updates.Repositories = req.Repositories } if err := s.db.UpdatePolicyRule(req.Id, updates); err != nil { if errors.Is(err, db.ErrPolicyRuleNotFound) { return nil, status.Errorf(codes.NotFound, "policy rule not found") } return nil, status.Errorf(codes.Internal, "internal error") } // Handle enabled separately since it's a bool. if mask["enabled"] { if err := s.db.SetPolicyRuleEnabled(req.Id, req.Enabled); err != nil { return nil, status.Errorf(codes.Internal, "internal error") } } // Reload policy engine. if s.engine != nil { _ = s.engine.Reload(s.db) } if s.auditFn != nil { claims := auth.ClaimsFromContext(ctx) actorID := "" if claims != nil { actorID = claims.Subject } s.auditFn("policy_rule_updated", actorID, "", "", "", map[string]string{ "rule_id": strconv.FormatInt(req.Id, 10), }) } updated, err := s.db.GetPolicyRule(req.Id) if err != nil { return nil, status.Errorf(codes.Internal, "internal error") } return policyRuleRowToProto(updated), nil } func (s *policyService) DeletePolicyRule(ctx context.Context, req *pb.DeletePolicyRuleRequest) (*pb.DeletePolicyRuleResponse, error) { if req.GetId() == 0 { return nil, status.Errorf(codes.InvalidArgument, "rule ID required") } if err := s.db.DeletePolicyRule(req.Id); err != nil { if errors.Is(err, db.ErrPolicyRuleNotFound) { return nil, status.Errorf(codes.NotFound, "policy rule not found") } return nil, status.Errorf(codes.Internal, "internal error") } // Reload policy engine. if s.engine != nil { _ = s.engine.Reload(s.db) } if s.auditFn != nil { claims := auth.ClaimsFromContext(ctx) actorID := "" if claims != nil { actorID = claims.Subject } s.auditFn("policy_rule_deleted", actorID, "", "", "", map[string]string{ "rule_id": strconv.FormatInt(req.Id, 10), }) } return &pb.DeletePolicyRuleResponse{}, nil } // policyRuleRowToProto converts a db.PolicyRuleRow to a protobuf PolicyRule. func policyRuleRowToProto(r *db.PolicyRuleRow) *pb.PolicyRule { return &pb.PolicyRule{ Id: r.ID, Priority: int32(r.Priority), //nolint:gosec // priority is always small positive int Description: r.Description, Effect: r.Effect, Roles: r.Roles, AccountTypes: r.AccountTypes, SubjectUuid: r.SubjectUUID, Actions: r.Actions, Repositories: r.Repositories, Enabled: r.Enabled, CreatedBy: r.CreatedBy, CreatedAt: r.CreatedAt, UpdatedAt: r.UpdatedAt, } } func validateEffect(effect string) error { if effect != "allow" && effect != "deny" { return fmt.Errorf("invalid effect: %q (must be 'allow' or 'deny')", effect) } return nil } func validateActions(actions []string) error { for _, a := range actions { if !validActions[a] { return fmt.Errorf("invalid action: %q", a) } } return nil }