Add L7 policies for user-agent blocking and required headers
Per-route HTTP-level blocking policies for L7 routes. Two rule types: block_user_agent (substring match against User-Agent, returns 403) and require_header (named header must be present, returns 403). Config: L7Policy struct with type/value fields, added as L7Policies slice on Route. Validated in config (type enum, non-empty value, warning if set on L4 routes). DB: Migration 4 creates l7_policies table with route_id FK (cascade delete), type CHECK constraint, UNIQUE(route_id, type, value). New l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy, GetRouteID. Seed updated to persist policies from config. L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates rules in order, returns 403 on first match, no-op if empty. Composed into the handler chain between context injection and reverse proxy. Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy mutation methods on ListenerState. handleL7 threads policies into l7.RouteConfig. Startup loads policies per L7 route from DB. Proto: L7Policy message, repeated l7_policies on Route. Three new RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the write-through pattern. Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy methods. CLI: mcproxyctl policies list/add/remove subcommands. Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match, header present/absent, multiple rules). 4 DB tests (CRUD, cascade, duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation). 2 end-to-end L7 tests (UA block, required header with allow/deny). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
16
PROGRESS.md
16
PROGRESS.md
@@ -60,14 +60,14 @@ proceeds. Each item is marked:
|
|||||||
|
|
||||||
## Phase 7: L7 Policies
|
## Phase 7: L7 Policies
|
||||||
|
|
||||||
- [ ] 7.1 Config: `L7Policy` struct, `L7Policies` on Route, validation
|
- [x] 7.1 Config: `L7Policy` struct, `L7Policies` on Route, validation
|
||||||
- [ ] 7.2 DB: migration 4, `l7_policies` table, CRUD in `l7policies.go`
|
- [x] 7.2 DB: migration 4, `l7_policies` table, CRUD in `l7policies.go`
|
||||||
- [ ] 7.3 L7 middleware: `PolicyMiddleware` in `internal/l7/policy.go`
|
- [x] 7.3 L7 middleware: `PolicyMiddleware` in `internal/l7/policy.go`
|
||||||
- [ ] 7.4 Server/L7 integration: thread policies from RouteInfo to RouteConfig
|
- [x] 7.4 Server/L7 integration: thread policies from RouteInfo to RouteConfig
|
||||||
- [ ] 7.5 Proto/gRPC: `L7Policy` message, policy management RPCs
|
- [x] 7.5 Proto/gRPC: `L7Policy` message, policy management RPCs
|
||||||
- [ ] 7.6 Client/CLI: policy methods, `mcproxyctl policies` subcommand
|
- [x] 7.6 Client/CLI: policy methods, `mcproxyctl policies` subcommand
|
||||||
- [ ] 7.7 Startup: load L7 policies per route in `loadListenersFromDB`
|
- [x] 7.7 Startup: load L7 policies per route in `loadListenersFromDB`
|
||||||
- [ ] 7.8 Tests: middleware unit, DB CRUD + cascade, gRPC round-trip, e2e
|
- [x] 7.8 Tests: middleware unit, DB CRUD + cascade, gRPC round-trip, e2e
|
||||||
|
|
||||||
## Phase 8: Prometheus Metrics
|
## Phase 8: Prometheus Metrics
|
||||||
|
|
||||||
|
|||||||
@@ -40,6 +40,12 @@ func (c *Client) Close() error {
|
|||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// L7Policy represents an HTTP-level blocking policy.
|
||||||
|
type L7Policy struct {
|
||||||
|
Type string // "block_user_agent" or "require_header"
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
// Route represents a hostname to backend mapping with mode and options.
|
// Route represents a hostname to backend mapping with mode and options.
|
||||||
type Route struct {
|
type Route struct {
|
||||||
Hostname string
|
Hostname string
|
||||||
@@ -49,6 +55,7 @@ type Route struct {
|
|||||||
TLSKey string
|
TLSKey string
|
||||||
BackendTLS bool
|
BackendTLS bool
|
||||||
SendProxyProtocol bool
|
SendProxyProtocol bool
|
||||||
|
L7Policies []L7Policy
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListRoutes returns all routes for the given listener address.
|
// ListRoutes returns all routes for the given listener address.
|
||||||
@@ -211,6 +218,42 @@ func (c *Client) SetListenerMaxConnections(ctx context.Context, listenerAddr str
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListL7Policies returns L7 policies for a route.
|
||||||
|
func (c *Client) ListL7Policies(ctx context.Context, listenerAddr, hostname string) ([]L7Policy, error) {
|
||||||
|
resp, err := c.admin.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
|
||||||
|
ListenerAddr: listenerAddr,
|
||||||
|
Hostname: hostname,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
policies := make([]L7Policy, len(resp.Policies))
|
||||||
|
for i, p := range resp.Policies {
|
||||||
|
policies[i] = L7Policy{Type: p.Type, Value: p.Value}
|
||||||
|
}
|
||||||
|
return policies, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddL7Policy adds an L7 policy to a route.
|
||||||
|
func (c *Client) AddL7Policy(ctx context.Context, listenerAddr, hostname string, policy L7Policy) error {
|
||||||
|
_, err := c.admin.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||||
|
ListenerAddr: listenerAddr,
|
||||||
|
Hostname: hostname,
|
||||||
|
Policy: &pb.L7Policy{Type: policy.Type, Value: policy.Value},
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveL7Policy removes an L7 policy from a route.
|
||||||
|
func (c *Client) RemoveL7Policy(ctx context.Context, listenerAddr, hostname string, policy L7Policy) error {
|
||||||
|
_, err := c.admin.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
|
||||||
|
ListenerAddr: listenerAddr,
|
||||||
|
Hostname: hostname,
|
||||||
|
Policy: &pb.L7Policy{Type: policy.Type, Value: policy.Value},
|
||||||
|
})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
// HealthStatus represents the health of the server.
|
// HealthStatus represents the health of the server.
|
||||||
type HealthStatus int
|
type HealthStatus int
|
||||||
|
|
||||||
|
|||||||
@@ -132,6 +132,17 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) {
|
|||||||
}
|
}
|
||||||
routes := make(map[string]server.RouteInfo, len(dbRoutes))
|
routes := make(map[string]server.RouteInfo, len(dbRoutes))
|
||||||
for _, r := range dbRoutes {
|
for _, r := range dbRoutes {
|
||||||
|
// Load L7 policies for this route.
|
||||||
|
var policies []server.L7PolicyRule
|
||||||
|
if r.Mode == "l7" {
|
||||||
|
dbPolicies, err := store.ListL7Policies(r.ID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("loading L7 policies for route %q: %w", r.Hostname, err)
|
||||||
|
}
|
||||||
|
for _, p := range dbPolicies {
|
||||||
|
policies = append(policies, server.L7PolicyRule{Type: p.Type, Value: p.Value})
|
||||||
|
}
|
||||||
|
}
|
||||||
routes[strings.ToLower(r.Hostname)] = server.RouteInfo{
|
routes[strings.ToLower(r.Hostname)] = server.RouteInfo{
|
||||||
Backend: r.Backend,
|
Backend: r.Backend,
|
||||||
Mode: r.Mode,
|
Mode: r.Mode,
|
||||||
@@ -139,6 +150,7 @@ func loadListenersFromDB(store *db.Store) ([]server.ListenerData, error) {
|
|||||||
TLSKey: r.TLSKey,
|
TLSKey: r.TLSKey,
|
||||||
BackendTLS: r.BackendTLS,
|
BackendTLS: r.BackendTLS,
|
||||||
SendProxyProtocol: r.SendProxyProtocol,
|
SendProxyProtocol: r.SendProxyProtocol,
|
||||||
|
L7Policies: policies,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
result = append(result, server.ListenerData{
|
result = append(result, server.ListenerData{
|
||||||
|
|||||||
114
cmd/mcproxyctl/policies.go
Normal file
114
cmd/mcproxyctl/policies.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
mcproxy "git.wntrmute.dev/kyle/mc-proxy/client/mcproxy"
|
||||||
|
)
|
||||||
|
|
||||||
|
func policiesCmd() *cobra.Command {
|
||||||
|
cmd := &cobra.Command{
|
||||||
|
Use: "policies",
|
||||||
|
Short: "Manage L7 policies",
|
||||||
|
Long: "Manage L7 HTTP policies for mc-proxy routes.",
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.AddCommand(policiesListCmd())
|
||||||
|
cmd.AddCommand(policiesAddCmd())
|
||||||
|
cmd.AddCommand(policiesRemoveCmd())
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
}
|
||||||
|
|
||||||
|
func policiesListCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "list LISTENER HOSTNAME",
|
||||||
|
Short: "List L7 policies for a route",
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
listenerAddr := args[0]
|
||||||
|
hostname := args[1]
|
||||||
|
client := clientFromContext(cmd.Context())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
policies, err := client.ListL7Policies(ctx, listenerAddr, hostname)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("listing policies: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(policies) == 0 {
|
||||||
|
fmt.Printf("No L7 policies for %s on %s\n", hostname, listenerAddr)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("L7 policies for %s on %s:\n", hostname, listenerAddr)
|
||||||
|
for _, p := range policies {
|
||||||
|
fmt.Printf(" %-20s %s\n", p.Type, p.Value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func policiesAddCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "add LISTENER HOSTNAME TYPE VALUE",
|
||||||
|
Short: "Add an L7 policy",
|
||||||
|
Long: "Add an L7 policy to a route. TYPE is block_user_agent or require_header.",
|
||||||
|
Args: cobra.ExactArgs(4),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
listenerAddr := args[0]
|
||||||
|
hostname := args[1]
|
||||||
|
policyType := args[2]
|
||||||
|
policyValue := args[3]
|
||||||
|
client := clientFromContext(cmd.Context())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := client.AddL7Policy(ctx, listenerAddr, hostname, mcproxy.L7Policy{
|
||||||
|
Type: policyType,
|
||||||
|
Value: policyValue,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("adding policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Added policy: %s %q on %s/%s\n", policyType, policyValue, listenerAddr, hostname)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func policiesRemoveCmd() *cobra.Command {
|
||||||
|
return &cobra.Command{
|
||||||
|
Use: "remove LISTENER HOSTNAME TYPE VALUE",
|
||||||
|
Short: "Remove an L7 policy",
|
||||||
|
Args: cobra.ExactArgs(4),
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
listenerAddr := args[0]
|
||||||
|
hostname := args[1]
|
||||||
|
policyType := args[2]
|
||||||
|
policyValue := args[3]
|
||||||
|
client := clientFromContext(cmd.Context())
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(cmd.Context(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if err := client.RemoveL7Policy(ctx, listenerAddr, hostname, mcproxy.L7Policy{
|
||||||
|
Type: policyType,
|
||||||
|
Value: policyValue,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("removing policy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Removed policy: %s %q from %s/%s\n", policyType, policyValue, listenerAddr, hostname)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -58,6 +58,7 @@ func rootCmd() *cobra.Command {
|
|||||||
cmd.AddCommand(healthCmd())
|
cmd.AddCommand(healthCmd())
|
||||||
cmd.AddCommand(routesCmd())
|
cmd.AddCommand(routesCmd())
|
||||||
cmd.AddCommand(firewallCmd())
|
cmd.AddCommand(firewallCmd())
|
||||||
|
cmd.AddCommand(policiesCmd())
|
||||||
|
|
||||||
return cmd
|
return cmd
|
||||||
}
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -26,6 +26,9 @@ const (
|
|||||||
ProxyAdminService_AddFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddFirewallRule"
|
ProxyAdminService_AddFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddFirewallRule"
|
||||||
ProxyAdminService_RemoveFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveFirewallRule"
|
ProxyAdminService_RemoveFirewallRule_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveFirewallRule"
|
||||||
ProxyAdminService_SetListenerMaxConnections_FullMethodName = "/mc_proxy.v1.ProxyAdminService/SetListenerMaxConnections"
|
ProxyAdminService_SetListenerMaxConnections_FullMethodName = "/mc_proxy.v1.ProxyAdminService/SetListenerMaxConnections"
|
||||||
|
ProxyAdminService_ListL7Policies_FullMethodName = "/mc_proxy.v1.ProxyAdminService/ListL7Policies"
|
||||||
|
ProxyAdminService_AddL7Policy_FullMethodName = "/mc_proxy.v1.ProxyAdminService/AddL7Policy"
|
||||||
|
ProxyAdminService_RemoveL7Policy_FullMethodName = "/mc_proxy.v1.ProxyAdminService/RemoveL7Policy"
|
||||||
ProxyAdminService_GetStatus_FullMethodName = "/mc_proxy.v1.ProxyAdminService/GetStatus"
|
ProxyAdminService_GetStatus_FullMethodName = "/mc_proxy.v1.ProxyAdminService/GetStatus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -43,6 +46,10 @@ type ProxyAdminServiceClient interface {
|
|||||||
RemoveFirewallRule(ctx context.Context, in *RemoveFirewallRuleRequest, opts ...grpc.CallOption) (*RemoveFirewallRuleResponse, error)
|
RemoveFirewallRule(ctx context.Context, in *RemoveFirewallRuleRequest, opts ...grpc.CallOption) (*RemoveFirewallRuleResponse, error)
|
||||||
// Connection limits
|
// Connection limits
|
||||||
SetListenerMaxConnections(ctx context.Context, in *SetListenerMaxConnectionsRequest, opts ...grpc.CallOption) (*SetListenerMaxConnectionsResponse, error)
|
SetListenerMaxConnections(ctx context.Context, in *SetListenerMaxConnectionsRequest, opts ...grpc.CallOption) (*SetListenerMaxConnectionsResponse, error)
|
||||||
|
// L7 policies
|
||||||
|
ListL7Policies(ctx context.Context, in *ListL7PoliciesRequest, opts ...grpc.CallOption) (*ListL7PoliciesResponse, error)
|
||||||
|
AddL7Policy(ctx context.Context, in *AddL7PolicyRequest, opts ...grpc.CallOption) (*AddL7PolicyResponse, error)
|
||||||
|
RemoveL7Policy(ctx context.Context, in *RemoveL7PolicyRequest, opts ...grpc.CallOption) (*RemoveL7PolicyResponse, error)
|
||||||
// Status
|
// Status
|
||||||
GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error)
|
GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error)
|
||||||
}
|
}
|
||||||
@@ -125,6 +132,36 @@ func (c *proxyAdminServiceClient) SetListenerMaxConnections(ctx context.Context,
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *proxyAdminServiceClient) ListL7Policies(ctx context.Context, in *ListL7PoliciesRequest, opts ...grpc.CallOption) (*ListL7PoliciesResponse, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(ListL7PoliciesResponse)
|
||||||
|
err := c.cc.Invoke(ctx, ProxyAdminService_ListL7Policies_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyAdminServiceClient) AddL7Policy(ctx context.Context, in *AddL7PolicyRequest, opts ...grpc.CallOption) (*AddL7PolicyResponse, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(AddL7PolicyResponse)
|
||||||
|
err := c.cc.Invoke(ctx, ProxyAdminService_AddL7Policy_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *proxyAdminServiceClient) RemoveL7Policy(ctx context.Context, in *RemoveL7PolicyRequest, opts ...grpc.CallOption) (*RemoveL7PolicyResponse, error) {
|
||||||
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
|
out := new(RemoveL7PolicyResponse)
|
||||||
|
err := c.cc.Invoke(ctx, ProxyAdminService_RemoveL7Policy_FullMethodName, in, out, cOpts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *proxyAdminServiceClient) GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error) {
|
func (c *proxyAdminServiceClient) GetStatus(ctx context.Context, in *GetStatusRequest, opts ...grpc.CallOption) (*GetStatusResponse, error) {
|
||||||
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...)
|
||||||
out := new(GetStatusResponse)
|
out := new(GetStatusResponse)
|
||||||
@@ -149,6 +186,10 @@ type ProxyAdminServiceServer interface {
|
|||||||
RemoveFirewallRule(context.Context, *RemoveFirewallRuleRequest) (*RemoveFirewallRuleResponse, error)
|
RemoveFirewallRule(context.Context, *RemoveFirewallRuleRequest) (*RemoveFirewallRuleResponse, error)
|
||||||
// Connection limits
|
// Connection limits
|
||||||
SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error)
|
SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error)
|
||||||
|
// L7 policies
|
||||||
|
ListL7Policies(context.Context, *ListL7PoliciesRequest) (*ListL7PoliciesResponse, error)
|
||||||
|
AddL7Policy(context.Context, *AddL7PolicyRequest) (*AddL7PolicyResponse, error)
|
||||||
|
RemoveL7Policy(context.Context, *RemoveL7PolicyRequest) (*RemoveL7PolicyResponse, error)
|
||||||
// Status
|
// Status
|
||||||
GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error)
|
GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error)
|
||||||
mustEmbedUnimplementedProxyAdminServiceServer()
|
mustEmbedUnimplementedProxyAdminServiceServer()
|
||||||
@@ -182,6 +223,15 @@ func (UnimplementedProxyAdminServiceServer) RemoveFirewallRule(context.Context,
|
|||||||
func (UnimplementedProxyAdminServiceServer) SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error) {
|
func (UnimplementedProxyAdminServiceServer) SetListenerMaxConnections(context.Context, *SetListenerMaxConnectionsRequest) (*SetListenerMaxConnectionsResponse, error) {
|
||||||
return nil, status.Error(codes.Unimplemented, "method SetListenerMaxConnections not implemented")
|
return nil, status.Error(codes.Unimplemented, "method SetListenerMaxConnections not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedProxyAdminServiceServer) ListL7Policies(context.Context, *ListL7PoliciesRequest) (*ListL7PoliciesResponse, error) {
|
||||||
|
return nil, status.Error(codes.Unimplemented, "method ListL7Policies not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedProxyAdminServiceServer) AddL7Policy(context.Context, *AddL7PolicyRequest) (*AddL7PolicyResponse, error) {
|
||||||
|
return nil, status.Error(codes.Unimplemented, "method AddL7Policy not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedProxyAdminServiceServer) RemoveL7Policy(context.Context, *RemoveL7PolicyRequest) (*RemoveL7PolicyResponse, error) {
|
||||||
|
return nil, status.Error(codes.Unimplemented, "method RemoveL7Policy not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedProxyAdminServiceServer) GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error) {
|
func (UnimplementedProxyAdminServiceServer) GetStatus(context.Context, *GetStatusRequest) (*GetStatusResponse, error) {
|
||||||
return nil, status.Error(codes.Unimplemented, "method GetStatus not implemented")
|
return nil, status.Error(codes.Unimplemented, "method GetStatus not implemented")
|
||||||
}
|
}
|
||||||
@@ -332,6 +382,60 @@ func _ProxyAdminService_SetListenerMaxConnections_Handler(srv interface{}, ctx c
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _ProxyAdminService_ListL7Policies_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(ListL7PoliciesRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(ProxyAdminServiceServer).ListL7Policies(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: ProxyAdminService_ListL7Policies_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(ProxyAdminServiceServer).ListL7Policies(ctx, req.(*ListL7PoliciesRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _ProxyAdminService_AddL7Policy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(AddL7PolicyRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(ProxyAdminServiceServer).AddL7Policy(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: ProxyAdminService_AddL7Policy_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(ProxyAdminServiceServer).AddL7Policy(ctx, req.(*AddL7PolicyRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _ProxyAdminService_RemoveL7Policy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(RemoveL7PolicyRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(ProxyAdminServiceServer).RemoveL7Policy(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: ProxyAdminService_RemoveL7Policy_FullMethodName,
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(ProxyAdminServiceServer).RemoveL7Policy(ctx, req.(*RemoveL7PolicyRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
func _ProxyAdminService_GetStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
func _ProxyAdminService_GetStatus_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
in := new(GetStatusRequest)
|
in := new(GetStatusRequest)
|
||||||
if err := dec(in); err != nil {
|
if err := dec(in); err != nil {
|
||||||
@@ -385,6 +489,18 @@ var ProxyAdminService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "SetListenerMaxConnections",
|
MethodName: "SetListenerMaxConnections",
|
||||||
Handler: _ProxyAdminService_SetListenerMaxConnections_Handler,
|
Handler: _ProxyAdminService_SetListenerMaxConnections_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "ListL7Policies",
|
||||||
|
Handler: _ProxyAdminService_ListL7Policies_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "AddL7Policy",
|
||||||
|
Handler: _ProxyAdminService_AddL7Policy_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "RemoveL7Policy",
|
||||||
|
Handler: _ProxyAdminService_RemoveL7Policy_Handler,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
MethodName: "GetStatus",
|
MethodName: "GetStatus",
|
||||||
Handler: _ProxyAdminService_GetStatus_Handler,
|
Handler: _ProxyAdminService_GetStatus_Handler,
|
||||||
|
|||||||
@@ -45,13 +45,20 @@ type Listener struct {
|
|||||||
|
|
||||||
// Route is a proxy route within a listener.
|
// Route is a proxy route within a listener.
|
||||||
type Route struct {
|
type Route struct {
|
||||||
Hostname string `toml:"hostname"`
|
Hostname string `toml:"hostname"`
|
||||||
Backend string `toml:"backend"`
|
Backend string `toml:"backend"`
|
||||||
Mode string `toml:"mode"` // "l4" (default) or "l7"
|
Mode string `toml:"mode"` // "l4" (default) or "l7"
|
||||||
TLSCert string `toml:"tls_cert"` // PEM certificate path (L7 only)
|
TLSCert string `toml:"tls_cert"` // PEM certificate path (L7 only)
|
||||||
TLSKey string `toml:"tls_key"` // PEM private key path (L7 only)
|
TLSKey string `toml:"tls_key"` // PEM private key path (L7 only)
|
||||||
BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only)
|
BackendTLS bool `toml:"backend_tls"` // re-encrypt to backend (L7 only)
|
||||||
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
|
SendProxyProtocol bool `toml:"send_proxy_protocol"` // send PROXY v2 header to backend
|
||||||
|
L7Policies []L7Policy `toml:"l7_policies"` // HTTP-level policies (L7 only)
|
||||||
|
}
|
||||||
|
|
||||||
|
// L7Policy is an HTTP-level blocking policy for L7 routes.
|
||||||
|
type L7Policy struct {
|
||||||
|
Type string `toml:"type"` // "block_user_agent" or "require_header"
|
||||||
|
Value string `toml:"value"` // UA substring or header name
|
||||||
}
|
}
|
||||||
|
|
||||||
// Firewall holds the global firewall configuration.
|
// Firewall holds the global firewall configuration.
|
||||||
@@ -168,6 +175,22 @@ func (c *Config) validate() error {
|
|||||||
i, l.Addr, j, r.Hostname, err)
|
i, l.Addr, j, r.Hostname, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate L7 policies.
|
||||||
|
if r.Mode == "l4" && len(r.L7Policies) > 0 {
|
||||||
|
slog.Warn("L4 route has l7_policies set (ignored)",
|
||||||
|
"listener", l.Addr, "hostname", r.Hostname)
|
||||||
|
}
|
||||||
|
for k, p := range r.L7Policies {
|
||||||
|
if p.Type != "block_user_agent" && p.Type != "require_header" {
|
||||||
|
return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: type must be \"block_user_agent\" or \"require_header\", got %q",
|
||||||
|
i, l.Addr, j, r.Hostname, k, p.Type)
|
||||||
|
}
|
||||||
|
if p.Value == "" {
|
||||||
|
return fmt.Errorf("listener %d (%s), route %d (%s), policy %d: value is required",
|
||||||
|
i, l.Addr, j, r.Hostname, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -508,3 +508,95 @@ func TestMigrationV2Upgrade(t *testing.T) {
|
|||||||
t.Fatal("expected proxy_protocol = false")
|
t.Fatal("expected proxy_protocol = false")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestL7PolicyCRUD(t *testing.T) {
|
||||||
|
store := openTestDB(t)
|
||||||
|
|
||||||
|
lid, _ := store.CreateListener(":443", false, 0)
|
||||||
|
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
||||||
|
|
||||||
|
// Create policies.
|
||||||
|
id1, err := store.CreateL7Policy(rid, "block_user_agent", "BadBot")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("create policy 1: %v", err)
|
||||||
|
}
|
||||||
|
if id1 == 0 {
|
||||||
|
t.Fatal("expected non-zero policy ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := store.CreateL7Policy(rid, "require_header", "X-API-Key"); err != nil {
|
||||||
|
t.Fatalf("create policy 2: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// List policies.
|
||||||
|
policies, err := store.ListL7Policies(rid)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("list: %v", err)
|
||||||
|
}
|
||||||
|
if len(policies) != 2 {
|
||||||
|
t.Fatalf("got %d policies, want 2", len(policies))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete one.
|
||||||
|
if err := store.DeleteL7Policy(rid, "block_user_agent", "BadBot"); err != nil {
|
||||||
|
t.Fatalf("delete: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
policies, _ = store.ListL7Policies(rid)
|
||||||
|
if len(policies) != 1 {
|
||||||
|
t.Fatalf("got %d policies after delete, want 1", len(policies))
|
||||||
|
}
|
||||||
|
if policies[0].Type != "require_header" {
|
||||||
|
t.Fatalf("remaining policy type = %q, want %q", policies[0].Type, "require_header")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7PolicyCascadeDelete(t *testing.T) {
|
||||||
|
store := openTestDB(t)
|
||||||
|
|
||||||
|
lid, _ := store.CreateListener(":443", false, 0)
|
||||||
|
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
||||||
|
store.CreateL7Policy(rid, "block_user_agent", "Bot")
|
||||||
|
|
||||||
|
// Deleting the route should cascade-delete its policies.
|
||||||
|
store.DeleteRoute(lid, "api.test")
|
||||||
|
|
||||||
|
policies, _ := store.ListL7Policies(rid)
|
||||||
|
if len(policies) != 0 {
|
||||||
|
t.Fatalf("got %d policies after cascade delete, want 0", len(policies))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7PolicyDuplicate(t *testing.T) {
|
||||||
|
store := openTestDB(t)
|
||||||
|
|
||||||
|
lid, _ := store.CreateListener(":443", false, 0)
|
||||||
|
rid, _ := store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
||||||
|
|
||||||
|
if _, err := store.CreateL7Policy(rid, "block_user_agent", "Bot"); err != nil {
|
||||||
|
t.Fatalf("first create: %v", err)
|
||||||
|
}
|
||||||
|
if _, err := store.CreateL7Policy(rid, "block_user_agent", "Bot"); err == nil {
|
||||||
|
t.Fatal("expected error for duplicate policy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGetRouteID(t *testing.T) {
|
||||||
|
store := openTestDB(t)
|
||||||
|
|
||||||
|
lid, _ := store.CreateListener(":443", false, 0)
|
||||||
|
store.CreateRoute(lid, "api.test", "127.0.0.1:8080", "l7", "/c.pem", "/k.pem", false, false)
|
||||||
|
|
||||||
|
rid, err := store.GetRouteID(lid, "api.test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetRouteID: %v", err)
|
||||||
|
}
|
||||||
|
if rid == 0 {
|
||||||
|
t.Fatal("expected non-zero route ID")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = store.GetRouteID(lid, "nonexistent.test")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nonexistent route")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
74
internal/db/l7policies.go
Normal file
74
internal/db/l7policies.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package db
|
||||||
|
|
||||||
|
import "fmt"
|
||||||
|
|
||||||
|
// L7Policy is a database L7 policy record.
|
||||||
|
type L7Policy struct {
|
||||||
|
ID int64
|
||||||
|
RouteID int64
|
||||||
|
Type string // "block_user_agent" or "require_header"
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListL7Policies returns all L7 policies for a route.
|
||||||
|
func (s *Store) ListL7Policies(routeID int64) ([]L7Policy, error) {
|
||||||
|
rows, err := s.db.Query(
|
||||||
|
"SELECT id, route_id, type, value FROM l7_policies WHERE route_id = ? ORDER BY id",
|
||||||
|
routeID,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("querying l7 policies: %w", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
var policies []L7Policy
|
||||||
|
for rows.Next() {
|
||||||
|
var p L7Policy
|
||||||
|
if err := rows.Scan(&p.ID, &p.RouteID, &p.Type, &p.Value); err != nil {
|
||||||
|
return nil, fmt.Errorf("scanning l7 policy: %w", err)
|
||||||
|
}
|
||||||
|
policies = append(policies, p)
|
||||||
|
}
|
||||||
|
return policies, rows.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateL7Policy inserts an L7 policy and returns its ID.
|
||||||
|
func (s *Store) CreateL7Policy(routeID int64, policyType, value string) (int64, error) {
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
"INSERT INTO l7_policies (route_id, type, value) VALUES (?, ?, ?)",
|
||||||
|
routeID, policyType, value,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("inserting l7 policy: %w", err)
|
||||||
|
}
|
||||||
|
return result.LastInsertId()
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteL7Policy deletes an L7 policy by route ID, type, and value.
|
||||||
|
func (s *Store) DeleteL7Policy(routeID int64, policyType, value string) error {
|
||||||
|
result, err := s.db.Exec(
|
||||||
|
"DELETE FROM l7_policies WHERE route_id = ? AND type = ? AND value = ?",
|
||||||
|
routeID, policyType, value,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("deleting l7 policy: %w", err)
|
||||||
|
}
|
||||||
|
n, _ := result.RowsAffected()
|
||||||
|
if n == 0 {
|
||||||
|
return fmt.Errorf("l7 policy not found (route %d, type %q, value %q)", routeID, policyType, value)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRouteID returns the route ID for a listener/hostname pair.
|
||||||
|
func (s *Store) GetRouteID(listenerID int64, hostname string) (int64, error) {
|
||||||
|
var id int64
|
||||||
|
err := s.db.QueryRow(
|
||||||
|
"SELECT id FROM routes WHERE listener_id = ? AND hostname = ?",
|
||||||
|
listenerID, hostname,
|
||||||
|
).Scan(&id)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("looking up route %q on listener %d: %w", hostname, listenerID, err)
|
||||||
|
}
|
||||||
|
return id, nil
|
||||||
|
}
|
||||||
@@ -45,6 +45,19 @@ ALTER TABLE routes ADD COLUMN send_proxy_protocol INTEGER NOT NULL DEFAULT 0;`,
|
|||||||
Name: "add_listener_max_connections",
|
Name: "add_listener_max_connections",
|
||||||
SQL: `ALTER TABLE listeners ADD COLUMN max_connections INTEGER NOT NULL DEFAULT 0;`,
|
SQL: `ALTER TABLE listeners ADD COLUMN max_connections INTEGER NOT NULL DEFAULT 0;`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Version: 4,
|
||||||
|
Name: "create_l7_policies_table",
|
||||||
|
SQL: `
|
||||||
|
CREATE TABLE IF NOT EXISTS l7_policies (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
route_id INTEGER NOT NULL REFERENCES routes(id) ON DELETE CASCADE,
|
||||||
|
type TEXT NOT NULL CHECK(type IN ('block_user_agent', 'require_header')),
|
||||||
|
value TEXT NOT NULL,
|
||||||
|
UNIQUE(route_id, type, value)
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_l7_policies_route ON l7_policies(route_id);`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
// Migrate runs all unapplied migrations sequentially.
|
// Migrate runs all unapplied migrations sequentially.
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
|
|||||||
if mode == "" {
|
if mode == "" {
|
||||||
mode = "l4"
|
mode = "l4"
|
||||||
}
|
}
|
||||||
_, err := tx.Exec(
|
routeResult, err := tx.Exec(
|
||||||
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
|
`INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol)
|
||||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||||
listenerID, strings.ToLower(r.Hostname), r.Backend,
|
listenerID, strings.ToLower(r.Hostname), r.Backend,
|
||||||
@@ -40,6 +40,18 @@ func (s *Store) Seed(listeners []config.Listener, fw config.Firewall) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("seeding route %q on listener %q: %w", r.Hostname, l.Addr, err)
|
return fmt.Errorf("seeding route %q on listener %q: %w", r.Hostname, l.Addr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(r.L7Policies) > 0 {
|
||||||
|
routeID, _ := routeResult.LastInsertId()
|
||||||
|
for _, p := range r.L7Policies {
|
||||||
|
if _, err := tx.Exec(
|
||||||
|
"INSERT INTO l7_policies (route_id, type, value) VALUES (?, ?, ?)",
|
||||||
|
routeID, p.Type, p.Value,
|
||||||
|
); err != nil {
|
||||||
|
return fmt.Errorf("seeding l7 policy on route %q: %w", r.Hostname, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -89,6 +89,10 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
|
|||||||
ListenerAddr: ls.Addr,
|
ListenerAddr: ls.Addr,
|
||||||
}
|
}
|
||||||
for hostname, route := range routes {
|
for hostname, route := range routes {
|
||||||
|
var policies []*pb.L7Policy
|
||||||
|
for _, p := range route.L7Policies {
|
||||||
|
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
|
||||||
|
}
|
||||||
resp.Routes = append(resp.Routes, &pb.Route{
|
resp.Routes = append(resp.Routes, &pb.Route{
|
||||||
Hostname: hostname,
|
Hostname: hostname,
|
||||||
Backend: route.Backend,
|
Backend: route.Backend,
|
||||||
@@ -97,6 +101,7 @@ func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (
|
|||||||
TlsKey: route.TLSKey,
|
TlsKey: route.TLSKey,
|
||||||
BackendTls: route.BackendTLS,
|
BackendTls: route.BackendTLS,
|
||||||
SendProxyProtocol: route.SendProxyProtocol,
|
SendProxyProtocol: route.SendProxyProtocol,
|
||||||
|
L7Policies: policies,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@@ -187,6 +192,100 @@ func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest)
|
|||||||
return &pb.RemoveRouteResponse{}, nil
|
return &pb.RemoveRouteResponse{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ListL7Policies returns L7 policies for a route.
|
||||||
|
func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) {
|
||||||
|
ls, err := a.findListener(req.ListenerAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := strings.ToLower(req.Hostname)
|
||||||
|
routes := ls.Routes()
|
||||||
|
route, ok := routes[hostname]
|
||||||
|
if !ok {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "route %q not found", hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
var policies []*pb.L7Policy
|
||||||
|
for _, p := range route.L7Policies {
|
||||||
|
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
|
||||||
|
}
|
||||||
|
return &pb.ListL7PoliciesResponse{Policies: policies}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddL7Policy adds an L7 policy to a route (write-through).
|
||||||
|
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) {
|
||||||
|
if req.Policy == nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "policy is required")
|
||||||
|
}
|
||||||
|
if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" {
|
||||||
|
return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type)
|
||||||
|
}
|
||||||
|
if req.Policy.Value == "" {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "policy value is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
ls, err := a.findListener(req.ListenerAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := strings.ToLower(req.Hostname)
|
||||||
|
|
||||||
|
// Get route ID from DB.
|
||||||
|
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||||
|
}
|
||||||
|
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write-through: DB first.
|
||||||
|
if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
|
||||||
|
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update in-memory state.
|
||||||
|
ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value})
|
||||||
|
|
||||||
|
a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
|
||||||
|
return &pb.AddL7PolicyResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveL7Policy removes an L7 policy from a route (write-through).
|
||||||
|
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) {
|
||||||
|
if req.Policy == nil {
|
||||||
|
return nil, status.Error(codes.InvalidArgument, "policy is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
ls, err := a.findListener(req.ListenerAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname := strings.ToLower(req.Hostname)
|
||||||
|
|
||||||
|
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
||||||
|
}
|
||||||
|
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
|
||||||
|
return nil, status.Errorf(codes.NotFound, "%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value)
|
||||||
|
|
||||||
|
a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type)
|
||||||
|
return &pb.RemoveL7PolicyResponse{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetFirewallRules returns all current firewall rules.
|
// GetFirewallRules returns all current firewall rules.
|
||||||
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
|
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
|
||||||
ips, cidrs, countries := a.srv.Firewall().Rules()
|
ips, cidrs, countries := a.srv.Firewall().Rules()
|
||||||
|
|||||||
@@ -735,3 +735,97 @@ func TestSetListenerMaxConnectionsNotFound(t *testing.T) {
|
|||||||
t.Fatalf("expected NotFound, got %v", err)
|
t.Fatalf("expected NotFound, got %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAddListL7Policy(t *testing.T) {
|
||||||
|
env := setup(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
Policy: &pb.L7Policy{Type: "block_user_agent", Value: "BadBot"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("AddL7Policy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("ListL7Policies: %v", err)
|
||||||
|
}
|
||||||
|
if len(resp.Policies) != 1 {
|
||||||
|
t.Fatalf("got %d policies, want 1", len(resp.Policies))
|
||||||
|
}
|
||||||
|
if resp.Policies[0].Type != "block_user_agent" || resp.Policies[0].Value != "BadBot" {
|
||||||
|
t.Fatalf("policy = %v, want block_user_agent/BadBot", resp.Policies[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
routeResp, _ := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
||||||
|
for _, r := range routeResp.Routes {
|
||||||
|
if r.Hostname == "a.test" && len(r.L7Policies) != 1 {
|
||||||
|
t.Fatalf("ListRoutes: route has %d policies, want 1", len(r.L7Policies))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRemoveL7Policy(t *testing.T) {
|
||||||
|
env := setup(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
|
||||||
|
})
|
||||||
|
|
||||||
|
_, err := env.client.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
Policy: &pb.L7Policy{Type: "require_header", Value: "X-Token"},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("RemoveL7Policy: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, _ := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
})
|
||||||
|
if len(resp.Policies) != 0 {
|
||||||
|
t.Fatalf("got %d policies after remove, want 0", len(resp.Policies))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddL7PolicyValidation(t *testing.T) {
|
||||||
|
env := setup(t)
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
Policy: &pb.L7Policy{Type: "invalid_type", Value: "x"},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for invalid policy type")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
Policy: &pb.L7Policy{Type: "block_user_agent", Value: ""},
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for empty policy value")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
|
||||||
|
ListenerAddr: ":443",
|
||||||
|
Hostname: "a.test",
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("expected error for nil policy")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
38
internal/l7/policy.go
Normal file
38
internal/l7/policy.go
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
package l7
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PolicyRule defines an L7 blocking policy.
|
||||||
|
type PolicyRule struct {
|
||||||
|
Type string // "block_user_agent" or "require_header"
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
|
// PolicyMiddleware returns an http.Handler that evaluates L7 policies
|
||||||
|
// before delegating to next. Returns HTTP 403 if any policy blocks.
|
||||||
|
// If policies is empty, returns next unchanged.
|
||||||
|
func PolicyMiddleware(policies []PolicyRule, next http.Handler) http.Handler {
|
||||||
|
if len(policies) == 0 {
|
||||||
|
return next
|
||||||
|
}
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
for _, p := range policies {
|
||||||
|
switch p.Type {
|
||||||
|
case "block_user_agent":
|
||||||
|
if strings.Contains(r.UserAgent(), p.Value) {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case "require_header":
|
||||||
|
if r.Header.Get(p.Value) == "" {
|
||||||
|
w.WriteHeader(http.StatusForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
158
internal/l7/policy_test.go
Normal file
158
internal/l7/policy_test.go
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
package l7
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPolicyMiddlewareNoPolicies(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := PolicyMiddleware(nil, next)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("next handler was not called")
|
||||||
|
}
|
||||||
|
if w.Code != 200 {
|
||||||
|
t.Fatalf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyBlockUserAgentMatch(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
policies := []PolicyRule{
|
||||||
|
{Type: "block_user_agent", Value: "BadBot"},
|
||||||
|
}
|
||||||
|
handler := PolicyMiddleware(policies, next)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set("User-Agent", "Mozilla/5.0 BadBot/1.0")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != 403 {
|
||||||
|
t.Fatalf("status = %d, want 403", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyBlockUserAgentNoMatch(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
policies := []PolicyRule{
|
||||||
|
{Type: "block_user_agent", Value: "BadBot"},
|
||||||
|
}
|
||||||
|
handler := PolicyMiddleware(policies, next)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set("User-Agent", "Mozilla/5.0 GoodBrowser/1.0")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("next handler was not called")
|
||||||
|
}
|
||||||
|
if w.Code != 200 {
|
||||||
|
t.Fatalf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRequireHeaderPresent(t *testing.T) {
|
||||||
|
called := false
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
called = true
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
policies := []PolicyRule{
|
||||||
|
{Type: "require_header", Value: "X-API-Key"},
|
||||||
|
}
|
||||||
|
handler := PolicyMiddleware(policies, next)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set("X-API-Key", "secret")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if !called {
|
||||||
|
t.Fatal("next handler was not called")
|
||||||
|
}
|
||||||
|
if w.Code != 200 {
|
||||||
|
t.Fatalf("status = %d, want 200", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyRequireHeaderAbsent(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
policies := []PolicyRule{
|
||||||
|
{Type: "require_header", Value: "X-API-Key"},
|
||||||
|
}
|
||||||
|
handler := PolicyMiddleware(policies, next)
|
||||||
|
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
if w.Code != 403 {
|
||||||
|
t.Fatalf("status = %d, want 403", w.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPolicyMultipleRules(t *testing.T) {
|
||||||
|
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.WriteHeader(200)
|
||||||
|
})
|
||||||
|
|
||||||
|
policies := []PolicyRule{
|
||||||
|
{Type: "block_user_agent", Value: "BadBot"},
|
||||||
|
{Type: "require_header", Value: "X-Token"},
|
||||||
|
}
|
||||||
|
handler := PolicyMiddleware(policies, next)
|
||||||
|
|
||||||
|
// Blocked by UA even though header is present.
|
||||||
|
req := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req.Header.Set("User-Agent", "BadBot/1.0")
|
||||||
|
req.Header.Set("X-Token", "abc")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w, req)
|
||||||
|
if w.Code != 403 {
|
||||||
|
t.Fatalf("UA block: status = %d, want 403", w.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Good UA but missing header.
|
||||||
|
req2 := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req2.Header.Set("User-Agent", "GoodBot/1.0")
|
||||||
|
w2 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w2, req2)
|
||||||
|
if w2.Code != 403 {
|
||||||
|
t.Fatalf("missing header: status = %d, want 403", w2.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Good UA and header present — passes.
|
||||||
|
req3 := httptest.NewRequest("GET", "/", nil)
|
||||||
|
req3.Header.Set("User-Agent", "GoodBot/1.0")
|
||||||
|
req3.Header.Set("X-Token", "abc")
|
||||||
|
w3 := httptest.NewRecorder()
|
||||||
|
handler.ServeHTTP(w3, req3)
|
||||||
|
if w3.Code != 200 {
|
||||||
|
t.Fatalf("pass: status = %d, want 200", w3.Code)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -26,6 +26,7 @@ type RouteConfig struct {
|
|||||||
BackendTLS bool
|
BackendTLS bool
|
||||||
SendProxyProtocol bool
|
SendProxyProtocol bool
|
||||||
ConnectTimeout time.Duration
|
ConnectTimeout time.Duration
|
||||||
|
Policies []PolicyRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// contextKey is an unexported type for context keys in this package.
|
// contextKey is an unexported type for context keys in this package.
|
||||||
@@ -74,10 +75,12 @@ func Serve(ctx context.Context, conn net.Conn, peeked []byte, route RouteConfig,
|
|||||||
return fmt.Errorf("creating reverse proxy: %w", err)
|
return fmt.Errorf("creating reverse proxy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Wrap the handler to inject the real client IP into the request context.
|
// Build handler chain: context injection → L7 policies → reverse proxy.
|
||||||
|
var inner http.Handler = rp
|
||||||
|
inner = PolicyMiddleware(route.Policies, inner)
|
||||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
|
r = r.WithContext(context.WithValue(r.Context(), clientAddrKey, clientAddr))
|
||||||
rp.ServeHTTP(w, r)
|
inner.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
|
|
||||||
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
|
// Serve HTTP on the TLS connection. Use HTTP/2 if negotiated,
|
||||||
|
|||||||
@@ -551,3 +551,115 @@ func TestL7HTTP11Fallback(t *testing.T) {
|
|||||||
t.Fatal("empty response body")
|
t.Fatal("empty response body")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestL7PolicyBlocksUserAgentE2E(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "policy.test")
|
||||||
|
|
||||||
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, "should-not-reach")
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
defer proxyLn.Close()
|
||||||
|
|
||||||
|
route := RouteConfig{
|
||||||
|
Backend: backendAddr,
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
Policies: []PolicyRule{
|
||||||
|
{Type: "block_user_agent", Value: "EvilBot"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
conn, err := proxyLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := dialTLSToProxy(t, proxyLn.Addr().String(), "policy.test")
|
||||||
|
req, _ := http.NewRequest("GET", "https://policy.test/", nil)
|
||||||
|
req.Header.Set("User-Agent", "EvilBot/1.0")
|
||||||
|
resp, err := client.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET: %v", err)
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if resp.StatusCode != 403 {
|
||||||
|
t.Fatalf("status = %d, want 403", resp.StatusCode)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestL7PolicyRequiresHeaderE2E(t *testing.T) {
|
||||||
|
certPath, keyPath := testCert(t, "reqhdr.test")
|
||||||
|
|
||||||
|
backendAddr := startH2CBackend(t, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
fmt.Fprint(w, "ok")
|
||||||
|
}))
|
||||||
|
|
||||||
|
proxyLn, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("proxy listen: %v", err)
|
||||||
|
}
|
||||||
|
defer proxyLn.Close()
|
||||||
|
|
||||||
|
route := RouteConfig{
|
||||||
|
Backend: backendAddr,
|
||||||
|
TLSCert: certPath,
|
||||||
|
TLSKey: keyPath,
|
||||||
|
ConnectTimeout: 5 * time.Second,
|
||||||
|
Policies: []PolicyRule{
|
||||||
|
{Type: "require_header", Value: "X-Auth-Token"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept two connections (one blocked, one allowed).
|
||||||
|
go func() {
|
||||||
|
for range 2 {
|
||||||
|
conn, err := proxyLn.Accept()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||||
|
Serve(context.Background(), conn, nil, route, netip.MustParseAddrPort("203.0.113.50:12345"), logger)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Without the required header → 403.
|
||||||
|
client1 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
|
||||||
|
resp1, err := client1.Get("https://reqhdr.test/")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET without header: %v", err)
|
||||||
|
}
|
||||||
|
resp1.Body.Close()
|
||||||
|
if resp1.StatusCode != 403 {
|
||||||
|
t.Fatalf("without header: status = %d, want 403", resp1.StatusCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// With the required header → 200.
|
||||||
|
client2 := dialTLSToProxy(t, proxyLn.Addr().String(), "reqhdr.test")
|
||||||
|
req, _ := http.NewRequest("GET", "https://reqhdr.test/", nil)
|
||||||
|
req.Header.Set("X-Auth-Token", "valid-token")
|
||||||
|
resp2, err := client2.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GET with header: %v", err)
|
||||||
|
}
|
||||||
|
defer resp2.Body.Close()
|
||||||
|
body, _ := io.ReadAll(resp2.Body)
|
||||||
|
if resp2.StatusCode != 200 {
|
||||||
|
t.Fatalf("with header: status = %d, want 200", resp2.StatusCode)
|
||||||
|
}
|
||||||
|
if string(body) != "ok" {
|
||||||
|
t.Fatalf("body = %q, want %q", body, "ok")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -19,6 +19,12 @@ import (
|
|||||||
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
"git.wntrmute.dev/kyle/mc-proxy/internal/sni"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// L7PolicyRule is an L7 blocking policy attached to a route.
|
||||||
|
type L7PolicyRule struct {
|
||||||
|
Type string // "block_user_agent" or "require_header"
|
||||||
|
Value string
|
||||||
|
}
|
||||||
|
|
||||||
// RouteInfo holds the full configuration for a single route.
|
// RouteInfo holds the full configuration for a single route.
|
||||||
type RouteInfo struct {
|
type RouteInfo struct {
|
||||||
Backend string
|
Backend string
|
||||||
@@ -27,6 +33,7 @@ type RouteInfo struct {
|
|||||||
TLSKey string
|
TLSKey string
|
||||||
BackendTLS bool
|
BackendTLS bool
|
||||||
SendProxyProtocol bool
|
SendProxyProtocol bool
|
||||||
|
L7Policies []L7PolicyRule
|
||||||
}
|
}
|
||||||
|
|
||||||
// ListenerState holds the mutable state for a single proxy listener.
|
// ListenerState holds the mutable state for a single proxy listener.
|
||||||
@@ -91,6 +98,40 @@ func (ls *ListenerState) RemoveRoute(hostname string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddL7Policy appends an L7 policy to a route's policy list.
|
||||||
|
func (ls *ListenerState) AddL7Policy(hostname string, policy L7PolicyRule) {
|
||||||
|
key := strings.ToLower(hostname)
|
||||||
|
|
||||||
|
ls.mu.Lock()
|
||||||
|
defer ls.mu.Unlock()
|
||||||
|
|
||||||
|
if route, ok := ls.routes[key]; ok {
|
||||||
|
route.L7Policies = append(route.L7Policies, policy)
|
||||||
|
ls.routes[key] = route
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveL7Policy removes an L7 policy from a route's policy list.
|
||||||
|
func (ls *ListenerState) RemoveL7Policy(hostname, policyType, policyValue string) {
|
||||||
|
key := strings.ToLower(hostname)
|
||||||
|
|
||||||
|
ls.mu.Lock()
|
||||||
|
defer ls.mu.Unlock()
|
||||||
|
|
||||||
|
route, ok := ls.routes[key]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
filtered := route.L7Policies[:0]
|
||||||
|
for _, p := range route.L7Policies {
|
||||||
|
if p.Type != policyType || p.Value != policyValue {
|
||||||
|
filtered = append(filtered, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
route.L7Policies = filtered
|
||||||
|
ls.routes[key] = route
|
||||||
|
}
|
||||||
|
|
||||||
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
|
func (ls *ListenerState) lookupRoute(hostname string) (RouteInfo, bool) {
|
||||||
ls.mu.RLock()
|
ls.mu.RLock()
|
||||||
defer ls.mu.RUnlock()
|
defer ls.mu.RUnlock()
|
||||||
@@ -362,6 +403,11 @@ func (s *Server) handleL4(ctx context.Context, conn net.Conn, addr netip.Addr, c
|
|||||||
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
|
func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, clientAddrPort netip.AddrPort, hostname string, route RouteInfo, peeked []byte) {
|
||||||
s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
|
s.logger.Debug("L7 proxying", "addr", addr, "hostname", hostname, "backend", route.Backend)
|
||||||
|
|
||||||
|
var policies []l7.PolicyRule
|
||||||
|
for _, p := range route.L7Policies {
|
||||||
|
policies = append(policies, l7.PolicyRule{Type: p.Type, Value: p.Value})
|
||||||
|
}
|
||||||
|
|
||||||
rc := l7.RouteConfig{
|
rc := l7.RouteConfig{
|
||||||
Backend: route.Backend,
|
Backend: route.Backend,
|
||||||
TLSCert: route.TLSCert,
|
TLSCert: route.TLSCert,
|
||||||
@@ -369,6 +415,7 @@ func (s *Server) handleL7(ctx context.Context, conn net.Conn, addr netip.Addr, c
|
|||||||
BackendTLS: route.BackendTLS,
|
BackendTLS: route.BackendTLS,
|
||||||
SendProxyProtocol: route.SendProxyProtocol,
|
SendProxyProtocol: route.SendProxyProtocol,
|
||||||
ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration,
|
ConnectTimeout: s.cfg.Proxy.ConnectTimeout.Duration,
|
||||||
|
Policies: policies,
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil {
|
if err := l7.Serve(ctx, conn, peeked, rc, clientAddrPort, s.logger); err != nil {
|
||||||
|
|||||||
@@ -20,20 +20,31 @@ service ProxyAdminService {
|
|||||||
// Connection limits
|
// Connection limits
|
||||||
rpc SetListenerMaxConnections(SetListenerMaxConnectionsRequest) returns (SetListenerMaxConnectionsResponse);
|
rpc SetListenerMaxConnections(SetListenerMaxConnectionsRequest) returns (SetListenerMaxConnectionsResponse);
|
||||||
|
|
||||||
|
// L7 policies
|
||||||
|
rpc ListL7Policies(ListL7PoliciesRequest) returns (ListL7PoliciesResponse);
|
||||||
|
rpc AddL7Policy(AddL7PolicyRequest) returns (AddL7PolicyResponse);
|
||||||
|
rpc RemoveL7Policy(RemoveL7PolicyRequest) returns (RemoveL7PolicyResponse);
|
||||||
|
|
||||||
// Status
|
// Status
|
||||||
rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
|
rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Routes
|
// Routes
|
||||||
|
|
||||||
|
message L7Policy {
|
||||||
|
string type = 1; // "block_user_agent" or "require_header"
|
||||||
|
string value = 2;
|
||||||
|
}
|
||||||
|
|
||||||
message Route {
|
message Route {
|
||||||
string hostname = 1;
|
string hostname = 1;
|
||||||
string backend = 2;
|
string backend = 2;
|
||||||
string mode = 3; // "l4" (default) or "l7"
|
string mode = 3; // "l4" (default) or "l7"
|
||||||
string tls_cert = 4; // PEM certificate path (L7 only)
|
string tls_cert = 4; // PEM certificate path (L7 only)
|
||||||
string tls_key = 5; // PEM private key path (L7 only)
|
string tls_key = 5; // PEM private key path (L7 only)
|
||||||
bool backend_tls = 6; // re-encrypt to backend (L7 only)
|
bool backend_tls = 6; // re-encrypt to backend (L7 only)
|
||||||
bool send_proxy_protocol = 7; // send PROXY v2 header to backend
|
bool send_proxy_protocol = 7; // send PROXY v2 header to backend
|
||||||
|
repeated L7Policy l7_policies = 8; // HTTP-level policies (L7 only)
|
||||||
}
|
}
|
||||||
|
|
||||||
message ListRoutesRequest {
|
message ListRoutesRequest {
|
||||||
@@ -59,6 +70,33 @@ message RemoveRouteRequest {
|
|||||||
|
|
||||||
message RemoveRouteResponse {}
|
message RemoveRouteResponse {}
|
||||||
|
|
||||||
|
// L7 Policies
|
||||||
|
|
||||||
|
message ListL7PoliciesRequest {
|
||||||
|
string listener_addr = 1;
|
||||||
|
string hostname = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
message ListL7PoliciesResponse {
|
||||||
|
repeated L7Policy policies = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AddL7PolicyRequest {
|
||||||
|
string listener_addr = 1;
|
||||||
|
string hostname = 2;
|
||||||
|
L7Policy policy = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message AddL7PolicyResponse {}
|
||||||
|
|
||||||
|
message RemoveL7PolicyRequest {
|
||||||
|
string listener_addr = 1;
|
||||||
|
string hostname = 2;
|
||||||
|
L7Policy policy = 3;
|
||||||
|
}
|
||||||
|
|
||||||
|
message RemoveL7PolicyResponse {}
|
||||||
|
|
||||||
// Firewall
|
// Firewall
|
||||||
|
|
||||||
enum FirewallRuleType {
|
enum FirewallRuleType {
|
||||||
|
|||||||
Reference in New Issue
Block a user