Add L7 policies for user-agent blocking and required headers
Per-route HTTP-level blocking policies for L7 routes. Two rule types: block_user_agent (substring match against User-Agent, returns 403) and require_header (named header must be present, returns 403). Config: L7Policy struct with type/value fields, added as L7Policies slice on Route. Validated in config (type enum, non-empty value, warning if set on L4 routes). DB: Migration 4 creates l7_policies table with route_id FK (cascade delete), type CHECK constraint, UNIQUE(route_id, type, value). New l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy, GetRouteID. Seed updated to persist policies from config. L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates rules in order, returns 403 on first match, no-op if empty. Composed into the handler chain between context injection and reverse proxy. Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy mutation methods on ListenerState. handleL7 threads policies into l7.RouteConfig. Startup loads policies per L7 route from DB. Proto: L7Policy message, repeated l7_policies on Route. Three new RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the write-through pattern. Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy methods. CLI: mcproxyctl policies list/add/remove subcommands. Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match, header present/absent, multiple rules). 4 DB tests (CRUD, cascade, duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation). 2 end-to-end L7 tests (UA block, required header with allow/deny). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -89,6 +89,10 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
|
||||
ListenerAddr: ls.Addr,
|
||||
}
|
||||
for hostname, route := range routes {
|
||||
var policies []*pb.L7Policy
|
||||
for _, p := range route.L7Policies {
|
||||
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
|
||||
}
|
||||
resp.Routes = append(resp.Routes, &pb.Route{
|
||||
Hostname: hostname,
|
||||
Backend: route.Backend,
|
||||
@@ -97,6 +101,7 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
|
||||
TlsKey: route.TLSKey,
|
||||
BackendTls: route.BackendTLS,
|
||||
SendProxyProtocol: route.SendProxyProtocol,
|
||||
L7Policies: policies,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
@@ -187,6 +192,100 @@ func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest)
|
||||
return &pb.RemoveRouteResponse{}, nil
|
||||
}
|
||||
|
||||
// ListL7Policies returns L7 policies for a route.
|
||||
func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) {
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostname := strings.ToLower(req.Hostname)
|
||||
routes := ls.Routes()
|
||||
route, ok := routes[hostname]
|
||||
if !ok {
|
||||
return nil, status.Errorf(codes.NotFound, "route %q not found", hostname)
|
||||
}
|
||||
|
||||
var policies []*pb.L7Policy
|
||||
for _, p := range route.L7Policies {
|
||||
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
|
||||
}
|
||||
return &pb.ListL7PoliciesResponse{Policies: policies}, nil
|
||||
}
|
||||
|
||||
// AddL7Policy adds an L7 policy to a route (write-through).
|
||||
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) {
|
||||
if req.Policy == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "policy is required")
|
||||
}
|
||||
if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type)
|
||||
}
|
||||
if req.Policy.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "policy value is required")
|
||||
}
|
||||
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostname := strings.ToLower(req.Hostname)
|
||||
|
||||
// Get route ID from DB.
|
||||
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||
}
|
||||
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
|
||||
}
|
||||
|
||||
// Write-through: DB first.
|
||||
if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||
}
|
||||
|
||||
// Update in-memory state.
|
||||
ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value})
|
||||
|
||||
a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
|
||||
return &pb.AddL7PolicyResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveL7Policy removes an L7 policy from a route (write-through).
|
||||
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) {
|
||||
if req.Policy == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "policy is required")
|
||||
}
|
||||
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hostname := strings.ToLower(req.Hostname)
|
||||
|
||||
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||
}
|
||||
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
|
||||
}
|
||||
|
||||
if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "%v", err)
|
||||
}
|
||||
|
||||
ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value)
|
||||
|
||||
a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type)
|
||||
return &pb.RemoveL7PolicyResponse{}, nil
|
||||
}
|
||||
|
||||
// GetFirewallRules returns all current firewall rules.
|
||||
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
|
||||
ips, cidrs, countries := a.srv.Firewall().Rules()
|
||||
|
||||
@@ -735,3 +735,97 @@ func TestSetListenerMaxConnectionsNotFound(t *testing.T) {
|
||||
t.Fatalf("expected NotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddListL7Policy(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
Policy: &pb.L7Policy{Type: "block_user_agent", Value: "BadBot"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AddL7Policy: %v", err)
|
||||
}
|
||||
|
||||
resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("ListL7Policies: %v", err)
|
||||
}
|
||||
if len(resp.Policies) != 1 {
|
||||
t.Fatalf("got %d policies, want 1", len(resp.Policies))
|
||||
}
|
||||
if resp.Policies[0].Type != "block_user_agent" || resp.Policies[0].Value != "BadBot" {
|
||||
t.Fatalf("policy = %v, want block_user_agent/BadBot", resp.Policies[0])
|
||||
}
|
||||
|
||||
routeResp, _ := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
||||
for _, r := range routeResp.Routes {
|
||||
if r.Hostname == "a.test" && len(r.L7Policies) != 1 {
|
||||
t.Fatalf("ListRoutes: route has %d policies, want 1", len(r.L7Policies))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveL7Policy(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
|
||||
})
|
||||
|
||||
_, err := env.client.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveL7Policy: %v", err)
|
||||
}
|
||||
|
||||
resp, _ := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
})
|
||||
if len(resp.Policies) != 0 {
|
||||
t.Fatalf("got %d policies after remove, want 0", len(resp.Policies))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddL7PolicyValidation(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
Policy: &pb.L7Policy{Type: "invalid_type", Value: "x"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid policy type")
|
||||
}
|
||||
|
||||
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
Policy: &pb.L7Policy{Type: "block_user_agent", Value: ""},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty policy value")
|
||||
}
|
||||
|
||||
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nil policy")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user