Files
mc-proxy/internal/grpcserver/grpcserver.go
Kyle Isom 28321e22f4 Make AddRoute idempotent (upsert instead of rejecting duplicates)
AddRoute now updates an existing route if one already exists for the
same (listener, hostname) pair, instead of returning AlreadyExists.
This makes repeated deploys idempotent — the MCP agent can register
routes on every deploy without needing to remove them first.

- DB: INSERT ... ON CONFLICT DO UPDATE (SQLite upsert)
- In-memory: overwrite existing route unconditionally
- gRPC: error code changed from AlreadyExists to Internal (for real DB errors)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-27 14:01:45 -07:00

483 lines
15 KiB
Go

package grpcserver
import (
"context"
"fmt"
"log/slog"
"net"
"net/netip"
"os"
"regexp"
"strings"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/health"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/mc/mc-proxy/gen/mc_proxy/v1"
"git.wntrmute.dev/mc/mc-proxy/internal/config"
"git.wntrmute.dev/mc/mc-proxy/internal/db"
"git.wntrmute.dev/mc/mc-proxy/internal/server"
)
var (
	// countryCodeRe accepts a country firewall-rule value: exactly two
	// ASCII upper-case letters (ISO 3166-1 alpha-2 style, e.g. "US", "DE").
	// Compiled once at package scope so validation never recompiles it.
	countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
)
// AdminServer implements the ProxyAdmin gRPC service.
//
// Mutating RPCs are write-through: the change is written to store first and
// applied to the running server's in-memory state only after the database
// write succeeds.
type AdminServer struct {
	// Embedded generated base type; presumably supplies default stubs for
	// any RPCs not defined on AdminServer (grpc-go codegen convention) —
	// confirm against the generated code.
	pb.UnimplementedProxyAdminServiceServer
	srv    *server.Server // running proxy: listeners, route tables, firewall, stats
	store  *db.Store      // database that every mutation is written to first
	logger *slog.Logger   // receives audit and inconsistency log messages
}
// NewAdminServer creates an AdminServer for use in testing or custom setups.
//
// srv supplies the running listeners and firewall. store is the database that
// every mutation is written to before in-memory state changes. logger receives
// audit and inconsistency messages.
//
// The handlers log unconditionally. Calling a method on a nil *slog.Logger
// panics, so a nil logger is replaced by slog.Default().
func NewAdminServer(srv *server.Server, store *db.Store, logger *slog.Logger) *AdminServer {
	if logger == nil {
		logger = slog.Default()
	}
	return &AdminServer{
		srv:    srv,
		store:  store,
		logger: logger,
	}
}
// New creates a gRPC server listening on a Unix socket.
//
// The socket is created at cfg.SocketPath() and its mode is set to 0600. The
// returned server has two services registered:
//   - the ProxyAdminService;
//   - the standard gRPC health service, reporting SERVING both overall ("")
//     and for "mc_proxy.v1.ProxyAdminService".
//
// New does not start serving. The caller passes the returned listener to the
// server's Serve method. On error, nothing is left open.
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
	admin := NewAdminServer(srv, store, logger)

	path := cfg.SocketPath()

	// Remove a stale socket left over from a previous run. Only an existing
	// socket is removed. If the path names any other kind of file, it is left
	// alone and net.Listen below reports the conflict; a misconfigured path
	// is never silently deleted.
	if fi, err := os.Lstat(path); err == nil && fi.Mode()&os.ModeSocket != 0 {
		_ = os.Remove(path) // best effort: a real problem surfaces from net.Listen
	}

	ln, err := net.Listen("unix", path)
	if err != nil {
		return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
	}
	// NOTE(review): between net.Listen and this Chmod the socket carries
	// umask-derived permissions. If the parent directory is not already
	// private, consider tightening the umask or the directory mode.
	if err := os.Chmod(path, 0600); err != nil {
		_ = ln.Close()
		return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
	}

	grpcServer := grpc.NewServer()
	pb.RegisterProxyAdminServiceServer(grpcServer, admin)

	// Register standard gRPC health check service.
	healthServer := health.NewServer()
	healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
	healthServer.SetServingStatus("mc_proxy.v1.ProxyAdminService", healthpb.HealthCheckResponse_SERVING)
	healthpb.RegisterHealthServer(grpcServer, healthServer)

	return grpcServer, ln, nil
}
// ListRoutes returns the full route table of the listener named in the
// request. Each route includes its TLS certificate/key paths, backend options
// and any L7 policies.
//
// The route table is a map keyed by hostname, so the order of the returned
// routes is unspecified. Returns NotFound if the listener does not exist.
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
	listener, lookupErr := a.findListener(req.ListenerAddr)
	if lookupErr != nil {
		return nil, lookupErr
	}

	out := &pb.ListRoutesResponse{ListenerAddr: listener.Addr}
	for host, r := range listener.Routes() {
		var rules []*pb.L7Policy
		for i := range r.L7Policies {
			rules = append(rules, &pb.L7Policy{
				Type:  r.L7Policies[i].Type,
				Value: r.L7Policies[i].Value,
			})
		}

		entry := &pb.Route{
			Hostname:          host,
			Backend:           r.Backend,
			Mode:              r.Mode,
			TlsCert:           r.TLSCert,
			TlsKey:            r.TLSKey,
			BackendTls:        r.BackendTLS,
			SendProxyProtocol: r.SendProxyProtocol,
			L7Policies:        rules,
		}
		out.Routes = append(out.Routes, entry)
	}
	return out, nil
}
// AddRoute writes to the database first, then updates in-memory state.
//
// The call is idempotent. If a route already exists for the same (listener,
// hostname) pair, it is overwritten (DB upsert) rather than rejected. The
// in-memory table changes only after the database write succeeds.
//
// Validation:
//   - hostname and backend are required; the hostname is lower-cased;
//   - backend must be host:port with a non-empty port;
//   - mode defaults to "l4" and must be "l4" or "l7";
//   - "l7" routes need both tls_cert and tls_key.
//
// Returns InvalidArgument for bad input, NotFound for an unknown listener and
// Internal when the database write fails.
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
	if req.Route == nil {
		return nil, status.Error(codes.InvalidArgument, "route is required")
	}
	if req.Route.Hostname == "" || req.Route.Backend == "" {
		return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
	}
	// Validate backend is a valid host:port. SplitHostPort accepts an empty
	// port ("host:"), which resolves to port 0 and is never a usable backend,
	// so reject it explicitly.
	_, port, err := net.SplitHostPort(req.Route.Backend)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
	}
	if port == "" {
		return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: missing port in %q", req.Route.Backend)
	}

	ls, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}
	hostname := strings.ToLower(req.Route.Hostname)

	// Normalize mode.
	mode := req.Route.Mode
	if mode == "" {
		mode = "l4"
	}
	if mode != "l4" && mode != "l7" {
		return nil, status.Errorf(codes.InvalidArgument, "mode must be \"l4\" or \"l7\", got %q", mode)
	}
	// L7 routes require cert/key paths.
	if mode == "l7" && (req.Route.TlsCert == "" || req.Route.TlsKey == "") {
		return nil, status.Error(codes.InvalidArgument, "L7 routes require tls_cert and tls_key")
	}

	// Write-through: DB first (upsert), then memory.
	if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, mode,
		req.Route.TlsCert, req.Route.TlsKey, req.Route.BackendTls, req.Route.SendProxyProtocol); err != nil {
		return nil, status.Errorf(codes.Internal, "%v", err)
	}

	info := server.RouteInfo{
		Backend:           req.Route.Backend,
		Mode:              mode,
		TLSCert:           req.Route.TlsCert,
		TLSKey:            req.Route.TlsKey,
		BackendTLS:        req.Route.BackendTls,
		SendProxyProtocol: req.Route.SendProxyProtocol,
	}
	// The upsert (INSERT ... ON CONFLICT DO UPDATE) updates the existing row
	// in place, so L7 policies attached to it remain in the database. Carry
	// the in-memory copies over too. Otherwise overwriting the route would
	// drop them from memory until the next restart reloads them.
	// NOTE(review): assumes ls.AddRoute replaces the entry wholesale; the
	// read-then-write here is not atomic with a concurrent AddL7Policy.
	if existing, ok := ls.Routes()[hostname]; ok {
		info.L7Policies = existing.L7Policies
	}
	ls.AddRoute(hostname, info)

	a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend, "mode", mode)
	return &pb.AddRouteResponse{}, nil
}
// RemoveRoute deletes the route for (listener, hostname). The hostname is
// lower-cased.
//
// The database row is removed first and the in-memory table afterwards. A
// failure in the second step is only logged, because the persistent state
// has already changed.
//
// Returns InvalidArgument for an empty hostname and NotFound for an unknown
// listener. Any error from the store delete is also reported as NotFound.
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
	if len(req.Hostname) == 0 {
		return nil, status.Error(codes.InvalidArgument, "hostname is required")
	}

	listener, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}
	host := strings.ToLower(req.Hostname)

	if delErr := a.store.DeleteRoute(listener.ID, host); delErr != nil {
		return nil, status.Errorf(codes.NotFound, "%v", delErr)
	}
	if memErr := listener.RemoveRoute(host); memErr != nil {
		// DB and memory now disagree; surface it loudly but report success,
		// since the durable change has been made.
		a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", memErr)
	}

	a.logger.Info("route removed", "listener", listener.Addr, "hostname", host)
	return &pb.RemoveRouteResponse{}, nil
}
// ListL7Policies returns the L7 policies attached to one route. The policies
// are read from the listener's in-memory route table, and the hostname is
// lower-cased before lookup.
//
// Returns NotFound if the listener or the route does not exist.
func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) {
	listener, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}

	host := strings.ToLower(req.Hostname)
	r, found := listener.Routes()[host]
	if !found {
		return nil, status.Errorf(codes.NotFound, "route %q not found", host)
	}

	resp := &pb.ListL7PoliciesResponse{}
	for i := range r.L7Policies {
		resp.Policies = append(resp.Policies, &pb.L7Policy{
			Type:  r.L7Policies[i].Type,
			Value: r.L7Policies[i].Value,
		})
	}
	return resp, nil
}
// AddL7Policy adds an L7 policy to a route (write-through).
//
// The policy is inserted into the database first and added to the in-memory
// route only after that succeeds. Supported types are "block_user_agent" and
// "require_header", and the value must be non-empty. The hostname is
// lower-cased.
//
// Returns InvalidArgument for a malformed policy and NotFound if the listener
// or route does not exist. Any error from the policy insert is reported as
// AlreadyExists.
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) {
	if req.Policy == nil {
		return nil, status.Error(codes.InvalidArgument, "policy is required")
	}
	if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" {
		return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type)
	}
	if req.Policy.Value == "" {
		return nil, status.Error(codes.InvalidArgument, "policy value is required")
	}

	ls, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}
	hostname := strings.ToLower(req.Hostname)

	// ls.ID is already the listener's database ID; the route and
	// connection-limit store calls take it directly. Resolve the route with
	// it rather than re-reading the listener row by address.
	routeID, err := a.store.GetRouteID(ls.ID, hostname)
	if err != nil {
		return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
	}

	// Write-through: DB first.
	if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
		return nil, status.Errorf(codes.AlreadyExists, "%v", err)
	}

	// Update in-memory state.
	ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value})

	a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
	return &pb.AddL7PolicyResponse{}, nil
}
// RemoveL7Policy removes an L7 policy from a route (write-through).
//
// The policy is matched by type and value. It is deleted from the database
// first and removed from the in-memory route afterwards. The hostname is
// lower-cased.
//
// Returns InvalidArgument if no policy is given and NotFound if the listener
// or route does not exist. Any error from the policy delete is also reported
// as NotFound.
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) {
	if req.Policy == nil {
		return nil, status.Error(codes.InvalidArgument, "policy is required")
	}

	ls, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}
	hostname := strings.ToLower(req.Hostname)

	// ls.ID is already the listener's database ID; the route and
	// connection-limit store calls take it directly. Resolve the route with
	// it rather than re-reading the listener row by address.
	routeID, err := a.store.GetRouteID(ls.ID, hostname)
	if err != nil {
		return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
	}

	if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
		return nil, status.Errorf(codes.NotFound, "%v", err)
	}
	ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value)

	// Log the value as well, matching AddL7Policy, so add/remove pairs can be
	// correlated in the audit trail.
	a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
	return &pb.RemoveL7PolicyResponse{}, nil
}
// GetFirewallRules returns all current firewall rules from the running
// firewall. Rules are grouped by type: IPs first, then CIDRs, then country
// codes. Within each group they keep the order the firewall reports them.
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
	ips, cidrs, countries := a.srv.Firewall().Rules()

	groups := []struct {
		kind   pb.FirewallRuleType
		values []string
	}{
		{pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, ips},
		{pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, cidrs},
		{pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, countries},
	}

	var rules []*pb.FirewallRule
	for _, g := range groups {
		for _, v := range g.values {
			rules = append(rules, &pb.FirewallRule{Type: g.kind, Value: v})
		}
	}
	return &pb.GetFirewallRulesResponse{Rules: rules}, nil
}
// AddFirewallRule writes to the database first, then updates in-memory state.
//
// Values must be in canonical form so that one rule cannot be stored under
// several spellings:
//   - ip: an address exactly as netip.Addr.String prints it (e.g.
//     "2001:db8::1", not "2001:DB8::1");
//   - cidr: a masked prefix (e.g. "192.168.0.0/16", not "192.168.1.5/16");
//   - country: two upper-case letters (ISO 3166-1 alpha-2, e.g. "US").
//
// Returns InvalidArgument for a missing or invalid rule. Any database write
// error is reported as AlreadyExists. If the in-memory update fails after a
// successful write, the failure is logged rather than returned.
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
	if req.Rule == nil {
		return nil, status.Error(codes.InvalidArgument, "rule is required")
	}
	ruleType, err := protoRuleTypeToString(req.Rule.Type)
	if err != nil {
		return nil, err
	}
	if req.Rule.Value == "" {
		return nil, status.Error(codes.InvalidArgument, "value is required")
	}

	// Validate the value matches the rule type before persisting.
	switch ruleType {
	case "ip":
		addr, err := netip.ParseAddr(req.Rule.Value)
		if err != nil {
			return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
		}
		// Require canonical form, mirroring the CIDR check. Otherwise e.g.
		// "2001:db8::1" and "2001:DB8::1" could be stored as two distinct rule
		// values for the same address, and deleting one would leave the other.
		if canon := addr.String(); canon != req.Rule.Value {
			return nil, status.Errorf(codes.InvalidArgument,
				"IP address not in canonical form: use %s", canon)
		}
	case "cidr":
		prefix, err := netip.ParsePrefix(req.Rule.Value)
		if err != nil {
			return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
		}
		// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
		if canon := prefix.Masked().String(); canon != req.Rule.Value {
			return nil, status.Errorf(codes.InvalidArgument,
				"CIDR not in canonical form: use %s", canon)
		}
	case "country":
		if !countryCodeRe.MatchString(req.Rule.Value) {
			return nil, status.Error(codes.InvalidArgument,
				"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
		}
	}

	// Write-through: DB first, then memory.
	if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
		return nil, status.Errorf(codes.AlreadyExists, "%v", err)
	}

	fw := a.srv.Firewall()
	switch ruleType {
	case "ip":
		if err := fw.AddIP(req.Rule.Value); err != nil {
			a.logger.Error("inconsistency: DB write succeeded but memory update failed",
				"type", ruleType, "value", req.Rule.Value, "error", err)
		}
	case "cidr":
		if err := fw.AddCIDR(req.Rule.Value); err != nil {
			a.logger.Error("inconsistency: DB write succeeded but memory update failed",
				"type", ruleType, "value", req.Rule.Value, "error", err)
		}
	case "country":
		fw.AddCountry(req.Rule.Value)
	}

	a.logger.Info("firewall rule added", "type", ruleType, "value", req.Rule.Value)
	return &pb.AddFirewallRuleResponse{}, nil
}
// RemoveFirewallRule deletes a firewall rule from the database and then from
// the running firewall. The value is passed through as given, without
// validation or normalisation.
//
// Returns InvalidArgument for a missing rule or unknown rule type. Any error
// from the store delete is reported as NotFound. If the in-memory removal
// fails after the database delete, the failure is logged rather than returned.
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
	rule := req.Rule
	if rule == nil {
		return nil, status.Error(codes.InvalidArgument, "rule is required")
	}
	kind, err := protoRuleTypeToString(rule.Type)
	if err != nil {
		return nil, err
	}
	value := rule.Value

	if err := a.store.DeleteFirewallRule(kind, value); err != nil {
		return nil, status.Errorf(codes.NotFound, "%v", err)
	}

	switch fw := a.srv.Firewall(); kind {
	case "ip":
		if err := fw.RemoveIP(value); err != nil {
			a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
		}
	case "cidr":
		if err := fw.RemoveCIDR(value); err != nil {
			a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
		}
	case "country":
		fw.RemoveCountry(value)
	}

	a.logger.Info("firewall rule removed", "type", kind, "value", value)
	return &pb.RemoveFirewallRuleResponse{}, nil
}
// SetListenerMaxConnections changes one listener's connection limit. The new
// value is persisted before the running listener is updated.
//
// Returns InvalidArgument for a negative limit, NotFound for an unknown
// listener and Internal when the store update fails.
func (a *AdminServer) SetListenerMaxConnections(_ context.Context, req *pb.SetListenerMaxConnectionsRequest) (*pb.SetListenerMaxConnectionsResponse, error) {
	limit := req.MaxConnections
	if limit < 0 {
		return nil, status.Error(codes.InvalidArgument, "max_connections must not be negative")
	}

	listener, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}

	if err = a.store.UpdateListenerMaxConns(listener.ID, limit); err != nil {
		return nil, status.Errorf(codes.Internal, "%v", err)
	}
	listener.SetMaxConnections(limit)

	a.logger.Info("connection limit updated", "listener", listener.Addr, "max_connections", limit)
	return &pb.SetListenerMaxConnectionsResponse{}, nil
}
// GetStatus returns the proxy's version, start time and total connection
// count, plus a summary of each listener.
//
// Each listener summary holds the route count, active connections, PROXY
// protocol flag, connection limit and a condensed route list. Condensed routes
// carry only hostname, backend, mode, backend-TLS and send-PROXY flags; TLS
// paths and L7 policies are omitted (ListRoutes returns those). Route order
// within a listener is unspecified because the route table is a map.
//
// NOTE(review): MaxConnections is read as a plain field here, while
// SetListenerMaxConnections updates it through SetMaxConnections. Confirm
// this read is race-free.
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
	// summarize condenses one listener into its status message.
	summarize := func(ls *server.ListenerState) *pb.ListenerStatus {
		table := ls.Routes()
		var condensed []*pb.Route
		for host, r := range table {
			condensed = append(condensed, &pb.Route{
				Hostname:          host,
				Backend:           r.Backend,
				Mode:              r.Mode,
				BackendTls:        r.BackendTLS,
				SendProxyProtocol: r.SendProxyProtocol,
			})
		}
		return &pb.ListenerStatus{
			Addr:              ls.Addr,
			RouteCount:        int32(len(table)), //nolint:gosec // route count can never exceed int32
			ActiveConnections: ls.ActiveConnections.Load(),
			ProxyProtocol:     ls.ProxyProtocol,
			MaxConnections:    ls.MaxConnections,
			Routes:            condensed,
		}
	}

	var summaries []*pb.ListenerStatus
	for _, ls := range a.srv.Listeners() {
		summaries = append(summaries, summarize(ls))
	}

	return &pb.GetStatusResponse{
		Version:          a.srv.Version(),
		StartedAt:        timestamppb.New(a.srv.StartedAt()),
		Listeners:        summaries,
		TotalConnections: a.srv.TotalConnections(),
	}, nil
}
// findListener returns the running listener whose address equals addr
// exactly (no normalisation). If none matches, it returns a NotFound status
// error that handlers can return to the client unchanged.
func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
	listeners := a.srv.Listeners()
	for i := range listeners {
		if candidate := listeners[i]; candidate.Addr == addr {
			return candidate, nil
		}
	}
	return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
}
// protoRuleTypeToString maps a protobuf firewall rule type to the string tag
// used by the store and firewall code: "ip", "cidr" or "country". Any other
// value yields an InvalidArgument status error.
func protoRuleTypeToString(t pb.FirewallRuleType) (string, error) {
	names := map[pb.FirewallRuleType]string{
		pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:      "ip",
		pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:    "cidr",
		pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: "country",
	}
	if name, ok := names[t]; ok {
		return name, nil
	}
	return "", status.Error(codes.InvalidArgument, "unknown rule type")
}