package grpcserver

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"os"
	"regexp"
	"slices"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/health"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
	"git.wntrmute.dev/kyle/mc-proxy/internal/config"
	"git.wntrmute.dev/kyle/mc-proxy/internal/db"
	"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)

// countryCodeRe matches exactly two upper-case ASCII letters, the shape of an
// ISO 3166-1 alpha-2 country code.
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)

// AdminServer implements the ProxyAdmin gRPC service. It holds the running
// proxy (for in-memory state), the persistent store, and a logger.
type AdminServer struct {
	pb.UnimplementedProxyAdminServiceServer

	srv    *server.Server
	store  *db.Store
	logger *slog.Logger
}

// NewAdminServer creates an AdminServer for use in testing or custom setups.
func NewAdminServer(srv *server.Server, store *db.Store, logger *slog.Logger) *AdminServer {
	return &AdminServer{
		srv:    srv,
		store:  store,
		logger: logger,
	}
}

// New creates a gRPC server listening on the Unix socket at cfg.SocketPath().
//
// Any stale socket file at that path is removed first, and the new socket is
// chmod'ed to 0600. The ProxyAdmin service and the standard gRPC health
// service (reporting SERVING) are registered on the returned server. The
// returned listener is the one the server is meant to be served on.
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
	admin := NewAdminServer(srv, store, logger)

	path := cfg.SocketPath()

	// Remove a stale socket file left by a previous run. A missing file is the
	// normal case. Any other failure is logged, and net.Listen below reports
	// it if it actually prevents binding.
	if err := os.Remove(path); err != nil && !errors.Is(err, os.ErrNotExist) {
		logger.Warn("could not remove stale socket", "path", path, "error", err)
	}

	ln, err := net.Listen("unix", path)
	if err != nil {
		return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
	}

	// NOTE(review): between Listen and Chmod the socket carries
	// umask-derived permissions. If that window matters, create the socket
	// inside a 0700 directory instead.
	if err := os.Chmod(path, 0600); err != nil {
		_ = ln.Close() // best effort: the chmod failure is the error we report
		return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
	}

	grpcServer := grpc.NewServer()
	pb.RegisterProxyAdminServiceServer(grpcServer, admin)

	// Register standard gRPC health check service.
	healthServer := health.NewServer()
	healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
	healthServer.SetServingStatus("mc_proxy.v1.ProxyAdminService", healthpb.HealthCheckResponse_SERVING)
	healthpb.RegisterHealthServer(grpcServer, healthServer)

	return grpcServer, ln, nil
}

// ListRoutes returns the route table for a listener, including each route's
// TLS settings and L7 policies. Routes are ordered by hostname so repeated
// calls produce stable output.
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
	ls, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}

	routes := ls.Routes()

	// Map iteration order is random, so sort the hostnames for deterministic output.
	hostnames := make([]string, 0, len(routes))
	for hostname := range routes {
		hostnames = append(hostnames, hostname)
	}
	slices.Sort(hostnames)

	resp := &pb.ListRoutesResponse{
		ListenerAddr: ls.Addr,
	}
	for _, hostname := range hostnames {
		route := routes[hostname]

		var policies []*pb.L7Policy
		for _, p := range route.L7Policies {
			policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
		}

		resp.Routes = append(resp.Routes, &pb.Route{
			Hostname:          hostname,
			Backend:           route.Backend,
			Mode:              route.Mode,
			TlsCert:           route.TLSCert,
			TlsKey:            route.TLSKey,
			BackendTls:        route.BackendTLS,
			SendProxyProtocol: route.SendProxyProtocol,
			L7Policies:        policies,
		})
	}
	return resp, nil
}

