Initial implementation of mc-proxy
Layer 4 TLS SNI proxy with global firewall (IP/CIDR/GeoIP blocking), per-listener route tables, bidirectional TCP relay with half-close propagation, and a gRPC admin API (routes, firewall, status) with TLS/mTLS support. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
246
internal/grpcserver/grpcserver.go
Normal file
246
internal/grpcserver/grpcserver.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/credentials"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/config"
|
||||
"git.wntrmute.dev/kyle/mc-proxy/internal/server"
|
||||
)
|
||||
|
||||
// AdminServer implements the ProxyAdmin gRPC service.
|
||||
type AdminServer struct {
|
||||
pb.UnimplementedProxyAdminServer
|
||||
srv *server.Server
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// New creates a gRPC server with TLS and optional mTLS.
|
||||
func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
|
||||
cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
|
||||
}
|
||||
|
||||
tlsConfig := &tls.Config{
|
||||
Certificates: []tls.Certificate{cert},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}
|
||||
|
||||
// mTLS: require and verify client certificates.
|
||||
if cfg.ClientCA != "" {
|
||||
caCert, err := os.ReadFile(cfg.ClientCA)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("reading client CA: %w", err)
|
||||
}
|
||||
pool := x509.NewCertPool()
|
||||
if !pool.AppendCertsFromPEM(caCert) {
|
||||
return nil, nil, fmt.Errorf("failed to parse client CA certificate")
|
||||
}
|
||||
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
|
||||
tlsConfig.ClientCAs = pool
|
||||
}
|
||||
|
||||
creds := credentials.NewTLS(tlsConfig)
|
||||
grpcServer := grpc.NewServer(grpc.Creds(creds))
|
||||
|
||||
admin := &AdminServer{
|
||||
srv: srv,
|
||||
logger: logger,
|
||||
}
|
||||
pb.RegisterProxyAdminServer(grpcServer, admin)
|
||||
|
||||
ln, err := net.Listen("tcp", cfg.Addr)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("listening on %s: %w", cfg.Addr, err)
|
||||
}
|
||||
|
||||
return grpcServer, ln, nil
|
||||
}
|
||||
|
||||
// ListRoutes returns the route table for a listener.
|
||||
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) {
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
routes := ls.Routes()
|
||||
resp := &pb.ListRoutesResponse{
|
||||
ListenerAddr: ls.Addr,
|
||||
}
|
||||
for hostname, backend := range routes {
|
||||
resp.Routes = append(resp.Routes, &pb.Route{
|
||||
Hostname: hostname,
|
||||
Backend: backend,
|
||||
})
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// AddRoute adds a route to a listener's route table.
|
||||
func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) {
|
||||
if req.Route == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "route is required")
|
||||
}
|
||||
if req.Route.Hostname == "" || req.Route.Backend == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "hostname and backend are required")
|
||||
}
|
||||
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil {
|
||||
return nil, status.Errorf(codes.AlreadyExists, "%v", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend)
|
||||
return &pb.AddRouteResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveRoute removes a route from a listener's route table.
|
||||
func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) {
|
||||
if req.Hostname == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "hostname is required")
|
||||
}
|
||||
|
||||
ls, err := a.findListener(req.ListenerAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := ls.RemoveRoute(req.Hostname); err != nil {
|
||||
return nil, status.Errorf(codes.NotFound, "%v", err)
|
||||
}
|
||||
|
||||
a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname)
|
||||
return &pb.RemoveRouteResponse{}, nil
|
||||
}
|
||||
|
||||
// GetFirewallRules returns all current firewall rules.
|
||||
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) {
|
||||
ips, cidrs, countries := a.srv.Firewall().Rules()
|
||||
|
||||
var rules []*pb.FirewallRule
|
||||
for _, ip := range ips {
|
||||
rules = append(rules, &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP,
|
||||
Value: ip,
|
||||
})
|
||||
}
|
||||
for _, cidr := range cidrs {
|
||||
rules = append(rules, &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR,
|
||||
Value: cidr,
|
||||
})
|
||||
}
|
||||
for _, code := range countries {
|
||||
rules = append(rules, &pb.FirewallRule{
|
||||
Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY,
|
||||
Value: code,
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.GetFirewallRulesResponse{Rules: rules}, nil
|
||||
}
|
||||
|
||||
// AddFirewallRule adds a firewall rule.
|
||||
func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) {
|
||||
if req.Rule == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch req.Rule.Type {
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
||||
if err := fw.AddIP(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
||||
if err := fw.AddCIDR(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
||||
if req.Rule.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "country code is required")
|
||||
}
|
||||
fw.AddCountry(req.Rule.Value)
|
||||
default:
|
||||
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value)
|
||||
return &pb.AddFirewallRuleResponse{}, nil
|
||||
}
|
||||
|
||||
// RemoveFirewallRule removes a firewall rule.
|
||||
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) {
|
||||
if req.Rule == nil {
|
||||
return nil, status.Error(codes.InvalidArgument, "rule is required")
|
||||
}
|
||||
|
||||
fw := a.srv.Firewall()
|
||||
switch req.Rule.Type {
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP:
|
||||
if err := fw.RemoveIP(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR:
|
||||
if err := fw.RemoveCIDR(req.Rule.Value); err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "%v", err)
|
||||
}
|
||||
case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY:
|
||||
if req.Rule.Value == "" {
|
||||
return nil, status.Error(codes.InvalidArgument, "country code is required")
|
||||
}
|
||||
fw.RemoveCountry(req.Rule.Value)
|
||||
default:
|
||||
return nil, status.Error(codes.InvalidArgument, "unknown rule type")
|
||||
}
|
||||
|
||||
a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value)
|
||||
return &pb.RemoveFirewallRuleResponse{}, nil
|
||||
}
|
||||
|
||||
// GetStatus returns the proxy's current status.
|
||||
func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) {
|
||||
var listeners []*pb.ListenerStatus
|
||||
for _, ls := range a.srv.Listeners() {
|
||||
routes := ls.Routes()
|
||||
listeners = append(listeners, &pb.ListenerStatus{
|
||||
Addr: ls.Addr,
|
||||
RouteCount: int32(len(routes)),
|
||||
ActiveConnections: ls.ActiveConnections.Load(),
|
||||
})
|
||||
}
|
||||
|
||||
return &pb.GetStatusResponse{
|
||||
Version: a.srv.Version(),
|
||||
StartedAt: timestamppb.New(a.srv.StartedAt()),
|
||||
Listeners: listeners,
|
||||
TotalConnections: a.srv.TotalConnections(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) {
|
||||
for _, ls := range a.srv.Listeners() {
|
||||
if ls.Addr == addr {
|
||||
return ls, nil
|
||||
}
|
||||
}
|
||||
return nil, status.Errorf(codes.NotFound, "listener %q not found", addr)
|
||||
}
|
||||
Reference in New Issue
Block a user