package db

import (
	"database/sql"
	"errors"
	"fmt"
	"net"
	"strings"
	"time"
)

// Record represents a DNS record stored in the database.
type Record struct {
	ID        int64
	ZoneID    int64
	ZoneName  string // name of the owning zone, joined from the zones table
	Name      string
	Type      string
	Value     string
	TTL       int
	CreatedAt string
	UpdatedAt string
}

// recordSelectCols is the column list every Record query selects, aliased
// against "records r JOIN zones z ON r.zone_id = z.id". Its order must match
// the destinations in scanRecordRow.
const recordSelectCols = `r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at`

// scanRecordRow fills r from one row selected with recordSelectCols. scan is
// the Scan method of a *sql.Row or *sql.Rows.
func scanRecordRow(scan func(dest ...any) error, r *Record) error {
	return scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt)
}

// collectRecordRows drains rows (selected with recordSelectCols) into a slice
// and closes them. It returns a nil slice when there are no rows, and any
// iteration error reported by rows.Err.
func collectRecordRows(rows *sql.Rows) ([]Record, error) {
	defer rows.Close()
	var records []Record
	for rows.Next() {
		var r Record
		if err := scanRecordRow(rows.Scan, &r); err != nil {
			return nil, fmt.Errorf("scan record: %w", err)
		}
		records = append(records, r)
	}
	return records, rows.Err()
}

// ListRecords returns records for a zone, optionally filtered by name and type.
//
// zoneName is lowercased and a trailing "." is stripped before the zone is
// resolved with GetZone, whose error is returned unchanged. An empty name or
// recordType disables that filter; otherwise name is matched lowercased and
// recordType uppercased. Results are ordered by name, type, then value.
func (d *DB) ListRecords(zoneName, name, recordType string) ([]Record, error) {
	zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
	zone, err := d.GetZone(zoneName)
	if err != nil {
		return nil, err
	}

	query := `SELECT ` + recordSelectCols + `
		FROM records r JOIN zones z ON r.zone_id = z.id
		WHERE r.zone_id = ?`
	args := []any{zone.ID}
	if name != "" {
		query += ` AND r.name = ?`
		args = append(args, strings.ToLower(name))
	}
	if recordType != "" {
		query += ` AND r.type = ?`
		args = append(args, strings.ToUpper(recordType))
	}
	query += ` ORDER BY r.name, r.type, r.value`

	rows, err := d.Query(query, args...)
	if err != nil {
		return nil, fmt.Errorf("list records: %w", err)
	}
	return collectRecordRows(rows)
}

// LookupRecords returns records matching a name and type within a zone.
// Used by the DNS handler for query resolution.
//
// The zone is matched by name in the same query (no separate GetZone call), so
// an unknown zone yields an empty result rather than an error. zoneName is
// lowercased with a trailing "." stripped, name is lowercased and recordType
// uppercased.
func (d *DB) LookupRecords(zoneName, name, recordType string) ([]Record, error) {
	zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
	name = strings.ToLower(name)

	rows, err := d.Query(`SELECT `+recordSelectCols+`
		FROM records r JOIN zones z ON r.zone_id = z.id
		WHERE z.name = ? AND r.name = ? AND r.type = ?`,
		zoneName, name, strings.ToUpper(recordType))
	if err != nil {
		return nil, fmt.Errorf("lookup records: %w", err)
	}
	return collectRecordRows(rows)
}

// LookupCNAME returns CNAME records for a name within a zone.
func (d *DB) LookupCNAME(zoneName, name string) ([]Record, error) {
	return d.LookupRecords(zoneName, name, "CNAME")
}

// HasRecordsForName checks if any records of the given types exist for a name.
//
// The query runs on tx, so it sees the transaction's own uncommitted writes.
// name is lowercased and each type uppercased before matching. An empty types
// slice matches nothing and returns false without touching the database
// (an empty "IN ()" list is not portable SQL).
func (d *DB) HasRecordsForName(tx *sql.Tx, zoneID int64, name string, types []string) (bool, error) {
	if len(types) == 0 {
		return false, nil
	}

	placeholders := make([]string, len(types))
	args := make([]any, 0, len(types)+2)
	args = append(args, zoneID, strings.ToLower(name))
	for i, t := range types {
		placeholders[i] = "?"
		args = append(args, strings.ToUpper(t))
	}
	query := fmt.Sprintf(`SELECT COUNT(*) FROM records WHERE zone_id = ? AND name = ? AND type IN (%s)`,
		strings.Join(placeholders, ","))

	var count int
	if err := tx.QueryRow(query, args...).Scan(&count); err != nil {
		return false, err
	}
	return count > 0, nil
}

// GetRecord returns a record by ID.
//
// It returns ErrNotFound when no record has that ID; other database errors
// are wrapped with "get record".
func (d *DB) GetRecord(id int64) (*Record, error) {
	var r Record
	row := d.QueryRow(`SELECT `+recordSelectCols+`
		FROM records r JOIN zones z ON r.zone_id = z.id
		WHERE r.id = ?`, id)
	err := scanRecordRow(row.Scan, &r)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		return nil, fmt.Errorf("get record: %w", err)
	}
	return &r, nil
}

