Files
mc-proxy/internal/grpcserver/grpcserver_test.go
Kyle Isom 42c7fffc3e Add L7 policies for user-agent blocking and required headers
Per-route HTTP-level blocking policies for L7 routes. Two rule types:
block_user_agent (substring match against User-Agent, returns 403)
and require_header (named header must be present, returns 403).

Config: L7Policy struct with type/value fields, added as L7Policies
slice on Route. Validated in config (type enum, non-empty value,
warning if set on L4 routes).

DB: Migration 4 creates l7_policies table with route_id FK (cascade
delete), type CHECK constraint, UNIQUE(route_id, type, value). New
l7policies.go with ListL7Policies, CreateL7Policy, DeleteL7Policy,
GetRouteID. Seed updated to persist policies from config.

L7 middleware: PolicyMiddleware in internal/l7/policy.go evaluates
rules in order, returns 403 on first match, no-op if empty. Composed
into the handler chain between context injection and reverse proxy.

Server: L7PolicyRule type on RouteInfo with AddL7Policy/RemoveL7Policy
mutation methods on ListenerState. handleL7 threads policies into
l7.RouteConfig. Startup loads policies per L7 route from DB.

Proto: L7Policy message, repeated l7_policies on Route. Three new
RPCs: ListL7Policies, AddL7Policy, RemoveL7Policy. All follow the
write-through pattern.

Client: L7Policy type, ListL7Policies/AddL7Policy/RemoveL7Policy
methods. CLI: mcproxyctl policies list/add/remove subcommands.

Tests: 6 PolicyMiddleware unit tests (no policies, UA match/no-match,
header present/absent, multiple rules). 4 DB tests (CRUD, cascade,
duplicate, GetRouteID). 3 gRPC tests (add+list, remove, validation).
2 end-to-end L7 tests (UA block, required header with allow/deny).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 17:11:05 -07:00

832 lines
21 KiB
Go

