Files
mc-proxy/internal/grpcserver/grpcserver_test.go
Kyle Isom ed94548dfa Add L7/PROXY protocol data model, config, and architecture docs
Extend the config, database schema, and server internals to support
per-route L4/L7 mode selection and PROXY protocol fields. This is the
foundation for L7 HTTP/2 reverse proxying and multi-hop PROXY protocol
support described in the updated ARCHITECTURE.md.

Config: Listener gains ProxyProtocol; Route gains Mode, TLSCert,
TLSKey, BackendTLS, SendProxyProtocol. L7 routes validated at load
time (cert/key pair must exist and parse). Mode defaults to "l4".

DB: Migration v2 adds columns to listeners and routes tables. CRUD
and seeding updated to persist all new fields.

Server: RouteInfo replaces bare backend string in route lookup.
handleConn dispatches on route.Mode (L7 path stubbed with error).
ListenerState and ListenerData carry ProxyProtocol flag.

All existing L4 tests pass unchanged. New tests cover migration v2,
L7 field persistence, config validation for mode/cert/key, and
proxy_protocol flag round-tripping.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-25 13:15:51 -07:00

520 lines
13 KiB
Go

package grpcserver
import (
"context"
"io"
"log/slog"
"net"
"path/filepath"
"testing"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
// testEnv bundles all the objects needed for a grpcserver test.
type testEnv struct {
client pb.ProxyAdminServiceClient
conn *grpc.ClientConn
store *db.Store
srv *server.Server
}
// setup builds a fully wired grpcserver test environment and registers
// t.Cleanup handlers for everything it creates.
//
// The environment consists of:
//   - a freshly migrated database in t.TempDir(), seeded with one listener
//     (":443") holding one route (a.test -> 127.0.0.1:8443) and one blocked
//     IP (10.0.0.1);
//   - a server.Server whose in-memory listener/route state is loaded back
//     from the database so listener IDs match the stored rows;
//   - an AdminServer for that server and store, served without TLS over an
//     in-memory bufconn listener, plus a client connected to it.
//
// Cleanups run in reverse registration order: the client connection, then
// the gRPC server (waiting for its Serve goroutine to exit), then the store.
func setup(t *testing.T) *testEnv {
	t.Helper()

	// blockedIP is seeded into both the database and the in-memory firewall
	// so the two views start out consistent.
	const blockedIP = "10.0.0.1"

	// Database in temp dir.
	dbPath := filepath.Join(t.TempDir(), "test.db")
	store, err := db.Open(dbPath)
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	t.Cleanup(func() { store.Close() })
	if err := store.Migrate(); err != nil {
		t.Fatalf("migrate: %v", err)
	}

	// Seed with one listener and one route.
	listeners := []config.Listener{
		{
			Addr: ":443",
			Routes: []config.Route{
				{Hostname: "a.test", Backend: "127.0.0.1:8443"},
			},
		},
	}
	fw := config.Firewall{
		BlockedIPs: []string{blockedIP},
	}
	if err := store.Seed(listeners, fw); err != nil {
		t.Fatalf("seed: %v", err)
	}

	// Build server with matching in-memory state.
	fwObj, err := firewall.New("", []string{blockedIP}, nil, nil, 0, 0)
	if err != nil {
		t.Fatalf("firewall: %v", err)
	}
	cfg := &config.Config{
		Proxy: config.Proxy{
			ConnectTimeout:  config.Duration{Duration: 5 * time.Second},
			IdleTimeout:     config.Duration{Duration: 30 * time.Second},
			ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
		},
	}
	logger := slog.New(slog.NewTextHandler(io.Discard, nil))

	// Load listener data from DB to get correct IDs.
	dbListeners, err := store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	listenerData := make([]server.ListenerData, 0, len(dbListeners))
	for _, l := range dbListeners {
		dbRoutes, err := store.ListRoutes(l.ID)
		if err != nil {
			t.Fatalf("list routes: %v", err)
		}
		routes := make(map[string]server.RouteInfo, len(dbRoutes))
		for _, r := range dbRoutes {
			routes[r.Hostname] = server.RouteInfo{
				Backend: r.Backend,
				Mode:    r.Mode,
			}
		}
		listenerData = append(listenerData, server.ListenerData{
			ID:            l.ID,
			Addr:          l.Addr,
			ProxyProtocol: l.ProxyProtocol,
			Routes:        routes,
		})
	}
	srv := server.New(cfg, fwObj, listenerData, logger, "test-version")

	// Set up bufconn gRPC server (no TLS for tests).
	lis := bufconn.Listen(1024 * 1024)
	grpcSrv := grpc.NewServer()
	admin := &AdminServer{
		srv:    srv,
		store:  store,
		logger: logger,
	}
	pb.RegisterProxyAdminServiceServer(grpcSrv, admin)

	// Serve in the background, but guarantee the goroutine has exited before
	// the test completes: calling t.Logf from a goroutine that outlives its
	// test panics. If Stop runs before Serve has started, Serve returns
	// grpc.ErrServerStopped (non-nil) and would log, so the cleanup must wait
	// on serveDone rather than rely on Stop alone.
	serveDone := make(chan struct{})
	go func() {
		defer close(serveDone)
		if err := grpcSrv.Serve(lis); err != nil {
			t.Logf("grpc serve: %v", err)
		}
	}()
	t.Cleanup(func() {
		grpcSrv.Stop()
		<-serveDone
	})

	conn, err := grpc.NewClient("passthrough://bufconn",
		grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
			return lis.DialContext(ctx)
		}),
		grpc.WithTransportCredentials(insecure.NewCredentials()),
	)
	if err != nil {
		t.Fatalf("dial bufconn: %v", err)
	}
	t.Cleanup(func() { conn.Close() })

	return &testEnv{
		client: pb.NewProxyAdminServiceClient(conn),
		conn:   conn,
		store:  store,
		srv:    srv,
	}
}
// TestGetStatus checks that GetStatus reports the version string given to
// server.New and exactly one listener (the seeded ":443") with one route.
func TestGetStatus(t *testing.T) {
	env := setup(t)

	got, err := env.client.GetStatus(context.Background(), &pb.GetStatusRequest{})
	if err != nil {
		t.Fatalf("GetStatus: %v", err)
	}

	const wantVersion = "test-version"
	if got.Version != wantVersion {
		t.Fatalf("got version %q, want %q", got.Version, wantVersion)
	}
	if n := len(got.Listeners); n != 1 {
		t.Fatalf("got %d listeners, want 1", n)
	}

	first := got.Listeners[0]
	if first.Addr != ":443" {
		t.Fatalf("got listener addr %q, want %q", first.Addr, ":443")
	}
	if first.RouteCount != 1 {
		t.Fatalf("got route count %d, want 1", first.RouteCount)
	}
}
// TestListRoutes checks that ListRoutes on the seeded listener returns the
// single seeded route with its hostname and backend unchanged.
func TestListRoutes(t *testing.T) {
	env := setup(t)
	req := &pb.ListRoutesRequest{ListenerAddr: ":443"}

	out, err := env.client.ListRoutes(context.Background(), req)
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if n := len(out.Routes); n != 1 {
		t.Fatalf("got %d routes, want 1", n)
	}

	const wantHost, wantBackend = "a.test", "127.0.0.1:8443"
	route := out.Routes[0]
	if route.Hostname != wantHost {
		t.Fatalf("got hostname %q, want %q", route.Hostname, wantHost)
	}
	if route.Backend != wantBackend {
		t.Fatalf("got backend %q, want %q", route.Backend, wantBackend)
	}
}
// TestListRoutesNotFound checks that listing routes for an address with no
// configured listener fails with codes.NotFound.
func TestListRoutesNotFound(t *testing.T) {
	env := setup(t)
	req := &pb.ListRoutesRequest{ListenerAddr: ":9999"}

	_, err := env.client.ListRoutes(context.Background(), req)
	if err == nil {
		t.Fatal("expected error for nonexistent listener")
	}
	// status.Code yields codes.Unknown for non-status errors, so anything
	// other than a NotFound status fails here.
	if status.Code(err) != codes.NotFound {
		t.Fatalf("expected NotFound, got %v", err)
	}
}
// TestAddRoute checks that AddRoute makes a new route visible both through
// ListRoutes (the server's in-memory state) and in the database, with the
// requested hostname and backend.
func TestAddRoute(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	const (
		newHost    = "b.test"
		newBackend = "127.0.0.1:9443"
	)
	_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Hostname: newHost, Backend: newBackend},
	})
	if err != nil {
		t.Fatalf("AddRoute: %v", err)
	}

	// Verify in-memory: the count grew and the new route carries the
	// requested backend.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if len(resp.Routes) != 2 {
		t.Fatalf("got %d routes, want 2", len(resp.Routes))
	}
	found := false
	for _, r := range resp.Routes {
		if r.Hostname != newHost {
			continue
		}
		found = true
		if r.Backend != newBackend {
			t.Fatalf("got backend %q for %q, want %q", r.Backend, newHost, newBackend)
		}
	}
	if !found {
		t.Fatalf("route %q missing from ListRoutes", newHost)
	}

	// Verify in DB. Check the listener count before indexing so a missing
	// listener fails the test instead of panicking.
	dbListeners, err := env.store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	if len(dbListeners) != 1 {
		t.Fatalf("DB has %d listeners, want 1", len(dbListeners))
	}
	dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
	if err != nil {
		t.Fatalf("list routes: %v", err)
	}
	if len(dbRoutes) != 2 {
		t.Fatalf("DB has %d routes, want 2", len(dbRoutes))
	}
	persisted := false
	for _, r := range dbRoutes {
		if r.Hostname == newHost && r.Backend == newBackend {
			persisted = true
			break
		}
	}
	if !persisted {
		t.Fatalf("route %q -> %q not persisted to DB", newHost, newBackend)
	}
}
// TestAddRouteDuplicate checks that adding a second route for a hostname
// already routed on the listener is rejected with codes.AlreadyExists.
func TestAddRouteDuplicate(t *testing.T) {
	env := setup(t)
	dup := &pb.AddRouteRequest{
		ListenerAddr: ":443",
		Route:        &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"},
	}

	_, err := env.client.AddRoute(context.Background(), dup)
	if err == nil {
		t.Fatal("expected error for duplicate route")
	}
	if status.Code(err) != codes.AlreadyExists {
		t.Fatalf("expected AlreadyExists, got %v", err)
	}
}
// TestAddRouteValidation checks that AddRoute rejects malformed requests:
// no route at all, an empty hostname, an empty backend, and a backend that
// is not in host:port form. Cases run in order; the first one that
// unexpectedly succeeds stops the test.
func TestAddRouteValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	cases := []struct {
		req  *pb.AddRouteRequest
		fail string
	}{
		{
			req:  &pb.AddRouteRequest{ListenerAddr: ":443"},
			fail: "expected error for nil route",
		},
		{
			req:  &pb.AddRouteRequest{ListenerAddr: ":443", Route: &pb.Route{Backend: "127.0.0.1:1"}},
			fail: "expected error for empty hostname",
		},
		{
			req:  &pb.AddRouteRequest{ListenerAddr: ":443", Route: &pb.Route{Hostname: "x.test"}},
			fail: "expected error for empty backend",
		},
		{
			req:  &pb.AddRouteRequest{ListenerAddr: ":443", Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"}},
			fail: "expected error for invalid backend address",
		},
	}

	for _, c := range cases {
		if _, err := env.client.AddRoute(ctx, c.req); err == nil {
			t.Fatal(c.fail)
		}
	}
}
// TestRemoveRoute checks that RemoveRoute deletes the seeded route both from
// the server's in-memory state (observed via ListRoutes) and from the
// database.
func TestRemoveRoute(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
		ListenerAddr: ":443",
		Hostname:     "a.test",
	})
	if err != nil {
		t.Fatalf("RemoveRoute: %v", err)
	}

	// Verify removed from memory.
	resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
	if err != nil {
		t.Fatalf("ListRoutes: %v", err)
	}
	if len(resp.Routes) != 0 {
		t.Fatalf("got %d routes, want 0", len(resp.Routes))
	}

	// Verify removed from DB as well. Check the listener count before
	// indexing so a missing listener fails instead of panicking.
	dbListeners, err := env.store.ListListeners()
	if err != nil {
		t.Fatalf("list listeners: %v", err)
	}
	if len(dbListeners) != 1 {
		t.Fatalf("DB has %d listeners, want 1", len(dbListeners))
	}
	dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
	if err != nil {
		t.Fatalf("list routes: %v", err)
	}
	if len(dbRoutes) != 0 {
		t.Fatalf("DB has %d routes, want 0", len(dbRoutes))
	}
}
// TestRemoveRouteNotFound checks that removing a hostname that has no route
// on the listener returns an error.
func TestRemoveRouteNotFound(t *testing.T) {
	env := setup(t)
	req := &pb.RemoveRouteRequest{
		ListenerAddr: ":443",
		Hostname:     "nonexistent.test",
	}

	if _, err := env.client.RemoveRoute(context.Background(), req); err == nil {
		t.Fatal("expected error for removing nonexistent route")
	}
}
// TestGetFirewallRules checks that GetFirewallRules returns exactly the one
// seeded rule: an IP rule for 10.0.0.1.
func TestGetFirewallRules(t *testing.T) {
	env := setup(t)

	resp, err := env.client.GetFirewallRules(context.Background(), &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if n := len(resp.Rules); n != 1 {
		t.Fatalf("got %d rules, want 1", n)
	}

	rule := resp.Rules[0]
	if rule.Type != pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP {
		t.Fatalf("got type %v, want IP", rule.Type)
	}
	if want := "10.0.0.1"; rule.Value != want {
		t.Fatalf("got value %q, want %q", rule.Value, want)
	}
}
// TestAddFirewallRule adds one rule of each type (IP, CIDR, country) and
// checks that the rule set grows from the single seeded rule to four, both
// as reported by GetFirewallRules and as stored in the database.
func TestAddFirewallRule(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	// label appears in the failure message to identify which add failed.
	additions := []struct {
		label string
		rule  *pb.FirewallRule
	}{
		{label: "IP", rule: &pb.FirewallRule{Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, Value: "10.0.0.2"}},
		{label: "CIDR", rule: &pb.FirewallRule{Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, Value: "192.168.0.0/16"}},
		{label: "country", rule: &pb.FirewallRule{Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, Value: "RU"}},
	}
	for _, a := range additions {
		req := &pb.AddFirewallRuleRequest{Rule: a.rule}
		if _, err := env.client.AddFirewallRule(ctx, req); err != nil {
			t.Fatalf("AddFirewallRule %s: %v", a.label, err)
		}
	}

	// Verify via the RPC.
	resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if n := len(resp.Rules); n != 4 {
		t.Fatalf("got %d rules, want 4", n)
	}

	// Verify the store holds the same number of rules.
	dbRules, err := env.store.ListFirewallRules()
	if err != nil {
		t.Fatalf("list firewall rules: %v", err)
	}
	if n := len(dbRules); n != 4 {
		t.Fatalf("DB has %d rules, want 4", n)
	}
}
// TestAddFirewallRuleValidation checks that AddFirewallRule rejects a
// missing rule, an unspecified type, an empty value, malformed IP and CIDR
// values, a CIDR with host bits set, and country codes that are lowercase or
// longer than two letters. Cases run in order; the first one that
// unexpectedly succeeds stops the test.
func TestAddFirewallRuleValidation(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	const (
		unspecified = pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED
		ip          = pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP
		cidr        = pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR
		country     = pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY
	)
	mk := func(typ pb.FirewallRuleType, value string) *pb.FirewallRule {
		return &pb.FirewallRule{Type: typ, Value: value}
	}

	cases := []struct {
		rule *pb.FirewallRule // nil exercises the missing-rule path
		fail string
	}{
		{rule: nil, fail: "expected error for nil rule"},
		{rule: mk(unspecified, "x"), fail: "expected error for unspecified rule type"},
		{rule: mk(ip, ""), fail: "expected error for empty value"},
		{rule: mk(ip, "not-an-ip"), fail: "expected error for invalid IP"},
		{rule: mk(cidr, "not-a-cidr"), fail: "expected error for invalid CIDR"},
		{rule: mk(cidr, "192.168.1.5/16"), fail: "expected error for non-canonical CIDR"},
		{rule: mk(country, "cn"), fail: "expected error for lowercase country code"},
		{rule: mk(country, "USA"), fail: "expected error for 3-letter country code"},
	}

	for _, c := range cases {
		req := &pb.AddFirewallRuleRequest{Rule: c.rule}
		if _, err := env.client.AddFirewallRule(ctx, req); err == nil {
			t.Fatal(c.fail)
		}
	}
}
// TestRemoveFirewallRule removes the single seeded IP rule and checks that
// no rules remain, both as reported by GetFirewallRules and as stored in the
// database.
func TestRemoveFirewallRule(t *testing.T) {
	env := setup(t)
	ctx := context.Background()

	_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
			Value: "10.0.0.1",
		},
	})
	if err != nil {
		t.Fatalf("RemoveFirewallRule: %v", err)
	}

	// Verify via the RPC.
	resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
	if err != nil {
		t.Fatalf("GetFirewallRules: %v", err)
	}
	if len(resp.Rules) != 0 {
		t.Fatalf("got %d rules, want 0", len(resp.Rules))
	}

	// Verify DB persistence of the removal.
	dbRules, err := env.store.ListFirewallRules()
	if err != nil {
		t.Fatalf("list firewall rules: %v", err)
	}
	if len(dbRules) != 0 {
		t.Fatalf("DB has %d rules, want 0", len(dbRules))
	}
}
// TestRemoveFirewallRuleNotFound checks that removing an IP rule that was
// never added returns an error.
func TestRemoveFirewallRuleNotFound(t *testing.T) {
	env := setup(t)
	req := &pb.RemoveFirewallRuleRequest{
		Rule: &pb.FirewallRule{
			Type:  pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
			Value: "99.99.99.99",
		},
	}

	if _, err := env.client.RemoveFirewallRule(context.Background(), req); err == nil {
		t.Fatal("expected error for removing nonexistent rule")
	}
}