mcproxyctl status now shows the individual routes on each listener, with hostname, backend, mode, and a re-encrypt indicator. The proto, gRPC server, client library, and CLI are all updated. The default gRPC socket path moved from /var/run/mc-proxy.sock to /srv/mc-proxy/mc-proxy.sock to match the service data directory convention.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
486 lines
15 KiB
Go
486 lines
15 KiB
Go
package grpcserver
|
|
|
|
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"os"
	"regexp"
	"slices"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/health"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
	"git.wntrmute.dev/kyle/mc-proxy/internal/config"
	"git.wntrmute.dev/kyle/mc-proxy/internal/db"
	"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
|
|
|
|
// countryCodeRe accepts an ISO 3166-1 alpha-2 country code: exactly two
// uppercase ASCII letters spanning the whole string.
var countryCodeRe = regexp.MustCompile("^[A-Z][A-Z]$")
|
|
|
|
// AdminServer implements the ProxyAdmin gRPC service.
//
// Mutating RPCs follow a write-through pattern: the change is persisted to
// the database first and only then applied to the live in-memory state held
// by the proxy server.
type AdminServer struct {
	// Embedded per the generated-code convention; presumably this makes any
	// RPC not implemented below answer codes.Unimplemented.
	pb.UnimplementedProxyAdminServiceServer

	srv    *server.Server // live proxy state: listeners, routes, firewall, stats
	store  *db.Store      // persistent store, written before in-memory state
	logger *slog.Logger   // used unconditionally by every mutating RPC
}
|
|
|
|
// NewAdminServer creates an AdminServer for use in testing or custom setups.
|
|
func NewAdminServer(srv *server.Server, store *db.Store, logger *slog.Logger) *AdminServer {
|
|
return &AdminServer{
|
|
srv: srv,
|
|
store: store,
|
|
logger: logger,
|
|
}
|
|
}
|
|
|
|
// New creates a gRPC server listening on a Unix socket.
|
|
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
|
admin := &AdminServer{
|
|
srv: srv,
|
|
store: store,
|
|
logger: logger,
|
|
}
|
|
|
|
path := cfg.SocketPath()
|
|
|
|
// Remove stale socket file from a previous run.
|
|
os.Remove(path)
|
|
|
|
ln, err := net.Listen("unix", path)
|
|
if err != nil {
|
|
return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
|
|
}
|
|
|
|
if err := os.Chmod(path, 0600); err != nil {
|
|
ln.Close()
|
|
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
|
|
}
|
|
|
|
grpcServer := grpc.NewServer()
|
|
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
|
|
|
|
// Register standard gRPC health check service.
|
|
healthServer := health.NewServer()
|
|
healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
|
|
healthServer.SetServingStatus("mc_proxy.v1.ProxyAdminService", healthpb.HealthCheckResponse_SERVING)
|
|
healthpb.RegisterHealthServer(grpcServer, healthServer)
|
|
|
|
return grpcServer, ln, nil
|
|
}
|
|
|
|
// ListRoutes returns the route table for a listener.
|
|
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
|
|
ls, err := a.findListener(req.ListenerAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
routes := ls.Routes()
|
|
resp := &pb.ListRoutesResponse{
|
|
ListenerAddr: ls.Addr,
|
|
}
|
|
for hostname, route := range routes {
|
|
var policies []*pb.L7Policy
|
|
for _, p := range route.L7Policies {
|
|
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
|
|
}
|
|
resp.Routes = append(resp.Routes, &pb.Route{
|
|
Hostname: hostname,
|
|
Backend: route.Backend,
|
|
Mode: route.Mode,
|
|
TlsCert: route.TLSCert,
|
|
TlsKey: route.TLSKey,
|
|
BackendTls: route.BackendTLS,
|
|
SendProxyProtocol: route.SendProxyProtocol,
|
|
L7Policies: policies,
|
|
})
|
|
}
|
|
return resp, nil
|
|
}
|
|
|
|
// AddRoute writes to the database first, then updates in-memory state.
|
|
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
|
|
if req.Route == nil {
|
|
return nil, status.Error(codes.InvalidArgument, "route is required")
|
|
}
|
|
if req.Route.Hostname == "" || req.Route.Backend == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
|
|
}
|
|
|
|
// Validate backend is a valid host:port.
|
|
if _, _, err := net.SplitHostPort(req.Route.Backend); err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
|
|
}
|
|
|
|
ls, err := a.findListener(req.ListenerAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hostname := strings.ToLower(req.Route.Hostname)
|
|
|
|
// Normalize mode.
|
|
mode := req.Route.Mode
|
|
if mode == "" {
|
|
mode = "l4"
|
|
}
|
|
if mode != "l4" && mode != "l7" {
|
|
return nil, status.Errorf(codes.InvalidArgument, "mode must be \"l4\" or \"l7\", got %q", mode)
|
|
}
|
|
|
|
// L7 routes require cert/key paths.
|
|
if mode == "l7" {
|
|
if req.Route.TlsCert == "" || req.Route.TlsKey == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "L7 routes require tls_cert and tls_key")
|
|
}
|
|
}
|
|
|
|
// Write-through: DB first, then memory.
|
|
if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend, mode,
|
|
req.Route.TlsCert, req.Route.TlsKey, req.Route.BackendTls, req.Route.SendProxyProtocol); err != nil {
|
|
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
|
}
|
|
|
|
info := server.RouteInfo{
|
|
Backend: req.Route.Backend,
|
|
Mode: mode,
|
|
TLSCert: req.Route.TlsCert,
|
|
TLSKey: req.Route.TlsKey,
|
|
BackendTLS: req.Route.BackendTls,
|
|
SendProxyProtocol: req.Route.SendProxyProtocol,
|
|
}
|
|
if err := ls.AddRoute(hostname, info); err != nil {
|
|
// DB succeeded but memory failed (should not happen since DB enforces uniqueness).
|
|
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
|
|
}
|
|
|
|
a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend, "mode", mode)
|
|
return &pb.AddRouteResponse{}, nil
|
|
}
|
|
|
|
// RemoveRoute writes to the database first, then updates in-memory state.
|
|
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
|
|
if req.Hostname == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "hostname is required")
|
|
}
|
|
|
|
ls, err := a.findListener(req.ListenerAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hostname := strings.ToLower(req.Hostname)
|
|
|
|
if err := a.store.DeleteRoute(ls.ID, hostname); err != nil {
|
|
return nil, status.Errorf(codes.NotFound, "%v", err)
|
|
}
|
|
|
|
if err := ls.RemoveRoute(hostname); err != nil {
|
|
a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
|
|
}
|
|
|
|
a.logger.Info("route removed", "listener", ls.Addr, "hostname", hostname)
|
|
return &pb.RemoveRouteResponse{}, nil
|
|
}
|
|
|
|
// ListL7Policies returns L7 policies for a route.
|
|
func (a *AdminServer) ListL7Policies(_ context.Context, req *pb.ListL7PoliciesRequest) (*pb.ListL7PoliciesResponse, error) {
|
|
ls, err := a.findListener(req.ListenerAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hostname := strings.ToLower(req.Hostname)
|
|
routes := ls.Routes()
|
|
route, ok := routes[hostname]
|
|
if !ok {
|
|
return nil, status.Errorf(codes.NotFound, "route %q not found", hostname)
|
|
}
|
|
|
|
var policies []*pb.L7Policy
|
|
for _, p := range route.L7Policies {
|
|
policies = append(policies, &pb.L7Policy{Type: p.Type, Value: p.Value})
|
|
}
|
|
return &pb.ListL7PoliciesResponse{Policies: policies}, nil
|
|
}
|
|
|
|
// AddL7Policy adds an L7 policy to a route (write-through).
|
|
func (a *AdminServer) AddL7Policy(_ context.Context, req *pb.AddL7PolicyRequest) (*pb.AddL7PolicyResponse, error) {
|
|
if req.Policy == nil {
|
|
return nil, status.Error(codes.InvalidArgument, "policy is required")
|
|
}
|
|
if req.Policy.Type != "block_user_agent" && req.Policy.Type != "require_header" {
|
|
return nil, status.Errorf(codes.InvalidArgument, "policy type must be \"block_user_agent\" or \"require_header\", got %q", req.Policy.Type)
|
|
}
|
|
if req.Policy.Value == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "policy value is required")
|
|
}
|
|
|
|
ls, err := a.findListener(req.ListenerAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hostname := strings.ToLower(req.Hostname)
|
|
|
|
// Get route ID from DB.
|
|
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
|
}
|
|
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
|
|
}
|
|
|
|
// Write-through: DB first.
|
|
if _, err := a.store.CreateL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
|
|
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
|
}
|
|
|
|
// Update in-memory state.
|
|
ls.AddL7Policy(hostname, server.L7PolicyRule{Type: req.Policy.Type, Value: req.Policy.Value})
|
|
|
|
a.logger.Info("L7 policy added", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type, "value", req.Policy.Value)
|
|
return &pb.AddL7PolicyResponse{}, nil
|
|
}
|
|
|
|
// RemoveL7Policy removes an L7 policy from a route (write-through).
|
|
func (a *AdminServer) RemoveL7Policy(_ context.Context, req *pb.RemoveL7PolicyRequest) (*pb.RemoveL7PolicyResponse, error) {
|
|
if req.Policy == nil {
|
|
return nil, status.Error(codes.InvalidArgument, "policy is required")
|
|
}
|
|
|
|
ls, err := a.findListener(req.ListenerAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hostname := strings.ToLower(req.Hostname)
|
|
|
|
dbListener, err := a.store.GetListenerByAddr(ls.Addr)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
|
}
|
|
routeID, err := a.store.GetRouteID(dbListener.ID, hostname)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.NotFound, "route %q not found: %v", hostname, err)
|
|
}
|
|
|
|
if err := a.store.DeleteL7Policy(routeID, req.Policy.Type, req.Policy.Value); err != nil {
|
|
return nil, status.Errorf(codes.NotFound, "%v", err)
|
|
}
|
|
|
|
ls.RemoveL7Policy(hostname, req.Policy.Type, req.Policy.Value)
|
|
|
|
a.logger.Info("L7 policy removed", "listener", ls.Addr, "hostname", hostname, "type", req.Policy.Type)
|
|
return &pb.RemoveL7PolicyResponse{}, nil
|
|
}
|
|
|
|
// GetFirewallRules returns all current firewall rules.
|
|
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
|
|
ips, cidrs, countries := a.srv.Firewall().Rules()
|
|
|
|
var rules []*pb.FirewallRule
|
|
for _, ip := range ips {
|
|
rules = append(rules, &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
|
Value: ip,
|
|
})
|
|
}
|
|
for _, cidr := range cidrs {
|
|
rules = append(rules, &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
|
Value: cidr,
|
|
})
|
|
}
|
|
for _, code := range countries {
|
|
rules = append(rules, &pb.FirewallRule{
|
|
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
|
Value: code,
|
|
})
|
|
}
|
|
|
|
return &pb.GetFirewallRulesResponse{Rules: rules}, nil
|
|
}
|
|
|
|
// AddFirewallRule writes to the database first, then updates in-memory state.
|
|
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
|
|
if req.Rule == nil {
|
|
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
|
}
|
|
|
|
ruleType, err := protoRuleTypeToString(req.Rule.Type)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if req.Rule.Value == "" {
|
|
return nil, status.Error(codes.InvalidArgument, "value is required")
|
|
}
|
|
|
|
// Validate the value matches the rule type before persisting.
|
|
switch ruleType {
|
|
case "ip":
|
|
if _, err := netip.ParseAddr(req.Rule.Value); err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
|
|
}
|
|
case "cidr":
|
|
prefix, err := netip.ParsePrefix(req.Rule.Value)
|
|
if err != nil {
|
|
return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
|
|
}
|
|
// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
|
|
if prefix.Masked().String() != req.Rule.Value {
|
|
return nil, status.Errorf(codes.InvalidArgument,
|
|
"CIDR not in canonical form: use %s", prefix.Masked().String())
|
|
}
|
|
case "country":
|
|
if !countryCodeRe.MatchString(req.Rule.Value) {
|
|
return nil, status.Error(codes.InvalidArgument,
|
|
"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
|
|
}
|
|
}
|
|
|
|
// Write-through: DB first, then memory.
|
|
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
|
|
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
|
}
|
|
|
|
fw := a.srv.Firewall()
|
|
switch ruleType {
|
|
case "ip":
|
|
if err := fw.AddIP(req.Rule.Value); err != nil {
|
|
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
|
|
}
|
|
case "cidr":
|
|
if err := fw.AddCIDR(req.Rule.Value); err != nil {
|
|
a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
|
|
}
|
|
case "country":
|
|
fw.AddCountry(req.Rule.Value)
|
|
}
|
|
|
|
a.logger.Info("firewall rule added", "type", ruleType, "value", req.Rule.Value)
|
|
return &pb.AddFirewallRuleResponse{}, nil
|
|
}
|
|
|
|
// RemoveFirewallRule writes to the database first, then updates in-memory state.
|
|
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
|
|
if req.Rule == nil {
|
|
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
|
}
|
|
|
|
ruleType, err := protoRuleTypeToString(req.Rule.Type)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := a.store.DeleteFirewallRule(ruleType, req.Rule.Value); err != nil {
|
|
return nil, status.Errorf(codes.NotFound, "%v", err)
|
|
}
|
|
|
|
fw := a.srv.Firewall()
|
|
switch ruleType {
|
|
case "ip":
|
|
if err := fw.RemoveIP(req.Rule.Value); err != nil {
|
|
a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
|
|
}
|
|
case "cidr":
|
|
if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
|
|
a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
|
|
}
|
|
case "country":
|
|
fw.RemoveCountry(req.Rule.Value)
|
|
}
|
|
|
|
a.logger.Info("firewall rule removed", "type", ruleType, "value", req.Rule.Value)
|
|
return &pb.RemoveFirewallRuleResponse{}, nil
|
|
}
|
|
|
|
// SetListenerMaxConnections updates the per-listener connection limit.
|
|
func (a *AdminServer) SetListenerMaxConnections(_ context.Context, req *pb.SetListenerMaxConnectionsRequest) (*pb.SetListenerMaxConnectionsResponse, error) {
|
|
if req.MaxConnections < 0 {
|
|
return nil, status.Error(codes.InvalidArgument, "max_connections must not be negative")
|
|
}
|
|
|
|
ls, err := a.findListener(req.ListenerAddr)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Write-through: DB first, then memory.
|
|
if err := a.store.UpdateListenerMaxConns(ls.ID, req.MaxConnections); err != nil {
|
|
return nil, status.Errorf(codes.Internal, "%v", err)
|
|
}
|
|
|
|
ls.SetMaxConnections(req.MaxConnections)
|
|
|
|
a.logger.Info("connection limit updated", "listener", ls.Addr, "max_connections", req.MaxConnections)
|
|
return &pb.SetListenerMaxConnectionsResponse{}, nil
|
|
}
|
|
|
|
// GetStatus returns the proxy's current status.
|
|
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
|
|
var listeners []*pb.ListenerStatus
|
|
for _, ls := range a.srv.Listeners() {
|
|
routes := ls.Routes()
|
|
var pbRoutes []*pb.Route
|
|
for hostname, route := range routes {
|
|
pbRoutes = append(pbRoutes, &pb.Route{
|
|
Hostname: hostname,
|
|
Backend: route.Backend,
|
|
Mode: route.Mode,
|
|
BackendTls: route.BackendTLS,
|
|
SendProxyProtocol: route.SendProxyProtocol,
|
|
})
|
|
}
|
|
listeners = append(listeners, &pb.ListenerStatus{
|
|
Addr: ls.Addr,
|
|
RouteCount: int32(len(routes)),
|
|
ActiveConnections: ls.ActiveConnections.Load(),
|
|
ProxyProtocol: ls.ProxyProtocol,
|
|
MaxConnections: ls.MaxConnections,
|
|
Routes: pbRoutes,
|
|
})
|
|
}
|
|
|
|
return &pb.GetStatusResponse{
|
|
Version: a.srv.Version(),
|
|
StartedAt: timestamppb.New(a.srv.StartedAt()),
|
|
Listeners: listeners,
|
|
TotalConnections: a.srv.TotalConnections(),
|
|
}, nil
|
|
}
|
|
|
|
func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
|
|
for _, ls := range a.srv.Listeners() {
|
|
if ls.Addr == addr {
|
|
return ls, nil
|
|
}
|
|
}
|
|
return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
|
|
}
|
|
|
|
func protoRuleTypeToString(t pb.FirewallRuleType) (string, error) {
|
|
switch t {
|
|
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
|
return "ip", nil
|
|
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
|
return "cidr", nil
|
|
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
|
return "country", nil
|
|
default:
|
|
return "", status.Error(codes.InvalidArgument, "unknown rule type")
|
|
}
|
|
}
|