Add L7 policies for user-agent blocking and required headers

Per-route HTTP-level blocking policies for L7 routes. Two rule types:
block_user_agent (substring match against User-Agent, returns 403)
and require_header (named header must be present, returns 403).

Config: L7Policy struct with type/value fields, added as L7Policies
slice on Route. Validated in config (type enum, non-empty value,
warning if set on L4 routes).

DB: Migration 4 creates l7_policies table with route_id FK (cascade
delete), type CHECK constraint, UNIQUE(route_id, type, value). New
l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy,
GetRouteID. Seed updated to persist policies from config.

L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates
rules in order, returns 403 on first match, no-op if empty. Composed
into the handler chain between context injection and reverse proxy.

Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy
mutation methods on ListenerState. handleL7 threads policies into
l7.RouteConfig. Startup loads policies per L7 route from DB.

Proto: L7Policy message, repeated l7_policies on Route. Three new
RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the
write-through pattern.

Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy
methods. CLI: mcproxyctl policies list/add/remove subcommands.

Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match,
header present/absent, multiple rules). 4 DB tests (CRUD, cascade,
duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation).
2 end-to-end L7 tests (UA block, required header with allow/deny).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 17:11:05 -07:00
parent 1ad42dbbee
commit 42c7fffc3e
20 changed files with 1613 additions and 136 deletions

View File

@@ -45,13 +45,20 @@ type Listener struct {
// Route is a proxy route within a listener.
type Route struct {
Hostname string `toml:"hostname"`
Backend string `toml:"backend"`
Mode string `toml:"mode"` // "l4" (default) or "l7"
TLSCert string `toml:"tls_cert"` // PEM certificate path (L7 only)
TLSKey string `toml:"tls_key"` // PEM private key path (L7 only)
BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only)
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
Hostname string `toml:"hostname"`
Backend string `toml:"backend"`
Mode string `toml:"mode"` // "l4" (default) or "l7"
TLSCert string `toml:"tls_cert"` // PEM certificate path (L7 only)
TLSKey string `toml:"tls_key"` // PEM private key path (L7 only)
BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only)
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
L7Policies []L7Policy `toml:"l7_policies"` // HTTP-level policies (L7 only)
}
// L7Policy is an HTTP-level blocking policy for L7 routes.
type L7Policy struct {
Type string `toml:"type"` // "block_user_agent" or "require_header"
Value string `toml:"value"` // UA substring or header name
}
// Firewall holds the global firewall configuration.
@@ -168,6 +175,22 @@ func (c *Config) validate() error {
i, l.Addr, j, r.Hostname, err)
}
}
// Validate L7 policies.
if r.Mode == "l4" && len(r.L7Policies) > 0 {
slog.Warn("L4 route has l7_policies set (ignored)",
"listener", l.Addr, "hostname", r.Hostname)
}
for k, p := range r.L7Policies {
if p.Type != "block_user_agent" && p.Type != "require_header" {
return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: type must be \"block_user_agent\" or \"require_header\", got %q",
i, l.Addr, j, r.Hostname, k, p.Type)
}
if p.Value == "" {
return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: value is required",
i, l.Addr, j, r.Hostname, k)
}
}
}
}

View File

