Files
mc-proxy/internal/grpcserver/grpcserver.go
Kyle Isom f24fa2a2b0 Switch gRPC admin API to Unix socket only, add client package
- Remove TCP listener support from gRPC server; Unix socket is now the
  only transport for the admin API (access controlled via filesystem
  permissions)
- Add standard gRPC health check service (grpc.health.v1.Health)
- Implement MCPROXY_* environment variable overrides for config
- Create client/mcproxy package with full API coverage and tests
- Update ARCHITECTURE.md and dev config (srv/mc-proxy.toml)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-03-19 07:48:11 -07:00

322 lines
9.8 KiB
Go

package grpcserver
import (
	"context"
	"fmt"
	"log/slog"
	"net"
	"net/netip"
	"os"
	"regexp"
	"slices"
	"strings"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/health"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc_proxy/v1"
	"git.wntrmute.dev/kyle/mc-proxy/internal/config"
	"git.wntrmute.dev/kyle/mc-proxy/internal/db"
	"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
// countryCodeRe accepts a country code written as exactly two uppercase
// ASCII letters (the ISO 3166-1 alpha-2 shape). Go's `$` only matches at the
// very end of the input, so trailing newlines are rejected too.
var countryCodeRe = regexp.MustCompile("^[A-Z]{2}$")
// AdminServer implements the ProxyAdminService gRPC service. Its mutating
// RPCs persist changes to the store before applying them to the running
// proxy's in-memory state.
type AdminServer struct {
	// Embedding the generated Unimplemented server keeps AdminServer
	// satisfying the service interface; RPCs without an override here
	// answer with codes.Unimplemented.
	pb.UnimplementedProxyAdminServiceServer

	srv    *server.Server // running proxy: listeners, firewall, counters
	store  *db.Store      // persisted routes and firewall rules
	logger *slog.Logger   // structured logger for audit and error messages
}
// NewAdminServer builds an AdminServer around the given proxy server, store
// and logger. It is intended for tests or custom setups that register the
// service themselves.
func NewAdminServer(srv *server.Server, store *db.Store, logger *slog.Logger) *AdminServer {
	a := new(AdminServer)
	a.srv = srv
	a.store = store
	a.logger = logger
	return a
}
// New creates the admin gRPC server together with a listener bound to the
// Unix socket at cfg.SocketPath(). The socket is the only transport; its
// file mode is set to 0600 so only the owning user can connect.
//
// The returned server has ProxyAdminService and the standard
// grpc.health.v1.Health service registered, both reported as SERVING.
//
// A stale socket left at the path by a previous run is removed first. If
// the path exists but is not a socket, New returns an error rather than
// deleting an unrelated file.
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
	admin := NewAdminServer(srv, store, logger)

	path := cfg.SocketPath()
	if err := removeStaleSocket(path); err != nil {
		return nil, nil, err
	}

	ln, err := net.Listen("unix", path)
	if err != nil {
		return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
	}
	// NOTE(review): net.Listen creates the socket file with umask-derived
	// permissions and it is only tightened by the Chmod below, so the
	// enclosing directory's permissions should also restrict access.
	if err := os.Chmod(path, 0o600); err != nil {
		_ = ln.Close() // best effort; the chmod failure is the error reported
		return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
	}

	grpcServer := grpc.NewServer()
	pb.RegisterProxyAdminServiceServer(grpcServer, admin)

	// Register standard gRPC health check service.
	healthServer := health.NewServer()
	healthServer.SetServingStatus("", healthpb.HealthCheckResponse_SERVING)
	healthServer.SetServingStatus("mc_proxy.v1.ProxyAdminService", healthpb.HealthCheckResponse_SERVING)
	healthpb.RegisterHealthServer(grpcServer, healthServer)

	return grpcServer, ln, nil
}

