Proto: Route message gains mode, tls_cert, tls_key, backend_tls,
send_proxy_protocol fields. ListenerStatus gains proxy_protocol.
Generated code regenerated with protoc v29.5.
gRPC server: AddRoute validates mode ("l4"/"l7", defaults to "l4"),
requires tls_cert/tls_key for L7 routes, persists all fields via
write-through. ListRoutes returns full route info. GetStatus
includes proxy_protocol on listener status.
Client package: Route struct expanded with Mode, TLSCert, TLSKey,
BackendTLS, SendProxyProtocol. AddRoute signature changed to accept
a Route struct instead of individual hostname/backend strings.
ListenerStatus gains ProxyProtocol. ListRoutes maps all proto fields.
mcproxyctl: routes add gains --mode, --tls-cert, --tls-key,
--backend-tls, --send-proxy-protocol flags. routes list displays
mode and option tags for each route.
New tests: add L7 route via gRPC with field round-trip verification,
L7 route missing cert/key (InvalidArgument), invalid mode rejection,
default-to-L4 backward compatibility, proxy_protocol in status.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
679 lines
17 KiB
Go
679 lines
17 KiB
Go
package grpcserver
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"log/slog"
|
|
"net"
|
|
"path/filepath"
|
|
"testing"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/credentials/insecure"
|
|
"google.golang.org/grpc/status"
|
|
"google.golang.org/grpc/test/bufconn"
|
|
|
|
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
|
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
|
)
|
|
|
|
// testEnv bundles all the objects needed for a grpcserver test.
|
|
type testEnv struct {
|
|
client pb.ProxyAdminServiceClient
|
|
conn *grpc.ClientConn
|
|
store *db.Store
|
|
srv *server.Server
|
|
}
|
|
|
|
func setup(t *testing.T) *testEnv {
|
|
t.Helper()
|
|
|
|
// Database in temp dir.
|
|
dbPath := filepath.Join(t.TempDir(), "test.db")
|
|
store, err := db.Open(dbPath)
|
|
if err != nil {
|
|
t.Fatalf("open db: %v", err)
|
|
}
|
|
t.Cleanup(func() { store.Close() })
|
|
|
|
if err := store.Migrate(); err != nil {
|
|
t.Fatalf("migrate: %v", err)
|
|
}
|
|
|
|
// Seed with one listener and one route.
|
|
listeners := []config.Listener{
|
|
{
|
|
Addr: ":443",
|
|
Routes: []config.Route{
|
|
{Hostname: "a.test", Backend: "127.0.0.1:8443"},
|
|
},
|
|
},
|
|
}
|
|
fw := config.Firewall{
|
|
BlockedIPs: []string{"10.0.0.1"},
|
|
}
|
|
if err := store.Seed(listeners, fw); err != nil {
|
|
t.Fatalf("seed: %v", err)
|
|
}
|
|
|
|
// Build server with matching in-memory state.
|
|
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
|
|
if err != nil {
|
|
t.Fatalf("firewall: %v", err)
|
|
}
|
|
|
|
cfg := &config.Config{
|
|
Proxy: config.Proxy{
|
|
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
|
|
IdleTimeout: config.Duration{Duration: 30 * time.Second},
|
|
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
|
|
},
|
|
}
|
|
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
|
|
|
// Load listener data from DB to get correct IDs.
|
|
dbListeners, err := store.ListListeners()
|
|
if err != nil {
|
|
t.Fatalf("list listeners: %v", err)
|
|
}
|
|
var listenerData []server.ListenerData
|
|
for _, l := range dbListeners {
|
|
dbRoutes, err := store.ListRoutes(l.ID)
|
|
if err != nil {
|
|
t.Fatalf("list routes: %v", err)
|
|
}
|
|
routes := make(map[string]server.RouteInfo, len(dbRoutes))
|
|
for _, r := range dbRoutes {
|
|
routes[r.Hostname] = server.RouteInfo{
|
|
Backend: r.Backend,
|
|
Mode: r.Mode,
|
|
}
|
|
}
|
|
listenerData = append(listenerData, server.ListenerData{
|
|
ID: l.ID,
|
|
Addr: l.Addr,
|
|
ProxyProtocol: l.ProxyProtocol,
|
|
Routes: routes,
|
|
})
|
|
}
|
|
|
|
srv := server.New(cfg, fwObj, listenerData, logger, "test-version")
|
|
|
|
// Set up bufconn gRPC server (no TLS for tests).
|
|
lis := bufconn.Listen(1024 * 1024)
|
|
grpcSrv := grpc.NewServer()
|
|
admin := &AdminServer{
|
|
srv: srv,
|
|
store: store,
|
|
logger: logger,
|
|
}
|
|
pb.RegisterProxyAdminServiceServer(grpcSrv, admin)
|
|
|
|
go func() {
|
|
if err := grpcSrv.Serve(lis); err != nil {
|
|
t.Logf("grpc serve: %v", err)
|
|
}
|
|
}()
|
|
t.Cleanup(grpcSrv.Stop)
|
|
|
|
conn, err := grpc.NewClient("passthrough://bufconn",
|
|
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
|
|
return lis.DialContext(ctx)
|
|
}),
|
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
|
)
|
|
if err != nil {
|
|
t.Fatalf("dial bufconn: %v", err)
|
|
}
|
|
t.Cleanup(func() { conn.Close() })
|
|
|
|
return &testEnv{
|
|
client: pb.NewProxyAdminServiceClient(conn),
|
|
conn: conn,
|
|
store: store,
|
|
srv: srv,
|
|
}
|
|
}
|
|
|
|
func TestGetStatus(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetStatus: %v", err)
|
|
}
|
|
if resp.Version != "test-version" {
|
|
t.Fatalf("got version %q, want %q", resp.Version, "test-version")
|
|
}
|
|
if len(resp.Listeners) != 1 {
|
|
t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
|
|
}
|
|
if resp.Listeners[0].Addr != ":443" {
|
|
t.Fatalf("got listener addr %q, want %q", resp.Listeners[0].Addr, ":443")
|
|
}
|
|
if resp.Listeners[0].RouteCount != 1 {
|
|
t.Fatalf("got route count %d, want 1", resp.Listeners[0].RouteCount)
|
|
}
|
|
}
|
|
|
|
func TestListRoutes(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(resp.Routes) != 1 {
|
|
t.Fatalf("got %d routes, want 1", len(resp.Routes))
|
|
}
|
|
if resp.Routes[0].Hostname != "a.test" {
|
|
t.Fatalf("got hostname %q, want %q", resp.Routes[0].Hostname, "a.test")
|
|
}
|
|
if resp.Routes[0].Backend != "127.0.0.1:8443" {
|
|
t.Fatalf("got backend %q, want %q", resp.Routes[0].Backend, "127.0.0.1:8443")
|
|
}
|
|
}
|
|
|
|
func TestListRoutesNotFound(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":9999"})
|
|
if err == nil {
|
|
t.Fatal("expected error for nonexistent listener")
|
|
}
|
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound {
|
|
t.Fatalf("expected NotFound, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAddRoute(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "b.test", Backend: "127.0.0.1:9443"},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddRoute: %v", err)
|
|
}
|
|
|
|
// Verify in-memory.
|
|
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(resp.Routes) != 2 {
|
|
t.Fatalf("got %d routes, want 2", len(resp.Routes))
|
|
}
|
|
|
|
// Verify in DB.
|
|
dbListeners, err := env.store.ListListeners()
|
|
if err != nil {
|
|
t.Fatalf("list listeners: %v", err)
|
|
}
|
|
dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
|
|
if err != nil {
|
|
t.Fatalf("list routes: %v", err)
|
|
}
|
|
if len(dbRoutes) != 2 {
|
|
t.Fatalf("DB has %d routes, want 2", len(dbRoutes))
|
|
}
|
|
}
|
|
|
|
func TestAddRouteDuplicate(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for duplicate route")
|
|
}
|
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.AlreadyExists {
|
|
t.Fatalf("expected AlreadyExists, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAddRouteValidation(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
// Missing route.
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ListenerAddr: ":443"})
|
|
if err == nil {
|
|
t.Fatal("expected error for nil route")
|
|
}
|
|
|
|
// Missing hostname.
|
|
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Backend: "127.0.0.1:1"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty hostname")
|
|
}
|
|
|
|
// Missing backend.
|
|
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "x.test"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty backend")
|
|
}
|
|
|
|
// Invalid backend (not host:port).
|
|
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid backend address")
|
|
}
|
|
}
|
|
|
|
func TestRemoveRoute(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Hostname: "a.test",
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("RemoveRoute: %v", err)
|
|
}
|
|
|
|
// Verify removed from memory.
|
|
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
if len(resp.Routes) != 0 {
|
|
t.Fatalf("got %d routes, want 0", len(resp.Routes))
|
|
}
|
|
}
|
|
|
|
func TestRemoveRouteNotFound(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Hostname: "nonexistent.test",
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for removing nonexistent route")
|
|
}
|
|
}
|
|
|
|
func TestGetFirewallRules(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(resp.Rules) != 1 {
|
|
t.Fatalf("got %d rules, want 1", len(resp.Rules))
|
|
}
|
|
if resp.Rules[0].Type != pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP {
|
|
t.Fatalf("got type %v, want IP", resp.Rules[0].Type)
|
|
}
|
|
if resp.Rules[0].Value != "10.0.0.1" {
|
|
t.Fatalf("got value %q, want %q", resp.Rules[0].Value, "10.0.0.1")
|
|
}
|
|
}
|
|
|
|
func TestAddFirewallRule(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
// Add IP rule.
|
|
_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "10.0.0.2",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddFirewallRule IP: %v", err)
|
|
}
|
|
|
|
// Add CIDR rule.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
|
Value: "192.168.0.0/16",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddFirewallRule CIDR: %v", err)
|
|
}
|
|
|
|
// Add country rule.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
|
Value: "RU",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddFirewallRule country: %v", err)
|
|
}
|
|
|
|
// Verify.
|
|
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(resp.Rules) != 4 {
|
|
t.Fatalf("got %d rules, want 4", len(resp.Rules))
|
|
}
|
|
|
|
// Verify DB persistence.
|
|
dbRules, err := env.store.ListFirewallRules()
|
|
if err != nil {
|
|
t.Fatalf("list firewall rules: %v", err)
|
|
}
|
|
if len(dbRules) != 4 {
|
|
t.Fatalf("DB has %d rules, want 4", len(dbRules))
|
|
}
|
|
}
|
|
|
|
func TestAddFirewallRuleValidation(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
// Nil rule.
|
|
_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{})
|
|
if err == nil {
|
|
t.Fatal("expected error for nil rule")
|
|
}
|
|
|
|
// Unknown type.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED,
|
|
Value: "x",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for unspecified rule type")
|
|
}
|
|
|
|
// Empty value.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for empty value")
|
|
}
|
|
|
|
// Invalid IP address.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "not-an-ip",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid IP")
|
|
}
|
|
|
|
// Invalid CIDR.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
|
Value: "not-a-cidr",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid CIDR")
|
|
}
|
|
|
|
// Non-canonical CIDR.
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
|
Value: "192.168.1.5/16",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for non-canonical CIDR")
|
|
}
|
|
|
|
// Invalid country code (lowercase).
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
|
Value: "cn",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for lowercase country code")
|
|
}
|
|
|
|
// Invalid country code (too long).
|
|
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
|
Value: "USA",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for 3-letter country code")
|
|
}
|
|
}
|
|
|
|
func TestRemoveFirewallRule(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "10.0.0.1",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("RemoveFirewallRule: %v", err)
|
|
}
|
|
|
|
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetFirewallRules: %v", err)
|
|
}
|
|
if len(resp.Rules) != 0 {
|
|
t.Fatalf("got %d rules, want 0", len(resp.Rules))
|
|
}
|
|
}
|
|
|
|
func TestRemoveFirewallRuleNotFound(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
|
|
Rule: &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: "99.99.99.99",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for removing nonexistent rule")
|
|
}
|
|
}
|
|
|
|
func TestAddRouteL7(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{
|
|
Hostname: "l7.test",
|
|
Backend: "127.0.0.1:8080",
|
|
Mode: "l7",
|
|
TlsCert: "/certs/l7.crt",
|
|
TlsKey: "/certs/l7.key",
|
|
BackendTls: false,
|
|
SendProxyProtocol: true,
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddRoute L7: %v", err)
|
|
}
|
|
|
|
// Verify in-memory via ListRoutes.
|
|
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
if err != nil {
|
|
t.Fatalf("ListRoutes: %v", err)
|
|
}
|
|
|
|
var found *pb.Route
|
|
for _, r := range resp.Routes {
|
|
if r.Hostname == "l7.test" {
|
|
found = r
|
|
break
|
|
}
|
|
}
|
|
if found == nil {
|
|
t.Fatal("L7 route not found in ListRoutes response")
|
|
}
|
|
if found.Mode != "l7" {
|
|
t.Fatalf("mode = %q, want %q", found.Mode, "l7")
|
|
}
|
|
if found.TlsCert != "/certs/l7.crt" {
|
|
t.Fatalf("tls_cert = %q, want %q", found.TlsCert, "/certs/l7.crt")
|
|
}
|
|
if found.TlsKey != "/certs/l7.key" {
|
|
t.Fatalf("tls_key = %q, want %q", found.TlsKey, "/certs/l7.key")
|
|
}
|
|
if found.BackendTls {
|
|
t.Fatal("expected backend_tls = false")
|
|
}
|
|
if !found.SendProxyProtocol {
|
|
t.Fatal("expected send_proxy_protocol = true")
|
|
}
|
|
|
|
// Verify DB persistence.
|
|
dbListeners, _ := env.store.ListListeners()
|
|
dbRoutes, _ := env.store.ListRoutes(dbListeners[0].ID)
|
|
var dbRoute *db.Route
|
|
for i := range dbRoutes {
|
|
if dbRoutes[i].Hostname == "l7.test" {
|
|
dbRoute = &dbRoutes[i]
|
|
break
|
|
}
|
|
}
|
|
if dbRoute == nil {
|
|
t.Fatal("L7 route not found in DB")
|
|
}
|
|
if dbRoute.Mode != "l7" {
|
|
t.Fatalf("DB mode = %q, want %q", dbRoute.Mode, "l7")
|
|
}
|
|
if !dbRoute.SendProxyProtocol {
|
|
t.Fatal("DB send_proxy_protocol should be true")
|
|
}
|
|
}
|
|
|
|
func TestAddRouteL7MissingCert(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{
|
|
Hostname: "nocert.test",
|
|
Backend: "127.0.0.1:8080",
|
|
Mode: "l7",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for L7 route without cert/key")
|
|
}
|
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.InvalidArgument {
|
|
t.Fatalf("expected InvalidArgument, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAddRouteInvalidMode(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{
|
|
Hostname: "badmode.test",
|
|
Backend: "127.0.0.1:8080",
|
|
Mode: "l5",
|
|
},
|
|
})
|
|
if err == nil {
|
|
t.Fatal("expected error for invalid mode")
|
|
}
|
|
if s, ok := status.FromError(err); !ok || s.Code() != codes.InvalidArgument {
|
|
t.Fatalf("expected InvalidArgument, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestAddRouteDefaultsToL4(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
// Add route without specifying mode — should default to "l4".
|
|
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
|
ListenerAddr: ":443",
|
|
Route: &pb.Route{
|
|
Hostname: "default.test",
|
|
Backend: "127.0.0.1:9443",
|
|
},
|
|
})
|
|
if err != nil {
|
|
t.Fatalf("AddRoute: %v", err)
|
|
}
|
|
|
|
resp, _ := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
|
for _, r := range resp.Routes {
|
|
if r.Hostname == "default.test" {
|
|
if r.Mode != "l4" {
|
|
t.Fatalf("mode = %q, want %q", r.Mode, "l4")
|
|
}
|
|
return
|
|
}
|
|
}
|
|
t.Fatal("route not found")
|
|
}
|
|
|
|
func TestGetStatusProxyProtocol(t *testing.T) {
|
|
env := setup(t)
|
|
ctx := context.Background()
|
|
|
|
resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
|
|
if err != nil {
|
|
t.Fatalf("GetStatus: %v", err)
|
|
}
|
|
|
|
// The seeded listener has proxy_protocol = false.
|
|
if len(resp.Listeners) != 1 {
|
|
t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
|
|
}
|
|
if resp.Listeners[0].ProxyProtocol {
|
|
t.Fatal("expected proxy_protocol = false")
|
|
}
|
|
}
|