Add documentation, Docker setup, and tests for server and gRPC packages
Rewrite README with project overview and quick start. Add RUNBOOK with operational procedures and incident playbooks. Fix Dockerfile for Go 1.25 with version injection. Add docker-compose.yml. Clean up golangci.yaml for mc-proxy. Add server tests (10) covering the full proxy pipeline with TCP echo backends, and grpcserver tests (13) covering all admin API RPCs with bufconn and write-through DB verification. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
451
internal/grpcserver/grpcserver_test.go
Normal file
451
internal/grpcserver/grpcserver_test.go
Normal file
@@ -0,0 +1,451 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/grpc/test/bufconn"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/db"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/firewall"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
||||
)
|
||||
|
||||
// testEnv bundles all the objects needed for a grpcserver test.
|
||||
type testEnv struct {
|
||||
client pb.ProxyAdminServiceClient
|
||||
conn *grpc.ClientConn
|
||||
store *db.Store
|
||||
srv *server.Server
|
||||
}
|
||||
|
||||
func setup(t *testing.T) *testEnv {
|
||||
t.Helper()
|
||||
|
||||
// Database in temp dir.
|
||||
dbPath := filepath.Join(t.TempDir(), "test.db")
|
||||
store, err := db.Open(dbPath)
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { store.Close() })
|
||||
|
||||
if err := store.Migrate(); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
|
||||
// Seed with one listener and one route.
|
||||
listeners := []config.Listener{
|
||||
{
|
||||
Addr: ":443",
|
||||
Routes: []config.Route{
|
||||
{Hostname: "a.test", Backend: "127.0.0.1:8443"},
|
||||
},
|
||||
},
|
||||
}
|
||||
fw := config.Firewall{
|
||||
BlockedIPs: []string{"10.0.0.1"},
|
||||
}
|
||||
if err := store.Seed(listeners, fw); err != nil {
|
||||
t.Fatalf("seed: %v", err)
|
||||
}
|
||||
|
||||
// Build server with matching in-memory state.
|
||||
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("firewall: %v", err)
|
||||
}
|
||||
|
||||
cfg := &config.Config{
|
||||
Proxy: config.Proxy{
|
||||
ConnectTimeout: config.Duration{Duration: 5 * time.Second},
|
||||
IdleTimeout: config.Duration{Duration: 30 * time.Second},
|
||||
ShutdownTimeout: config.Duration{Duration: 5 * time.Second},
|
||||
},
|
||||
}
|
||||
logger := slog.New(slog.NewTextHandler(io.Discard, nil))
|
||||
|
||||
// Load listener data from DB to get correct IDs.
|
||||
dbListeners, err := store.ListListeners()
|
||||
if err != nil {
|
||||
t.Fatalf("list listeners: %v", err)
|
||||
}
|
||||
var listenerData []server.ListenerData
|
||||
for _, l := range dbListeners {
|
||||
dbRoutes, err := store.ListRoutes(l.ID)
|
||||
if err != nil {
|
||||
t.Fatalf("list routes: %v", err)
|
||||
}
|
||||
routes := make(map[string]string, len(dbRoutes))
|
||||
for _, r := range dbRoutes {
|
||||
routes[r.Hostname] = r.Backend
|
||||
}
|
||||
listenerData = append(listenerData, server.ListenerData{
|
||||
ID: l.ID,
|
||||
Addr: l.Addr,
|
||||
Routes: routes,
|
||||
})
|
||||
}
|
||||
|
||||
srv := server.New(cfg, fwObj, listenerData, logger, "test-version")
|
||||
|
||||
// Set up bufconn gRPC server (no TLS for tests).
|
||||
lis := bufconn.Listen(1024 * 1024)
|
||||
grpcSrv := grpc.NewServer()
|
||||
admin := &AdminServer{
|
||||
srv: srv,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
pb.RegisterProxyAdminServiceServer(grpcSrv, admin)
|
||||
|
||||
go func() {
|
||||
if err := grpcSrv.Serve(lis); err != nil {
|
||||
t.Logf("grpc serve: %v", err)
|
||||
}
|
||||
}()
|
||||
t.Cleanup(grpcSrv.Stop)
|
||||
|
||||
conn, err := grpc.NewClient("passthrough://bufconn",
|
||||
grpc.WithContextDialer(func(ctx context.Context, _ string) (net.Conn, error) {
|
||||
return lis.DialContext(ctx)
|
||||
}),
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatalf("dial bufconn: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { conn.Close() })
|
||||
|
||||
return &testEnv{
|
||||
client: pb.NewProxyAdminServiceClient(conn),
|
||||
conn: conn,
|
||||
store: store,
|
||||
srv: srv,
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetStatus(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := env.client.GetStatus(ctx, &pb.GetStatusRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetStatus: %v", err)
|
||||
}
|
||||
if resp.Version != "test-version" {
|
||||
t.Fatalf("got version %q, want %q", resp.Version, "test-version")
|
||||
}
|
||||
if len(resp.Listeners) != 1 {
|
||||
t.Fatalf("got %d listeners, want 1", len(resp.Listeners))
|
||||
}
|
||||
if resp.Listeners[0].Addr != ":443" {
|
||||
t.Fatalf("got listener addr %q, want %q", resp.Listeners[0].Addr, ":443")
|
||||
}
|
||||
if resp.Listeners[0].RouteCount != 1 {
|
||||
t.Fatalf("got route count %d, want 1", resp.Listeners[0].RouteCount)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRoutes(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
||||
if err != nil {
|
||||
t.Fatalf("ListRoutes: %v", err)
|
||||
}
|
||||
if len(resp.Routes) != 1 {
|
||||
t.Fatalf("got %d routes, want 1", len(resp.Routes))
|
||||
}
|
||||
if resp.Routes[0].Hostname != "a.test" {
|
||||
t.Fatalf("got hostname %q, want %q", resp.Routes[0].Hostname, "a.test")
|
||||
}
|
||||
if resp.Routes[0].Backend != "127.0.0.1:8443" {
|
||||
t.Fatalf("got backend %q, want %q", resp.Routes[0].Backend, "127.0.0.1:8443")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRoutesNotFound(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":9999"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nonexistent listener")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.NotFound {
|
||||
t.Fatalf("expected NotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRoute(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
||||
ListenerAddr: ":443",
|
||||
Route: &pb.Route{Hostname: "b.test", Backend: "127.0.0.1:9443"},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AddRoute: %v", err)
|
||||
}
|
||||
|
||||
// Verify in-memory.
|
||||
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
||||
if err != nil {
|
||||
t.Fatalf("ListRoutes: %v", err)
|
||||
}
|
||||
if len(resp.Routes) != 2 {
|
||||
t.Fatalf("got %d routes, want 2", len(resp.Routes))
|
||||
}
|
||||
|
||||
// Verify in DB.
|
||||
dbListeners, err := env.store.ListListeners()
|
||||
if err != nil {
|
||||
t.Fatalf("list listeners: %v", err)
|
||||
}
|
||||
dbRoutes, err := env.store.ListRoutes(dbListeners[0].ID)
|
||||
if err != nil {
|
||||
t.Fatalf("list routes: %v", err)
|
||||
}
|
||||
if len(dbRoutes) != 2 {
|
||||
t.Fatalf("DB has %d routes, want 2", len(dbRoutes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRouteDuplicate(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
||||
ListenerAddr: ":443",
|
||||
Route: &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate route")
|
||||
}
|
||||
if s, ok := status.FromError(err); !ok || s.Code() != codes.AlreadyExists {
|
||||
t.Fatalf("expected AlreadyExists, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRouteValidation(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Missing route.
|
||||
_, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ListenerAddr: ":443"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nil route")
|
||||
}
|
||||
|
||||
// Missing hostname.
|
||||
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
||||
ListenerAddr: ":443",
|
||||
Route: &pb.Route{Backend: "127.0.0.1:1"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty hostname")
|
||||
}
|
||||
|
||||
// Missing backend.
|
||||
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
||||
ListenerAddr: ":443",
|
||||
Route: &pb.Route{Hostname: "x.test"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty backend")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveRoute(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "a.test",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveRoute: %v", err)
|
||||
}
|
||||
|
||||
// Verify removed from memory.
|
||||
resp, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"})
|
||||
if err != nil {
|
||||
t.Fatalf("ListRoutes: %v", err)
|
||||
}
|
||||
if len(resp.Routes) != 0 {
|
||||
t.Fatalf("got %d routes, want 0", len(resp.Routes))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveRouteNotFound(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.RemoveRoute(ctx, &pb.RemoveRouteRequest{
|
||||
ListenerAddr: ":443",
|
||||
Hostname: "nonexistent.test",
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for removing nonexistent route")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetFirewallRules(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFirewallRules: %v", err)
|
||||
}
|
||||
if len(resp.Rules) != 1 {
|
||||
t.Fatalf("got %d rules, want 1", len(resp.Rules))
|
||||
}
|
||||
if resp.Rules[0].Type != pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP {
|
||||
t.Fatalf("got type %v, want IP", resp.Rules[0].Type)
|
||||
}
|
||||
if resp.Rules[0].Value != "10.0.0.1" {
|
||||
t.Fatalf("got value %q, want %q", resp.Rules[0].Value, "10.0.0.1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFirewallRule(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Add IP rule.
|
||||
_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||
Value: "10.0.0.2",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AddFirewallRule IP: %v", err)
|
||||
}
|
||||
|
||||
// Add CIDR rule.
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
||||
Value: "192.168.0.0/16",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AddFirewallRule CIDR: %v", err)
|
||||
}
|
||||
|
||||
// Add country rule.
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
||||
Value: "RU",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("AddFirewallRule country: %v", err)
|
||||
}
|
||||
|
||||
// Verify.
|
||||
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFirewallRules: %v", err)
|
||||
}
|
||||
if len(resp.Rules) != 4 {
|
||||
t.Fatalf("got %d rules, want 4", len(resp.Rules))
|
||||
}
|
||||
|
||||
// Verify DB persistence.
|
||||
dbRules, err := env.store.ListFirewallRules()
|
||||
if err != nil {
|
||||
t.Fatalf("list firewall rules: %v", err)
|
||||
}
|
||||
if len(dbRules) != 4 {
|
||||
t.Fatalf("DB has %d rules, want 4", len(dbRules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddFirewallRuleValidation(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
// Nil rule.
|
||||
_, err := env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for nil rule")
|
||||
}
|
||||
|
||||
// Unknown type.
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_UNSPECIFIED,
|
||||
Value: "x",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for unspecified rule type")
|
||||
}
|
||||
|
||||
// Empty value.
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty value")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFirewallRule(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||
Value: "10.0.0.1",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("RemoveFirewallRule: %v", err)
|
||||
}
|
||||
|
||||
resp, err := env.client.GetFirewallRules(ctx, &pb.GetFirewallRulesRequest{})
|
||||
if err != nil {
|
||||
t.Fatalf("GetFirewallRules: %v", err)
|
||||
}
|
||||
if len(resp.Rules) != 0 {
|
||||
t.Fatalf("got %d rules, want 0", len(resp.Rules))
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFirewallRuleNotFound(t *testing.T) {
|
||||
env := setup(t)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err := env.client.RemoveFirewallRule(ctx, &pb.RemoveFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||
Value: "99.99.99.99",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for removing nonexistent rule")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user