// removeStaleSocket deletes a Unix socket left at path by a previous run.
// A missing path is not an error. Anything that is not a socket (regular
// file, directory, symlink, ...) is left in place and reported, so a
// misconfigured socket path cannot cause unrelated data to be deleted.
func removeStaleSocket(path string) error {
	fi, err := os.Lstat(path)
	if os.IsNotExist(err) {
		return nil
	}
	if err != nil {
		return fmt.Errorf("checking socket path %s: %w", path, err)
	}
	if fi.Mode()&os.ModeSocket == 0 {
		return fmt.Errorf("socket path %s exists and is not a socket (mode %s)", path, fi.Mode())
	}
	// Tolerate a concurrent removal between Lstat and Remove.
	if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
		return fmt.Errorf("removing stale socket %s: %w", path, err)
	}
	return nil
}
// ListRoutes returns the route table of the listener identified by
// req.ListenerAddr. Routes are ordered by hostname so that repeated calls
// produce identical listings.
//
// It returns a NotFound status error if no listener has that address.
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
	ls, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}

	routes := ls.Routes()

	// Map iteration order is randomized; sort the hostnames so clients get
	// a stable, readable listing.
	hostnames := make([]string, 0, len(routes))
	for hostname := range routes {
		hostnames = append(hostnames, hostname)
	}
	slices.Sort(hostnames)

	resp := &pb.ListRoutesResponse{
		ListenerAddr: ls.Addr,
		Routes:       make([]*pb.Route, 0, len(hostnames)),
	}
	for _, hostname := range hostnames {
		resp.Routes = append(resp.Routes, &pb.Route{
			Hostname: hostname,
			Backend:  routes[hostname],
		})
	}
	return resp, nil
}
// AddRoute adds a hostname → backend route to the listener at
// req.ListenerAddr. The hostname is lowercased before it is stored. The
// backend must be a host:port address with a non-empty port.
//
// The route is written to the database first and only then to the
// listener's in-memory table; if the database write fails, memory is left
// untouched.
//
// Errors: InvalidArgument for a missing or malformed route, NotFound for an
// unknown listener, AlreadyExists when the store rejects the route.
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
	if req.Route == nil {
		return nil, status.Error(codes.InvalidArgument, "route is required")
	}
	if req.Route.Hostname == "" || req.Route.Backend == "" {
		return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
	}

	// Validate backend is a valid host:port. SplitHostPort also accepts an
	// empty port ("host:"), which is not a usable backend, so reject it.
	_, port, err := net.SplitHostPort(req.Route.Backend)
	if err != nil {
		return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
	}
	if port == "" {
		return nil, status.Errorf(codes.InvalidArgument, "invalid backend address %q: missing port", req.Route.Backend)
	}

	ls, err := a.findListener(req.ListenerAddr)
	if err != nil {
		return nil, err
	}

	hostname := strings.ToLower(req.Route.Hostname)

	// Write-through: DB first, then memory.
	// NOTE(review): every store error is reported as AlreadyExists, which
	// misreports non-uniqueness DB failures; confirm what
	// db.Store.CreateRoute can return.
	if _, err := a.store.CreateRoute(ls.ID, hostname, req.Route.Backend); err != nil {
		return nil, status.Errorf(codes.AlreadyExists, "%v", err)
	}
	if err := ls.AddRoute(hostname, req.Route.Backend); err != nil {
		// DB succeeded but memory failed (should not happen since DB enforces uniqueness).
		a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
	}

	a.logger.Info("route added", "listener", ls.Addr, "hostname", hostname, "backend", req.Route.Backend)
	return &pb.AddRouteResponse{}, nil
}
// RemoveRoute deletes the route for req.Hostname (matched case-insensitively
// by lowercasing) from the listener at req.ListenerAddr.
//
// The database row is deleted first; if that fails, a NotFound status error
// is returned and the in-memory route table is left untouched. A failure to
// update memory after a successful delete is logged, not returned.
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
	if req.Hostname == "" {
		return nil, status.Error(codes.InvalidArgument, "hostname is required")
	}

	listener, lookupErr := a.findListener(req.ListenerAddr)
	if lookupErr != nil {
		return nil, lookupErr
	}
	name := strings.ToLower(req.Hostname)

	if delErr := a.store.DeleteRoute(listener.ID, name); delErr != nil {
		return nil, status.Errorf(codes.NotFound, "%v", delErr)
	}
	if memErr := listener.RemoveRoute(name); memErr != nil {
		a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", memErr)
	}

	a.logger.Info("route removed", "listener", listener.Addr, "hostname", name)
	return &pb.RemoveRouteResponse{}, nil
}
// GetFirewallRules reports every rule currently loaded in the running
// firewall: all IP rules first, then CIDR rules, then country rules, each
// group in the order the firewall returns it.
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
	ips, cidrs, countries := a.srv.Firewall().Rules()

	var rules []*pb.FirewallRule
	emit := func(kind pb.FirewallRuleType, value string) {
		rules = append(rules, &pb.FirewallRule{Type: kind, Value: value})
	}

	for _, v := range ips {
		emit(pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, v)
	}
	for _, v := range cidrs {
		emit(pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, v)
	}
	for _, v := range countries {
		emit(pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, v)
	}

	return &pb.GetFirewallRulesResponse{Rules: rules}, nil
}
// AddFirewallRule validates req.Rule, persists it, and then applies it to
// the running firewall.
//
// Values must be in canonical form for their type (see
// validateFirewallValue). The rule is written to the database first; if
// that fails, an AlreadyExists status error is returned and the firewall is
// left untouched. A failure to apply the rule in memory after a successful
// write is logged, not returned.
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
	if req.Rule == nil {
		return nil, status.Error(codes.InvalidArgument, "rule is required")
	}
	ruleType, err := protoRuleTypeToString(req.Rule.Type)
	if err != nil {
		return nil, err
	}
	value := req.Rule.Value
	if value == "" {
		return nil, status.Error(codes.InvalidArgument, "value is required")
	}
	// Validate the value matches the rule type before persisting.
	if err := validateFirewallValue(ruleType, value); err != nil {
		return nil, err
	}

	// Write-through: DB first, then memory.
	// NOTE(review): every store error is reported as AlreadyExists; confirm
	// what db.Store.CreateFirewallRule can return.
	if _, err := a.store.CreateFirewallRule(ruleType, value); err != nil {
		return nil, status.Errorf(codes.AlreadyExists, "%v", err)
	}

	fw := a.srv.Firewall()
	switch ruleType {
	case "ip":
		if err := fw.AddIP(value); err != nil {
			a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
		}
	case "cidr":
		if err := fw.AddCIDR(value); err != nil {
			a.logger.Error("inconsistency: DB write succeeded but memory update failed", "error", err)
		}
	case "country":
		fw.AddCountry(value)
	}

	a.logger.Info("firewall rule added", "type", ruleType, "value", value)
	return &pb.AddFirewallRuleResponse{}, nil
}

