Extend the config, database schema, and server internals to support per-route L4/L7 mode selection and PROXY protocol fields. This is the foundation for L7 HTTP/2 reverse proxying and the multi-hop PROXY protocol support described in the updated ARCHITECTURE.md. Config: Listener gains ProxyProtocol; Route gains Mode, TLSCert, TLSKey, BackendTLS, and SendProxyProtocol. L7 routes are validated at load time (the cert/key pair must exist and parse). Mode defaults to "l4". DB: Migration v2 adds columns to the listeners and routes tables. CRUD and seeding are updated to persist all new fields. Server: RouteInfo replaces the bare backend string in route lookup. handleConn dispatches on route.Mode (the L7 path is stubbed to return an error). ListenerState and ListenerData carry the ProxyProtocol flag. All existing L4 tests pass unchanged. New tests cover migration v2, L7 field persistence, config validation for mode/cert/key, and proxy_protocol flag round-tripping. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
336 lines
8.9 KiB
Go
336 lines
8.9 KiB
Go
package mcproxy
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/health"
|
|
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
|
|
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/grpcserver"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
|
)
|
|
|
|
func setupTestClient(t *testing.T) *Client {
|
|
t.Helper()
|
|
|
|
// Database in temp dir.
|
|
dbPath := filepath.Join(t.TempDir(), "test.db")
|
|
store, err := db.Open(dbPath)
|
|
if err != nil {
|
|
t.Fatalf("open db: %v", err)
|
|
}
|
|
t.Cleanup(func() { store.Close() })
|
|
|
|
if err := store.Migrate(); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
|
|
// Seed with one listener and one route.
|
|
listeners := []config.Listener{
|
|
{
|
|
Addr: ":443",
|
|
Routes: []config.Route{
|
|
{Hostname: "example.test", Backend: "127.0.0.1:8443"},
|
|
},
|
|
},
|
|
}
|
|
fw := config.Firewall{
|
|
BlockedIPs: []string{"10.0.0.1"},
|
|
}
|
|
if err := store.Seed(listeners, fw); err != nil {
|
|
t.Fatalf("seed: %v", err)
|
|
}
|
|
|
|
// Build server with matching in-memory state.
|
|
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
|
|
if err != nil {
|
|
t.Fatalf("firewall: %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Proxy: config.Proxy{
|
|
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
|
|
IdleTimeout: config.Duration{Duration: 30 * time.Second},
|
|
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
|
|
},
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
|
|
// Load listener data from DB to get correct IDs.
|
|
dbListeners, err := store.ListListeners()
|
|
if err != nil {
|
|
t.Fatalf("list listeners: %v", err)
|
|
}
|
|
var listenerData []server.ListenerData
|
|
for _, l := range dbListeners {
|
|
dbRoutes, err := store.ListRoutes(l.ID)
|
|
if err != nil {
|
|
t.Fatalf("list routes: %v", err)
|
|
}
|
|
routes := make(map[string]server.RouteInfo, len(dbRoutes))
|
|
for _, r := range dbRoutes {
|
|
routes[r.Hostname] = server.RouteInfo{
|
|
Backend: r.Backend,
|
|
Mode: r.Mode,
|
|
}
|
|
}
|
|
listenerData = append(listenerData, server.ListenerData{
|
|
ID: l.ID,
|
|
Addr: l.Addr,
|
|
ProxyProtocol: l.ProxyProtocol,
|
|
Routes: routes,
|
|
})
|
|
}
|
|
|
|
srv := server.New(cfg, fwObj, listenerData, logger, "test-version")
|
|
|
|
// Set up bufconn gRPC server.
|
|
lis := bufconn.Listen(1024 * 1024)
|
|
grpcSrv := grpc.NewServer()
|
|
|
|
pb.RegisterProxyAdminServiceServer(grpcSrv, &testAdminServer{
|
|
srv: srv,
|
|
store: store,
|
|
logger: logger,
|
|
})
|
|
|
|
// Register health service.
|
|
healthServer := health.NewServer()
|
|
healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
|
|
healthpb.RegisterHealthServer(grpcSrv, healthServer)
|
|
|
|
go func() {
|
|
if err := grpcSrv.Serve(lis); err != nil {
|
|
t.Logf("grpc serve: %v", err)
|
|
}
|
|
}()
|
|
t.Cleanup(grpcSrv.Stop)
|
|
|
|
conn, err := grpc.NewClient("passthrough://bufconn",
|
|
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
|
|
return lis.DialContext(ctx)
|
|
}),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("dial bufconn: %v", err)
|
|
}
|
|
t.Cleanup(func() { conn.Close() })
|
|
|
|
return &Client{
|
|
conn: conn,
|
|
admin: pb.NewProxyAdminServiceClient(conn),
|
|
health: healthpb.NewHealthClient(conn),
|
|
}
|
|
}
|
|
|
|
// testAdminServer is a minimal implementation for testing.
|
|
// It delegates to the real grpcserver.AdminServer logic.
|
|
type testAdminServer struct {
|
|
pb.UnimplementedProxyAdminServiceServer
|
|
srv *server.Server
|
|
store *db.Store
|
|
logger *slog.Logger
|
|
}
|
|
|
|
func (s *testAdminServer) GetStatus(ctx context.Context, req *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
|
|
return grpcserver.NewAdminServer(s.srv, s.store, s.logger).GetStatus(ctx, req)
|
|
}
|
|
|
|
func (s *testAdminServer) ListRoutes(ctx context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
|
|
return grpcserver.NewAdminServer(s.srv, s.store, s.logger).ListRoutes(ctx, req)
|
|
}
|
|
|
|
func (s *testAdminServer) AddRoute(ctx context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
|
|
return grpcserver.NewAdminServer(s.srv, s.store, s.logger).AddRoute(ctx, req)
|
|
}
|
|
|
|
func (s *testAdminServer) RemoveRoute(ctx context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
|
|
return grpcserver.NewAdminServer(s.srv, s.store, s.logger).RemoveRoute(ctx, req)
|
|
}
|
|
|
|
func (s *testAdminServer) GetFirewallRules(ctx context.Context, req *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
|
|
return grpcserver.NewAdminServer(s.srv, s.store, s.logger).GetFirewallRules(ctx, req)
|
|
}
|
|
|
|
func (s *testAdminServer) AddFirewallRule(ctx context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
|
|
return grpcserver.NewAdminServer(s.srv, s.store, s.logger).AddFirewallRule(ctx, req)
|
|
}
|
|
|
|
func (s *testAdminServer) RemoveFirewallRule(ctx context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
|
|
return grpcserver.NewAdminServer(s.srv, s.store, s.logger).RemoveFirewallRule(ctx, req)
|
|
}
|
|
|
|
func TestClientGetStatus(t *testing.T) {
|
|
client := setupTestClient(t)
|
|
ctx := context.Background()
|
|
|
|
status, err := client.GetStatus(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GetStatus: %v", err)
|
|
}
|
|
|
|
if status.Version != "test-version" {
|
|
t.Errorf("got version %q, want %q", status.Version, "test-version")
|
|
}
|
|
if len(status.Listeners) != 1 {
|
|
t.Errorf("got %d listeners, want 1", len(status.Listeners))
|
|
}
|
|
if status.Listeners[0].Addr != ":443" {
|
|
t.Errorf("got listener addr %q, want %q", status.Listeners[0].Addr, ":443")
|
|
}
|
|
}
|
|
|
|
func TestClientListRoutes(t *testing.T) {
|
|
client := setupTestClient(t)
|
|
ctx := context.Background()
|
|
|
|
routes, err := client.ListRoutes(ctx, ":443")
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
|
|
if len(routes) != 1 {
|
|
t.Fatalf("got %d routes, want 1", len(routes))
|
|
}
|
|
if routes[0].Hostname != "example.test" {
|
|
t.Errorf("got hostname %q, want %q", routes[0].Hostname, "example.test")
|
|
}
|
|
if routes[0].Backend != "127.0.0.1:8443" {
|
|
t.Errorf("got backend %q, want %q", routes[0].Backend, "127.0.0.1:8443")
|
|
}
|
|
}
|
|
|
|
func TestClientAddRemoveRoute(t *testing.T) {
|
|
client := setupTestClient(t)
|
|
ctx := context.Background()
|
|
|
|
// Add a new route.
|
|
err := client.AddRoute(ctx, ":443", "new.test", "127.0.0.1:9443")
|
|
if err != nil {
|
|
t.Fatalf("AddRoute: %v", err)
|
|
}
|
|
|
|
// Verify it was added.
|
|
routes, err := client.ListRoutes(ctx, ":443")
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(routes) != 2 {
|
|
t.Fatalf("got %d routes after add, want 2", len(routes))
|
|
}
|
|
|
|
// Remove the route.
|
|
err = client.RemoveRoute(ctx, ":443", "new.test")
|
|
if err != nil {
|
|
t.Fatalf("RemoveRoute: %v", err)
|
|
}
|
|
|
|
// Verify it was removed.
|
|
routes, err = client.ListRoutes(ctx, ":443")
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(routes) != 1 {
|
|
t.Fatalf("got %d routes after remove, want 1", len(routes))
|
|
}
|
|
}
|
|
|
|
func TestClientGetFirewallRules(t *testing.T) {
|
|
client := setupTestClient(t)
|
|
ctx := context.Background()
|
|
|
|
rules, err := client.GetFirewallRules(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
|
|
if len(rules) != 1 {
|
|
t.Fatalf("got %d rules, want 1", len(rules))
|
|
}
|
|
if rules[0].Type != FirewallRuleIP {
|
|
t.Errorf("got type %q, want %q", rules[0].Type, FirewallRuleIP)
|
|
}
|
|
if rules[0].Value != "10.0.0.1" {
|
|
t.Errorf("got value %q, want %q", rules[0].Value, "10.0.0.1")
|
|
}
|
|
}
|
|
|
|
func TestClientAddRemoveFirewallRule(t *testing.T) {
|
|
client := setupTestClient(t)
|
|
ctx := context.Background()
|
|
|
|
// Add a CIDR rule.
|
|
err := client.AddFirewallRule(ctx, FirewallRuleCIDR, "192.168.0.0/16")
|
|
if err != nil {
|
|
t.Fatalf("AddFirewallRule: %v", err)
|
|
}
|
|
|
|
// Verify it was added.
|
|
rules, err := client.GetFirewallRules(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(rules) != 2 {
|
|
t.Fatalf("got %d rules after add, want 2", len(rules))
|
|
}
|
|
|
|
// Remove the rule.
|
|
err = client.RemoveFirewallRule(ctx, FirewallRuleCIDR, "192.168.0.0/16")
|
|
if err != nil {
|
|
t.Fatalf("RemoveFirewallRule: %v", err)
|
|
}
|
|
|
|
// Verify it was removed.
|
|
rules, err = client.GetFirewallRules(ctx)
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(rules) != 1 {
|
|
t.Fatalf("got %d rules after remove, want 1", len(rules))
|
|
}
|
|
}
|
|
|
|
func TestClientCheckHealth(t *testing.T) {
|
|
client := setupTestClient(t)
|
|
ctx := context.Background()
|
|
|
|
status, err := client.CheckHealth(ctx)
|
|
if err != nil {
|
|
t.Fatalf("CheckHealth: %v", err)
|
|
}
|
|
|
|
if status != HealthServing {
|
|
t.Errorf("got health status %v, want %v", status, HealthServing)
|
|
}
|
|
}
|
|
|
|
func TestHealthStatusString(t *testing.T) {
|
|
tests := []struct {
|
|
status HealthStatus
|
|
want string
|
|
}{
|
|
{HealthUnknown, "UNKNOWN"},
|
|
{HealthServing, "SERVING"},
|
|
{HealthNotServing, "NOT_SERVING"},
|
|
}
|
|
for _, tt := range tests {
|
|
if got := tt.status.String(); got != tt.want {
|
|
t.Errorf("HealthStatus(%d).String() = %q, want %q", tt.status, got, tt.want)
|
|
}
|
|
}
|
|
}
|