Add L7 policies for user-agent blocking and required headers

Per-route HTTP-level blocking policies for L7 routes. Two rule types:
block_user_agent (substring match against User-Agent, returns 403)
and require_header (named header must be present, returns 403).

Config: L7Policy struct with type/value fields, added as L7Policies
slice on Route. Validated in config (type enum, non-empty value,
warning if set on L4 routes).

DB: Migration 4 creates l7_policies table with route_id FK (cascade
delete), type CHECK constraint, UNIQUE(route_id, type, value). New
l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy,
GetRouteID. Seed updated to persist policies from config.

L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates
rules in order, returns 403 on first match, no-op if empty. Composed
into the handler chain between context injection and reverse proxy.

Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy
mutation methods on ListenerState. handleL7 threads policies into
l7.RouteConfig. Startup loads policies per L7 route from DB.

Proto: L7Policy message, repeated l7_policies on Route. Three new
RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the
write-through pattern.

Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy
methods. CLI: mcproxyctl policies list/add/remove subcommands.

Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match,
header present/absent, multiple rules). 4 DB tests (CRUD, cascade,
duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation).
2 end-to-end L7 tests (UA block, required header with allow/deny).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-25 17:11:05 -07:00
parent 1ad42dbbee
commit 42c7fffc3e
20 changed files with 1613 additions and 136 deletions

View File

@@ -60,14 +60,14 @@ proceeds. Each item is marked:
## Phase 7: L7 Policies
- [ ] 7.1 Config: `L7Policy` struct, `L7Policies` on Route, validation
- [ ] 7.2 DB: migration 4, `l7_policies` table, CRUD in `l7policies.go`
- [ ] 7.3 L7 middleware: `PolicyMiddleware` in `internal/l7/policy.go`
- [ ] 7.4 Server/L7 integration: thread policies from RouteInfo to RouteConfig
- [ ] 7.5 Proto/gRPC: `L7Policy` message, policy management RPCs
- [ ] 7.6 Client/CLI: policy methods, `mcproxyctl policies` subcommand
- [ ] 7.7 Startup: load L7 policies per route in `loadListenersFromDB`
- [ ] 7.8 Tests: middleware unit, DB CRUD + cascade, gRPC round-trip, e2e
- [x] 7.1 Config: `L7Policy` struct, `L7Policies` on Route, validation
- [x] 7.2 DB: migration 4, `l7_policies` table, CRUD in `l7policies.go`
- [x] 7.3 L7 middleware: `PolicyMiddleware` in `internal/l7/policy.go`
- [x] 7.4 Server/L7 integration: thread policies from RouteInfo to RouteConfig
- [x] 7.5 Proto/gRPC: `L7Policy` message, policy management RPCs
- [x] 7.6 Client/CLI: policy methods, `mcproxyctl policies` subcommand
- [x] 7.7 Startup: load L7 policies per route in `loadListenersFromDB`
- [x] 7.8 Tests: middleware unit, DB CRUD + cascade, gRPC round-trip, e2e
## Phase 8: Prometheus Metrics

View File

@@ -40,6 +40,12 @@ func (c *Client) Close() error {
return c.conn.Close()
}
// L7Policy represents an HTTP-level blocking policy.
type L7Policy struct {
Type string // "block_user_agent" or "require_header"
Value string
}
// Route represents a hostname to backend mapping with mode and options.
type Route struct {
Hostname string
@@ -49,6 +55,7 @@ type Route struct {
TLSKey string
BackendTLS bool
SendProxyProtocol bool
L7Policies []L7Policy
}
// ListRoutes returns all routes for the given listener address.
@@ -211,6 +218,42 @@ func (c *Client) SetListenerMaxConnections(ctx context.Context, listenerAddr str
return err
}
// ListL7Policies returns L7 policies for a route.
func (c *Client) ListL7Policies(ctx context.Context, listenerAddr, hostname string) ([]L7Policy, error) {
resp, err := c.admin.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
ListenerAddr: listenerAddr,
Hostname: hostname,
})
if err != nil {
return nil, err
}
policies := make([]L7Policy, len(resp.Policies))
for i, p := range resp.Policies {
policies[i] = L7Policy{Type: p.Type, Value: p.Value}
}
return policies, nil
}
// AddL7Policy adds an L7 policy to a route.
func (c *Client) AddL7Policy(ctx context.Context, listenerAddr, hostname string, policy L7Policy) error {
_, err := c.admin.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: listenerAddr,
Hostname: hostname,
Policy: &pb.L7Policy{Type: policy.Type, Value: policy.Value},
})
return err
}
// RemoveL7Policy removes an L7 policy from a route.
func (c *Client) RemoveL7Policy(ctx context.Context, listenerAddr, hostname string, policy L7Policy) error {
_, err := c.admin.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
ListenerAddr: listenerAddr,
Hostname: hostname,
Policy: &pb.L7Policy{Type: policy.Type, Value: policy.Value},
})
return err
}
// HealthStatus represents the health of the server.
type HealthStatus int