// AddRoute adds a route to a listener. The request is validated, written to
// the database, and then applied to the listener's in-memory route table.
//
// Hostnames are stored lower-cased. An empty mode defaults to "l4". "l7"
// routes must supply tls_cert and tls_key. L7 policies carried on the request
// route are not read here; they are managed through AddL7Policy.
//
// Any database error on insert is reported as AlreadyExists.
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
	if req.Route == nil {
		return nil, status.Error(codes.InvalidArgument, "route is required")
	}
	if req.Route.Hostname == "" || req.Route.Backend == "" {
		return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
	}

	// Validate that the backend is a host:port. SplitHostPort accepts an empty
	// port ("host:"), which is not a usable backend, so reject it explicitly.
	_, port, err := net.SplitHostPort(req.Route.Backend)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
	}
	if port == "" {
		return nil, status.Errorf(codes.InvalidArgument, "invalid backend address %q: missing port", req.Route.Backend)
	}

	ls, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}

	hostname := strings.ToLower(req.Route.Hostname)

	// Normalize mode.
	mode := req.Route.Mode
	if mode == "" {
		mode = "l4"
	}
	if mode != "l4" && mode != "l7" {
		return nil, status.Errorf(codes.InvalidArgument, "mode must be \"l4\" or \"l7\", got %q", mode)
	}

	// L7 routes terminate TLS, so they require cert/key paths.
	if mode == "l7" && (req.Route.TlsCert == "" || req.Route.TlsKey == "") {
		return nil, status.Error(codes.InvalidArgument, "L7 routes require tls_cert and tls_key")
	}

	// Write-through: DB first, then memory.
	if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, mode,
		req.Route.TlsCert, req.Route.TlsKey, req.Route.BackendTls, req.Route.SendProxyProtocol); err != nil {
		return nil, status.Errorf(codes.AlreadyExists, "%v", err)
	}

	info := server.RouteInfo{
		Backend:           req.Route.Backend,
		Mode:              mode,
		TLSCert:           req.Route.TlsCert,
		TLSKey:            req.Route.TlsKey,
		BackendTLS:        req.Route.BackendTls,
		SendProxyProtocol: req.Route.SendProxyProtocol,
	}
	if err := ls.AddRoute(hostname, info); err != nil {
		// The DB enforces uniqueness, so this should not happen. If it does,
		// the persisted and live route tables have diverged.
		a.logger.Error("inconsistency: DB write succeeded but memory update failed",
			"listener", ls.Addr, "hostname", hostname, "error", err)
	}

	a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend, "mode", mode)
	return &pb.AddRouteResponse{}, nil
}

