package master import ( "bytes" "context" "crypto/tls" "crypto/x509" "encoding/json" "fmt" "io" "log/slog" "net/http" "os" "strings" "time" "git.wntrmute.dev/mc/mcp/internal/auth" "git.wntrmute.dev/mc/mcp/internal/config" ) // DNSClient creates and removes A records in MCNS. Unlike the agent's // DNSRegistrar, the master registers records for different node IPs // (the nodeAddr is a per-call parameter, not a fixed config value). type DNSClient struct { serverURL string token string zone string httpClient *http.Client logger *slog.Logger } type dnsRecord struct { ID int `json:"ID"` Name string `json:"Name"` Type string `json:"Type"` Value string `json:"Value"` TTL int `json:"TTL"` } // NewDNSClient creates a DNS client. Returns (nil, nil) if serverURL is empty. func NewDNSClient(cfg config.MCNSConfig, logger *slog.Logger) (*DNSClient, error) { if cfg.ServerURL == "" { logger.Info("mcns not configured, DNS registration disabled") return nil, nil } token, err := auth.LoadToken(cfg.TokenPath) if err != nil { return nil, fmt.Errorf("load mcns token: %w", err) } httpClient, err := newHTTPClient(cfg.CACert) if err != nil { return nil, fmt.Errorf("create mcns HTTP client: %w", err) } logger.Info("master DNS client enabled", "server", cfg.ServerURL, "zone", cfg.Zone) return &DNSClient{ serverURL: strings.TrimRight(cfg.ServerURL, "/"), token: token, zone: cfg.Zone, httpClient: httpClient, logger: logger, }, nil } // Zone returns the configured DNS zone. func (d *DNSClient) Zone() string { if d == nil { return "" } return d.zone } // EnsureRecord ensures an A record exists for serviceName pointing to nodeAddr. func (d *DNSClient) EnsureRecord(ctx context.Context, serviceName, nodeAddr string) error { if d == nil { return nil } existing, err := d.listRecords(ctx, serviceName) if err != nil { return fmt.Errorf("list DNS records: %w", err) } for _, r := range existing { if r.Value == nodeAddr { d.logger.Debug("DNS record exists", "service", serviceName, "value", r.Value) return nil } } if len(existing) > 0 { d.logger.Info("updating DNS record", "service", serviceName, "old_value", existing[0].Value, "new_value", nodeAddr) return d.updateRecord(ctx, existing[0].ID, serviceName, nodeAddr) } d.logger.Info("creating DNS record", "service", serviceName, "record", serviceName+"."+d.zone, "value", nodeAddr) return d.createRecord(ctx, serviceName, nodeAddr) } // RemoveRecord removes A records for serviceName. func (d *DNSClient) RemoveRecord(ctx context.Context, serviceName string) error { if d == nil { return nil } existing, err := d.listRecords(ctx, serviceName) if err != nil { return fmt.Errorf("list DNS records: %w", err) } for _, r := range existing { d.logger.Info("removing DNS record", "service", serviceName, "id", r.ID) if err := d.deleteRecord(ctx, r.ID); err != nil { return err } } return nil } func (d *DNSClient) listRecords(ctx context.Context, serviceName string) ([]dnsRecord, error) { url := fmt.Sprintf("%s/v1/zones/%s/records?name=%s&type=A", d.serverURL, d.zone, serviceName) req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) if err != nil { return nil, fmt.Errorf("create list request: %w", err) } req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("list records: %w", err) } defer func() { _ = resp.Body.Close() }() body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("read list response: %w", err) } if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("list records: mcns returned %d: %s", resp.StatusCode, string(body)) } var envelope struct { Records []dnsRecord `json:"records"` } if err := json.Unmarshal(body, &envelope); err != nil { return nil, fmt.Errorf("parse list response: %w", err) } return envelope.Records, nil } func (d *DNSClient) createRecord(ctx context.Context, serviceName, nodeAddr string) error { reqBody, _ := json.Marshal(map[string]interface{}{ "name": serviceName, "type": "A", "value": nodeAddr, "ttl": 300, }) url := fmt.Sprintf("%s/v1/zones/%s/records", d.serverURL, d.zone) req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody)) if err != nil { return fmt.Errorf("create record request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return fmt.Errorf("create record: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusCreated && resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return fmt.Errorf("create record: mcns returned %d: %s", resp.StatusCode, string(respBody)) } return nil } func (d *DNSClient) updateRecord(ctx context.Context, recordID int, serviceName, nodeAddr string) error { reqBody, _ := json.Marshal(map[string]interface{}{ "name": serviceName, "type": "A", "value": nodeAddr, "ttl": 300, }) url := fmt.Sprintf("%s/v1/zones/%s/records/%d", d.serverURL, d.zone, recordID) req, err := http.NewRequestWithContext(ctx, http.MethodPut, url, bytes.NewReader(reqBody)) if err != nil { return fmt.Errorf("create update request: %w", err) } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return fmt.Errorf("update record: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return fmt.Errorf("update record: mcns returned %d: %s", resp.StatusCode, string(respBody)) } return nil } func (d *DNSClient) deleteRecord(ctx context.Context, recordID int) error { url := fmt.Sprintf("%s/v1/zones/%s/records/%d", d.serverURL, d.zone, recordID) req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil) if err != nil { return fmt.Errorf("create delete request: %w", err) } req.Header.Set("Authorization", "Bearer "+d.token) resp, err := d.httpClient.Do(req) if err != nil { return fmt.Errorf("delete record: %w", err) } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusNoContent && resp.StatusCode != http.StatusOK { respBody, _ := io.ReadAll(resp.Body) return fmt.Errorf("delete record: mcns returned %d: %s", resp.StatusCode, string(respBody)) } return nil } func newHTTPClient(caCertPath string) (*http.Client, error) { tlsConfig := &tls.Config{ MinVersion: tls.VersionTLS13, } if caCertPath != "" { caCert, err := os.ReadFile(caCertPath) //nolint:gosec // path from trusted config if err != nil { return nil, fmt.Errorf("read CA cert %q: %w", caCertPath, err) } pool := x509.NewCertPool() if !pool.AppendCertsFromPEM(caCert) { return nil, fmt.Errorf("parse CA cert %q: no valid certificates found", caCertPath) } tlsConfig.RootCAs = pool } return &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ TLSClientConfig: tlsConfig, }, }, nil }