View File

@@ -132,6 +132,17 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) {
}
routes := make(map[string]server.RouteInfo, len(dbRoutes))
for _, r := range dbRoutes {
// Load L7 policies for this route.
var policies []server.L7PolicyRule
if r.Mode == "l7" {
dbPolicies, err := store.ListL7Policies(r.ID)
if err != nil {
return nil, fmt.Errorf("loading L7 policies for route %q: %w", r.Hostname, err)
}
for _, p := range dbPolicies {
policies = append(policies, server.L7PolicyRule{Type: p.Type, Value: p.Value})
}
}
routes[strings.ToLower(r.Hostname)] = server.RouteInfo{
Backend: r.Backend,
Mode: r.Mode,
@@ -139,6 +150,7 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) {
TLSKey: r.TLSKey,
BackendTLS: r.BackendTLS,
SendProxyProtocol: r.SendProxyProtocol,
L7Policies: policies,
}
}
result = append(result, server.ListenerData{

114
cmd/mcproxyctl/policies.go Normal file
View File

@@ -0,0 +1,114 @@
package main
import (
"context"
"fmt"
"time"
"github.com/spf13/cobra"
mcproxy "git.wntrmute.dev/kyle/mc-proxy/client/mcproxy"
)
func policiesCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "policies",
Short: "Manage L7 policies",
Long: "Manage L7 HTTP policies for mc-proxy routes.",
}
cmd.AddCommand(policiesListCmd())
cmd.AddCommand(policiesAddCmd())
cmd.AddCommand(policiesRemoveCmd())
return cmd
}
func policiesListCmd() *cobra.Command {
return &cobra.Command{
Use: "list LISTENER HOSTNAME",
Short: "List L7 policies for a route",
Args: cobra.ExactArgs(2),
RunE: func(cmd *cobra.Command, args []string) error {
listenerAddr := args[0]
hostname := args[1]
client := clientFromContext(cmd.Context())
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer cancel()
policies, err := client.ListL7Policies(ctx, listenerAddr, hostname)
if err != nil {
return fmt.Errorf("listing policies: %w", err)
}
if len(policies) == 0 {
fmt.Printf("No L7 policies for %s on %s\n", hostname, listenerAddr)
return nil
}
fmt.Printf("L7 policies for %s on %s:\n", hostname, listenerAddr)
for _, p := range policies {
fmt.Printf(" %-20s %s\n", p.Type, p.Value)
}
return nil
},
}
}
func policiesAddCmd() *cobra.Command {
return &cobra.Command{
Use: "add LISTENER HOSTNAME TYPE VALUE",
Short: "Add an L7 policy",
Long: "Add an L7 policy to a route. TYPE is block_user_agent or require_header.",
Args: cobra.ExactArgs(4),
RunE: func(cmd *cobra.Command, args []string) error {
listenerAddr := args[0]
hostname := args[1]
policyType := args[2]
policyValue := args[3]
client := clientFromContext(cmd.Context())
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer cancel()
if err := client.AddL7Policy(ctx, listenerAddr, hostname, mcproxy.L7Policy{
Type: policyType,
Value: policyValue,
}); err != nil {
return fmt.Errorf("adding policy: %w", err)
}
fmt.Printf("Added policy: %s %q on %s/%s\n", policyType, policyValue, listenerAddr, hostname)
return nil
},
}
}
func policiesRemoveCmd() *cobra.Command {
return &cobra.Command{
Use: "remove LISTENER HOSTNAME TYPE VALUE",
Short: "Remove an L7 policy",
Args: cobra.ExactArgs(4),
RunE: func(cmd *cobra.Command, args []string) error {
listenerAddr := args[0]
hostname := args[1]
policyType := args[2]
policyValue := args[3]
client := clientFromContext(cmd.Context())
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
defer cancel()
if err := client.RemoveL7Policy(ctx, listenerAddr, hostname, mcproxy.L7Policy{
Type: policyType,
Value: policyValue,
}); err != nil {
return fmt.Errorf("removing policy: %w", err)
}
fmt.Printf("Removed policy: %s %q from %s/%s\n", policyType, policyValue, listenerAddr, hostname)
return nil
},
}
}

View File

@@ -58,6 +58,7 @@ func rootCmd() *cobra.Command {
cmd.AddCommand(healthCmd())
cmd.AddCommand(routesCmd())
cmd.AddCommand(firewallCmd())
cmd.AddCommand(policiesCmd())
return cmd
}

File diff suppressed because it is too large Load Diff

View File