@@ -508,3 +508,95 @@ func TestMigrationV2Upgrade(t *testing.T) {
t.Fatal("expected proxy_protocol = false")
}
}
func TestL7PolicyCRUD(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
// Create policies.
id1, err := store.CreateL7Policy(rid, "block_user_agent", "BadBot")
if err != nil {
t.Fatalf("create policy 1: %v", err)
}
if id1 == 0 {
t.Fatal("expected non-zero policy ID")
}
if _, err := store.CreateL7Policy(rid, "require_header", "X-API-Key"); err != nil {
t.Fatalf("create policy 2: %v", err)
}
// List policies.
policies, err := store.ListL7Policies(rid)
if err != nil {
t.Fatalf("list: %v", err)
}
if len(policies) != 2 {
t.Fatalf("got %d policies, want 2", len(policies))
}
// Delete one.
if err := store.DeleteL7Policy(rid, "block_user_agent", "BadBot"); err != nil {
t.Fatalf("delete: %v", err)
}
policies, _ = store.ListL7Policies(rid)
if len(policies) != 1 {
t.Fatalf("got %d policies after delete, want 1", len(policies))
}
if policies[0].Type != "require_header" {
t.Fatalf("remaining policy type = %q, want %q", policies[0].Type, "require_header")
}
}
func TestL7PolicyCascadeDelete(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
store.CreateL7Policy(rid, "block_user_agent", "Bot")
// Deleting the route should cascade-delete its policies.
store.DeleteRoute(lid, "api.test")
policies, _ := store.ListL7Policies(rid)
if len(policies) != 0 {
t.Fatalf("got %d policies after cascade delete, want 0", len(policies))
}
}
func TestL7PolicyDuplicate(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
if _, err := store.CreateL7Policy(rid, "block_user_agent", "Bot"); err != nil {
t.Fatalf("first create: %v", err)
}
if _, err := store.CreateL7Policy(rid, "block_user_agent", "Bot"); err == nil {
t.Fatal("expected error for duplicate policy")
}
}
func TestGetRouteID(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
rid, err := store.GetRouteID(lid, "api.test")
if err != nil {
t.Fatalf("GetRouteID: %v", err)
}
if rid == 0 {
t.Fatal("expected non-zero route ID")
}
_, err = store.GetRouteID(lid, "nonexistent.test")
if err == nil {
t.Fatal("expected error for nonexistent route")
}
}

74
internal/db/l7policies.go Normal file
View File

@@ -0,0 +1,74 @@
package db
import "fmt"
// L7Policy is a database L7 policy record.
type L7Policy struct {
ID int64
RouteID int64
Type string // "block_user_agent" or "require_header"
Value string
}
// ListL7Policies returns all L7 policies for a route.
func (s *Store) ListL7Policies(routeID int64) ([]L7Policy, error) {
rows, err := s.db.Query(
"SELECT id, route_id, type, value FROM l7_policies WHERE route_id = ? ORDER BY id",
routeID,
)
if err != nil {
return nil, fmt.Errorf("querying l7 policies: %w", err)
}
defer rows.Close()
var policies []L7Policy
for rows.Next() {
var p L7Policy
if err := rows.Scan(&p.ID, &p.RouteID, &p.Type, &p.Value); err != nil {
return nil, fmt.Errorf("scanning l7 policy: %w", err)
}
policies = append(policies, p)
}
return policies, rows.Err()
}
// CreateL7Policy inserts an L7 policy and returns its ID.
func (s *Store) CreateL7Policy(routeID int64, policyType, value string) (int64, error) {
result, err := s.db.Exec(
"INSERT INTO l7_policies (route_id, type, value) VALUES (?, ?, ?)",
routeID, policyType, value,
)
if err != nil {
return 0, fmt.Errorf("inserting l7 policy: %w", err)
}
return result.LastInsertId()
}
// DeleteL7Policy deletes an L7 policy by route ID, type, and value.
func (s *Store) DeleteL7Policy(routeID int64, policyType, value string) error {
result, err := s.db.Exec(
"DELETE FROM l7_policies WHERE route_id = ? AND type = ? AND value = ?",
routeID, policyType, value,
)
if err != nil {
return fmt.Errorf("deleting l7 policy: %w", err)
}
n, _ := result.RowsAffected()
if n == 0 {
return fmt.Errorf("l7 policy not found (route %d, type %q, value %q)", routeID, policyType, value)
}
return nil
}
// GetRouteID returns the route ID for a listener/hostname pair.
func (s *Store) GetRouteID(listenerID int64, hostname string) (int64, error) {
var id int64
err := s.db.QueryRow(
"SELECT id FROM routes WHERE listener_id = ? AND hostname = ?",
listenerID, hostname,
).Scan(&id)
if err != nil {
return 0, fmt.Errorf("looking up route %q on listener %d: %w", hostname, listenerID, err)
}
return id, nil
}

View File

