Add L7 policies for user-agent blocking and required headers

Per-route HTTP-level blocking policies for L7 routes. Two rule types:
block_user_agent (substring match against User-Agent, returns 403)
and require_header (named header must be present, returns 403).

Config: L7Policy struct with type/value fields, added as L7Policies
slice on Route. Validated in config (type enum, non-empty value,
warning if set on L4 routes).

DB: Migration 4 creates l7_policies table with route_id FK (cascade
delete), type CHECK constraint, UNIQUE(route_id, type, value). New
l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy,
GetRouteID. Seed updated to persist policies from config.

L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates
rules in order, returns 403 on first match, no-op if empty. Composed
into the handler chain between context injection and reverse proxy.

Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy
mutation methods on ListenerState. handleL7 threads policies into
l7.RouteConfig. Startup loads policies per L7 route from DB.

Proto: L7Policy message, repeated l7_policies on Route. Three new
RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the
write-through pattern.

Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy
methods. CLI: mcproxyctl policies list/add/remove subcommands.

Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match,
header present/absent, multiple rules). 4 DB tests (CRUD, cascade,
duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation).
2 end-to-end L7 tests (UA block, required header with allow/deny).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 17:11:05 -07:00
parent 1ad42dbbee
commit 42c7fffc3e
20 changed files with 1613 additions and 136 deletions

View File

@@ -89,6 +89,10 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
ListenerAddr: ls.Addr,
}
for hostname, route := range routes {
var policies []*pb.L7Policy
for _, p := range route.L7Policies {
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
}
resp.Routes = append(resp.Routes, &pb.Route{
Hostname: hostname,
Backend: route.Backend,
@@ -97,6 +101,7 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
TlsKey: route.TLSKey,
BackendTls: route.BackendTLS,
SendProxyProtocol: route.SendProxyProtocol,
L7Policies: policies,
})
}
return resp, nil
@@ -187,6 +192,100 @@ func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest)
return &pb.RemoveRouteResponse{}, nil
}
// ListL7Policies returns L7 policies for a route.
func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) {
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
routes := ls.Routes()
route, ok := routes[hostname]
if !ok {
return nil, status.Errorf(codes.NotFound, "route %q not found", hostname)
}
var policies []*pb.L7Policy
for _, p := range route.L7Policies {
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
}
return &pb.ListL7PoliciesResponse{Policies: policies}, nil
}
// AddL7Policy adds an L7 policy to a route (write-through).
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) {
if req.Policy == nil {
return nil, status.Error(codes.InvalidArgument, "policy is required")
}
if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" {
return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type)
}
if req.Policy.Value == "" {
return nil, status.Error(codes.InvalidArgument, "policy value is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
// Get route ID from DB.
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
if err != nil {
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
}
// Write-through: DB first.
if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
}
// Update in-memory state.
ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value})
a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
return &pb.AddL7PolicyResponse{}, nil
}
// RemoveL7Policy removes an L7 policy from a route (write-through).
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) {
if req.Policy == nil {
return nil, status.Error(codes.InvalidArgument, "policy is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
if err != nil {
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
}
if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
return nil, status.Errorf(codes.NotFound, "%v", err)
}
ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value)
a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type)
return &pb.RemoveL7PolicyResponse{}, nil
}
// GetFirewallRules returns all current firewall rules.
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
ips, cidrs, countries := a.srv.Firewall().Rules()

View File

@@ -735,3 +735,97 @@ func TestSetListenerMaxConnectionsNotFound(t *testing.T) {
t.Fatalf("expected NotFound, got %v", err)
}
}
func TestAddListL7Policy(t *testing.T) {
env := setup(t)
ctx := context.Background()
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "block_user_agent", Value: "BadBot"},
})
if err != nil {
t.Fatalf("AddL7Policy: %v", err)
}
resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if err != nil {
t.Fatalf("ListL7Policies: %v", err)
}
if len(resp.Policies) != 1 {
t.Fatalf("got %d policies, want 1", len(resp.Policies))
}
if resp.Policies[0].Type != "block_user_agent" || resp.Policies[0].Value != "BadBot" {
t.Fatalf("policy = %v, want block_user_agent/BadBot", resp.Policies[0])
}
routeResp, _ := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
for _, r := range routeResp.Routes {
if r.Hostname == "a.test" && len(r.L7Policies) != 1 {
t.Fatalf("ListRoutes: route has %d policies, want 1", len(r.L7Policies))
}
}
}
func TestRemoveL7Policy(t *testing.T) {
env := setup(t)
ctx := context.Background()
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
})
_, err := env.client.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
})
if err != nil {
t.Fatalf("RemoveL7Policy: %v", err)
}
resp, _ := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if len(resp.Policies) != 0 {
t.Fatalf("got %d policies after remove, want 0", len(resp.Policies))
}
}
func TestAddL7PolicyValidation(t *testing.T) {
env := setup(t)
ctx := context.Background()
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "invalid_type", Value: "x"},
})
if err == nil {
t.Fatal("expected error for invalid policy type")
}
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "block_user_agent", Value: ""},
})
if err == nil {
t.Fatal("expected error for empty policy value")
}
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if err == nil {
t.Fatal("expected error for nil policy")
}
}