@@ -26,6 +26,9 @@ const (
ProxyAdminService_AddFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddFirewallRule"
ProxyAdminService_RemoveFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveFirewallRule"
ProxyAdminService_SetListenerMaxConnections_FullMethodName = "/mc_proxy.v1.ProxyAdminService/SetListenerMaxConnections"
ProxyAdminService_ListL7Policies_FullMethodName = "/mc_proxy.v1.ProxyAdminService/ListL7Policies"
ProxyAdminService_AddL7Policy_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddL7Policy"
ProxyAdminService_RemoveL7Policy_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveL7Policy"
ProxyAdminService_GetStatus_FullMethodName = "/mc_proxy.v1.ProxyAdminService/GetStatus"
)
@@ -43,6 +46,10 @@ type ProxyAdminServiceClient interface {
RemoveFirewallRule(ctx context.Context, in *RemoveFirewallRuleRequest, opts ...grpc.CallOption) (*RemoveFirewallRuleResponse, error)
// Connection limits
SetListenerMaxConnections(ctx context.Context, in *SetListenerMaxConnectionsRequest, opts ...grpc.CallOption) (*SetListenerMaxConnectionsResponse, error)
// L7 policies
ListL7Policies(ctx context.Context, in *ListL7PoliciesRequest, opts ...grpc.CallOption) (*ListL7PoliciesResponse, error)
AddL7Policy(ctx context.Context, in *AddL7PolicyRequest, opts ...grpc.CallOption) (*AddL7PolicyResponse, error)
RemoveL7Policy(ctx context.Context, in *RemoveL7PolicyRequest, opts ...grpc.CallOption) (*RemoveL7PolicyResponse, error)
// Status
GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error)
}
@@ -125,6 +132,36 @@ func (c *proxyAdminServiceClient) SetListenerMaxConnections(ctx context.Context,
return out, nil
}
func (c *proxyAdminServiceClient) ListL7Policies(ctx context.Context, in *ListL7PoliciesRequest, opts ...grpc.CallOption) (*ListL7PoliciesResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(ListL7PoliciesResponse)
err := c.cc.Invoke(ctx, ProxyAdminService_ListL7Policies_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *proxyAdminServiceClient) AddL7Policy(ctx context.Context, in *AddL7PolicyRequest, opts ...grpc.CallOption) (*AddL7PolicyResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(AddL7PolicyResponse)
err := c.cc.Invoke(ctx, ProxyAdminService_AddL7Policy_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *proxyAdminServiceClient) RemoveL7Policy(ctx context.Context, in *RemoveL7PolicyRequest, opts ...grpc.CallOption) (*RemoveL7PolicyResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(RemoveL7PolicyResponse)
err := c.cc.Invoke(ctx, ProxyAdminService_RemoveL7Policy_FullMethodName, in, out, cOpts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *proxyAdminServiceClient) GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error) {
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
out := new(GetStatusResponse)
@@ -149,6 +186,10 @@ type ProxyAdminServiceServer interface {
RemoveFirewallRule(context.Context, *RemoveFirewallRuleRequest) (*RemoveFirewallRuleResponse, error)
// Connection limits
SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error)
// L7 policies
ListL7Policies(context.Context, *ListL7PoliciesRequest) (*ListL7PoliciesResponse, error)
AddL7Policy(context.Context, *AddL7PolicyRequest) (*AddL7PolicyResponse, error)
RemoveL7Policy(context.Context, *RemoveL7PolicyRequest) (*RemoveL7PolicyResponse, error)
// Status
GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error)
mustEmbedUnimplementedProxyAdminServiceServer()
@@ -182,6 +223,15 @@ func (UnimplementedProxyAdminServiceServer) RemoveFirewallRule(context.Context,
func (UnimplementedProxyAdminServiceServer) SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error) {
return nil, status.Error(codes.Unimplemented, "method SetListenerMaxConnections not implemented")
}
func (UnimplementedProxyAdminServiceServer) ListL7Policies(context.Context, *ListL7PoliciesRequest) (*ListL7PoliciesResponse, error) {
return nil, status.Error(codes.Unimplemented, "method ListL7Policies not implemented")
}
func (UnimplementedProxyAdminServiceServer) AddL7Policy(context.Context, *AddL7PolicyRequest) (*AddL7PolicyResponse, error) {
return nil, status.Error(codes.Unimplemented, "method AddL7Policy not implemented")
}
func (UnimplementedProxyAdminServiceServer) RemoveL7Policy(context.Context, *RemoveL7PolicyRequest) (*RemoveL7PolicyResponse, error) {
return nil, status.Error(codes.Unimplemented, "method RemoveL7Policy not implemented")
}
func (UnimplementedProxyAdminServiceServer) GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error) {
return nil, status.Error(codes.Unimplemented, "method GetStatus not implemented")
}
@@ -332,6 +382,60 @@ func _ProxyAdminService_SetListenerMaxConnections_Handler(srv interface{}, ctx c
return interceptor(ctx, in, info, handler)
}
func _ProxyAdminService_ListL7Policies_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(ListL7PoliciesRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ProxyAdminServiceServer).ListL7Policies(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ProxyAdminService_ListL7Policies_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ProxyAdminServiceServer).ListL7Policies(ctx, req.(*ListL7PoliciesRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ProxyAdminService_AddL7Policy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(AddL7PolicyRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ProxyAdminServiceServer).AddL7Policy(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ProxyAdminService_AddL7Policy_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ProxyAdminServiceServer).AddL7Policy(ctx, req.(*AddL7PolicyRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ProxyAdminService_RemoveL7Policy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RemoveL7PolicyRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(ProxyAdminServiceServer).RemoveL7Policy(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: ProxyAdminService_RemoveL7Policy_FullMethodName,
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(ProxyAdminServiceServer).RemoveL7Policy(ctx, req.(*RemoveL7PolicyRequest))
}
return interceptor(ctx, in, info, handler)
}
func _ProxyAdminService_GetStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(GetStatusRequest)
if err := dec(in); err != nil {
@@ -385,6 +489,18 @@ var ProxyAdminService_ServiceDesc = grpc.ServiceDesc{
MethodName: "SetListenerMaxConnections",
Handler: _ProxyAdminService_SetListenerMaxConnections_Handler,
},
{
MethodName: "ListL7Policies",
Handler: _ProxyAdminService_ListL7Policies_Handler,
},
{
MethodName: "AddL7Policy",
Handler: _ProxyAdminService_AddL7Policy_Handler,
},
{
MethodName: "RemoveL7Policy",
Handler: _ProxyAdminService_RemoveL7Policy_Handler,
},
{
MethodName: "GetStatus",
Handler: _ProxyAdminService_GetStatus_Handler,

View File

@@ -52,6 +52,13 @@ type Route struct {
TLSKey string `toml:"tls_key"` // PEM private key path (L7 only)
BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only)
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
L7Policies []L7Policy `toml:"l7_policies"` // HTTP-level policies (L7 only)
}
// L7Policy is an HTTP-level blocking policy for L7 routes.
type L7Policy struct {
Type string `toml:"type"` // "block_user_agent" or "require_header"
Value string `toml:"value"` // UA substring or header name
}
// Firewall holds the global firewall configuration.
@@ -168,6 +175,22 @@ func (c *Config) validate() error {
i, l.Addr, j, r.Hostname, err)
}
}
// Validate L7 policies.
if r.Mode == "l4" && len(r.L7Policies) > 0 {
slog.Warn("L4 route has l7_policies set (ignored)",
"listener", l.Addr, "hostname", r.Hostname)
}
for k, p := range r.L7Policies {
if p.Type != "block_user_agent" && p.Type != "require_header" {
return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: type must be \"block_user_agent\" or \"require_header\", got %q",
i, l.Addr, j, r.Hostname, k, p.Type)
}
if p.Value == "" {
return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: value is required",
i, l.Addr, j, r.Hostname, k)
}
}
}
}

