Add per-IP rate limiting and Unix socket support for gRPC admin API
Rate limiting: per-source-IP connection rate limiter in the firewall layer with configurable limit and sliding window. Blocklisted IPs are rejected before rate limit evaluation to avoid wasting quota. Unix socket: the gRPC admin API can now listen on a Unix domain socket (no TLS required), secured by file permissions (0600), as a simpler alternative for local-only access. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -7,7 +7,9 @@ import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
@@ -22,6 +24,8 @@ import (
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
||||
)
|
||||
|
||||
var countryCodeRe = regexp.MustCompile(`^[A-Z]{2}$`)
|
||||
|
||||
// AdminServer implements the ProxyAdmin gRPC service.
|
||||
type AdminServer struct {
|
||||
pb.UnimplementedProxyAdminServiceServer
|
||||
@@ -30,8 +34,43 @@ type AdminServer struct {
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a gRPC server with TLS and optional mTLS.
|
||||
// New creates a gRPC server. For Unix sockets, no TLS is used. For TCP
|
||||
// addresses, TLS is required with optional mTLS.
|
||||
func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
||||
admin := &AdminServer{
|
||||
srv: srv,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
if cfg.IsUnixSocket() {
|
||||
return newUnixServer(cfg, admin)
|
||||
}
|
||||
return newTCPServer(cfg, admin)
|
||||
}
|
||||
|
||||
func newUnixServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
|
||||
path := cfg.SocketPath()
|
||||
|
||||
// Remove stale socket file from a previous run.
|
||||
os.Remove(path)
|
||||
|
||||
ln, err := net.Listen("unix", path)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("listening on unix socket %s: %w", path, err)
|
||||
}
|
||||
|
||||
if err := os.Chmod(path, 0600); err != nil {
|
||||
ln.Close()
|
||||
return nil, nil, fmt.Errorf("setting socket permissions: %w", err)
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
|
||||
return grpcServer, ln, nil
|
||||
}
|
||||
|
||||
func newTCPServer(cfg config.GRPC, admin *AdminServer) (*grpc.Server, net.Listener, error) {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
|
||||
@@ -57,12 +96,6 @@ func New(cfg config.GRPC, srv *server.Server, store *db.Store, logger *slog.Logg
|
||||
|
||||
creds := credentials.NewTLS(tlsConfig)
|
||||
grpcServer := grpc.NewServer(grpc.Creds(creds))
|
||||
|
||||
admin := &AdminServer{
|
||||
srv: srv,
|
||||
store: store,
|
||||
logger: logger,
|
||||
}
|
||||
pb.RegisterProxyAdminServiceServer(grpcServer, admin)
|
||||
|
||||
ln, err := net.Listen("tcp", cfg.Addr)
|
||||
@@ -102,6 +135,11 @@ func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.
|
||||
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
|
||||
}
|
||||
|
||||
// Validate backend is a valid host:port.
|
||||
if _, _, err := net.SplitHostPort(req.Route.Backend); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid backend address: %v", err)
|
||||
}
|
||||
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -190,6 +228,29 @@ func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRule
|
||||
return nil, status.Error(codes.InvalidArgument, "value is required")
|
||||
}
|
||||
|
||||
// Validate the value matches the rule type before persisting.
|
||||
switch ruleType {
|
||||
case "ip":
|
||||
if _, err := netip.ParseAddr(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid IP address: %v", err)
|
||||
}
|
||||
case "cidr":
|
||||
prefix, err := netip.ParsePrefix(req.Rule.Value)
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid CIDR: %v", err)
|
||||
}
|
||||
// Require canonical form (e.g. 192.168.0.0/16 not 192.168.1.5/16).
|
||||
if prefix.Masked().String() != req.Rule.Value {
|
||||
return nil, status.Errorf(codes.InvalidArgument,
|
||||
"CIDR not in canonical form: use %s", prefix.Masked().String())
|
||||
}
|
||||
case "country":
|
||||
if !countryCodeRe.MatchString(req.Rule.Value) {
|
||||
return nil, status.Error(codes.InvalidArgument,
|
||||
"country code must be exactly 2 uppercase letters (ISO 3166-1 alpha-2)")
|
||||
}
|
||||
}
|
||||
|
||||
// Write-through: DB first, then memory.
|
||||
if _, err := a.store.CreateFirewallRule(ruleType, req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||
|
||||
@@ -62,7 +62,7 @@ func setup(t *testing.T) *testEnv {
|
||||
}
|
||||
|
||||
// Build server with matching in-memory state.
|
||||
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil)
|
||||
fwObj, err := firewall.New("", []string{"10.0.0.1"}, nil, nil, 0, 0)
|
||||
if err != nil {
|
||||
t.Fatalf("firewall: %v", err)
|
||||
}
|
||||
@@ -268,6 +268,15 @@ func TestAddRouteValidation(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty backend")
|
||||
}
|
||||
|
||||
// Invalid backend (not host:port).
|
||||
_, err = env.client.AddRoute(ctx, &pb.AddRouteRequest{
|
||||
ListenerAddr: ":443",
|
||||
Route: &pb.Route{Hostname: "y.test", Backend: "not-a-host-port"},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid backend address")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveRoute(t *testing.T) {
|
||||
@@ -410,6 +419,61 @@ func TestAddFirewallRuleValidation(t *testing.T) {
|
||||
if err == nil {
|
||||
t.Fatal("expected error for empty value")
|
||||
}
|
||||
|
||||
// Invalid IP address.
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||
Value: "not-an-ip",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid IP")
|
||||
}
|
||||
|
||||
// Invalid CIDR.
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
||||
Value: "not-a-cidr",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid CIDR")
|
||||
}
|
||||
|
||||
// Non-canonical CIDR.
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
||||
Value: "192.168.1.5/16",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for non-canonical CIDR")
|
||||
}
|
||||
|
||||
// Invalid country code (lowercase).
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
||||
Value: "cn",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for lowercase country code")
|
||||
}
|
||||
|
||||
// Invalid country code (too long).
|
||||
_, err = env.client.AddFirewallRule(ctx, &pb.AddFirewallRuleRequest{
|
||||
Rule: &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
||||
Value: "USA",
|
||||
},
|
||||
})
|
||||
if err == nil {
|
||||
t.Fatal("expected error for 3-letter country code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveFirewallRule(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user