Extend the config, database schema, and server internals to support per-route L4/L7 mode selection and PROXY protocol fields.

This is the foundation for L7 HTTP/2 reverse proxying and multi-hop PROXY protocol support described in the updated ARCHITECTURE.md.

Config: Listener gains ProxyProtocol; Route gains Mode, TLSCert, TLSKey, BackendTLS, SendProxyProtocol. L7 routes are validated at load time (the cert/key pair must exist and parse). Mode defaults to "l4".

DB: Migration v2 adds columns to the listeners and routes tables. CRUD and seeding are updated to persist all new fields.

Server: RouteInfo replaces the bare backend string in route lookup. handleConn dispatches on route.Mode (the L7 path is stubbed with an error). ListenerState and ListenerData carry the ProxyProtocol flag.

All existing L4 tests pass unchanged. New tests cover migration v2, L7 field persistence, config validation for mode/cert/key, and proxy_protocol flag round-tripping.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
520 lines
13 KiB
Go
520 lines
13 KiB
Go
package grpcserver
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
|
|
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
|
)
|
|
|
|
// testEnv bundles all the objects needed for a grpcserver test.
|
|
type testEnv struct {
|
|
client pb.ProxyAdminServiceClient
|
|
conn *grpc.ClientConn
|
|
store *db.Store
|
|
srv *server.Server
|
|
}
|
|
|
|
func setup(t *testing.T) *testEnv {
|
|
t.Helper()
|
|
|
|
// Database in temp dir.
|
|
dbPath := filepath.Join(t.TempDir(), "test.db")
|
|
store, err := db.Open(dbPath)
|
|
if err != nil {
|
|
t.Fatalf("open db: %v", err)
|
|
}
|
|
t.Cleanup(func() { store.Close() })
|
|
|
|
if err := store.Migrate(); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
|
|
// Seed with one listener and one route.
|
|
listeners := []config.Listener{
|
|
{
|
|
Addr: ":443",
|
|
Routes: []config.Route{
|
|
{Hostname: "a.test", Backend: "127.0.0.1:8443"},
|
|
},
|
|
},
|
|
}
|
|
fw := config.Firewall{
|
|
BlockedIPs: []string{"10.0.0.1"},
|
|
}
|
|
if err := store.Seed(listeners, fw); err != nil {
|
|
t.Fatalf("seed: %v", err)
|
|
}
|
|
|
|
// Build server with matching in-memory state.
|
|
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
|
|
if err != nil {
|
|
t.Fatalf("firewall: %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Proxy: config.Proxy{
|
|
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
|
|
IdleTimeout: config.Duration{Duration: 30 * time.Second},
|
|
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
|
|
},
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
|
|
// Load listener data from DB to get correct IDs.
|
|
dbListeners, err := store.ListListeners()
|
|
if err != nil {
|
|
t.Fatalf("list listeners: %v", err)
|
|
}
|
|
var listenerData []server.ListenerData
|
|
for _, l := range dbListeners {
|
|
dbRoutes, err := store.ListRoutes(l.ID)
|
|
if err != nil {
|
|
t.Fatalf("list routes: %v", err)
|
|
}
|
|
routes := make(map[string]server.RouteInfo, len(dbRoutes))
|
|
for _, r := range dbRoutes {
|
|
routes[r.Hostname] = server.RouteInfo{
|
|
Backend: r.Backend,
|
|
Mode: r.Mode,
|
|
}
|
|
}
|
|
listenerData = append(listenerData, server.ListenerData{
|
|
ID: l.ID,
|
|
Addr: l.Addr,
|
|
ProxyProtocol: l.ProxyProtocol,
|
|
Routes: routes,
|
|
})
|
|
}
|
|
|
|
srv := server.New(cfg, fwObj, listenerData, logger, "test-version")
|
|
|
|
// Set up bufconn gRPC server (no TLS for tests).
|
|
lis := bufconn.Listen(1024 * 1024)
|
|
grpcSrv := grpc.NewServer()
|
|
admin := &AdminServer{
|
|
srv: srv,
|
|
store: store,
|
|
logger: logger,
|
|
}
|
|
pb.RegisterProxyAdminServiceServer(grpcSrv, admin)
|
|
|
|
go func() {
|
|
if err := grpcSrv.Serve(lis); err != nil {
|
|
t.Logf("grpc serve: %v", err)
|
|
}
|
|
}()
|
|
t.Cleanup(grpcSrv.Stop)
|
|
|
|
conn, err := grpc.NewClient("passthrough://bufconn",
|
|
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
|
|
return lis.DialContext(ctx)
|
|
}),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("dial bufconn: %v", err)
|
|
}
|
|
t.Cleanup(func() { conn.Close() })
|
|
|
|
return &testEnv{
|
|
client: pb.NewProxyAdminServiceClient(conn),
|
|
conn: conn,
|
|
store: store,
|
|
srv: srv,
|
|
}
|
|
}
|
|
|
|
func TestGetStatus(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetStatus: %v", err)
|
|
}
|
|
if resp.Version != "test-version" {
|
|
t.Fatalf("got version %q, want %q", resp.Version, "test-version")
|
|
}
|
|
if len(resp.Listeners) != 1 {
|
|
t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
|
|
}
|
|
if resp.Listeners[0].Addr != ":443" {
|
|
t.Fatalf("got listener addr %q, want %q", resp.Listeners[0].Addr, ":443")
|
|
}
|
|
if resp.Listeners[0].RouteCount != 1 {
|
|
t.Fatalf("got route count %d, want 1", resp.Listeners[0].RouteCount)
|
|
}
|
|
}
|
|
|
|
func TestListRoutes(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(resp.Routes) != 1 {
|
|
t.Fatalf("got %d routes, want 1", len(resp.Routes))
|
|
}
|
|
if resp.Routes[0].Hostname != "a.test" {
|
|
t.Fatalf("got hostname %q, want %q", resp.Routes[0].Hostname, "a.test")
|
|
}
|
|
if resp.Routes[0].Backend != "127.0.0.1:8443" {
|
|
t.Fatalf("got backend %q, want %q", resp.Routes[0].Backend, "127.0.0.1:8443")
|
|
}
|
|
}
|
|
|
|
func TestListRoutesNotFound(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":9999"})
|
|
if err == nil {
|
|
t.Fatal("expected error for nonexistent listener")
|
|
}
|
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound {
|
|
t.Fatalf("expected NotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAddRoute(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "b.test", Backend: "127.0.0.1:9443"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddRoute: %v", err)
|
|
}
|
|
|
|
// Verify in-memory.
|
|
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(resp.Routes) != 2 {
|
|
t.Fatalf("got %d routes, want 2", len(resp.Routes))
|
|
}
|
|
|
|
// Verify in DB.
|
|
dbListeners, err := env.store.ListListeners()
|
|
if err != nil {
|
|
t.Fatalf("list listeners: %v", err)
|
|
}
|
|
dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
|
|
if err != nil {
|
|
t.Fatalf("list routes: %v", err)
|
|
}
|
|
if len(dbRoutes) != 2 {
|
|
t.Fatalf("DB has %d routes, want 2", len(dbRoutes))
|
|
}
|
|
}
|
|
|
|
func TestAddRouteDuplicate(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for duplicate route")
|
|
}
|
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.AlreadyExists {
|
|
t.Fatalf("expected AlreadyExists, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAddRouteValidation(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
// Missing route.
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ListenerAddr: ":443"})
|
|
if err == nil {
|
|
t.Fatal("expected error for nil route")
|
|
}
|
|
|
|
// Missing hostname.
|
|
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Backend: "127.0.0.1:1"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty hostname")
|
|
}
|
|
|
|
// Missing backend.
|
|
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "x.test"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty backend")
|
|
}
|
|
|
|
// Invalid backend (not host:port).
|
|
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid backend address")
|
|
}
|
|
}
|
|
|
|
func TestRemoveRoute(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Hostname: "a.test",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("RemoveRoute: %v", err)
|
|
}
|
|
|
|
// Verify removed from memory.
|
|
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(resp.Routes) != 0 {
|
|
t.Fatalf("got %d routes, want 0", len(resp.Routes))
|
|
}
|
|
}
|
|
|
|
func TestRemoveRouteNotFound(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Hostname: "nonexistent.test",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for removing nonexistent route")
|
|
}
|
|
}
|
|
|
|
func TestGetFirewallRules(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(resp.Rules) != 1 {
|
|
t.Fatalf("got %d rules, want 1", len(resp.Rules))
|
|
}
|
|
if resp.Rules[0].Type != pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP {
|
|
t.Fatalf("got type %v, want IP", resp.Rules[0].Type)
|
|
}
|
|
if resp.Rules[0].Value != "10.0.0.1" {
|
|
t.Fatalf("got value %q, want %q", resp.Rules[0].Value, "10.0.0.1")
|
|
}
|
|
}
|
|
|
|
func TestAddFirewallRule(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
// Add IP rule.
|
|
_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "10.0.0.2",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddFirewallRule IP: %v", err)
|
|
}
|
|
|
|
// Add CIDR rule.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
|
Value: "192.168.0.0/16",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddFirewallRule CIDR: %v", err)
|
|
}
|
|
|
|
// Add country rule.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
|
Value: "RU",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddFirewallRule country: %v", err)
|
|
}
|
|
|
|
// Verify.
|
|
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(resp.Rules) != 4 {
|
|
t.Fatalf("got %d rules, want 4", len(resp.Rules))
|
|
}
|
|
|
|
// Verify DB persistence.
|
|
dbRules, err := env.store.ListFirewallRules()
|
|
if err != nil {
|
|
t.Fatalf("list firewall rules: %v", err)
|
|
}
|
|
if len(dbRules) != 4 {
|
|
t.Fatalf("DB has %d rules, want 4", len(dbRules))
|
|
}
|
|
}
|
|
|
|
func TestAddFirewallRuleValidation(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
// Nil rule.
|
|
_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{})
|
|
if err == nil {
|
|
t.Fatal("expected error for nil rule")
|
|
}
|
|
|
|
// Unknown type.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED,
|
|
Value: "x",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for unspecified rule type")
|
|
}
|
|
|
|
// Empty value.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty value")
|
|
}
|
|
|
|
// Invalid IP address.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "not-an-ip",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid IP")
|
|
}
|
|
|
|
// Invalid CIDR.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
|
Value: "not-a-cidr",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid CIDR")
|
|
}
|
|
|
|
// Non-canonical CIDR.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
|
Value: "192.168.1.5/16",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for non-canonical CIDR")
|
|
}
|
|
|
|
// Invalid country code (lowercase).
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
|
Value: "cn",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for lowercase country code")
|
|
}
|
|
|
|
// Invalid country code (too long).
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
|
Value: "USA",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for 3-letter country code")
|
|
}
|
|
}
|
|
|
|
func TestRemoveFirewallRule(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "10.0.0.1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("RemoveFirewallRule: %v", err)
|
|
}
|
|
|
|
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(resp.Rules) != 0 {
|
|
t.Fatalf("got %d rules, want 0", len(resp.Rules))
|
|
}
|
|
}
|
|
|
|
func TestRemoveFirewallRuleNotFound(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "99.99.99.99",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for removing nonexistent rule")
|
|
}
|
|
}
|