View File

@@ -508,3 +508,95 @@ func TestMigrationV2Upgrade(t *testing.T) {
t.Fatal("expected proxy_protocol = false")
}
}
func TestL7PolicyCRUD(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
// Create policies.
id1, err := store.CreateL7Policy(rid, "block_user_agent", "BadBot")
if err != nil {
t.Fatalf("create policy 1: %v", err)
}
if id1 == 0 {
t.Fatal("expected non-zero policy ID")
}
if _, err := store.CreateL7Policy(rid, "require_header", "X-API-Key"); err != nil {
t.Fatalf("create policy 2: %v", err)
}
// List policies.
policies, err := store.ListL7Policies(rid)
if err != nil {
t.Fatalf("list: %v", err)
}
if len(policies) != 2 {
t.Fatalf("got %d policies, want 2", len(policies))
}
// Delete one.
if err := store.DeleteL7Policy(rid, "block_user_agent", "BadBot"); err != nil {
t.Fatalf("delete: %v", err)
}
policies, _ = store.ListL7Policies(rid)
if len(policies) != 1 {
t.Fatalf("got %d policies after delete, want 1", len(policies))
}
if policies[0].Type != "require_header" {
t.Fatalf("remaining policy type = %q, want %q", policies[0].Type, "require_header")
}
}
func TestL7PolicyCascadeDelete(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
store.CreateL7Policy(rid, "block_user_agent", "Bot")
// Deleting the route should cascade-delete its policies.
store.DeleteRoute(lid, "api.test")
policies, _ := store.ListL7Policies(rid)
if len(policies) != 0 {
t.Fatalf("got %d policies after cascade delete, want 0", len(policies))
}
}
func TestL7PolicyDuplicate(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
if _, err := store.CreateL7Policy(rid, "block_user_agent", "Bot"); err != nil {
t.Fatalf("first create: %v", err)
}
if _, err := store.CreateL7Policy(rid, "block_user_agent", "Bot"); err == nil {
t.Fatal("expected error for duplicate policy")
}
}
func TestGetRouteID(t *testing.T) {
store := openTestDB(t)
lid, _ := store.CreateListener(":443", false, 0)
store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
rid, err := store.GetRouteID(lid, "api.test")
if err != nil {
t.Fatalf("GetRouteID: %v", err)
}
if rid == 0 {
t.Fatal("expected non-zero route ID")
}
_, err = store.GetRouteID(lid, "nonexistent.test")
if err == nil {
t.Fatal("expected error for nonexistent route")
}
}

