Add per-IP rate limiting and Unix socket support for gRPC admin API

Rate limiting: per-source-IP connection rate limiter in the firewall layer
with configurable limit and sliding window. Blocklisted IPs are rejected
before rate limit evaluation to avoid wasting quota. Unix socket: the gRPC
admin API can now listen on a Unix domain socket (no TLS required), secured
by file permissions (0600), as a simpler alternative for local-only access.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-17 14:37:21 -07:00
parent e84093b7fb
commit b25e1b0e79
16 changed files with 694 additions and 43 deletions

View File

@@ -7,7 +7,9 @@ import (
"fmt"
"log/slog"
"net"
"net/netip"
"os"
"regexp"
"strings"
"google.golang.org/grpc"
@@ -22,6 +24,8 @@ import (
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
// AdminServer implements the ProxyAdmin gRPC service.
type AdminServer struct {
pb.UnimplementedProxyAdminServiceServer
@@ -30,8 +34,43 @@ type AdminServer struct {
logger *slog.Logger
}
// New creates a gRPC server with TLS and optional mTLS.
// New creates a gRPC server. For Unix sockets, no TLS is used. For TCP
// addresses, TLS is required with optional mTLS.
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
admin := &AdminServer{
srv: srv,
store: store,
logger: logger,
}
if cfg.IsUnixSocket() {
return newUnixServer(cfg, admin)
}
return newTCPServer(cfg, admin)
}
func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
path := cfg.SocketPath()
// Remove stale socket file from a previous run.
os.Remove(path)
ln, err := net.Listen("unix", path)
if err != nil {
return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
}
if err := os.Chmod(path, 0600); err != nil {
ln.Close()
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
}
grpcServer := grpc.NewServer()
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
return grpcServer, ln, nil
}
func newTCPServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
if err != nil {
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
@@ -57,12 +96,6 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
creds := credentials.NewTLS(tlsConfig)
grpcServer := grpc.NewServer(grpc.Creds(creds))
admin := &AdminServer{
srv: srv,
store: store,
logger: logger,
}
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
ln, err := net.Listen("tcp", cfg.Addr)
@@ -102,6 +135,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
}
// Validate backend is a valid host:port.
if _, _, err := net.SplitHostPort(req.Route.Backend); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
}
ls, err := a.findListener(req.ListenerAddr)
if err != nil {
return nil, err
@@ -190,6 +228,29 @@ func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRule
return nil, status.Error(codes.InvalidArgument, "value is required")
}
// Validate the value matches the rule type before persisting.
switch ruleType {
case "ip":
if _, err := netip.ParseAddr(req.Rule.Value); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
}
case "cidr":
prefix, err := netip.ParsePrefix(req.Rule.Value)
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
}
// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
if prefix.Masked().String() != req.Rule.Value {
return nil, status.Errorf(codes.InvalidArgument,
"CIDR not in canonical form: use %s", prefix.Masked().String())
}
case "country":
if !countryCodeRe.MatchString(req.Rule.Value) {
return nil, status.Error(codes.InvalidArgument,
"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
}
}
// Write-through: DB first, then memory.
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
return nil, status.Errorf(codes.AlreadyExists, "%v", err)