diff --git a/gen/mcp/v1/master.pb.go b/gen/mcp/v1/master.pb.go index f18cc84..97cbff4 100644 --- a/gen/mcp/v1/master.pb.go +++ b/gen/mcp/v1/master.pb.go @@ -712,6 +712,238 @@ func (x *NodeInfo) GetServices() int32 { return 0 } +type RegisterRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Role string `protobuf:"bytes,2,opt,name=role,proto3" json:"role,omitempty"` // "worker", "edge", or "master" + Address string `protobuf:"bytes,3,opt,name=address,proto3" json:"address,omitempty"` // agent gRPC address + Arch string `protobuf:"bytes,4,opt,name=arch,proto3" json:"arch,omitempty"` // "amd64" or "arm64" + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RegisterRequest) Reset() { + *x = RegisterRequest{} + mi := &file_proto_mcp_v1_master_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RegisterRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RegisterRequest) ProtoMessage() {} + +func (x *RegisterRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_mcp_v1_master_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RegisterRequest.ProtoReflect.Descriptor instead. +func (*RegisterRequest) Descriptor() ([]byte, []int) { + return file_proto_mcp_v1_master_proto_rawDescGZIP(), []int{12} +} + +func (x *RegisterRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *RegisterRequest) GetRole() string { + if x != nil { + return x.Role + } + return "" +} + +func (x *RegisterRequest) GetAddress() string { + if x != nil { + return x.Address + } + return "" +} + +func (x *RegisterRequest) GetArch() string { + if x != nil { + return x.Arch + } + return "" +} + +type RegisterResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Accepted bool `protobuf:"varint,1,opt,name=accepted,proto3" json:"accepted,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RegisterResponse) Reset() { + *x = RegisterResponse{} + mi := &file_proto_mcp_v1_master_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RegisterResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RegisterResponse) ProtoMessage() {} + +func (x *RegisterResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_mcp_v1_master_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RegisterResponse.ProtoReflect.Descriptor instead. +func (*RegisterResponse) Descriptor() ([]byte, []int) { + return file_proto_mcp_v1_master_proto_rawDescGZIP(), []int{13} +} + +func (x *RegisterResponse) GetAccepted() bool { + if x != nil { + return x.Accepted + } + return false +} + +type HeartbeatRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + CpuMillicores int64 `protobuf:"varint,2,opt,name=cpu_millicores,json=cpuMillicores,proto3" json:"cpu_millicores,omitempty"` + MemoryBytes int64 `protobuf:"varint,3,opt,name=memory_bytes,json=memoryBytes,proto3" json:"memory_bytes,omitempty"` + DiskBytes int64 `protobuf:"varint,4,opt,name=disk_bytes,json=diskBytes,proto3" json:"disk_bytes,omitempty"` + Containers int32 `protobuf:"varint,5,opt,name=containers,proto3" json:"containers,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HeartbeatRequest) Reset() { + *x = HeartbeatRequest{} + mi := &file_proto_mcp_v1_master_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HeartbeatRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HeartbeatRequest) ProtoMessage() {} + +func (x *HeartbeatRequest) ProtoReflect() protoreflect.Message { + mi := &file_proto_mcp_v1_master_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HeartbeatRequest.ProtoReflect.Descriptor instead. +func (*HeartbeatRequest) Descriptor() ([]byte, []int) { + return file_proto_mcp_v1_master_proto_rawDescGZIP(), []int{14} +} + +func (x *HeartbeatRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *HeartbeatRequest) GetCpuMillicores() int64 { + if x != nil { + return x.CpuMillicores + } + return 0 +} + +func (x *HeartbeatRequest) GetMemoryBytes() int64 { + if x != nil { + return x.MemoryBytes + } + return 0 +} + +func (x *HeartbeatRequest) GetDiskBytes() int64 { + if x != nil { + return x.DiskBytes + } + return 0 +} + +func (x *HeartbeatRequest) GetContainers() int32 { + if x != nil { + return x.Containers + } + return 0 +} + +type HeartbeatResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Acknowledged bool `protobuf:"varint,1,opt,name=acknowledged,proto3" json:"acknowledged,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HeartbeatResponse) Reset() { + *x = HeartbeatResponse{} + mi := &file_proto_mcp_v1_master_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HeartbeatResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HeartbeatResponse) ProtoMessage() {} + +func (x *HeartbeatResponse) ProtoReflect() protoreflect.Message { + mi := &file_proto_mcp_v1_master_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HeartbeatResponse.ProtoReflect.Descriptor instead. +func (*HeartbeatResponse) Descriptor() ([]byte, []int) { + return file_proto_mcp_v1_master_proto_rawDescGZIP(), []int{15} +} + +func (x *HeartbeatResponse) GetAcknowledged() bool { + if x != nil { + return x.Acknowledged + } + return false +} + var File_proto_mcp_v1_master_proto protoreflect.FileDescriptor const file_proto_mcp_v1_master_proto_rawDesc = "" + @@ -765,12 +997,32 @@ const file_proto_mcp_v1_master_proto_rawDesc = "" + "containers\x18\x06 \x01(\x05R\n" + "containers\x12%\n" + "\x0elast_heartbeat\x18\a \x01(\tR\rlastHeartbeat\x12\x1a\n" + - "\bservices\x18\b \x01(\x05R\bservices2\xa9\x02\n" + + "\bservices\x18\b \x01(\x05R\bservices\"g\n" + + "\x0fRegisterRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x12\n" + + "\x04role\x18\x02 \x01(\tR\x04role\x12\x18\n" + + "\aaddress\x18\x03 \x01(\tR\aaddress\x12\x12\n" + + "\x04arch\x18\x04 \x01(\tR\x04arch\".\n" + + "\x10RegisterResponse\x12\x1a\n" + + "\baccepted\x18\x01 \x01(\bR\baccepted\"\xaf\x01\n" + + "\x10HeartbeatRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12%\n" + + "\x0ecpu_millicores\x18\x02 \x01(\x03R\rcpuMillicores\x12!\n" + + "\fmemory_bytes\x18\x03 \x01(\x03R\vmemoryBytes\x12\x1d\n" + + "\n" + + "disk_bytes\x18\x04 \x01(\x03R\tdiskBytes\x12\x1e\n" + + "\n" + + "containers\x18\x05 \x01(\x05R\n" + + "containers\"7\n" + + "\x11HeartbeatResponse\x12\"\n" + + "\facknowledged\x18\x01 \x01(\bR\facknowledged2\xaa\x03\n" + "\x10McpMasterService\x12C\n" + "\x06Deploy\x12\x1b.mcp.v1.MasterDeployRequest\x1a\x1c.mcp.v1.MasterDeployResponse\x12I\n" + "\bUndeploy\x12\x1d.mcp.v1.MasterUndeployRequest\x1a\x1e.mcp.v1.MasterUndeployResponse\x12C\n" + "\x06Status\x12\x1b.mcp.v1.MasterStatusRequest\x1a\x1c.mcp.v1.MasterStatusResponse\x12@\n" + - "\tListNodes\x12\x18.mcp.v1.ListNodesRequest\x1a\x19.mcp.v1.ListNodesResponseB*Z(git.wntrmute.dev/mc/mcp/gen/mcp/v1;mcpv1b\x06proto3" + "\tListNodes\x12\x18.mcp.v1.ListNodesRequest\x1a\x19.mcp.v1.ListNodesResponse\x12=\n" + + "\bRegister\x12\x17.mcp.v1.RegisterRequest\x1a\x18.mcp.v1.RegisterResponse\x12@\n" + + "\tHeartbeat\x12\x18.mcp.v1.HeartbeatRequest\x1a\x19.mcp.v1.HeartbeatResponseB*Z(git.wntrmute.dev/mc/mcp/gen/mcp/v1;mcpv1b\x06proto3" var ( file_proto_mcp_v1_master_proto_rawDescOnce sync.Once @@ -784,7 +1036,7 @@ func file_proto_mcp_v1_master_proto_rawDescGZIP() []byte { return file_proto_mcp_v1_master_proto_rawDescData } -var file_proto_mcp_v1_master_proto_msgTypes = make([]protoimpl.MessageInfo, 12) +var file_proto_mcp_v1_master_proto_msgTypes = make([]protoimpl.MessageInfo, 16) var file_proto_mcp_v1_master_proto_goTypes = []any{ (*MasterDeployRequest)(nil), // 0: mcp.v1.MasterDeployRequest (*MasterDeployResponse)(nil), // 1: mcp.v1.MasterDeployResponse @@ -798,10 +1050,14 @@ var file_proto_mcp_v1_master_proto_goTypes = []any{ (*ListNodesRequest)(nil), // 9: mcp.v1.ListNodesRequest (*ListNodesResponse)(nil), // 10: mcp.v1.ListNodesResponse (*NodeInfo)(nil), // 11: mcp.v1.NodeInfo - (*ServiceSpec)(nil), // 12: mcp.v1.ServiceSpec + (*RegisterRequest)(nil), // 12: mcp.v1.RegisterRequest + (*RegisterResponse)(nil), // 13: mcp.v1.RegisterResponse + (*HeartbeatRequest)(nil), // 14: mcp.v1.HeartbeatRequest + (*HeartbeatResponse)(nil), // 15: mcp.v1.HeartbeatResponse + (*ServiceSpec)(nil), // 16: mcp.v1.ServiceSpec } var file_proto_mcp_v1_master_proto_depIdxs = []int32{ - 12, // 0: mcp.v1.MasterDeployRequest.service:type_name -> mcp.v1.ServiceSpec + 16, // 0: mcp.v1.MasterDeployRequest.service:type_name -> mcp.v1.ServiceSpec 2, // 1: mcp.v1.MasterDeployResponse.deploy_result:type_name -> mcp.v1.StepResult 2, // 2: mcp.v1.MasterDeployResponse.edge_route_result:type_name -> mcp.v1.StepResult 2, // 3: mcp.v1.MasterDeployResponse.dns_result:type_name -> mcp.v1.StepResult @@ -812,12 +1068,16 @@ var file_proto_mcp_v1_master_proto_depIdxs = []int32{ 3, // 8: mcp.v1.McpMasterService.Undeploy:input_type -> mcp.v1.MasterUndeployRequest 5, // 9: mcp.v1.McpMasterService.Status:input_type -> mcp.v1.MasterStatusRequest 9, // 10: mcp.v1.McpMasterService.ListNodes:input_type -> mcp.v1.ListNodesRequest - 1, // 11: mcp.v1.McpMasterService.Deploy:output_type -> mcp.v1.MasterDeployResponse - 4, // 12: mcp.v1.McpMasterService.Undeploy:output_type -> mcp.v1.MasterUndeployResponse - 6, // 13: mcp.v1.McpMasterService.Status:output_type -> mcp.v1.MasterStatusResponse - 10, // 14: mcp.v1.McpMasterService.ListNodes:output_type -> mcp.v1.ListNodesResponse - 11, // [11:15] is the sub-list for method output_type - 7, // [7:11] is the sub-list for method input_type + 12, // 11: mcp.v1.McpMasterService.Register:input_type -> mcp.v1.RegisterRequest + 14, // 12: mcp.v1.McpMasterService.Heartbeat:input_type -> mcp.v1.HeartbeatRequest + 1, // 13: mcp.v1.McpMasterService.Deploy:output_type -> mcp.v1.MasterDeployResponse + 4, // 14: mcp.v1.McpMasterService.Undeploy:output_type -> mcp.v1.MasterUndeployResponse + 6, // 15: mcp.v1.McpMasterService.Status:output_type -> mcp.v1.MasterStatusResponse + 10, // 16: mcp.v1.McpMasterService.ListNodes:output_type -> mcp.v1.ListNodesResponse + 13, // 17: mcp.v1.McpMasterService.Register:output_type -> mcp.v1.RegisterResponse + 15, // 18: mcp.v1.McpMasterService.Heartbeat:output_type -> mcp.v1.HeartbeatResponse + 13, // [13:19] is the sub-list for method output_type + 7, // [7:13] is the sub-list for method input_type 7, // [7:7] is the sub-list for extension type_name 7, // [7:7] is the sub-list for extension extendee 0, // [0:7] is the sub-list for field type_name @@ -835,7 +1095,7 @@ func file_proto_mcp_v1_master_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_proto_mcp_v1_master_proto_rawDesc), len(file_proto_mcp_v1_master_proto_rawDesc)), NumEnums: 0, - NumMessages: 12, + NumMessages: 16, NumExtensions: 0, NumServices: 1, }, diff --git a/gen/mcp/v1/master_grpc.pb.go b/gen/mcp/v1/master_grpc.pb.go index 726952b..4c58c8d 100644 --- a/gen/mcp/v1/master_grpc.pb.go +++ b/gen/mcp/v1/master_grpc.pb.go @@ -25,6 +25,8 @@ const ( McpMasterService_Undeploy_FullMethodName = "/mcp.v1.McpMasterService/Undeploy" McpMasterService_Status_FullMethodName = "/mcp.v1.McpMasterService/Status" McpMasterService_ListNodes_FullMethodName = "/mcp.v1.McpMasterService/ListNodes" + McpMasterService_Register_FullMethodName = "/mcp.v1.McpMasterService/Register" + McpMasterService_Heartbeat_FullMethodName = "/mcp.v1.McpMasterService/Heartbeat" ) // McpMasterServiceClient is the client API for McpMasterService service. @@ -40,6 +42,9 @@ type McpMasterServiceClient interface { Undeploy(ctx context.Context, in *MasterUndeployRequest, opts ...grpc.CallOption) (*MasterUndeployResponse, error) Status(ctx context.Context, in *MasterStatusRequest, opts ...grpc.CallOption) (*MasterStatusResponse, error) ListNodes(ctx context.Context, in *ListNodesRequest, opts ...grpc.CallOption) (*ListNodesResponse, error) + // Agent registration and health (called by agents). + Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error) + Heartbeat(ctx context.Context, in *HeartbeatRequest, opts ...grpc.CallOption) (*HeartbeatResponse, error) } type mcpMasterServiceClient struct { @@ -90,6 +95,26 @@ func (c *mcpMasterServiceClient) ListNodes(ctx context.Context, in *ListNodesReq return out, nil } +func (c *mcpMasterServiceClient) Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RegisterResponse) + err := c.cc.Invoke(ctx, McpMasterService_Register_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *mcpMasterServiceClient) Heartbeat(ctx context.Context, in *HeartbeatRequest, opts ...grpc.CallOption) (*HeartbeatResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(HeartbeatResponse) + err := c.cc.Invoke(ctx, McpMasterService_Heartbeat_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + // McpMasterServiceServer is the server API for McpMasterService service. // All implementations must embed UnimplementedMcpMasterServiceServer // for forward compatibility. @@ -103,6 +128,9 @@ type McpMasterServiceServer interface { Undeploy(context.Context, *MasterUndeployRequest) (*MasterUndeployResponse, error) Status(context.Context, *MasterStatusRequest) (*MasterStatusResponse, error) ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error) + // Agent registration and health (called by agents). + Register(context.Context, *RegisterRequest) (*RegisterResponse, error) + Heartbeat(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error) mustEmbedUnimplementedMcpMasterServiceServer() } @@ -125,6 +153,12 @@ func (UnimplementedMcpMasterServiceServer) Status(context.Context, *MasterStatus func (UnimplementedMcpMasterServiceServer) ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error) { return nil, status.Error(codes.Unimplemented, "method ListNodes not implemented") } +func (UnimplementedMcpMasterServiceServer) Register(context.Context, *RegisterRequest) (*RegisterResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Register not implemented") +} +func (UnimplementedMcpMasterServiceServer) Heartbeat(context.Context, *HeartbeatRequest) (*HeartbeatResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Heartbeat not implemented") +} func (UnimplementedMcpMasterServiceServer) mustEmbedUnimplementedMcpMasterServiceServer() {} func (UnimplementedMcpMasterServiceServer) testEmbeddedByValue() {} @@ -218,6 +252,42 @@ func _McpMasterService_ListNodes_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } +func _McpMasterService_Register_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RegisterRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(McpMasterServiceServer).Register(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: McpMasterService_Register_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(McpMasterServiceServer).Register(ctx, req.(*RegisterRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _McpMasterService_Heartbeat_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HeartbeatRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(McpMasterServiceServer).Heartbeat(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: McpMasterService_Heartbeat_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(McpMasterServiceServer).Heartbeat(ctx, req.(*HeartbeatRequest)) + } + return interceptor(ctx, in, info, handler) +} + // McpMasterService_ServiceDesc is the grpc.ServiceDesc for McpMasterService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -241,6 +311,14 @@ var McpMasterService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ListNodes", Handler: _McpMasterService_ListNodes_Handler, }, + { + MethodName: "Register", + Handler: _McpMasterService_Register_Handler, + }, + { + MethodName: "Heartbeat", + Handler: _McpMasterService_Heartbeat_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "proto/mcp/v1/master.proto", diff --git a/internal/agent/heartbeat.go b/internal/agent/heartbeat.go new file mode 100644 index 0000000..49d4027 --- /dev/null +++ b/internal/agent/heartbeat.go @@ -0,0 +1,183 @@ +package agent + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "runtime" + "strings" + "sync" + "time" + + mcpv1 "git.wntrmute.dev/mc/mcp/gen/mcp/v1" + "git.wntrmute.dev/mc/mcp/internal/config" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" +) + +// MasterConfig holds the optional master connection settings for the agent. +// When configured, the agent self-registers and sends periodic heartbeats. +type MasterConfig struct { + Address string `toml:"address"` // master gRPC address + CACert string `toml:"ca_cert"` // CA cert to verify master's TLS + TokenPath string `toml:"token_path"` // MCIAS service token for auth +} + +// HeartbeatClient manages the agent's connection to the master for +// registration and heartbeats. +type HeartbeatClient struct { + client mcpv1.McpMasterServiceClient + conn *grpc.ClientConn + nodeName string + role string + address string // agent's own gRPC address + arch string + interval time.Duration + stop chan struct{} + wg sync.WaitGroup + logger interface{ Info(string, ...any); Warn(string, ...any); Error(string, ...any) } +} + +// NewHeartbeatClient creates a client that registers with the master and +// sends periodic heartbeats. Returns nil if master address is not configured. +func NewHeartbeatClient(cfg config.AgentConfig, logger interface{ Info(string, ...any); Warn(string, ...any); Error(string, ...any) }) (*HeartbeatClient, error) { + masterAddr := os.Getenv("MCP_MASTER_ADDRESS") + masterCACert := os.Getenv("MCP_MASTER_CA_CERT") + masterToken := os.Getenv("MCP_MASTER_TOKEN_PATH") + + if masterAddr == "" { + return nil, nil // master not configured + } + + token := "" + if masterToken != "" { + data, err := os.ReadFile(masterToken) //nolint:gosec // trusted config + if err != nil { + return nil, fmt.Errorf("read master token: %w", err) + } + token = strings.TrimSpace(string(data)) + } + + tlsConfig := &tls.Config{MinVersion: tls.VersionTLS13} + if masterCACert != "" { + caCert, err := os.ReadFile(masterCACert) //nolint:gosec // trusted config + if err != nil { + return nil, fmt.Errorf("read master CA cert: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caCert) { + return nil, fmt.Errorf("invalid master CA cert") + } + tlsConfig.RootCAs = pool + } + + conn, err := grpc.NewClient( + masterAddr, + grpc.WithTransportCredentials(credentials.NewTLS(tlsConfig)), + grpc.WithUnaryInterceptor(func(ctx context.Context, method string, req, reply any, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { + if token != "" { + ctx = metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+token) + } + return invoker(ctx, method, req, reply, cc, opts...) + }), + ) + if err != nil { + return nil, fmt.Errorf("dial master: %w", err) + } + + return &HeartbeatClient{ + client: mcpv1.NewMcpMasterServiceClient(conn), + conn: conn, + nodeName: cfg.Agent.NodeName, + role: "worker", // default; master node sets this via config + address: cfg.Server.GRPCAddr, + arch: runtime.GOARCH, + interval: 30 * time.Second, + stop: make(chan struct{}), + logger: logger, + }, nil +} + +// Start registers with the master and begins the heartbeat loop. +func (hc *HeartbeatClient) Start() { + if hc == nil { + return + } + + // Register with the master (retry with backoff). + hc.wg.Add(1) + go func() { + defer hc.wg.Done() + + backoff := time.Second + for { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + resp, err := hc.client.Register(ctx, &mcpv1.RegisterRequest{ + Name: hc.nodeName, + Role: hc.role, + Address: hc.address, + Arch: hc.arch, + }) + cancel() + + if err == nil && resp.GetAccepted() { + hc.logger.Info("registered with master", + "node", hc.nodeName, "master_accepted", true) + break + } + + hc.logger.Warn("registration failed, retrying", + "node", hc.nodeName, "err", err, "backoff", backoff) + + select { + case <-hc.stop: + return + case <-time.After(backoff): + } + + backoff *= 2 + if backoff > 60*time.Second { + backoff = 60 * time.Second + } + } + + // Heartbeat loop. + ticker := time.NewTicker(hc.interval) + defer ticker.Stop() + + for { + select { + case <-hc.stop: + return + case <-ticker.C: + hc.sendHeartbeat() + } + } + }() +} + +func (hc *HeartbeatClient) sendHeartbeat() { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + _, err := hc.client.Heartbeat(ctx, &mcpv1.HeartbeatRequest{ + Name: hc.nodeName, + Containers: 0, // TODO: count from runtime + }) + if err != nil { + hc.logger.Warn("heartbeat failed", "node", hc.nodeName, "err", err) + } +} + +// Stop stops the heartbeat loop and closes the master connection. +func (hc *HeartbeatClient) Stop() { + if hc == nil { + return + } + close(hc.stop) + hc.wg.Wait() + _ = hc.conn.Close() +} diff --git a/internal/master/master.go b/internal/master/master.go index 247a65e..352d596 100644 --- a/internal/master/master.go +++ b/internal/master/master.go @@ -124,6 +124,10 @@ func Run(cfg *config.MasterConfig, version string) error { "nodes", len(cfg.Nodes), ) + // Start heartbeat monitor. + hbMonitor := NewHeartbeatMonitor(m) + hbMonitor.Start() + // Signal handling. ctx, stop := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) defer stop() @@ -136,10 +140,12 @@ func Run(cfg *config.MasterConfig, version string) error { select { case <-ctx.Done(): logger.Info("shutting down") + hbMonitor.Stop() server.GracefulStop() pool.Close() return nil case err := <-errCh: + hbMonitor.Stop() pool.Close() return fmt.Errorf("serve: %w", err) } diff --git a/internal/master/registration.go b/internal/master/registration.go new file mode 100644 index 0000000..322cacf --- /dev/null +++ b/internal/master/registration.go @@ -0,0 +1,235 @@ +package master + +import ( + "context" + "strings" + "sync" + "time" + + mcpv1 "git.wntrmute.dev/mc/mcp/gen/mcp/v1" + "git.wntrmute.dev/mc/mcp/internal/auth" + "git.wntrmute.dev/mc/mcp/internal/masterdb" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// Register handles agent self-registration. Identity-bound: the agent's +// MCIAS service name must match the claimed node name (agent-rift → rift). +func (m *Master) Register(ctx context.Context, req *mcpv1.RegisterRequest) (*mcpv1.RegisterResponse, error) { + // Extract caller identity from the auth context. + tokenInfo := auth.TokenInfoFromContext(ctx) + if tokenInfo == nil { + return nil, status.Error(codes.Unauthenticated, "no auth context") + } + + // Identity binding: agent-rift can only register name="rift". + expectedName := strings.TrimPrefix(tokenInfo.Username, "agent-") + if expectedName == tokenInfo.Username { + // Not an agent-* account — also allow mcp-agent (legacy). + expectedName = req.GetName() + } + if req.GetName() != expectedName { + m.Logger.Warn("registration rejected: name mismatch", + "claimed", req.GetName(), "identity", tokenInfo.Username) + return nil, status.Errorf(codes.PermissionDenied, + "identity %q cannot register as %q", tokenInfo.Username, req.GetName()) + } + + // Check allowlist. + if len(m.Config.Registration.AllowedAgents) > 0 { + allowed := false + for _, a := range m.Config.Registration.AllowedAgents { + if a == tokenInfo.Username { + allowed = true + break + } + } + if !allowed { + m.Logger.Warn("registration rejected: not in allowlist", + "identity", tokenInfo.Username) + return nil, status.Errorf(codes.PermissionDenied, + "identity %q not in registration allowlist", tokenInfo.Username) + } + } + + // Check max nodes. + nodes, err := masterdb.ListNodes(m.DB) + if err == nil && len(nodes) >= m.Config.Registration.MaxNodes { + // Check if this is a re-registration (existing node). + found := false + for _, n := range nodes { + if n.Name == req.GetName() { + found = true + break + } + } + if !found { + return nil, status.Error(codes.ResourceExhausted, "max nodes reached") + } + } + + // Upsert node in registry. + role := req.GetRole() + if role == "" { + role = "worker" + } + arch := req.GetArch() + if arch == "" { + arch = "amd64" + } + + if err := masterdb.UpsertNode(m.DB, req.GetName(), req.GetAddress(), role, arch); err != nil { + m.Logger.Error("registration upsert failed", "node", req.GetName(), "err", err) + return nil, status.Error(codes.Internal, "registration failed") + } + if err := masterdb.UpdateNodeStatus(m.DB, req.GetName(), "healthy"); err != nil { + m.Logger.Warn("update node status", "node", req.GetName(), "err", err) + } + + // Update the agent pool connection. + if addErr := m.Pool.AddNode(req.GetName(), req.GetAddress()); addErr != nil { + m.Logger.Warn("pool update failed", "node", req.GetName(), "err", addErr) + } + + m.Logger.Info("agent registered", + "node", req.GetName(), "address", req.GetAddress(), + "role", role, "arch", arch, "identity", tokenInfo.Username) + + return &mcpv1.RegisterResponse{Accepted: true}, nil +} + +// Heartbeat handles agent heartbeats. Updates the node's resource data +// and last-heartbeat timestamp. Derives the node name from the MCIAS +// identity, not the request (security: don't trust self-reported name). +func (m *Master) Heartbeat(ctx context.Context, req *mcpv1.HeartbeatRequest) (*mcpv1.HeartbeatResponse, error) { + // Derive node name from identity. + tokenInfo := auth.TokenInfoFromContext(ctx) + if tokenInfo == nil { + return nil, status.Error(codes.Unauthenticated, "no auth context") + } + + nodeName := strings.TrimPrefix(tokenInfo.Username, "agent-") + if nodeName == tokenInfo.Username { + // Legacy mcp-agent account — use the request name. + nodeName = req.GetName() + } + + // Verify the node is registered. + node, err := masterdb.GetNode(m.DB, nodeName) + if err != nil || node == nil { + return nil, status.Errorf(codes.NotFound, "node %q not registered", nodeName) + } + + // Update heartbeat data. + now := time.Now().UTC().Format(time.RFC3339) + _, err = m.DB.Exec(` + UPDATE nodes SET + containers = ?, + status = 'healthy', + last_heartbeat = ?, + updated_at = datetime('now') + WHERE name = ? + `, req.GetContainers(), now, nodeName) + if err != nil { + m.Logger.Warn("heartbeat update failed", "node", nodeName, "err", err) + } + + return &mcpv1.HeartbeatResponse{Acknowledged: true}, nil +} + +// HeartbeatMonitor runs in the background, checking for agents that have +// missed heartbeats and probing them via HealthCheck. +type HeartbeatMonitor struct { + master *Master + interval time.Duration // heartbeat check interval (default: 30s) + timeout time.Duration // missed heartbeat threshold (default: 90s) + stop chan struct{} + wg sync.WaitGroup +} + +// NewHeartbeatMonitor creates a heartbeat monitor. +func NewHeartbeatMonitor(m *Master) *HeartbeatMonitor { + return &HeartbeatMonitor{ + master: m, + interval: 30 * time.Second, + timeout: 90 * time.Second, + stop: make(chan struct{}), + } +} + +// Start begins the heartbeat monitoring loop. +func (hm *HeartbeatMonitor) Start() { + hm.wg.Add(1) + go func() { + defer hm.wg.Done() + // Initial warm-up: don't alert for the first cycle. + time.Sleep(hm.timeout) + + ticker := time.NewTicker(hm.interval) + defer ticker.Stop() + + for { + select { + case <-hm.stop: + return + case <-ticker.C: + hm.check() + } + } + }() +} + +// Stop stops the heartbeat monitor. +func (hm *HeartbeatMonitor) Stop() { + close(hm.stop) + hm.wg.Wait() +} + +func (hm *HeartbeatMonitor) check() { + nodes, err := masterdb.ListNodes(hm.master.DB) + if err != nil { + hm.master.Logger.Warn("heartbeat check: list nodes", "err", err) + return + } + + now := time.Now() + for _, node := range nodes { + if node.Status == "unhealthy" { + continue // already marked, don't spam probes + } + + if node.LastHeartbeat == nil { + continue // never sent a heartbeat, skip + } + + if now.Sub(*node.LastHeartbeat) > hm.timeout { + hm.master.Logger.Warn("missed heartbeats, probing", + "node", node.Name, + "last_heartbeat", node.LastHeartbeat.Format(time.RFC3339)) + + // Probe the agent. + client, err := hm.master.Pool.Get(node.Name) + if err != nil { + hm.master.Logger.Warn("probe failed: no connection", + "node", node.Name, "err", err) + _ = masterdb.UpdateNodeStatus(hm.master.DB, node.Name, "unhealthy") + continue + } + + ctx, cancel := context.WithTimeout(context.Background(), + hm.master.Config.Timeouts.HealthCheck.Duration) + _, probeErr := client.HealthCheck(ctx, &mcpv1.HealthCheckRequest{}) + cancel() + + if probeErr != nil { + hm.master.Logger.Warn("probe failed", + "node", node.Name, "err", probeErr) + _ = masterdb.UpdateNodeStatus(hm.master.DB, node.Name, "unhealthy") + } else { + // Probe succeeded — node is alive, just not sending heartbeats. + hm.master.Logger.Info("probe succeeded (heartbeats stale)", + "node", node.Name) + } + } + } +} diff --git a/proto/mcp/v1/master.proto b/proto/mcp/v1/master.proto index ce4e0f4..e68e42b 100644 --- a/proto/mcp/v1/master.proto +++ b/proto/mcp/v1/master.proto @@ -16,6 +16,10 @@ service McpMasterService { rpc Undeploy(MasterUndeployRequest) returns (MasterUndeployResponse); rpc Status(MasterStatusRequest) returns (MasterStatusResponse); rpc ListNodes(ListNodesRequest) returns (ListNodesResponse); + + // Agent registration and health (called by agents). + rpc Register(RegisterRequest) returns (RegisterResponse); + rpc Heartbeat(HeartbeatRequest) returns (HeartbeatResponse); } // --- Deploy --- @@ -93,3 +97,30 @@ message NodeInfo { string last_heartbeat = 7; // RFC3339 int32 services = 8; // placement count } + +// --- Registration --- + +message RegisterRequest { + string name = 1; + string role = 2; // "worker", "edge", or "master" + string address = 3; // agent gRPC address + string arch = 4; // "amd64" or "arm64" +} + +message RegisterResponse { + bool accepted = 1; +} + +// --- Heartbeat --- + +message HeartbeatRequest { + string name = 1; + int64 cpu_millicores = 2; + int64 memory_bytes = 3; + int64 disk_bytes = 4; + int32 containers = 5; +} + +message HeartbeatResponse { + bool acknowledged = 1; +}