74
internal/db/l7policies.go Normal file
View File

@@ -0,0 +1,74 @@
package db
import "fmt"
// L7Policy is a database L7 policy record.
type L7Policy struct {
ID int64
RouteID int64
Type string // "block_user_agent" or "require_header"
Value string
}
// ListL7Policies returns all L7 policies for a route.
func (s *Store) ListL7Policies(routeID int64) ([]L7Policy, error) {
rows, err := s.db.Query(
"SELECT id, route_id, type, value FROM l7_policies WHERE route_id = ? ORDER BY id",
routeID,
)
if err != nil {
return nil, fmt.Errorf("querying l7 policies: %w", err)
}
defer rows.Close()
var policies []L7Policy
for rows.Next() {
var p L7Policy
if err := rows.Scan(&p.ID, &p.RouteID, &p.Type, &p.Value); err != nil {
return nil, fmt.Errorf("scanning l7 policy: %w", err)
}
policies = append(policies, p)
}
return policies, rows.Err()
}
// CreateL7Policy inserts an L7 policy and returns its ID.
func (s *Store) CreateL7Policy(routeID int64, policyType, value string) (int64, error) {
result, err := s.db.Exec(
"INSERT INTO l7_policies (route_id, type, value) VALUES (?, ?, ?)",
routeID, policyType, value,
)
if err != nil {
return 0, fmt.Errorf("inserting l7 policy: %w", err)
}
return result.LastInsertId()
}
// DeleteL7Policy deletes an L7 policy by route ID, type, and value.
func (s *Store) DeleteL7Policy(routeID int64, policyType, value string) error {
result, err := s.db.Exec(
"DELETE FROM l7_policies WHERE route_id = ? AND type = ? AND value = ?",
routeID, policyType, value,
)
if err != nil {
return fmt.Errorf("deleting l7 policy: %w", err)
}
n, _ := result.RowsAffected()
if n == 0 {
return fmt.Errorf("l7 policy not found (route %d, type %q, value %q)", routeID, policyType, value)
}
return nil
}
// GetRouteID returns the route ID for a listener/hostname pair.
func (s *Store) GetRouteID(listenerID int64, hostname string) (int64, error) {
var id int64
err := s.db.QueryRow(
"SELECT id FROM routes WHERE listener_id = ? AND hostname = ?",
listenerID, hostname,
).Scan(&id)
if err != nil {
return 0, fmt.Errorf("looking up route %q on listener %d: %w", hostname, listenerID, err)
}
return id, nil
}

View File

