package grpcserver

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"log/slog"
	"net"
	"os"

	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/types/known/timestamppb"

	pb "git.wntrmute.dev/kyle/mc-proxy/gen/mc-proxy/v1"
	"git.wntrmute.dev/kyle/mc-proxy/internal/config"
	"git.wntrmute.dev/kyle/mc-proxy/internal/server"
)

// AdminServer implements the ProxyAdmin gRPC service.
type AdminServer struct {
	pb.UnimplementedProxyAdminServer

	// srv is the proxy whose listeners and firewall this service inspects
	// and mutates.
	srv *server.Server
	// logger records successful administrative changes.
	logger *slog.Logger
}

// New creates a gRPC server with TLS and optional mTLS.
//
// The server presents the keypair loaded from cfg.TLSCert/cfg.TLSKey and
// accepts only TLS 1.3. When cfg.ClientCA is non-empty, every client must
// present a certificate that verifies against the CA certificates in that
// PEM file. The ProxyAdmin service is registered on the returned server, and
// the returned listener is bound to cfg.Addr. New does not call Serve.
//
// On error, all return values other than the error are nil.
func New(cfg config.GRPC, srv *server.Server, logger *slog.Logger) (*grpc.Server, net.Listener, error) {
	cert, err := tls.LoadX509KeyPair(cfg.TLSCert, cfg.TLSKey)
	if err != nil {
		return nil, nil, fmt.Errorf("loading TLS keypair: %w", err)
	}

	tlsConfig := &tls.Config{
		Certificates: []tls.Certificate{cert},
		MinVersion:   tls.VersionTLS13,
	}

	// mTLS: require and verify client certificates.
	if cfg.ClientCA != "" {
		caCert, err := os.ReadFile(cfg.ClientCA)
		if err != nil {
			return nil, nil, fmt.Errorf("reading client CA: %w", err)
		}
		pool := x509.NewCertPool()
		if !pool.AppendCertsFromPEM(caCert) {
			return nil, nil, fmt.Errorf("parsing client CA %q: no certificates could be parsed", cfg.ClientCA)
		}
		tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
		tlsConfig.ClientCAs = pool
	}

	// Bind the listener before constructing the gRPC server. If listening
	// fails, no server has been created, so none is left un-stopped.
	ln, err := net.Listen("tcp", cfg.Addr)
	if err != nil {
		return nil, nil, fmt.Errorf("listening on %s: %w", cfg.Addr, err)
	}

	grpcServer := grpc.NewServer(grpc.Creds(credentials.NewTLS(tlsConfig)))
	pb.RegisterProxyAdminServer(grpcServer, &AdminServer{
		srv:    srv,
		logger: logger,
	})

	return grpcServer, ln, nil
}