// validateFirewallValue checks that value is well-formed and canonical for
// ruleType ("ip", "cidr" or "country") and returns an InvalidArgument status
// error otherwise. Requiring canonical spellings for IPs and CIDRs prevents
// equivalent values (e.g. 2001:DB8::1 vs 2001:db8::1) from being stored as
// distinct rules, which exact-value removal would then fail to match.
func validateFirewallValue(ruleType, value string) error {
	switch ruleType {
	case "ip":
		addr, err := netip.ParseAddr(value)
		if err != nil {
			return status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
		}
		// Require canonical form (e.g. 2001:db8::1 not 2001:DB8:0::1).
		if canonical := addr.String(); canonical != value {
			return status.Errorf(codes.InvalidArgument,
				"IP address not in canonical form: use %s", canonical)
		}
	case "cidr":
		prefix, err := netip.ParsePrefix(value)
		if err != nil {
			return status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
		}
		// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
		if canonical := prefix.Masked().String(); canonical != value {
			return status.Errorf(codes.InvalidArgument,
				"CIDR not in canonical form: use %s", canonical)
		}
	case "country":
		if !countryCodeRe.MatchString(value) {
			return status.Error(codes.InvalidArgument,
				"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
		}
	}
	return nil
}
// RemoveFirewallRule deletes req.Rule from the database and then from the
// running firewall.
//
// The value must be non-empty, matching AddFirewallRule's requirement. The
// database delete happens first; if it fails, a NotFound status error is
// returned and the firewall is left untouched. A failure to update memory
// after a successful delete is logged, not returned.
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
	if req.Rule == nil {
		return nil, status.Error(codes.InvalidArgument, "rule is required")
	}
	ruleType, err := protoRuleTypeToString(req.Rule.Type)
	if err != nil {
		return nil, err
	}
	// Reject empty values up front, consistent with AddFirewallRule, rather
	// than sending them to the store and reporting NotFound.
	if req.Rule.Value == "" {
		return nil, status.Error(codes.InvalidArgument, "value is required")
	}

	if err := a.store.DeleteFirewallRule(ruleType, req.Rule.Value); err != nil {
		return nil, status.Errorf(codes.NotFound, "%v", err)
	}

	fw := a.srv.Firewall()
	switch ruleType {
	case "ip":
		if err := fw.RemoveIP(req.Rule.Value); err != nil {
			a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
		}
	case "cidr":
		if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
			a.logger.Error("inconsistency: DB delete succeeded but memory update failed", "error", err)
		}
	case "country":
		fw.RemoveCountry(req.Rule.Value)
	}

	a.logger.Info("firewall rule removed", "type", ruleType, "value", req.Rule.Value)
	return &pb.RemoveFirewallRuleResponse{}, nil
}
// GetStatus reports the proxy's version, start time and total connection
// count, plus, for each listener, its address, number of routes and number
// of active connections.
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
	var perListener []*pb.ListenerStatus
	for _, l := range a.srv.Listeners() {
		routeCount := int32(len(l.Routes()))
		perListener = append(perListener, &pb.ListenerStatus{
			Addr:              l.Addr,
			RouteCount:        routeCount,
			ActiveConnections: l.ActiveConnections.Load(),
		})
	}

	resp := new(pb.GetStatusResponse)
	resp.Version = a.srv.Version()
	resp.StartedAt = timestamppb.New(a.srv.StartedAt())
	resp.Listeners = perListener
	resp.TotalConnections = a.srv.TotalConnections()
	return resp, nil
}
// findListener returns the listener whose Addr equals addr exactly, or a
// NotFound status error naming addr if none matches.
func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
	var match *server.ListenerState
	for _, candidate := range a.srv.Listeners() {
		if candidate.Addr == addr {
			match = candidate
			break
		}
	}
	if match == nil {
		return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
	}
	return match, nil
}
// protoRuleTypeToString maps a protobuf FirewallRuleType to the short name
// passed to the store ("ip", "cidr" or "country"). Any other value yields an
// InvalidArgument status error.
func protoRuleTypeToString(t pb.FirewallRuleType) (string, error) {
	var name string
	switch t {
	case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
		name = "ip"
	case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
		name = "cidr"
	case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
		name = "country"
	}
	if name == "" {
		return "", status.Error(codes.InvalidArgument, "unknown rule type")
	}
	return name, nil
}