package grpcserver
import (
"context"
"io"
"log/slog"
"net"
"path/filepath"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
// testEnv bundles all the objects needed for a grpcserver test.
// testEnv holds the handles a grpcserver test works with: an admin client
// wired to an in-process AdminServer over bufconn, the client connection
// behind it, and the store and server the AdminServer was built from.
type testEnv struct {
	// client issues ProxyAdminService RPCs against the in-process server.
	client pb.ProxyAdminServiceClient
	// conn is the gRPC client connection that client is bound to.
	conn *grpc.ClientConn

	// store is the database the AdminServer was given; tests read it
	// directly to check persistence.
	store *db.Store
	// srv is the server.Server whose state the AdminServer reads and mutates.
	srv *server.Server
}
// setup builds an isolated environment for a single test.
//
// It does the following:
//   - Opens and migrates a fresh database under t.TempDir().
//   - Seeds one listener (":443") with one route ("a.test" ->
//     "127.0.0.1:8443") and one blocked IP ("10.0.0.1").
//   - Builds a server.Server whose in-memory state mirrors the seeded rows,
//     using the IDs the database assigned.
//   - Serves an AdminServer over an in-memory bufconn listener without TLS.
//
// Every resource is released through t.Cleanup. The gRPC Serve goroutine is
// stopped and joined before the test finishes, so it never outlives the test.
func setup(t *testing.T) *testEnv {
	t.Helper()

	// Database in temp dir.
	dbPath := filepath.Join(t.TempDir(), "test.db")
	store, err := db.Open(dbPath)
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	// Close error is ignored: the temp dir is discarded with the test anyway.
	t.Cleanup(func() { store.Close() })

	if err := store.Migrate(); err != nil {
		t.Fatalf("migrate: %v", err)
	}

	// Seed with one listener and one route.
	listeners := []config.Listener{
		{
			Addr: ":443",
			Routes: []config.Route{
				{Hostname: "a.test", Backend: "127.0.0.1:8443"},
			},
		},
	}
	fw := config.Firewall{
		BlockedIPs: []string{"10.0.0.1"},
	}
	if err := store.Seed(listeners, fw); err != nil {
		t.Fatalf("seed: %v", err)
	}

	// Build server with matching in-memory state.
	fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("firewall: %v", err)
	}

	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))

	// Load listener data from DB to get correct IDs.
	dbListeners, err := store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	listenerData := make([]server.ListenerData, 0, len(dbListeners))
	for _, l := range dbListeners {
		dbRoutes, err := store.ListRoutes(l.ID)
		if err != nil {
			t.Fatalf("list routes: %v", err)
		}
		routes := make(map[string]server.RouteInfo, len(dbRoutes))
		for _, r := range dbRoutes {
			routes[r.Hostname] = server.RouteInfo{
				Backend: r.Backend,
				Mode:    r.Mode,
			}
		}
		listenerData = append(listenerData, server.ListenerData{
			ID:            l.ID,
			Addr:          l.Addr,
			ProxyProtocol: l.ProxyProtocol,
			Routes:        routes,
		})
	}

	srv := server.New(cfg, fwObj, listenerData, logger, "test-version")

	// Set up bufconn gRPC server (no TLS for tests).
	lis := bufconn.Listen(1024 * 1024)
	grpcSrv := grpc.NewServer()
	admin := &AdminServer{
		srv:    srv,
		store:  store,
		logger: logger,
	}
	pb.RegisterProxyAdminServiceServer(grpcSrv, admin)

	// Run Serve in the background, but capture its result so cleanup can
	// join the goroutine. An unjoined goroutine may outlive the test, and
	// logging through t after the test has completed panics.
	serveErr := make(chan error, 1)
	go func() { serveErr <- grpcSrv.Serve(lis) }()
	t.Cleanup(func() {
		grpcSrv.Stop()
		// Serve returns nil once Stop has been called. Any other result is a
		// real serving failure, so log it.
		if err := <-serveErr; err != nil {
			t.Logf("grpc serve: %v", err)
		}
	})

	conn, err := grpc.NewClient("passthrough://bufconn",
		grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
			return lis.DialContext(ctx)
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		t.Fatalf("dial bufconn: %v", err)
	}
	// Registered last, so it runs first: the client closes before the server stops.
	t.Cleanup(func() { conn.Close() })

	return &testEnv{
		client: pb.NewProxyAdminServiceClient(conn),
		conn:   conn,
		store:  store,
		srv:    srv,
	}
}
// TestGetStatus checks that GetStatus reports the version string passed to
// server.New in setup, plus the single seeded listener and its one route.
func TestGetStatus(t *testing.T) {
	env := setup(t)

	got, err := env.client.GetStatus(context.Background(), &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}

	const wantVersion = "test-version"
	if got.Version != wantVersion {
		t.Fatalf("got version %q, want %q", got.Version, wantVersion)
	}
	if n := len(got.Listeners); n != 1 {
		t.Fatalf("got %d listeners, want 1", n)
	}

	only := got.Listeners[0]
	if only.Addr != ":443" {
		t.Fatalf("got listener addr %q, want %q", only.Addr, ":443")
	}
	if only.RouteCount != 1 {
		t.Fatalf("got route count %d, want 1", only.RouteCount)
	}
}
// TestListRoutes checks that ListRoutes on the seeded listener returns
// exactly the one seeded route, with hostname and backend intact.
func TestListRoutes(t *testing.T) {
	env := setup(t)
	req := &pb.ListRoutesRequest{ListenerAddr: ":443"}

	out, err := env.client.ListRoutes(context.Background(), req)
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if n := len(out.Routes); n != 1 {
		t.Fatalf("got %d routes, want 1", n)
	}

	const wantHost, wantBackend = "a.test", "127.0.0.1:8443"
	route := out.Routes[0]
	if route.Hostname != wantHost {
		t.Fatalf("got hostname %q, want %q", route.Hostname, wantHost)
	}
	if route.Backend != wantBackend {
		t.Fatalf("got backend %q, want %q", route.Backend, wantBackend)
	}
}
// TestListRoutesNotFound checks that ListRoutes on an address with no
// listener fails with codes.NotFound.
func TestListRoutesNotFound(t *testing.T) {
	env := setup(t)

	req := &pb.ListRoutesRequest{ListenerAddr: ":9999"}
	_, err := env.client.ListRoutes(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for nonexistent listener")
	}

	s, ok := status.FromError(err)
	if ok && s.Code() == codes.NotFound {
		return
	}
	t.Fatalf("expected NotFound, got %v", err)
}
// TestAddRoute checks that AddRoute does two things. First, it installs the
// new route in the server, visible through ListRoutes with the requested
// backend. Second, it persists the route to the database alongside the
// seeded one.
func TestAddRoute(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Hostname: "b.test", Backend: "127.0.0.1:9443"},
	})
	if err != nil {
		t.Fatalf("AddRoute: %v", err)
	}

	// Verify in-memory.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if len(resp.Routes) != 2 {
		t.Fatalf("got %d routes, want 2", len(resp.Routes))
	}
	// The count alone would also pass if the wrong route were stored, so
	// check the added route's contents too.
	var added *pb.Route
	for _, r := range resp.Routes {
		if r.Hostname == "b.test" {
			added = r
			break
		}
	}
	if added == nil {
		t.Fatal("route b.test not found in ListRoutes response")
	}
	if added.Backend != "127.0.0.1:9443" {
		t.Fatalf("got backend %q, want %q", added.Backend, "127.0.0.1:9443")
	}

	// Verify in DB.
	dbListeners, err := env.store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	// Guard the index so a missing listener fails cleanly instead of panicking.
	if len(dbListeners) == 0 {
		t.Fatal("DB has no listeners")
	}
	dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
	if err != nil {
		t.Fatalf("list routes: %v", err)
	}
	if len(dbRoutes) != 2 {
		t.Fatalf("DB has %d routes, want 2", len(dbRoutes))
	}
}
// TestAddRouteDuplicate checks that adding a route for a hostname the
// listener already serves is rejected with codes.AlreadyExists, even when
// the backend differs.
func TestAddRouteDuplicate(t *testing.T) {
	env := setup(t)

	dup := &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"}
	_, err := env.client.AddRoute(context.Background(), &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        dup,
	})
	if err == nil {
		t.Fatal("expected error for duplicate route")
	}

	if s, ok := status.FromError(err); ok && s.Code() == codes.AlreadyExists {
		return
	}
	t.Fatalf("expected AlreadyExists, got %v", err)
}
// TestAddRouteValidation checks that AddRoute rejects malformed requests: a
// missing route, an empty hostname, an empty backend, and a backend that is
// not host:port. Cases run in order and the test stops at the first one
// that is wrongly accepted.
func TestAddRouteValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	cases := []struct {
		why string
		req *pb.AddRouteRequest
	}{
		{
			why: "expected error for nil route",
			req: &pb.AddRouteRequest{ListenerAddr: ":443"},
		},
		{
			why: "expected error for empty hostname",
			req: &pb.AddRouteRequest{ListenerAddr: ":443", Route: &pb.Route{Backend: "127.0.0.1:1"}},
		},
		{
			why: "expected error for empty backend",
			req: &pb.AddRouteRequest{ListenerAddr: ":443", Route: &pb.Route{Hostname: "x.test"}},
		},
		{
			why: "expected error for invalid backend address",
			req: &pb.AddRouteRequest{ListenerAddr: ":443", Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"}},
		},
	}

	for _, c := range cases {
		if _, err := env.client.AddRoute(ctx, c.req); err == nil {
			t.Fatal(c.why)
		}
	}
}
// TestRemoveRoute checks that RemoveRoute removes the seeded route in two
// places. It must disappear from the server's in-memory table, as seen
// through ListRoutes, and from the database, matching how TestAddRoute
// checks both.
func TestRemoveRoute(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err != nil {
		t.Fatalf("RemoveRoute: %v", err)
	}

	// Verify removed from memory.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if len(resp.Routes) != 0 {
		t.Fatalf("got %d routes, want 0", len(resp.Routes))
	}

	// Verify removed from DB.
	dbListeners, err := env.store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	if len(dbListeners) == 0 {
		t.Fatal("DB has no listeners")
	}
	dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
	if err != nil {
		t.Fatalf("list routes: %v", err)
	}
	if len(dbRoutes) != 0 {
		t.Fatalf("DB has %d routes, want 0", len(dbRoutes))
	}
}
// TestRemoveRouteNotFound checks that removing a hostname with no route on
// the listener returns an error.
func TestRemoveRouteNotFound(t *testing.T) {
	env := setup(t)

	req := &pb.RemoveRouteRequest{
		ListenerAddr: ":443",
		Hostname:     "nonexistent.test",
	}
	if _, err := env.client.RemoveRoute(context.Background(), req); err == nil {
		t.Fatal("expected error for removing nonexistent route")
	}
}
// TestGetFirewallRules checks that GetFirewallRules returns exactly the one
// seeded rule: an IP rule for 10.0.0.1.
func TestGetFirewallRules(t *testing.T) {
	env := setup(t)

	resp, err := env.client.GetFirewallRules(context.Background(), &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if count := len(resp.Rules); count != 1 {
		t.Fatalf("got %d rules, want 1", count)
	}

	rule := resp.Rules[0]
	switch {
	case rule.Type != pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
		t.Fatalf("got type %v, want IP", rule.Type)
	case rule.Value != "10.0.0.1":
		t.Fatalf("got value %q, want %q", rule.Value, "10.0.0.1")
	}
}
// TestAddFirewallRule adds one rule of each type (IP, CIDR, country). It then
// checks that the three new rules plus the seeded IP rule are reported by
// GetFirewallRules and are all present in the database.
func TestAddFirewallRule(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	additions := []struct {
		label string
		rule  *pb.FirewallRule
	}{
		{"IP", &pb.FirewallRule{Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, Value: "10.0.0.2"}},
		{"CIDR", &pb.FirewallRule{Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, Value: "192.168.0.0/16"}},
		{"country", &pb.FirewallRule{Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, Value: "RU"}},
	}
	for _, a := range additions {
		req := &pb.AddFirewallRuleRequest{Rule: a.rule}
		if _, err := env.client.AddFirewallRule(ctx, req); err != nil {
			t.Fatalf("AddFirewallRule %s: %v", a.label, err)
		}
	}

	// One seeded rule plus the three additions above.
	listed, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if n := len(listed.Rules); n != 4 {
		t.Fatalf("got %d rules, want 4", n)
	}

	stored, err := env.store.ListFirewallRules()
	if err != nil {
		t.Fatalf("list firewall rules: %v", err)
	}
	if n := len(stored); n != 4 {
		t.Fatalf("DB has %d rules, want 4", n)
	}
}
// TestAddFirewallRuleValidation checks that AddFirewallRule rejects
// malformed rules. The cases are:
//   - a missing rule
//   - an unspecified type
//   - an empty value
//   - an unparsable IP
//   - an unparsable CIDR
//   - a CIDR with host bits set
//   - a lowercase country code
//   - a three-letter country code
//
// Cases run in order and the test stops at the first one wrongly accepted.
func TestAddFirewallRuleValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	withRule := func(typ pb.FirewallRuleType, value string) *pb.AddFirewallRuleRequest {
		return &pb.AddFirewallRuleRequest{
			Rule: &pb.FirewallRule{Type: typ, Value: value},
		}
	}

	cases := []struct {
		why string
		req *pb.AddFirewallRuleRequest
	}{
		{"expected error for nil rule", &pb.AddFirewallRuleRequest{}},
		{"expected error for unspecified rule type", withRule(pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED, "x")},
		{"expected error for empty value", withRule(pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, "")},
		{"expected error for invalid IP", withRule(pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, "not-an-ip")},
		{"expected error for invalid CIDR", withRule(pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, "not-a-cidr")},
		{"expected error for non-canonical CIDR", withRule(pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, "192.168.1.5/16")},
		{"expected error for lowercase country code", withRule(pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, "cn")},
		{"expected error for 3-letter country code", withRule(pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, "USA")},
	}

	for _, c := range cases {
		if _, err := env.client.AddFirewallRule(ctx, c.req); err == nil {
			t.Fatal(c.why)
		}
	}
}
// TestRemoveFirewallRule removes the seeded IP rule. It checks that the rule
// is gone from GetFirewallRules and also from the database, mirroring the
// persistence check in TestAddFirewallRule.
func TestRemoveFirewallRule(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
			Value: "10.0.0.1",
		},
	})
	if err != nil {
		t.Fatalf("RemoveFirewallRule: %v", err)
	}

	resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if len(resp.Rules) != 0 {
		t.Fatalf("got %d rules, want 0", len(resp.Rules))
	}

	// Verify removed from DB.
	dbRules, err := env.store.ListFirewallRules()
	if err != nil {
		t.Fatalf("list firewall rules: %v", err)
	}
	if len(dbRules) != 0 {
		t.Fatalf("DB has %d rules, want 0", len(dbRules))
	}
}
// TestRemoveFirewallRuleNotFound checks that removing an IP rule that was
// never added returns an error.
func TestRemoveFirewallRuleNotFound(t *testing.T) {
	env := setup(t)

	missing := &pb.FirewallRule{
		Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
		Value: "99.99.99.99",
	}
	req := &pb.RemoveFirewallRuleRequest{Rule: missing}
	if _, err := env.client.RemoveFirewallRule(context.Background(), req); err == nil {
		t.Fatal("expected error for removing nonexistent rule")
	}
}
// TestAddRouteL7 adds an L7 route with TLS cert/key and PROXY protocol
// enabled. It checks that ListRoutes echoes every L7 field back, and that
// the database row records the mode and the send_proxy_protocol flag.
func TestAddRouteL7(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route: &pb.Route{
			Hostname:          "l7.test",
			Backend:           "127.0.0.1:8080",
			Mode:              "l7",
			TlsCert:           "/certs/l7.crt",
			TlsKey:            "/certs/l7.key",
			BackendTls:        false,
			SendProxyProtocol: true,
		},
	})
	if err != nil {
		t.Fatalf("AddRoute L7: %v", err)
	}

	// Verify in-memory via ListRoutes.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	var found *pb.Route
	for _, r := range resp.Routes {
		if r.Hostname == "l7.test" {
			found = r
			break
		}
	}
	if found == nil {
		t.Fatal("L7 route not found in ListRoutes response")
	}
	if found.Mode != "l7" {
		t.Fatalf("mode = %q, want %q", found.Mode, "l7")
	}
	if found.TlsCert != "/certs/l7.crt" {
		t.Fatalf("tls_cert = %q, want %q", found.TlsCert, "/certs/l7.crt")
	}
	if found.TlsKey != "/certs/l7.key" {
		t.Fatalf("tls_key = %q, want %q", found.TlsKey, "/certs/l7.key")
	}
	if found.BackendTls {
		t.Fatal("expected backend_tls = false")
	}
	if !found.SendProxyProtocol {
		t.Fatal("expected send_proxy_protocol = true")
	}

	// Verify DB persistence. Check errors and guard the index so a failed
	// query reports its cause instead of panicking on dbListeners[0].
	dbListeners, err := env.store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	if len(dbListeners) == 0 {
		t.Fatal("DB has no listeners")
	}
	dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
	if err != nil {
		t.Fatalf("list routes: %v", err)
	}
	var dbRoute *db.Route
	for i := range dbRoutes {
		if dbRoutes[i].Hostname == "l7.test" {
			dbRoute = &dbRoutes[i]
			break
		}
	}
	if dbRoute == nil {
		t.Fatal("L7 route not found in DB")
	}
	if dbRoute.Mode != "l7" {
		t.Fatalf("DB mode = %q, want %q", dbRoute.Mode, "l7")
	}
	if !dbRoute.SendProxyProtocol {
		t.Fatal("DB send_proxy_protocol should be true")
	}
}
// TestAddRouteL7MissingCert checks that an L7 route with no TLS cert/key is
// rejected with codes.InvalidArgument.
func TestAddRouteL7MissingCert(t *testing.T) {
	env := setup(t)

	noCert := &pb.Route{
		Hostname: "nocert.test",
		Backend:  "127.0.0.1:8080",
		Mode:     "l7",
	}
	_, err := env.client.AddRoute(context.Background(), &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        noCert,
	})
	if err == nil {
		t.Fatal("expected error for L7 route without cert/key")
	}

	s, ok := status.FromError(err)
	if ok && s.Code() == codes.InvalidArgument {
		return
	}
	t.Fatalf("expected InvalidArgument, got %v", err)
}
// TestAddRouteInvalidMode checks that a route whose mode is neither "l4" nor
// "l7" (here "l5") is rejected with codes.InvalidArgument.
func TestAddRouteInvalidMode(t *testing.T) {
	env := setup(t)

	req := &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route: &pb.Route{
			Hostname: "badmode.test",
			Backend:  "127.0.0.1:8080",
			Mode:     "l5",
		},
	}
	_, err := env.client.AddRoute(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for invalid mode")
	}

	if s, ok := status.FromError(err); ok && s.Code() == codes.InvalidArgument {
		return
	}
	t.Fatalf("expected InvalidArgument, got %v", err)
}
// TestAddRouteDefaultsToL4 checks that a route added without an explicit
// mode is reported by ListRoutes with mode "l4".
func TestAddRouteDefaultsToL4(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Add route without specifying mode — should default to "l4".
	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route: &pb.Route{
			Hostname: "default.test",
			Backend:  "127.0.0.1:9443",
		},
	})
	if err != nil {
		t.Fatalf("AddRoute: %v", err)
	}

	// Check the error: on failure resp is nil, and reading resp.Routes
	// would panic instead of reporting the RPC error.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	for _, r := range resp.Routes {
		if r.Hostname == "default.test" {
			if r.Mode != "l4" {
				t.Fatalf("mode = %q, want %q", r.Mode, "l4")
			}
			return
		}
	}
	t.Fatal("route not found")
}
// TestGetStatusProxyProtocol checks that GetStatus reports proxy_protocol as
// off for the seeded listener. setup seeds that listener without setting
// ProxyProtocol.
func TestGetStatusProxyProtocol(t *testing.T) {
	env := setup(t)

	resp, err := env.client.GetStatus(context.Background(), &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}

	// The seeded listener has proxy_protocol = false.
	listeners := resp.Listeners
	if len(listeners) != 1 {
		t.Fatalf("got %d listeners, want 1", len(listeners))
	}
	if listeners[0].ProxyProtocol {
		t.Fatal("expected proxy_protocol = false")
	}
}
// TestSetListenerMaxConnections sets max_connections to 5000 and then back
// to 0 (unlimited). After each change it checks the value in GetStatus and
// in the database.
func TestSetListenerMaxConnections(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Set max connections.
	_, err := env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{
		ListenerAddr:   ":443",
		MaxConnections: 5000,
	})
	if err != nil {
		t.Fatalf("SetListenerMaxConnections: %v", err)
	}

	// Verify via GetStatus.
	resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}
	if len(resp.Listeners) != 1 {
		t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
	}
	if resp.Listeners[0].MaxConnections != 5000 {
		t.Fatalf("max_connections = %d, want 5000", resp.Listeners[0].MaxConnections)
	}

	// Verify DB persistence. Check the lookup error rather than reading a
	// possibly-invalid result.
	l, err := env.store.GetListenerByAddr(":443")
	if err != nil {
		t.Fatalf("get listener by addr: %v", err)
	}
	if l.MaxConnections != 5000 {
		t.Fatalf("DB max_connections = %d, want 5000", l.MaxConnections)
	}

	// Set to 0 (unlimited).
	_, err = env.client.SetListenerMaxConnections(ctx, &pb.SetListenerMaxConnectionsRequest{
		ListenerAddr:   ":443",
		MaxConnections: 0,
	})
	if err != nil {
		t.Fatalf("SetListenerMaxConnections to 0: %v", err)
	}

	resp, err = env.client.GetStatus(ctx, &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}
	if len(resp.Listeners) != 1 {
		t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
	}
	if resp.Listeners[0].MaxConnections != 0 {
		t.Fatalf("max_connections = %d, want 0", resp.Listeners[0].MaxConnections)
	}

	// The reset must be persisted too, not just applied in memory.
	l, err = env.store.GetListenerByAddr(":443")
	if err != nil {
		t.Fatalf("get listener by addr: %v", err)
	}
	if l.MaxConnections != 0 {
		t.Fatalf("DB max_connections = %d, want 0", l.MaxConnections)
	}
}
// TestSetListenerMaxConnectionsNotFound checks that setting max_connections
// on an address with no listener fails with codes.NotFound.
func TestSetListenerMaxConnectionsNotFound(t *testing.T) {
	env := setup(t)

	req := &pb.SetListenerMaxConnectionsRequest{
		ListenerAddr:   ":9999",
		MaxConnections: 100,
	}
	_, err := env.client.SetListenerMaxConnections(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for nonexistent listener")
	}

	if s, ok := status.FromError(err); ok && s.Code() == codes.NotFound {
		return
	}
	t.Fatalf("expected NotFound, got %v", err)
}
// TestAddListL7Policy adds a block_user_agent policy to route a.test. It
// checks that ListL7Policies returns exactly that policy, and that
// ListRoutes reports one policy on the route.
func TestAddListL7Policy(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "block_user_agent", Value: "BadBot"},
	})
	if err != nil {
		t.Fatalf("AddL7Policy: %v", err)
	}

	resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err != nil {
		t.Fatalf("ListL7Policies: %v", err)
	}
	if len(resp.Policies) != 1 {
		t.Fatalf("got %d policies, want 1", len(resp.Policies))
	}
	if resp.Policies[0].Type != "block_user_agent" || resp.Policies[0].Value != "BadBot" {
		t.Fatalf("policy = %v, want block_user_agent/BadBot", resp.Policies[0])
	}

	routeResp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	// Require the route to be present. Without this, the check below would
	// pass vacuously if a.test were missing from the response.
	found := false
	for _, r := range routeResp.Routes {
		if r.Hostname != "a.test" {
			continue
		}
		found = true
		if len(r.L7Policies) != 1 {
			t.Fatalf("ListRoutes: route has %d policies, want 1", len(r.L7Policies))
		}
	}
	if !found {
		t.Fatal("ListRoutes: route a.test not found")
	}
}
// TestRemoveL7Policy adds a require_header policy to route a.test, removes
// it, and checks that ListL7Policies then reports no policies.
func TestRemoveL7Policy(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// Fail fast if setup of the policy fails. Otherwise a broken Add would
	// surface as a confusing Remove failure.
	_, err := env.client.AddL7Policy(ctx, &pb.AddL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "require_header", Value: "X-Token"},
	})
	if err != nil {
		t.Fatalf("AddL7Policy: %v", err)
	}

	_, err = env.client.RemoveL7Policy(ctx, &pb.RemoveL7PolicyRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
		Policy:       &pb.L7Policy{Type: "require_header", Value: "X-Token"},
	})
	if err != nil {
		t.Fatalf("RemoveL7Policy: %v", err)
	}

	// Check the error: on failure resp is nil, and reading resp.Policies
	// would panic.
	resp, err := env.client.ListL7Policies(ctx, &pb.ListL7PoliciesRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err != nil {
		t.Fatalf("ListL7Policies: %v", err)
	}
	if len(resp.Policies) != 0 {
		t.Fatalf("got %d policies after remove, want 0", len(resp.Policies))
	}
}
// TestAddL7PolicyValidation checks that AddL7Policy rejects an unknown
// policy type, an empty policy value, and a missing policy. Cases run in
// order and the test stops at the first one wrongly accepted.
func TestAddL7PolicyValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	cases := []struct {
		why    string
		policy *pb.L7Policy
	}{
		{"expected error for invalid policy type", &pb.L7Policy{Type: "invalid_type", Value: "x"}},
		{"expected error for empty policy value", &pb.L7Policy{Type: "block_user_agent", Value: ""}},
		{"expected error for nil policy", nil},
	}

	for _, c := range cases {
		req := &pb.AddL7PolicyRequest{
			ListenerAddr: ":443",
			Hostname:     "a.test",
			Policy:       c.policy,
		}
		if _, err := env.client.AddL7Policy(ctx, req); err == nil {
			t.Fatal(c.why)
		}
	}
}