@@ -45,6 +45,19 @@ ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0;`,
Name: "add_listener_max_connections",
SQL: `ALTER TABLE listeners ADD COLUMN max_connections INTEGER NOT NULL DEFAULT 0;`,
},
{
Version: 4,
Name: "create_l7_policies_table",
SQL: `
CREATE TABLE IF NOT EXISTS l7_policies (
id INTEGER PRIMARY KEY,
route_id INTEGER NOT NULL REFERENCES routes(id) ON DELETE CASCADE,
type TEXT NOT NULL CHECK(type IN ('block_user_agent', 'require_header')),
value TEXT NOT NULL,
UNIQUE(route_id, type, value)
);
CREATE INDEX IF NOT EXISTS idx_l7_policies_route ON l7_policies(route_id);`,
},
}
// Migrate runs all unapplied migrations sequentially.

View File

@@ -31,7 +31,7 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
if mode == "" {
mode = "l4"
}
_, err := tx.Exec(
routeResult, err := tx.Exec(
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
listenerID, strings.ToLower(r.Hostname), r.Backend,
@@ -40,6 +40,18 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
if err != nil {
return fmt.Errorf("seeding route %q on listener %q: %w", r.Hostname, l.Addr, err)
}
if len(r.L7Policies) > 0 {
routeID, _ := routeResult.LastInsertId()
for _, p := range r.L7Policies {
if _, err := tx.Exec(
"INSERT INTO l7_policies (route_id, type, value) VALUES (?, ?, ?)",
routeID, p.Type, p.Value,
); err != nil {
return fmt.Errorf("seeding l7 policy on route %q: %w", r.Hostname, err)
}
}
}
}
}

View File

@@ -89,6 +89,10 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
ListenerAddr: ls.Addr,
}
for hostname, route := range routes {
var policies []*pb.L7Policy
for _, p := range route.L7Policies {
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
}
resp.Routes = append(resp.Routes, &pb.Route{
Hostname: hostname,
Backend: route.Backend,
@@ -97,6 +101,7 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
TlsKey: route.TLSKey,
BackendTls: route.BackendTLS,
SendProxyProtocol: route.SendProxyProtocol,
L7Policies: policies,
})
}
return resp, nil
@@ -187,6 +192,100 @@ func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest)
return &pb.RemoveRouteResponse{}, nil
}
// ListL7Policies returns L7 policies for a route.
func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) {
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
routes := ls.Routes()
route, ok := routes[hostname]
if !ok {
return nil, status.Errorf(codes.NotFound, "route %q not found", hostname)
}
var policies []*pb.L7Policy
for _, p := range route.L7Policies {
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
}
return &pb.ListL7PoliciesResponse{Policies: policies}, nil
}
// AddL7Policy adds an L7 policy to a route (write-through).
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) {
if req.Policy == nil {
return nil, status.Error(codes.InvalidArgument, "policy is required")
}
if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" {
return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type)
}
if req.Policy.Value == "" {
return nil, status.Error(codes.InvalidArgument, "policy value is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
// Get route ID from DB.
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
if err != nil {
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
}
// Write-through: DB first.
if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
}
// Update in-memory state.
ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value})
a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
return &pb.AddL7PolicyResponse{}, nil
}
// RemoveL7Policy removes an L7 policy from a route (write-through).
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) {
if req.Policy == nil {
return nil, status.Error(codes.InvalidArgument, "policy is required")
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
}
hostname := strings.ToLower(req.Hostname)
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
if err != nil {
return nil, status.Errorf(codes.Internal, "%v", err)
}
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
if err != nil {
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
}
if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
return nil, status.Errorf(codes.NotFound, "%v", err)
}
ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value)
a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type)
return &pb.RemoveL7PolicyResponse{}, nil
}
// GetFirewallRules returns all current firewall rules.
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
ips, cidrs, countries := a.srv.Firewall().Rules()

View File

@@ -735,3 +735,97 @@ func TestSetListenerMaxConnectionsNotFound(t *testing.T) {
t.Fatalf("expected NotFound, got %v", err)
}
}
func TestAddListL7Policy(t *testing.T) {
env := setup(t)
ctx := context.Background()
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "block_user_agent", Value: "BadBot"},
})
if err != nil {
t.Fatalf("AddL7Policy: %v", err)
}
resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if err != nil {
t.Fatalf("ListL7Policies: %v", err)
}
if len(resp.Policies) != 1 {
t.Fatalf("got %d policies, want 1", len(resp.Policies))
}
if resp.Policies[0].Type != "block_user_agent" || resp.Policies[0].Value != "BadBot" {
t.Fatalf("policy = %v, want block_user_agent/BadBot", resp.Policies[0])
}
routeResp, _ := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
for _, r := range routeResp.Routes {
if r.Hostname == "a.test" && len(r.L7Policies) != 1 {
t.Fatalf("ListRoutes: route has %d policies, want 1", len(r.L7Policies))
}
}
}
func TestRemoveL7Policy(t *testing.T) {
env := setup(t)
ctx := context.Background()
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
})
_, err := env.client.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
})
if err != nil {
t.Fatalf("RemoveL7Policy: %v", err)
}
resp, _ := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if len(resp.Policies) != 0 {
t.Fatalf("got %d policies after remove, want 0", len(resp.Policies))
}
}
func TestAddL7PolicyValidation(t *testing.T) {
env := setup(t)
ctx := context.Background()
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "invalid_type", Value: "x"},
})
if err == nil {
t.Fatal("expected error for invalid policy type")
}
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
Policy: &pb.L7Policy{Type: "block_user_agent", Value: ""},
})
if err == nil {
t.Fatal("expected error for empty policy value")
}
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
ListenerAddr: ":443",
Hostname: "a.test",
})
if err == nil {
t.Fatal("expected error for nil policy")
}
}

38
internal/l7/policy.go Normal file
View File

@@ -0,0 +1,38 @@
package l7
import (
"net/http"
"strings"
)
// PolicyRule defines an L7 blocking policy.
type PolicyRule struct {
Type string // "block_user_agent" or "require_header"
Value string
}
// PolicyMiddleware returns an http.Handler that evaluates L7 policies
// before delegating to next. Returns HTTP 403 if any policy blocks.
// If policies is empty, returns next unchanged.
func PolicyMiddleware(policies []PolicyRule, next http.Handler) http.Handler {
if len(policies) == 0 {
return next
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, p := range policies {
switch p.Type {
case "block_user_agent":
if strings.Contains(r.UserAgent(), p.Value) {
w.WriteHeader(http.StatusForbidden)
return
}
case "require_header":
if r.Header.Get(p.Value) == "" {
w.WriteHeader(http.StatusForbidden)
return
}
}
}
next.ServeHTTP(w, r)
})
}

158
internal/l7/policy_test.go Normal file
View File

@@ -0,0 +1,158 @@
package l7
import (
"net/http"
"net/http/httptest"
"testing"
)
func TestPolicyMiddlewareNoPolicies(t *testing.T) {
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(200)
})
handler := PolicyMiddleware(nil, next)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !called {
t.Fatal("next handler was not called")
}
if w.Code != 200 {
t.Fatalf("status = %d, want 200", w.Code)
}
}
func TestPolicyBlockUserAgentMatch(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "Mozilla/5.0 BadBot/1.0")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != 403 {
t.Fatalf("status = %d, want 403", w.Code)
}
}
func TestPolicyBlockUserAgentNoMatch(t *testing.T) {
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "Mozilla/5.0 GoodBrowser/1.0")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !called {
t.Fatal("next handler was not called")
}
if w.Code != 200 {
t.Fatalf("status = %d, want 200", w.Code)
}
}
func TestPolicyRequireHeaderPresent(t *testing.T) {
called := false
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
called = true
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "require_header", Value: "X-API-Key"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("X-API-Key", "secret")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if !called {
t.Fatal("next handler was not called")
}
if w.Code != 200 {
t.Fatalf("status = %d, want 200", w.Code)
}
}
func TestPolicyRequireHeaderAbsent(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "require_header", Value: "X-API-Key"},
}
handler := PolicyMiddleware(policies, next)
req := httptest.NewRequest("GET", "/", nil)
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != 403 {
t.Fatalf("status = %d, want 403", w.Code)
}
}
func TestPolicyMultipleRules(t *testing.T) {
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(200)
})
policies := []PolicyRule{
{Type: "block_user_agent", Value: "BadBot"},
{Type: "require_header", Value: "X-Token"},
}
handler := PolicyMiddleware(policies, next)
// Blocked by UA even though header is present.
req := httptest.NewRequest("GET", "/", nil)
req.Header.Set("User-Agent", "BadBot/1.0")
req.Header.Set("X-Token", "abc")
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
if w.Code != 403 {
t.Fatalf("UA block: status = %d, want 403", w.Code)
}
// Good UA but missing header.
req2 := httptest.NewRequest("GET", "/", nil)
req2.Header.Set("User-Agent", "GoodBot/1.0")
w2 := httptest.NewRecorder()
handler.ServeHTTP(w2, req2)
if w2.Code != 403 {
t.Fatalf("missing header: status = %d, want 403", w2.Code)
}
// Good UA and header present — passes.
req3 := httptest.NewRequest("GET", "/", nil)
req3.Header.Set("User-Agent", "GoodBot/1.0")
req3.Header.Set("X-Token", "abc")
w3 := httptest.NewRecorder()
handler.ServeHTTP(w3, req3)
if w3.Code != 200 {
t.Fatalf("pass: status = %d, want 200", w3.Code)
}
}

View File

@@ -26,6 +26,7 @@ type RouteConfig struct {
BackendTLS bool
SendProxyProtocol bool
ConnectTimeout time.Duration
Policies []PolicyRule
}
// contextKey is an unexported type for context keys in this package.
@@ -74,10 +75,12 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
return fmt.Errorf("creating reverse proxy: %w", err)
}
// Wrap the handler to inject the real client IP into the request context.
// Build handler chain: context injection → L7 policies → reverse proxy.
var inner http.Handler = rp
inner = PolicyMiddleware(route.Policies, inner)
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
rp.ServeHTTP(w, r)
inner.ServeHTTP(w, r)
})
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,

View File

@@ -551,3 +551,115 @@ func TestL7HTTP11Fallback(t *testing.T) {
t.Fatal("empty response body")
}
}
func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
certPath, keyPath := testCert(t, "policy.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "should-not-reach")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: backendAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 5 * time.Second,
Policies: []PolicyRule{
{Type: "block_user_agent", Value: "EvilBot"},
},
}
go func() {
conn, err := proxyLn.Accept()
if err != nil {
return
}
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
req, _ := http.NewRequest("GET", "https://policy.test/", nil)
req.Header.Set("User-Agent", "EvilBot/1.0")
resp, err := client.Do(req)
if err != nil {
t.Fatalf("GET: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != 403 {
t.Fatalf("status = %d, want 403", resp.StatusCode)
}
}
func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
certPath, keyPath := testCert(t, "reqhdr.test")
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, "ok")
}))
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("proxy listen: %v", err)
}
defer proxyLn.Close()
route := RouteConfig{
Backend: backendAddr,
TLSCert: certPath,
TLSKey: keyPath,
ConnectTimeout: 5 * time.Second,
Policies: []PolicyRule{
{Type: "require_header", Value: "X-Auth-Token"},
},
}
// Accept two connections (one blocked, one allowed).
go func() {
for range 2 {
conn, err := proxyLn.Accept()
if err != nil {
return
}
go func() {
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
}()
}
}()
// Without the required header → 403.
client1 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
resp1, err := client1.Get("https://reqhdr.test/")
if err != nil {
t.Fatalf("GET without header: %v", err)
}
resp1.Body.Close()
if resp1.StatusCode != 403 {
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
}
// With the required header → 200.
client2 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
req, _ := http.NewRequest("GET", "https://reqhdr.test/", nil)
req.Header.Set("X-Auth-Token", "valid-token")
resp2, err := client2.Do(req)
if err != nil {
t.Fatalf("GET with header: %v", err)
}
defer resp2.Body.Close()
body, _ := io.ReadAll(resp2.Body)
if resp2.StatusCode != 200 {
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)
}
if string(body) != "ok" {
t.Fatalf("body = %q, want %q", body, "ok")
}
}

View File

@@ -19,6 +19,12 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
)
// L7PolicyRule is an L7 blocking policy attached to a route.
type L7PolicyRule struct {
Type string // "block_user_agent" or "require_header"
Value string
}
// RouteInfo holds the full configuration for a single route.
type RouteInfo struct {
Backend string
@@ -27,6 +33,7 @@ type RouteInfo struct {
TLSKey string
BackendTLS bool
SendProxyProtocol bool
L7Policies []L7PolicyRule
}
// ListenerState holds the mutable state for a single proxy listener.
@@ -91,6 +98,40 @@ func (ls *ListenerState) RemoveRoute(hostname string) error {
return nil
}
// AddL7Policy appends an L7 policy to a route's policy list.
func (ls *ListenerState) AddL7Policy(hostname string, policy L7PolicyRule) {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
if route, ok := ls.routes[key]; ok {
route.L7Policies = append(route.L7Policies, policy)
ls.routes[key] = route
}
}
// RemoveL7Policy removes an L7 policy from a route's policy list.
func (ls *ListenerState) RemoveL7Policy(hostname, policyType, policyValue string) {
key := strings.ToLower(hostname)
ls.mu.Lock()
defer ls.mu.Unlock()
route, ok := ls.routes[key]
if !ok {
return
}
filtered := route.L7Policies[:0]
for _, p := range route.L7Policies {
if p.Type != policyType || p.Value != policyValue {
filtered = append(filtered, p)
}
}
route.L7Policies = filtered
ls.routes[key] = route
}
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
ls.mu.RLock()
defer ls.mu.RUnlock()
@@ -362,6 +403,11 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
var policies []l7.PolicyRule
for _, p := range route.L7Policies {
policies = append(policies, l7.PolicyRule{Type: p.Type, Value: p.Value})
}
rc := l7.RouteConfig{
Backend: route.Backend,
TLSCert: route.TLSCert,
@@ -369,6 +415,7 @@ func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, c
BackendTLS: route.BackendTLS,
SendProxyProtocol: route.SendProxyProtocol,
ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration,
Policies: policies,
}
if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil {

View File

@@ -20,12 +20,22 @@ service ProxyAdminService {
// Connection limits
rpc SetListenerMaxConnections(SetListenerMaxConnectionsRequest) returns (SetListenerMaxConnectionsResponse);
// L7 policies
rpc ListL7Policies(ListL7PoliciesRequest) returns (ListL7PoliciesResponse);
rpc AddL7Policy(AddL7PolicyRequest) returns (AddL7PolicyResponse);
rpc RemoveL7Policy(RemoveL7PolicyRequest) returns (RemoveL7PolicyResponse);
// Status
rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
}
// Routes
message L7Policy {
string type = 1; // "block_user_agent" or "require_header"
string value = 2;
}
message Route {
string hostname = 1;
string backend = 2;
@@ -34,6 +44,7 @@ message Route {
string tls_key = 5; // PEM private key path (L7 only)
bool backend_tls = 6; // re-encrypt to backend (L7 only)
bool send_proxy_protocol = 7; // send PROXY v2 header to backend
repeated L7Policy l7_policies = 8; // HTTP-level policies (L7 only)
}
message ListRoutesRequest {
@@ -59,6 +70,33 @@ message RemoveRouteRequest {
message RemoveRouteResponse {}
// L7 Policies
message ListL7PoliciesRequest {
string listener_addr = 1;
string hostname = 2;
}
message ListL7PoliciesResponse {
repeated L7Policy policies = 1;
}
message AddL7PolicyRequest {
string listener_addr = 1;
string hostname = 2;
L7Policy policy = 3;
}
message AddL7PolicyResponse {}
message RemoveL7PolicyRequest {
string listener_addr = 1;
string hostname = 2;
L7Policy policy = 3;
}
message RemoveL7PolicyResponse {}
// Firewall
enum FirewallRuleType {