diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 2632c31..b504e3f 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -34,6 +34,7 @@ type Agent struct { PortAlloc *PortAllocator Proxy *ProxyRouter Certs *CertProvisioner + DNS *DNSRegistrar } // Run starts the agent: opens the database, sets up the gRPC server with @@ -63,6 +64,11 @@ func Run(cfg *config.AgentConfig) error { return fmt.Errorf("create cert provisioner: %w", err) } + dns, err := NewDNSRegistrar(cfg.MCNS, logger) + if err != nil { + return fmt.Errorf("create DNS registrar: %w", err) + } + a := &Agent{ Config: cfg, DB: db, @@ -72,6 +78,7 @@ func Run(cfg *config.AgentConfig) error { PortAlloc: NewPortAllocator(), Proxy: proxy, Certs: certs, + DNS: dns, } tlsCert, err := tls.LoadX509KeyPair(cfg.Server.TLSCert, cfg.Server.TLSKey) diff --git a/internal/agent/deploy.go b/internal/agent/deploy.go index 468e415..1e08cee 100644 --- a/internal/agent/deploy.go +++ b/internal/agent/deploy.go @@ -164,6 +164,13 @@ func (a *Agent) deployComponent(ctx context.Context, serviceName string, cs *mcp } } + // Register DNS record for the service. + if a.DNS != nil && len(regRoutes) > 0 { + if err := a.DNS.EnsureRecord(ctx, serviceName); err != nil { + a.Logger.Warn("failed to register DNS record", "service", serviceName, "err", err) + } + } + if err := registry.UpdateComponentState(a.DB, serviceName, compName, "running", "running"); err != nil { a.Logger.Warn("failed to update component state", "service", serviceName, "component", compName, "err", err) } diff --git a/internal/agent/dns.go b/internal/agent/dns.go new file mode 100644 index 0000000..875161e --- /dev/null +++ b/internal/agent/dns.go @@ -0,0 +1,260 @@ +package agent + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "strings" + + "git.wntrmute.dev/mc/mcp/internal/auth" + "git.wntrmute.dev/mc/mcp/internal/config" +) + +// DNSRegistrar creates and removes A records in MCNS during deploy +// and stop. It is nil-safe: all methods are no-ops when the receiver +// is nil. +type DNSRegistrar struct { + serverURL string + token string + zone string + nodeAddr string + httpClient *http.Client + logger *slog.Logger +} + +// dnsRecord is the JSON representation of an MCNS record. +type dnsRecord struct { + ID int `json:"id"` + Name string `json:"name"` + Type string `json:"type"` + Value string `json:"value"` + TTL int `json:"ttl"` +} + +// NewDNSRegistrar creates a DNSRegistrar. Returns (nil, nil) if +// cfg.ServerURL is empty (DNS registration disabled). +func NewDNSRegistrar(cfg config.MCNSConfig, logger *slog.Logger) (*DNSRegistrar, error) { + if cfg.ServerURL == "" { + logger.Info("mcns not configured, DNS registration disabled") + return nil, nil + } + + token, err := auth.LoadToken(cfg.TokenPath) + if err != nil { + return nil, fmt.Errorf("load mcns token: %w", err) + } + + httpClient, err := newTLSClient(cfg.CACert) + if err != nil { + return nil, fmt.Errorf("create mcns HTTP client: %w", err) + } + + logger.Info("mcns DNS registrar enabled", "server", cfg.ServerURL, "zone", cfg.Zone, "node_addr", cfg.NodeAddr) + return &DNSRegistrar{ + serverURL: strings.TrimRight(cfg.ServerURL, "/"), + token: token, + zone: cfg.Zone, + nodeAddr: cfg.NodeAddr, + httpClient: httpClient, + logger: logger, + }, nil +} + +// EnsureRecord ensures an A record exists for the service in the +// configured zone, pointing to the node's address. +func (d *DNSRegistrar) EnsureRecord(ctx context.Context, serviceName string) error { + if d == nil { + return nil + } + + existing, err := d.listRecords(ctx, serviceName) + if err != nil { + return fmt.Errorf("list DNS records: %w", err) + } + + if len(existing) > 0 { + r := existing[0] + if r.Value == d.nodeAddr { + d.logger.Debug("DNS record exists, skipping", + "service", serviceName, + "record", r.Name+"."+d.zone, + "value", r.Value, + ) + return nil + } + // Wrong value — update it. + d.logger.Info("updating DNS record", + "service", serviceName, + "old_value", r.Value, + "new_value", d.nodeAddr, + ) + return d.updateRecord(ctx, r.ID, serviceName) + } + + // No existing record — create one. + d.logger.Info("creating DNS record", + "service", serviceName, + "record", serviceName+"."+d.zone, + "value", d.nodeAddr, + ) + return d.createRecord(ctx, serviceName) +} + +// RemoveRecord removes A records for the service from the configured zone. +func (d *DNSRegistrar) RemoveRecord(ctx context.Context, serviceName string) error { + if d == nil { + return nil + } + + existing, err := d.listRecords(ctx, serviceName) + if err != nil { + return fmt.Errorf("list DNS records: %w", err) + } + + if len(existing) == 0 { + d.logger.Debug("no DNS record to remove", "service", serviceName) + return nil + } + + for _, r := range existing { + d.logger.Info("removing DNS record", + "service", serviceName, + "record", r.Name+"."+d.zone, + "id", r.ID, + ) + if err := d.deleteRecord(ctx, r.ID); err != nil { + return err + } + } + return nil +} + +// listRecords returns A records matching the service name in the zone. +func (d *DNSRegistrar) listRecords(ctx context.Context, serviceName string) ([]dnsRecord, error) { + url := fmt.Sprintf("%s/v1/zones/%s/records?name=%s&type=A", d.serverURL, d.zone, serviceName) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("create list request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+d.token) + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("list records: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read list response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("list records: mcns returned %d: %s", resp.StatusCode, string(body)) + } + + var records []dnsRecord + if err := json.Unmarshal(body, &records); err != nil { + return nil, fmt.Errorf("parse list response: %w", err) + } + return records, nil +} + +// createRecord creates an A record in the zone. +func (d *DNSRegistrar) createRecord(ctx context.Context, serviceName string) error { + reqBody := map[string]interface{}{ + "name": serviceName, + "type": "A", + "value": d.nodeAddr, + "ttl": 300, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshal create request: %w", err) + } + + url := fmt.Sprintf("%s/v1/zones/%s/records", d.serverURL, d.zone) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("create record request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+d.token) + + resp, err := d.httpClient.Do(req) + if err != nil { + return fmt.Errorf("create record: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("create record: mcns returned %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// updateRecord updates an existing record's value. +func (d *DNSRegistrar) updateRecord(ctx context.Context, recordID int, serviceName string) error { + reqBody := map[string]interface{}{ + "name": serviceName, + "type": "A", + "value": d.nodeAddr, + "ttl": 300, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("marshal update request: %w", err) + } + + url := fmt.Sprintf("%s/v1/zones/%s/records/%d", d.serverURL, d.zone, recordID) + req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("create update request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+d.token) + + resp, err := d.httpClient.Do(req) + if err != nil { + return fmt.Errorf("update record: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("update record: mcns returned %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} + +// deleteRecord deletes a record by ID. +func (d *DNSRegistrar) deleteRecord(ctx context.Context, recordID int) error { + url := fmt.Sprintf("%s/v1/zones/%s/records/%d", d.serverURL, d.zone, recordID) + req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) + if err != nil { + return fmt.Errorf("create delete request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+d.token) + + resp, err := d.httpClient.Do(req) + if err != nil { + return fmt.Errorf("delete record: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { + respBody, _ := io.ReadAll(resp.Body) + return fmt.Errorf("delete record: mcns returned %d: %s", resp.StatusCode, string(respBody)) + } + + return nil +} diff --git a/internal/agent/dns_test.go b/internal/agent/dns_test.go new file mode 100644 index 0000000..df48cb4 --- /dev/null +++ b/internal/agent/dns_test.go @@ -0,0 +1,214 @@ +package agent + +import ( + "context" + "encoding/json" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "git.wntrmute.dev/mc/mcp/internal/config" +) + +func TestNilDNSRegistrarIsNoop(t *testing.T) { + var d *DNSRegistrar + if err := d.EnsureRecord(context.Background(), "svc"); err != nil { + t.Fatalf("EnsureRecord on nil: %v", err) + } + if err := d.RemoveRecord(context.Background(), "svc"); err != nil { + t.Fatalf("RemoveRecord on nil: %v", err) + } +} + +func TestNewDNSRegistrarDisabledWhenUnconfigured(t *testing.T) { + d, err := NewDNSRegistrar(config.MCNSConfig{}, slog.Default()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if d != nil { + t.Fatal("expected nil registrar for empty config") + } +} + +func TestEnsureRecordCreatesWhenMissing(t *testing.T) { + var gotMethod, gotPath, gotAuth string + var gotBody map[string]interface{} + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + // List returns empty — no existing records. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("[]")) + return + } + gotMethod = r.Method + gotPath = r.URL.Path + gotAuth = r.Header.Get("Authorization") + _ = json.NewDecoder(r.Body).Decode(&gotBody) + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{"id":1}`)) + })) + defer srv.Close() + + d := &DNSRegistrar{ + serverURL: srv.URL, + token: "test-token", + zone: "svc.mcp.metacircular.net", + nodeAddr: "192.168.88.181", + httpClient: srv.Client(), + logger: slog.Default(), + } + + if err := d.EnsureRecord(context.Background(), "myservice"); err != nil { + t.Fatalf("EnsureRecord: %v", err) + } + + if gotMethod != http.MethodPost { + t.Fatalf("method: got %q, want POST", gotMethod) + } + if gotPath != "/v1/zones/svc.mcp.metacircular.net/records" { + t.Fatalf("path: got %q", gotPath) + } + if gotAuth != "Bearer test-token" { + t.Fatalf("auth: got %q", gotAuth) + } + if gotBody["name"] != "myservice" { + t.Fatalf("name: got %v", gotBody["name"]) + } + if gotBody["type"] != "A" { + t.Fatalf("type: got %v", gotBody["type"]) + } + if gotBody["value"] != "192.168.88.181" { + t.Fatalf("value: got %v", gotBody["value"]) + } +} + +func TestEnsureRecordSkipsWhenExists(t *testing.T) { + createCalled := false + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + // Return an existing record with the correct value. + records := []dnsRecord{{ID: 1, Name: "myservice", Type: "A", Value: "192.168.88.181", TTL: 300}} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(records) + return + } + createCalled = true + w.WriteHeader(http.StatusCreated) + })) + defer srv.Close() + + d := &DNSRegistrar{ + serverURL: srv.URL, + token: "test-token", + zone: "svc.mcp.metacircular.net", + nodeAddr: "192.168.88.181", + httpClient: srv.Client(), + logger: slog.Default(), + } + + if err := d.EnsureRecord(context.Background(), "myservice"); err != nil { + t.Fatalf("EnsureRecord: %v", err) + } + if createCalled { + t.Fatal("should not create when record already exists with correct value") + } +} + +func TestEnsureRecordUpdatesWrongValue(t *testing.T) { + var gotMethod string + var gotPath string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + // Return a record with a stale value. + records := []dnsRecord{{ID: 42, Name: "myservice", Type: "A", Value: "10.0.0.1", TTL: 300}} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(records) + return + } + gotMethod = r.Method + gotPath = r.URL.Path + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + d := &DNSRegistrar{ + serverURL: srv.URL, + token: "test-token", + zone: "svc.mcp.metacircular.net", + nodeAddr: "192.168.88.181", + httpClient: srv.Client(), + logger: slog.Default(), + } + + if err := d.EnsureRecord(context.Background(), "myservice"); err != nil { + t.Fatalf("EnsureRecord: %v", err) + } + if gotMethod != http.MethodPut { + t.Fatalf("method: got %q, want PUT", gotMethod) + } + if gotPath != "/v1/zones/svc.mcp.metacircular.net/records/42" { + t.Fatalf("path: got %q", gotPath) + } +} + +func TestRemoveRecordDeletes(t *testing.T) { + var gotMethod, gotPath string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == http.MethodGet { + records := []dnsRecord{{ID: 7, Name: "myservice", Type: "A", Value: "192.168.88.181", TTL: 300}} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(records) + return + } + gotMethod = r.Method + gotPath = r.URL.Path + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + d := &DNSRegistrar{ + serverURL: srv.URL, + token: "test-token", + zone: "svc.mcp.metacircular.net", + nodeAddr: "192.168.88.181", + httpClient: srv.Client(), + logger: slog.Default(), + } + + if err := d.RemoveRecord(context.Background(), "myservice"); err != nil { + t.Fatalf("RemoveRecord: %v", err) + } + if gotMethod != http.MethodDelete { + t.Fatalf("method: got %q, want DELETE", gotMethod) + } + if gotPath != "/v1/zones/svc.mcp.metacircular.net/records/7" { + t.Fatalf("path: got %q", gotPath) + } +} + +func TestRemoveRecordNoopWhenMissing(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // List returns empty. + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte("[]")) + })) + defer srv.Close() + + d := &DNSRegistrar{ + serverURL: srv.URL, + token: "test-token", + zone: "svc.mcp.metacircular.net", + nodeAddr: "192.168.88.181", + httpClient: srv.Client(), + logger: slog.Default(), + } + + if err := d.RemoveRecord(context.Background(), "myservice"); err != nil { + t.Fatalf("RemoveRecord: %v", err) + } +} diff --git a/internal/agent/lifecycle.go b/internal/agent/lifecycle.go index 9a40afb..9257447 100644 --- a/internal/agent/lifecycle.go +++ b/internal/agent/lifecycle.go @@ -37,6 +37,13 @@ func (a *Agent) StopService(ctx context.Context, req *mcpv1.StopServiceRequest) } } + // Remove DNS record when stopping the service. + if len(c.Routes) > 0 && a.DNS != nil { + if err := a.DNS.RemoveRecord(ctx, req.GetName()); err != nil { + a.Logger.Warn("failed to remove DNS record", "service", req.GetName(), "err", err) + } + } + if err := a.Runtime.Stop(ctx, containerName); err != nil { a.Logger.Info("stop container (ignored)", "container", containerName, "error", err) } diff --git a/internal/config/agent.go b/internal/config/agent.go index e5b105c..6fac1ce 100644 --- a/internal/config/agent.go +++ b/internal/config/agent.go @@ -16,6 +16,7 @@ type AgentConfig struct { Agent AgentSettings `toml:"agent"` MCProxy MCProxyConfig `toml:"mcproxy"` Metacrypt MetacryptConfig `toml:"metacrypt"` + MCNS MCNSConfig `toml:"mcns"` Monitor MonitorConfig `toml:"monitor"` Log LogConfig `toml:"log"` } @@ -40,6 +41,26 @@ type MetacryptConfig struct { TokenPath string `toml:"token_path"` } +// MCNSConfig holds the MCNS DNS integration settings for automated +// DNS record registration. If ServerURL is empty, DNS registration +// is disabled. +type MCNSConfig struct { + // ServerURL is the MCNS API base URL (e.g. "https://localhost:28443"). + ServerURL string `toml:"server_url"` + + // CACert is the path to the CA certificate for verifying MCNS's TLS. + CACert string `toml:"ca_cert"` + + // TokenPath is the path to the MCIAS service token file. + TokenPath string `toml:"token_path"` + + // Zone is the DNS zone for service records. Defaults to "svc.mcp.metacircular.net". + Zone string `toml:"zone"` + + // NodeAddr is the IP address to register as the A record value. + NodeAddr string `toml:"node_addr"` +} + // MCProxyConfig holds the mc-proxy connection settings. type MCProxyConfig struct { // Socket is the path to the mc-proxy gRPC admin API Unix socket. @@ -177,6 +198,9 @@ func applyAgentDefaults(cfg *AgentConfig) { if cfg.Metacrypt.Issuer == "" { cfg.Metacrypt.Issuer = "infra" } + if cfg.MCNS.Zone == "" { + cfg.MCNS.Zone = "svc.mcp.metacircular.net" + } } func applyAgentEnvOverrides(cfg *AgentConfig) { @@ -213,6 +237,15 @@ func applyAgentEnvOverrides(cfg *AgentConfig) { if v := os.Getenv("MCP_AGENT_METACRYPT_TOKEN_PATH"); v != "" { cfg.Metacrypt.TokenPath = v } + if v := os.Getenv("MCP_AGENT_MCNS_SERVER_URL"); v != "" { + cfg.MCNS.ServerURL = v + } + if v := os.Getenv("MCP_AGENT_MCNS_TOKEN_PATH"); v != "" { + cfg.MCNS.TokenPath = v + } + if v := os.Getenv("MCP_AGENT_MCNS_NODE_ADDR"); v != "" { + cfg.MCNS.NodeAddr = v + } } func validateAgentConfig(cfg *AgentConfig) error { diff --git a/internal/config/config_test.go b/internal/config/config_test.go index daa7398..25debc0 100644 --- a/internal/config/config_test.go +++ b/internal/config/config_test.go @@ -171,6 +171,11 @@ func TestLoadAgentConfig(t *testing.T) { if cfg.Metacrypt.Issuer != "infra" { t.Fatalf("metacrypt.issuer default: got %q, want infra", cfg.Metacrypt.Issuer) } + + // MCNS defaults when section is omitted. + if cfg.MCNS.Zone != "svc.mcp.metacircular.net" { + t.Fatalf("mcns.zone default: got %q, want svc.mcp.metacircular.net", cfg.MCNS.Zone) + } } func TestCLIConfigValidation(t *testing.T) { @@ -521,6 +526,81 @@ node_name = "rift" } } +func TestAgentConfigMCNS(t *testing.T) { + cfgStr := ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +[mcns] +server_url = "https://localhost:28443" +ca_cert = "/srv/mcp/certs/metacircular-ca.pem" +token_path = "/srv/mcp/metacrypt-token" +zone = "custom.zone" +node_addr = "10.0.0.1" +` + path := writeTempConfig(t, cfgStr) + cfg, err := LoadAgentConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + if cfg.MCNS.ServerURL != "https://localhost:28443" { + t.Fatalf("mcns.server_url: got %q", cfg.MCNS.ServerURL) + } + if cfg.MCNS.CACert != "/srv/mcp/certs/metacircular-ca.pem" { + t.Fatalf("mcns.ca_cert: got %q", cfg.MCNS.CACert) + } + if cfg.MCNS.Zone != "custom.zone" { + t.Fatalf("mcns.zone: got %q", cfg.MCNS.Zone) + } + if cfg.MCNS.NodeAddr != "10.0.0.1" { + t.Fatalf("mcns.node_addr: got %q", cfg.MCNS.NodeAddr) + } +} + +func TestAgentConfigMCNSEnvOverrides(t *testing.T) { + minimal := ` +[server] +grpc_addr = "0.0.0.0:9444" +tls_cert = "/srv/mcp/cert.pem" +tls_key = "/srv/mcp/key.pem" +[database] +path = "/srv/mcp/mcp.db" +[mcias] +server_url = "https://mcias.metacircular.net:8443" +service_name = "mcp-agent" +[agent] +node_name = "rift" +` + t.Setenv("MCP_AGENT_MCNS_SERVER_URL", "https://override:28443") + t.Setenv("MCP_AGENT_MCNS_TOKEN_PATH", "/override/token") + t.Setenv("MCP_AGENT_MCNS_NODE_ADDR", "10.0.0.99") + + path := writeTempConfig(t, minimal) + cfg, err := LoadAgentConfig(path) + if err != nil { + t.Fatalf("load: %v", err) + } + + if cfg.MCNS.ServerURL != "https://override:28443" { + t.Fatalf("mcns.server_url: got %q", cfg.MCNS.ServerURL) + } + if cfg.MCNS.TokenPath != "/override/token" { + t.Fatalf("mcns.token_path: got %q", cfg.MCNS.TokenPath) + } + if cfg.MCNS.NodeAddr != "10.0.0.99" { + t.Fatalf("mcns.node_addr: got %q", cfg.MCNS.NodeAddr) + } +} + func TestDurationParsing(t *testing.T) { tests := []struct { input string