// ListRoutes returns the route table for a listener.
func (a *AdminServer) ListRoutes(_ context.Context, req *pb.ListRoutesRequest) (*pb.ListRoutesResponse, error) { ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } routes := ls.Routes() resp := &pb.ListRoutesResponse{ ListenerAddr: ls.Addr, } for hostname, backend := range routes { resp.Routes = append(resp.Routes, &pb.Route{ Hostname: hostname, Backend: backend, }) } return resp, nil } // AddRoute adds a route to a listener's route table. func (a *AdminServer) AddRoute(_ context.Context, req *pb.AddRouteRequest) (*pb.AddRouteResponse, error) { if req.Route == nil { return nil, status.Error(codes.InvalidArgument, "route is required") } if req.Route.Hostname == "" || req.Route.Backend == "" { return nil, status.Error(codes.InvalidArgument, "hostname and backend are required") } ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } if err := ls.AddRoute(req.Route.Hostname, req.Route.Backend); err != nil { return nil, status.Errorf(codes.AlreadyExists, "%v", err) } a.logger.Info("route added", "listener", ls.Addr, "hostname", req.Route.Hostname, "backend", req.Route.Backend) return &pb.AddRouteResponse{}, nil } // RemoveRoute removes a route from a listener's route table. func (a *AdminServer) RemoveRoute(_ context.Context, req *pb.RemoveRouteRequest) (*pb.RemoveRouteResponse, error) { if req.Hostname == "" { return nil, status.Error(codes.InvalidArgument, "hostname is required") } ls, err := a.findListener(req.ListenerAddr) if err != nil { return nil, err } if err := ls.RemoveRoute(req.Hostname); err != nil { return nil, status.Errorf(codes.NotFound, "%v", err) } a.logger.Info("route removed", "listener", ls.Addr, "hostname", req.Hostname) return &pb.RemoveRouteResponse{}, nil } // GetFirewallRules returns all current firewall rules. 
func (a *AdminServer) GetFirewallRules(_ context.Context, _ *pb.GetFirewallRulesRequest) (*pb.GetFirewallRulesResponse, error) { ips, cidrs, countries := a.srv.Firewall().Rules() var rules []*pb.FirewallRule for _, ip := range ips { rules = append(rules, &pb.FirewallRule{ Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP, Value: ip, }) } for _, cidr := range cidrs { rules = append(rules, &pb.FirewallRule{ Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR, Value: cidr, }) } for _, code := range countries { rules = append(rules, &pb.FirewallRule{ Type: pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY, Value: code, }) } return &pb.GetFirewallRulesResponse{Rules: rules}, nil } // AddFirewallRule adds a firewall rule. func (a *AdminServer) AddFirewallRule(_ context.Context, req *pb.AddFirewallRuleRequest) (*pb.AddFirewallRuleResponse, error) { if req.Rule == nil { return nil, status.Error(codes.InvalidArgument, "rule is required") } fw := a.srv.Firewall() switch req.Rule.Type { case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: if err := fw.AddIP(req.Rule.Value); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%v", err) } case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: if err := fw.AddCIDR(req.Rule.Value); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%v", err) } case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: if req.Rule.Value == "" { return nil, status.Error(codes.InvalidArgument, "country code is required") } fw.AddCountry(req.Rule.Value) default: return nil, status.Error(codes.InvalidArgument, "unknown rule type") } a.logger.Info("firewall rule added", "type", req.Rule.Type, "value", req.Rule.Value) return &pb.AddFirewallRuleResponse{}, nil } // RemoveFirewallRule removes a firewall rule. 
func (a *AdminServer) RemoveFirewallRule(_ context.Context, req *pb.RemoveFirewallRuleRequest) (*pb.RemoveFirewallRuleResponse, error) { if req.Rule == nil { return nil, status.Error(codes.InvalidArgument, "rule is required") } fw := a.srv.Firewall() switch req.Rule.Type { case pb.FirewallRuleType_FIREWALL_RULE_TYPE_IP: if err := fw.RemoveIP(req.Rule.Value); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%v", err) } case pb.FirewallRuleType_FIREWALL_RULE_TYPE_CIDR: if err := fw.RemoveCIDR(req.Rule.Value); err != nil { return nil, status.Errorf(codes.InvalidArgument, "%v", err) } case pb.FirewallRuleType_FIREWALL_RULE_TYPE_COUNTRY: if req.Rule.Value == "" { return nil, status.Error(codes.InvalidArgument, "country code is required") } fw.RemoveCountry(req.Rule.Value) default: return nil, status.Error(codes.InvalidArgument, "unknown rule type") } a.logger.Info("firewall rule removed", "type", req.Rule.Type, "value", req.Rule.Value) return &pb.RemoveFirewallRuleResponse{}, nil } // GetStatus returns the proxy's current status. func (a *AdminServer) GetStatus(_ context.Context, _ *pb.GetStatusRequest) (*pb.GetStatusResponse, error) { var listeners []*pb.ListenerStatus for _, ls := range a.srv.Listeners() { routes := ls.Routes() listeners = append(listeners, &pb.ListenerStatus{ Addr: ls.Addr, RouteCount: int32(len(routes)), ActiveConnections: ls.ActiveConnections.Load(), }) } return &pb.GetStatusResponse{ Version: a.srv.Version(), StartedAt: timestamppb.New(a.srv.StartedAt()), Listeners: listeners, TotalConnections: a.srv.TotalConnections(), }, nil } func (a *AdminServer) findListener(addr string) (*server.ListenerState, error) { for _, ls := range a.srv.Listeners() { if ls.Addr == addr { return ls, nil } } return nil, status.Errorf(codes.NotFound, "listener %q not found", addr) }