diff --git a/internal/db/db_test.go b/internal/db/db_test.go index ac8dc88..58d6ab1 100644 --- a/internal/db/db_test.go +++ b/internal/db/db_test.go @@ -250,15 +250,30 @@ func TestRouteL7Fields(t *testing.T) { } } -func TestRouteDuplicateHostname(t *testing.T) { +func TestRouteUpsert(t *testing.T) { store := openTestDB(t) listenerID, _ := store.CreateListener(":443", false, 0) if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:8443", "l4", "", "", false, false); err != nil { t.Fatalf("first create: %v", err) } - if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443", "l4", "", "", false, false); err == nil { - t.Fatal("expected error for duplicate hostname on same listener") + // Same (listener, hostname) with different backend — should upsert, not error. + if _, err := store.CreateRoute(listenerID, "example.com", "127.0.0.1:9443", "l7", "/cert.pem", "/key.pem", false, false); err != nil { + t.Fatalf("upsert: %v", err) + } + + routes, err := store.ListRoutes(listenerID) + if err != nil { + t.Fatalf("list routes: %v", err) + } + if len(routes) != 1 { + t.Fatalf("expected 1 route after upsert, got %d", len(routes)) + } + if routes[0].Backend != "127.0.0.1:9443" { + t.Fatalf("expected updated backend, got %q", routes[0].Backend) + } + if routes[0].Mode != "l7" { + t.Fatalf("expected updated mode, got %q", routes[0].Mode) } } diff --git a/internal/db/routes.go b/internal/db/routes.go index cf7526e..b4a8fd2 100644 --- a/internal/db/routes.go +++ b/internal/db/routes.go @@ -39,11 +39,20 @@ func (s *Store) ListRoutes(listenerID int64) ([]Route, error) { return routes, rows.Err() } -// CreateRoute inserts a route and returns its ID. +// CreateRoute inserts or updates a route and returns its ID. If a route +// for the same (listener_id, hostname) already exists, it is updated +// with the new values (upsert), making the operation idempotent. func (s *Store) CreateRoute(listenerID int64, hostname, backend, mode, tlsCert, tlsKey string, backendTLS, sendProxyProtocol bool) (int64, error) { result, err := s.db.Exec( `INSERT INTO routes (listener_id, hostname, backend, mode, tls_cert, tls_key, backend_tls, send_proxy_protocol) - VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(listener_id, hostname) DO UPDATE SET + backend = excluded.backend, + mode = excluded.mode, + tls_cert = excluded.tls_cert, + tls_key = excluded.tls_key, + backend_tls = excluded.backend_tls, + send_proxy_protocol = excluded.send_proxy_protocol`, listenerID, hostname, backend, mode, tlsCert, tlsKey, backendTLS, sendProxyProtocol, ) if err != nil { diff --git a/internal/grpcserver/grpcserver.go b/internal/grpcserver/grpcserver.go index 694d10a..905b952 100644 --- a/internal/grpcserver/grpcserver.go +++ b/internal/grpcserver/grpcserver.go @@ -144,10 +144,10 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb. } } - // Write-through: DB first, then memory. + // Write-through: DB first (upsert), then memory. if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, mode, req.Route.TlsCert, req.Route.TlsKey, req.Route.BackendTls, req.Route.SendProxyProtocol); err != nil { - return nil, status.Errorf(codes.AlreadyExists, "%v", err) + return nil, status.Errorf(codes.Internal, "%v", err) } info := server.RouteInfo{ @@ -158,10 +158,7 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb. BackendTLS: req.Route.BackendTls, SendProxyProtocol: req.Route.SendProxyProtocol, } - if err := ls.AddRoute(hostname, info); err != nil { - // DB succeeded but memory failed (should not happen since DB enforces uniqueness). - a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err) - } + ls.AddRoute(hostname, info) a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend, "mode", mode) return &pb.AddRouteResponse{}, nil diff --git a/internal/grpcserver/grpcserver_test.go b/internal/grpcserver/grpcserver_test.go index 34b6ae4..bdec161 100644 --- a/internal/grpcserver/grpcserver_test.go +++ b/internal/grpcserver/grpcserver_test.go @@ -229,19 +229,29 @@ func TestAddRoute(t *testing.T) { } } -func TestAddRouteDuplicate(t *testing.T) { +func TestAddRouteUpsert(t *testing.T) { env := setup(t) ctx := context.Background() + // a.test already exists from setup(). Adding again with a different + // backend should succeed (upsert) and update the route. _, err := env.client.AddRoute(ctx, &pb.AddRouteRequest{ ListenerAddr: ":443", Route: &pb.Route{Hostname: "a.test", Backend: "127.0.0.1:1111"}, }) - if err == nil { - t.Fatal("expected error for duplicate route") + if err != nil { + t.Fatalf("upsert should succeed: %v", err) } - if s, ok := status.FromError(err); !ok || s.Code() != codes.AlreadyExists { - t.Fatalf("expected AlreadyExists, got %v", err) + + // Verify the route was updated, not duplicated. + routes, err := env.client.ListRoutes(ctx, &pb.ListRoutesRequest{ListenerAddr: ":443"}) + if err != nil { + t.Fatalf("list routes: %v", err) + } + for _, r := range routes.Routes { + if r.Hostname == "a.test" && r.Backend != "127.0.0.1:1111" { + t.Fatalf("expected updated backend 127.0.0.1:1111, got %q", r.Backend) + } } } diff --git a/internal/server/server.go b/internal/server/server.go index c9a2c47..ec8ce8e 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -69,19 +69,15 @@ func (ls *ListenerState) Routes() map[string]RouteInfo { return m } -// AddRoute adds a route to the listener. Returns an error if the hostname -// already exists. -func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) error { +// AddRoute adds or updates a route on the listener. If a route for the +// hostname already exists, it is replaced (upsert). +func (ls *ListenerState) AddRoute(hostname string, info RouteInfo) { key := strings.ToLower(hostname) ls.mu.Lock() defer ls.mu.Unlock() - if _, ok := ls.routes[key]; ok { - return fmt.Errorf("route %q already exists", hostname) - } ls.routes[key] = info - return nil } // RemoveRoute removes a route from the listener. Returns an error if the diff --git a/internal/server/server_test.go b/internal/server/server_test.go index f86ad67..8f24c28 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -679,16 +679,12 @@ func TestListenerStateRoutes(t *testing.T) { } // AddRoute - if err := ls.AddRoute("b.test", l4Route("127.0.0.1:2")); err != nil { - t.Fatalf("AddRoute: %v", err) - } + ls.AddRoute("b.test", l4Route("127.0.0.1:2")) - // AddRoute duplicate - if err := ls.AddRoute("b.test", l4Route("127.0.0.1:3")); err == nil { - t.Fatal("expected error for duplicate route") - } + // AddRoute duplicate (upsert — updates in place) + ls.AddRoute("b.test", l4Route("127.0.0.1:3")) - // Routes snapshot + // Routes snapshot — still 2 routes, the duplicate replaced the first. routes := ls.Routes() if len(routes) != 2 { t.Fatalf("expected 2 routes, got %d", len(routes)) @@ -708,8 +704,8 @@ func TestListenerStateRoutes(t *testing.T) { if len(routes) != 1 { t.Fatalf("expected 1 route, got %d", len(routes)) } - if routes["b.test"].Backend != "127.0.0.1:2" { - t.Fatalf("expected b.test → 127.0.0.1:2, got %q", routes["b.test"].Backend) + if routes["b.test"].Backend != "127.0.0.1:3" { + t.Fatalf("expected b.test → 127.0.0.1:3 (upserted), got %q", routes["b.test"].Backend) } }