// CreateRecord inserts a new record, enforcing CNAME exclusivity and
// value validation. Bumps the zone serial within the same transaction.
func (d *DB) CreateRecord(zoneName, name, recordType, value string, ttl int) (*Record, error) { zoneName = strings.ToLower(strings.TrimSuffix(zoneName, ".")) name = strings.ToLower(name) recordType = strings.ToUpper(recordType) if err := validateRecordValue(recordType, value); err != nil { return nil, err } zone, err := d.GetZone(zoneName) if err != nil { return nil, err } if ttl <= 0 { ttl = 300 } tx, err := d.Begin() if err != nil { return nil, fmt.Errorf("begin tx: %w", err) } defer func() { _ = tx.Rollback() }() // Enforce CNAME exclusivity. if recordType == "CNAME" { hasAddr, err := d.HasRecordsForName(tx, zone.ID, name, []string{"A", "AAAA"}) if err != nil { return nil, fmt.Errorf("check cname exclusivity: %w", err) } if hasAddr { return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name) } } else if recordType == "A" || recordType == "AAAA" { hasCNAME, err := d.HasRecordsForName(tx, zone.ID, name, []string{"CNAME"}) if err != nil { return nil, fmt.Errorf("check cname exclusivity: %w", err) } if hasCNAME { return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name) } } res, err := tx.Exec(`INSERT INTO records (zone_id, name, type, value, ttl) VALUES (?, ?, ?, ?, ?)`, zone.ID, name, recordType, value, ttl) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint") { return nil, fmt.Errorf("%w: record already exists", ErrConflict) } return nil, fmt.Errorf("insert record: %w", err) } recordID, err := res.LastInsertId() if err != nil { return nil, fmt.Errorf("last insert id: %w", err) } if err := d.BumpSerial(tx, zone.ID); err != nil { return nil, fmt.Errorf("bump serial: %w", err) } if err := tx.Commit(); err != nil { return nil, fmt.Errorf("commit: %w", err) } return d.GetRecord(recordID) } // UpdateRecord updates an existing record's fields and bumps the zone serial. 
func (d *DB) UpdateRecord(id int64, name, recordType, value string, ttl int) (*Record, error) { name = strings.ToLower(name) recordType = strings.ToUpper(recordType) if err := validateRecordValue(recordType, value); err != nil { return nil, err } existing, err := d.GetRecord(id) if err != nil { return nil, err } if ttl <= 0 { ttl = 300 } tx, err := d.Begin() if err != nil { return nil, fmt.Errorf("begin tx: %w", err) } defer func() { _ = tx.Rollback() }() // Enforce CNAME exclusivity for the new type/name combo. if recordType == "CNAME" { hasAddr, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"A", "AAAA"}) if err != nil { return nil, fmt.Errorf("check cname exclusivity: %w", err) } if hasAddr { return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name) } } else if recordType == "A" || recordType == "AAAA" { hasCNAME, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"CNAME"}) if err != nil { return nil, fmt.Errorf("check cname exclusivity: %w", err) } if hasCNAME { return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name) } } now := time.Now().UTC().Format("2006-01-02T15:04:05Z") _, err = tx.Exec(`UPDATE records SET name = ?, type = ?, value = ?, ttl = ?, updated_at = ? WHERE id = ?`, name, recordType, value, ttl, now, id) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint") { return nil, fmt.Errorf("%w: record already exists", ErrConflict) } return nil, fmt.Errorf("update record: %w", err) } if err := d.BumpSerial(tx, existing.ZoneID); err != nil { return nil, fmt.Errorf("bump serial: %w", err) } if err := tx.Commit(); err != nil { return nil, fmt.Errorf("commit: %w", err) } return d.GetRecord(id) } // DeleteRecord deletes a record and bumps the zone serial. 
func (d *DB) DeleteRecord(id int64) error { existing, err := d.GetRecord(id) if err != nil { return err } tx, err := d.Begin() if err != nil { return fmt.Errorf("begin tx: %w", err) } defer func() { _ = tx.Rollback() }() _, err = tx.Exec(`DELETE FROM records WHERE id = ?`, id) if err != nil { return fmt.Errorf("delete record: %w", err) } if err := d.BumpSerial(tx, existing.ZoneID); err != nil { return fmt.Errorf("bump serial: %w", err) } return tx.Commit() } // validateRecordValue checks that a record value is valid for its type. func validateRecordValue(recordType, value string) error { switch recordType { case "A": ip := net.ParseIP(value) if ip == nil || ip.To4() == nil { return fmt.Errorf("invalid IPv4 address: %q", value) } case "AAAA": ip := net.ParseIP(value) if ip == nil || ip.To4() != nil { return fmt.Errorf("invalid IPv6 address: %q", value) } case "CNAME": if !strings.HasSuffix(value, ".") { return fmt.Errorf("CNAME value must be a fully-qualified domain name ending with '.': %q", value) } default: return fmt.Errorf("unsupported record type: %q", recordType) } return nil }