// RemoveRoute writes to the database first, then updates in-memory state.
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) { if req.Hostname == "" { return nil, status.Error(codes.InvalidArgument, "hostname is required") } ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } hostname := strings.ToLower(req.Hostname) if err := a.store.DeleteRoute(ls.ID, hostname); err != nil { return nil, status.Errorf(codes.NotFound, "%v", err) } if err := ls.RemoveRoute(hostname); err != nil { a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err) } a.logger.Info("route removed", "listener", ls.Addr, "hostname", hostname) return &pb.RemoveRouteResponse{}, nil } // ListL7Policies returns L7 policies for a route. func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) { ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } hostname := strings.ToLower(req.Hostname) routes := ls.Routes() route, ok := routes[hostname] if !ok { return nil, status.Errorf(codes.NotFound, "route %q not found", hostname) } var policies []*pb.L7Policy for _, p := range route.L7Policies { policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value}) } return &pb.ListL7PoliciesResponse{Policies: policies}, nil } // AddL7Policy adds an L7 policy to a route (write-through). 
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) { if req.Policy == nil { return nil, status.Error(codes.InvalidArgument, "policy is required") } if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" { return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type) } if req.Policy.Value == "" { return nil, status.Error(codes.InvalidArgument, "policy value is required") } ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } hostname := strings.ToLower(req.Hostname) // Get route ID from DB. dbListener, err := a.store.GetListenerByAddr(ls.Addr) if err != nil { return nil, status.Errorf(codes.Internal, "%v", err) } routeID, err := a.store.GetRouteID(dbListener.ID, hostname) if err != nil { return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err) } // Write-through: DB first. if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil { return nil, status.Errorf(codes.AlreadyExists, "%v", err) } // Update in-memory state. ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value}) a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value) return &pb.AddL7PolicyResponse{}, nil } // RemoveL7Policy removes an L7 policy from a route (write-through). 
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) { if req.Policy == nil { return nil, status.Error(codes.InvalidArgument, "policy is required") } ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } hostname := strings.ToLower(req.Hostname) dbListener, err := a.store.GetListenerByAddr(ls.Addr) if err != nil { return nil, status.Errorf(codes.Internal, "%v", err) } routeID, err := a.store.GetRouteID(dbListener.ID, hostname) if err != nil { return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err) } if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil { return nil, status.Errorf(codes.NotFound, "%v", err) } ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value) a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type) return &pb.RemoveL7PolicyResponse{}, nil } // GetFirewallRules returns all current firewall rules. func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) { ips, cidrs, countries := a.srv.Firewall().Rules() var rules []*pb.FirewallRule for _, ip := range ips { rules = append(rules, &pb.FirewallRule{ Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, Value: ip, }) } for _, cidr := range cidrs { rules = append(rules, &pb.FirewallRule{ Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, Value: cidr, }) } for _, code := range countries { rules = append(rules, &pb.FirewallRule{ Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, Value: code, }) } return &pb.GetFirewallRulesResponse{Rules: rules}, nil } // AddFirewallRule writes to the database first, then updates in-memory state. 
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) { if req.Rule == nil { return nil, status.Error(codes.InvalidArgument, "rule is required") } ruleType, err := protoRuleTypeToString(req.Rule.Type) if err != nil { return nil, err } if req.Rule.Value == "" { return nil, status.Error(codes.InvalidArgument, "value is required") } // Validate the value matches the rule type before persisting. switch ruleType { case "ip": if _, err := netip.ParseAddr(req.Rule.Value); err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err) } case "cidr": prefix, err := netip.ParsePrefix(req.Rule.Value) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err) } // Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16). if prefix.Masked().String() != req.Rule.Value { return nil, status.Errorf(codes.InvalidArgument, "CIDR not in canonical form: use %s", prefix.Masked().String()) } case "country": if !countryCodeRe.MatchString(req.Rule.Value) { return nil, status.Error(codes.InvalidArgument, "country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)") } } // Write-through: DB first, then memory. if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil { return nil, status.Errorf(codes.AlreadyExists, "%v", err) } fw := a.srv.Firewall() switch ruleType { case "ip": if err := fw.AddIP(req.Rule.Value); err != nil { a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err) } case "cidr": if err := fw.AddCIDR(req.Rule.Value); err != nil { a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err) } case "country": fw.AddCountry(req.Rule.Value) } a.logger.Info("firewall rule added", "type", ruleType, "value", req.Rule.Value) return &pb.AddFirewallRuleResponse{}, nil } // RemoveFirewallRule writes to the database first, then updates in-memory state. 
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) { if req.Rule == nil { return nil, status.Error(codes.InvalidArgument, "rule is required") } ruleType, err := protoRuleTypeToString(req.Rule.Type) if err != nil { return nil, err } if err := a.store.DeleteFirewallRule(ruleType, req.Rule.Value); err != nil { return nil, status.Errorf(codes.NotFound, "%v", err) } fw := a.srv.Firewall() switch ruleType { case "ip": if err := fw.RemoveIP(req.Rule.Value); err != nil { a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err) } case "cidr": if err := fw.RemoveCIDR(req.Rule.Value); err != nil { a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err) } case "country": fw.RemoveCountry(req.Rule.Value) } a.logger.Info("firewall rule removed", "type", ruleType, "value", req.Rule.Value) return &pb.RemoveFirewallRuleResponse{}, nil } // SetListenerMaxConnections updates the per-listener connection limit. func (a *AdminServer) SetListenerMaxConnections(_ context.Context, req *pb.SetListenerMaxConnectionsRequest) (*pb.SetListenerMaxConnectionsResponse, error) { if req.MaxConnections < 0 { return nil, status.Error(codes.InvalidArgument, "max_connections must not be negative") } ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } // Write-through: DB first, then memory. if err := a.store.UpdateListenerMaxConns(ls.ID, req.MaxConnections); err != nil { return nil, status.Errorf(codes.Internal, "%v", err) } ls.SetMaxConnections(req.MaxConnections) a.logger.Info("connection limit updated", "listener", ls.Addr, "max_connections", req.MaxConnections) return &pb.SetListenerMaxConnectionsResponse{}, nil } // GetStatus returns the proxy's current status. 
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) { var listeners []*pb.ListenerStatus for _, ls := range a.srv.Listeners() { routes := ls.Routes() var pbRoutes []*pb.Route for hostname, route := range routes { pbRoutes = append(pbRoutes, &pb.Route{ Hostname: hostname, Backend: route.Backend, Mode: route.Mode, BackendTls: route.BackendTLS, SendProxyProtocol: route.SendProxyProtocol, }) } listeners = append(listeners, &pb.ListenerStatus{ Addr: ls.Addr, RouteCount: int32(len(routes)), ActiveConnections: ls.ActiveConnections.Load(), ProxyProtocol: ls.ProxyProtocol, MaxConnections: ls.MaxConnections, Routes: pbRoutes, }) } return &pb.GetStatusResponse{ Version: a.srv.Version(), StartedAt: timestamppb.New(a.srv.StartedAt()), Listeners: listeners, TotalConnections: a.srv.TotalConnections(), }, nil } func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) { for _, ls := range a.srv.Listeners() { if ls.Addr == addr { return ls, nil } } return nil, status.Errorf(codes.NotFound, "listener %q not found", addr) } func protoRuleTypeToString(t pb.FirewallRuleType) (string, error) { switch t { case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: return "ip", nil case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: return "cidr", nil case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: return "country", nil default: return "", status.Error(codes.InvalidArgument, "unknown rule type") } }