diff --git a/internal/master/agentclient.go b/internal/master/agentclient.go new file mode 100644 index 0000000..1834aa0 --- /dev/null +++ b/internal/master/agentclient.go @@ -0,0 +1,190 @@ +// Package master implements the mcp-master orchestrator. +package master + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "strings" + "sync" + + mcpv1 "git.wntrmute.dev/mc/mcp/gen/mcp/v1" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" +) + +// AgentClient wraps a gRPC connection to a single mcp-agent. +type AgentClient struct { + conn *grpc.ClientConn + client mcpv1.McpAgentServiceClient + Node string +} + +// DialAgent connects to an agent at the given address using TLS 1.3. +// The token is attached to every outgoing RPC via metadata. +func DialAgent(address, caCertPath, token string) (*AgentClient, error) { + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS13, + } + + if caCertPath != "" { + caCert, err := os.ReadFile(caCertPath) //nolint:gosec // trusted config path + if err != nil { + return nil, fmt.Errorf("read CA cert %q: %w", caCertPath, err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("invalid CA cert %q", caCertPath) + } + tlsConfig.RootCAs = pool + } + + conn, err := grpc.NewClient( + address, + grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithUnaryInterceptor(agentTokenInterceptor(token)), + grpc.WithStreamInterceptor(agentStreamTokenInterceptor(token)), + ) + if err != nil { + return nil, fmt.Errorf("dial agent %q: %w", address, err) + } + + return &AgentClient{ + conn: conn, + client: mcpv1.NewMcpAgentServiceClient(conn), + }, nil +} + +// Close closes the underlying gRPC connection. +func (c *AgentClient) Close() error { + if c == nil || c.conn == nil { + return nil + } + return c.conn.Close() +} + +// Deploy forwards a deploy request to the agent. +func (c *AgentClient) Deploy(ctx context.Context, req *mcpv1.DeployRequest) (*mcpv1.DeployResponse, error) { + return c.client.Deploy(ctx, req) +} + +// UndeployService forwards an undeploy request to the agent. +func (c *AgentClient) UndeployService(ctx context.Context, req *mcpv1.UndeployServiceRequest) (*mcpv1.UndeployServiceResponse, error) { + return c.client.UndeployService(ctx, req) +} + +// GetServiceStatus queries a service's status on the agent. +func (c *AgentClient) GetServiceStatus(ctx context.Context, req *mcpv1.GetServiceStatusRequest) (*mcpv1.GetServiceStatusResponse, error) { + return c.client.GetServiceStatus(ctx, req) +} + +// ListServices lists all services on the agent. +func (c *AgentClient) ListServices(ctx context.Context, req *mcpv1.ListServicesRequest) (*mcpv1.ListServicesResponse, error) { + return c.client.ListServices(ctx, req) +} + +// SetupEdgeRoute sets up an edge route on the agent. +func (c *AgentClient) SetupEdgeRoute(ctx context.Context, req *mcpv1.SetupEdgeRouteRequest) (*mcpv1.SetupEdgeRouteResponse, error) { + return c.client.SetupEdgeRoute(ctx, req) +} + +// RemoveEdgeRoute removes an edge route from the agent. +func (c *AgentClient) RemoveEdgeRoute(ctx context.Context, req *mcpv1.RemoveEdgeRouteRequest) (*mcpv1.RemoveEdgeRouteResponse, error) { + return c.client.RemoveEdgeRoute(ctx, req) +} + +// ListEdgeRoutes lists edge routes on the agent. +func (c *AgentClient) ListEdgeRoutes(ctx context.Context, req *mcpv1.ListEdgeRoutesRequest) (*mcpv1.ListEdgeRoutesResponse, error) { + return c.client.ListEdgeRoutes(ctx, req) +} + +// HealthCheck checks the agent's health. +func (c *AgentClient) HealthCheck(ctx context.Context, req *mcpv1.HealthCheckRequest) (*mcpv1.HealthCheckResponse, error) { + return c.client.HealthCheck(ctx, req) +} + +// agentTokenInterceptor attaches the bearer token to outgoing RPCs. +func agentTokenInterceptor(token string) grpc.UnaryClientInterceptor { + return func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token) + return invoker(ctx, method, req, reply, cc, opts...) + } +} + +func agentStreamTokenInterceptor(token string) grpc.StreamClientInterceptor { + return func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { + ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token) + return streamer(ctx, desc, cc, method, opts...) + } +} + +// AgentPool manages connections to multiple agents, keyed by node name. +type AgentPool struct { + mu sync.RWMutex + clients map[string]*AgentClient + caCert string + token string +} + +// NewAgentPool creates a pool with the given CA cert and service token. +func NewAgentPool(caCertPath, token string) *AgentPool { + return &AgentPool{ + clients: make(map[string]*AgentClient), + caCert: caCertPath, + token: token, + } +} + +// AddNode dials an agent and adds it to the pool. +func (p *AgentPool) AddNode(name, address string) error { + client, err := DialAgent(address, p.caCert, p.token) + if err != nil { + return fmt.Errorf("add node %s: %w", name, err) + } + client.Node = name + + p.mu.Lock() + defer p.mu.Unlock() + + // Close existing connection if re-adding. + if old, ok := p.clients[name]; ok { + _ = old.Close() + } + p.clients[name] = client + return nil +} + +// Get returns the agent client for a node. +func (p *AgentPool) Get(name string) (*AgentClient, error) { + p.mu.RLock() + defer p.mu.RUnlock() + + client, ok := p.clients[name] + if !ok { + return nil, fmt.Errorf("node %q not found in pool", name) + } + return client, nil +} + +// Close closes all agent connections. +func (p *AgentPool) Close() { + p.mu.Lock() + defer p.mu.Unlock() + + for _, c := range p.clients { + _ = c.Close() + } + p.clients = make(map[string]*AgentClient) +} + +// LoadServiceToken reads a token from a file path. +func LoadServiceToken(path string) (string, error) { + data, err := os.ReadFile(path) //nolint:gosec // trusted config path + if err != nil { + return "", fmt.Errorf("read service token %q: %w", path, err) + } + return strings.TrimSpace(string(data)), nil +}