package agent import ( "bytes" "context" "encoding/json" "fmt" "io" "log/slog" "net/http" "strings" "git.wntrmute.dev/mc/mcp/internal/auth" "git.wntrmute.dev/mc/mcp/internal/config" ) // DNSRegistrar creates and removes A records in MCNS during deploy // and stop. It is nil-safe: all methods are no-ops when the receiver // is nil. type DNSRegistrar struct { serverURL string token string zone string nodeAddr string httpClient *http.Client logger *slog.Logger } // dnsRecord is the JSON representation of an MCNS record. type dnsRecord struct { ID int `json:"id"` Name string `json:"name"` Type string `json:"type"` Value string `json:"value"` TTL int `json:"ttl"` } // NewDNSRegistrar creates a DNSRegistrar. Returns (nil, nil) if // cfg.ServerURL is empty (DNS registration disabled). func NewDNSRegistrar(cfg config.MCNSConfig, logger *slog.Logger) (*DNSRegistrar, error) { if cfg.ServerURL == "" { logger.Info("mcns not configured, DNS registration disabled") return nil, nil } token, err := auth.LoadToken(cfg.TokenPath) if err != nil { return nil, fmt.Errorf("load mcns token: %w", err) } httpClient, err := newTLSClient(cfg.CACert) if err != nil { return nil, fmt.Errorf("create mcns HTTP client: %w", err) } logger.Info("mcns DNS registrar enabled", "server", cfg.ServerURL, "zone", cfg.Zone, "node_addr", cfg.NodeAddr) return &DNSRegistrar{ serverURL: strings.TrimRight(cfg.ServerURL, "/"), token: token, zone: cfg.Zone, nodeAddr: cfg.NodeAddr, httpClient: httpClient, logger: logger, }, nil } // EnsureRecord ensures an A record exists for the service in the // configured zone, pointing to the node's address. func (d *DNSRegistrar) EnsureRecord(ctx context.Context, serviceName string) error { if d == nil { return nil } existing, err := d.listRecords(ctx, serviceName) if err != nil { return fmt.Errorf("list DNS records: %w", err) } if len(existing) > 0 { r := existing[0] if r.Value == d.nodeAddr { d.logger.Debug("DNS record exists, skipping", "service", serviceName, "record", r.Name+"."+d.zone, "value", r.Value, ) return nil } // Wrong value — update it. d.logger.Info("updating DNS record", "service", serviceName, "old_value", r.Value, "new_value", d.nodeAddr, ) return d.updateRecord(ctx, r.ID, serviceName) } // No existing record — create one. d.logger.Info("creating DNS record", "service", serviceName, "record", serviceName+"."+d.zone, "value", d.nodeAddr, ) return d.createRecord(ctx, serviceName) } // RemoveRecord removes A records for the service from the configured zone. func (d *DNSRegistrar) RemoveRecord(ctx context.Context, serviceName string) error { if d == nil { return nil } existing, err := d.listRecords(ctx, serviceName) if err != nil { return fmt.Errorf("list DNS records: %w", err) } if len(existing) == 0 { d.logger.Debug("no DNS record to remove", "service", serviceName) return nil } for _, r := range existing { d.logger.Info("removing DNS record", "service", serviceName, "record", r.Name+"."+d.zone, "id", r.ID, ) if err := d.deleteRecord(ctx, r.ID); err != nil { return err } } return nil } // listRecords returns A records matching the service name in the zone. func (d *DNSRegistrar) listRecords(ctx context.Context, serviceName string) ([]dnsRecord, error) { url := fmt.Sprintf("%s/v1/zones/%s/records?name=%s&type=A", d.serverURL, d.zone, serviceName) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("create list request: %w", err) } req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("list records: %w", err) } defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("read list response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("list records: mcns returned %d: %s", resp.StatusCode, string(body)) } var records []dnsRecord if err := json.Unmarshal(body, &records); err != nil { return nil, fmt.Errorf("parse list response: %w", err) } return records, nil } // createRecord creates an A record in the zone. func (d *DNSRegistrar) createRecord(ctx context.Context, serviceName string) error { reqBody := map[string]interface{}{ "name": serviceName, "type": "A", "value": d.nodeAddr, "ttl": 300, } body, err := json.Marshal(reqBody) if err != nil { return fmt.Errorf("marshal create request: %w", err) } url := fmt.Sprintf("%s/v1/zones/%s/records", d.serverURL, d.zone) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if err != nil { return fmt.Errorf("create record request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return fmt.Errorf("create record: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return fmt.Errorf("create record: mcns returned %d: %s", resp.StatusCode, string(respBody)) } return nil } // updateRecord updates an existing record's value. func (d *DNSRegistrar) updateRecord(ctx context.Context, recordID int, serviceName string) error { reqBody := map[string]interface{}{ "name": serviceName, "type": "A", "value": d.nodeAddr, "ttl": 300, } body, err := json.Marshal(reqBody) if err != nil { return fmt.Errorf("marshal update request: %w", err) } url := fmt.Sprintf("%s/v1/zones/%s/records/%d", d.serverURL, d.zone, recordID) req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, bytes.NewReader(body)) if err != nil { return fmt.Errorf("create update request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return fmt.Errorf("update record: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return fmt.Errorf("update record: mcns returned %d: %s", resp.StatusCode, string(respBody)) } return nil } // deleteRecord deletes a record by ID. func (d *DNSRegistrar) deleteRecord(ctx context.Context, recordID int) error { url := fmt.Sprintf("%s/v1/zones/%s/records/%d", d.serverURL, d.zone, recordID) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) if err != nil { return fmt.Errorf("create delete request: %w", err) } req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return fmt.Errorf("delete record: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return fmt.Errorf("delete record: mcns returned %d: %s", resp.StatusCode, string(respBody)) } return nil }