@@ -45,6 +45,19 @@ ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0;`,
Name: "add_listener_max_connections",
SQL: `ALTER TABLE listeners ADD COLUMN max_connections INTEGER NOT NULL DEFAULT 0;`,
},
{
Version: 4,
Name: "create_l7_policies_table",
SQL: `
CREATE TABLE IF NOT EXISTS l7_policies (
id INTEGER PRIMARY KEY,
route_id INTEGER NOT NULL REFERENCES routes(id) ON DELETE CASCADE,
type TEXT NOT NULL CHECK(type IN ('block_user_agent', 'require_header')),
value TEXT NOT NULL,
UNIQUE(route_id, type, value)
);
CREATE INDEX IF NOT EXISTS idx_l7_policies_route ON l7_policies(route_id);`,
},
}
// Migrate runs all unapplied migrations sequentially.

View File

@@ -31,7 +31,7 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
if mode == "" {
mode = "l4"
}
_, err := tx.Exec(
routeResult, err := tx.Exec(
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
listenerID, strings.ToLower(r.Hostname), r.Backend,
@@ -40,6 +40,18 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
if err != nil {
return fmt.Errorf("seeding route %q on listener %q: %w", r.Hostname, l.Addr, err)
}
if len(r.L7Policies) > 0 {
routeID, _ := routeResult.LastInsertId()
for _, p := range r.L7Policies {
if _, err := tx.Exec(
"INSERT INTO l7_policies (route_id, type, value) VALUES (?, ?, ?)",
routeID, p.Type, p.Value,
); err != nil {
return fmt.Errorf("seeding l7 policy on route %q: %w", r.Hostname, err)
}
}
}
}
}

View File

@@ -89,6 +89,10 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
ListenerAddr: ls.Addr,
}
for hostname, route := range routes {
var policies []*pb.L7Policy
for _, p := range route.L7Policies {
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
}
resp.Routes = append(resp.Routes, &pb.Route{
Hostname: hostname,
Backend: route.Backend,
@@ -97,6 +101,7 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
TlsKey: route.TLSKey,
BackendTls: route.BackendTLS,
SendProxyProtocol: route.SendProxyProtocol,
L7Policies: policies,
})
}
return resp, nil
@@ -187,6 +192,100 @@ func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest)
return &pb.RemoveRouteResponse{}, nil
}
// ListL7Policies returns L7 policies for a route.
func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) {
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
routes := ls.Routes()
route, ok := routes[hostname]
if !ok {
return nil, status.Errorf(codes.NotFound, "route %q not found", hostname)
}
var policies []*pb.L7Policy
for _, p := range route.L7Policies {
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
}
return &pb.ListL7PoliciesResponse{Policies: policies}, nil
}
// AddL7Policy adds an L7 policy to a route (write-through).
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) {
if req.Policy == nil {
return nil, status.Error(codes.InvalidArgument, "policy is required")
}
if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" {
return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type)
}
if req.Policy.Value == "" {
return nil, status.Error(codes.InvalidArgument, "policy value is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
// Get route ID from DB.
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
if err != nil {
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
}
// Write-through: DB first.
if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
}
// Update in-memory state.
ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value})
a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
return &pb.AddL7PolicyResponse{}, nil
}
// RemoveL7Policy removes an L7 policy from a route (write-through).
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) {
if req.Policy == nil {
return nil, status.Error(codes.InvalidArgument, "policy is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
if err != nil {
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
}
if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
return nil, status.Errorf(codes.NotFound, "%v", err)
}
ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value)
a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type)
return &pb.RemoveL7PolicyResponse{}, nil
}
// GetFirewallRules returns all current firewall rules.
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
ips, cidrs, countries := a.srv.Firewall().Rules()

View File

@@ -735,3 +735,97 @@ func TestSetListenerMaxConnectionsNotFound(t *testing.T) {
t.Fatalf("expected NotFound, got %v", err)
}
}
func TestAddListL7Policy(t *testing.T) {
env := setup(t)
ctx := context.Background()
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "block_user_agent", Value: "BadBot"},
})
if err != nil {
t.Fatalf("AddL7Policy: %v", err)
}
resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if err != nil {
t.Fatalf("ListL7Policies: %v", err)
}
if len(resp.Policies) != 1 {
t.Fatalf("got %d policies, want 1", len(resp.Policies))
}
if resp.Policies[0].Type != "block_user_agent" || resp.Policies[0].Value != "BadBot" {
t.Fatalf("policy = %v, want block_user_agent/BadBot", resp.Policies[0])
}
routeResp, _ := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
for _, r := range routeResp.Routes {
if r.Hostname == "a.test" && len(r.L7Policies) != 1 {
t.Fatalf("ListRoutes: route has %d policies, want 1", len(r.L7Policies))
}
}
}
func TestRemoveL7Policy(t *testing.T) {
env := setup(t)
ctx := context.Background()
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
})
_, err := env.client.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
})
if err != nil {
t.Fatalf("RemoveL7Policy: %v", err)
}
resp, _ := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if len(resp.Policies) != 0 {
t.Fatalf("got %d policies after remove, want 0", len(resp.Policies))
}
}
func TestAddL7PolicyValidation(t *testing.T) {
env := setup(t)
ctx := context.Background()
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "invalid_type", Value: "x"},
})
if err == nil {
t.Fatal("expected error for invalid policy type")
}
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "block_user_agent", Value: ""},
})
if err == nil {
t.Fatal("expected error for empty policy value")
}
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if err == nil {
t.Fatal("expected error for nil policy")
}
}

38
internal/l7/policy.go Normal file
View File

@@ -0,0 +1,38 @@
package l7
import (
"net/http"
"strings"
)
// PolicyRule defines an L7 blocking policy.
type PolicyRule struct {
Type string // "block_user_agent" or "require_header"
Value string
}
// PolicyMiddleware returns an http.Handler that evaluates L7 policies
// before delegating to next. Returns HTTP 403 if any policy blocks.
// If policies is empty, returns next unchanged.
func PolicyMiddleware(policies []PolicyRule, next http.Handler) http.Handler {
if len(policies) == 0 {
return next
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, p := range policies {
switch p.Type {
case "block_user_agent":
if strings.Contains(r.UserAgent(), p.Value) {
w.WriteHeader(http.StatusForbidden)
return
}
case "require_header":
if r.Header.Get(p.Value) == "" {
w.WriteHeader(http.StatusForbidden)
return
}
}
}
next.ServeHTTP(w, r)
})
}

158
internal/l7/policy_test.go Normal file
View File

@@ -0,0 +1,158 @@
package l7
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestPolicyMiddlewareNoPolicies(t *testing.T) {
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(200)
})
handler := PolicyMiddleware(nil, next)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !called {
t.Fatal("next handler was not called")
}
if w.Code != 200 {
t.Fatalf("status = %d, want 200", w.Code)
}
}
func TestPolicyBlockUserAgentMatch(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "Mozilla/5.0 BadBot/1.0")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != 403 {
t.Fatalf("status = %d, want 403", w.Code)
}
}
func TestPolicyBlockUserAgentNoMatch(t *testing.T) {
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "Mozilla/5.0 GoodBrowser/1.0")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !called {
t.Fatal("next handler was not called")
}
if w.Code != 200 {
t.Fatalf("status = %d, want 200", w.Code)
}
}
func TestPolicyRequireHeaderPresent(t *testing.T) {
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "require_header", Value: "X-API-Key"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-API-Key", "secret")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !called {
t.Fatal("next handler was not called")
}
if w.Code != 200 {
t.Fatalf("status = %d, want 200", w.Code)
}
}
func TestPolicyRequireHeaderAbsent(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "require_header", Value: "X-API-Key"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != 403 {
t.Fatalf("status = %d, want 403", w.Code)
}
}
func TestPolicyMultipleRules(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
{Type: "require_header", Value: "X-Token"},
}
handler := PolicyMiddleware(policies, next)
// Blocked by UA even though header is present.
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "BadBot/1.0")
req.Header.Set("X-Token", "abc")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != 403 {
t.Fatalf("UA block: status = %d, want 403", w.Code)
}
// Good UA but missing header.
req2 := httptest.NewRequest("GET", "/", nil)
req2.Header.Set("User-Agent", "GoodBot/1.0")
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w2.Code != 403 {
t.Fatalf("missing header: status = %d, want 403", w2.Code)
}
// Good UA and header present — passes.
req3 := httptest.NewRequest("GET", "/", nil)
req3.Header.Set("User-Agent", "GoodBot/1.0")
req3.Header.Set("X-Token", "abc")
w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3)
if w3.Code != 200 {
t.Fatalf("pass: status = %d, want 200", w3.Code)
}
}

View File

@@ -26,6 +26,7 @@ type RouteConfig struct {
BackendTLS bool
SendProxyProtocol bool
ConnectTimeout time.Duration
Policies []PolicyRule
}
// contextKey is an unexported type for context keys in this package.
@@ -74,10 +75,12 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
return fmt.Errorf("creating reverse proxy: %w", err)
}
// Wrap the handler to inject the real client IP into the request context.
// Build handler chain: context injection → L7 policies → reverse proxy.
var inner http.Handler = rp
inner = PolicyMiddleware(route.Policies, inner)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
rp.ServeHTTP(w, r)
inner.ServeHTTP(w, r)
})
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,

View File

@@ -551,3 +551,115 @@ func TestL7HTTP11Fallback(t *testing.T) {
t.Fatal("empty response body")
}
}
func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
certPath, keyPath := testCert(t, "policy.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "should-not-reach")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: backendAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 5 * time.Second,
Policies: []PolicyRule{
{Type: "block_user_agent", Value: "EvilBot"},
},
}
go func() {
conn, err := proxyLn.Accept()
if err != nil {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
req, _ := http.NewRequest("GET", "https://policy.test/", nil)
req.Header.Set("User-Agent", "EvilBot/1.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 403 {
t.Fatalf("status = %d, want 403", resp.StatusCode)
}
}
func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
certPath, keyPath := testCert(t, "reqhdr.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "ok")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: backendAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 5 * time.Second,
Policies: []PolicyRule{
{Type: "require_header", Value: "X-Auth-Token"},
},
}
// Accept two connections (one blocked, one allowed).
go func() {
for range 2 {
conn, err := proxyLn.Accept()
if err != nil {
return
}
go func() {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
}
}()
// Without the required header → 403.
client1 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
resp1, err := client1.Get("https://reqhdr.test/")
if err != nil {
t.Fatalf("GET without header: %v", err)
}
resp1.Body.Close()
if resp1.StatusCode != 403 {
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
}
// With the required header → 200.
client2 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
req, _ := http.NewRequest("GET", "https://reqhdr.test/", nil)
req.Header.Set("X-Auth-Token", "valid-token")
resp2, err := client2.Do(req)
if err != nil {
t.Fatalf("GET with header: %v", err)
}
defer resp2.Body.Close()
body, _ := io.ReadAll(resp2.Body)
if resp2.StatusCode != 200 {
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)
}
if string(body) != "ok" {
t.Fatalf("body = %q, want %q", body, "ok")
}
}

View File

@@ -19,6 +19,12 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
)
// L7PolicyRule is an L7 blocking policy attached to a route.
type L7PolicyRule struct {
Type string // "block_user_agent" or "require_header"
Value string
}
// RouteInfo holds the full configuration for a single route.
type RouteInfo struct {
Backend string
@@ -27,6 +33,7 @@ type RouteInfo struct {
TLSKey string
BackendTLS bool
SendProxyProtocol bool
L7Policies []L7PolicyRule
}
// ListenerState holds the mutable state for a single proxy listener.
@@ -91,6 +98,40 @@ func (ls *ListenerState) RemoveRoute(hostname string) error {
return nil
}
// AddL7Policy appends an L7 policy to a route's policy list.
func (ls *ListenerState) AddL7Policy(hostname string, policy L7PolicyRule) {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
if route, ok := ls.routes[key]; ok {
route.L7Policies = append(route.L7Policies, policy)
ls.routes[key] = route
}
}
// RemoveL7Policy removes an L7 policy from a route's policy list.
func (ls *ListenerState) RemoveL7Policy(hostname, policyType, policyValue string) {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
route, ok := ls.routes[key]
if !ok {
return
}
filtered := route.L7Policies[:0]
for _, p := range route.L7Policies {
if p.Type != policyType || p.Value != policyValue {
filtered = append(filtered, p)
}
}
route.L7Policies = filtered
ls.routes[key] = route
}
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
ls.mu.RLock()
defer ls.mu.RUnlock()
@@ -362,6 +403,11 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
var policies []l7.PolicyRule
for _, p := range route.L7Policies {
policies = append(policies, l7.PolicyRule{Type: p.Type, Value: p.Value})
}
rc := l7.RouteConfig{
Backend: route.Backend,
TLSCert: route.TLSCert,
@@ -369,6 +415,7 @@ func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, c
BackendTLS: route.BackendTLS,
SendProxyProtocol: route.SendProxyProtocol,
ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration,
Policies: policies,
}
if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil {