package db import ( "database/sql" "errors" "fmt" "strconv" "strings" "time" ) // Zone represents a DNS zone stored in the database. type Zone struct { ID int64 Name string PrimaryNS string AdminEmail string Refresh int Retry int Expire int MinimumTTL int Serial int64 CreatedAt string UpdatedAt string } // ErrNotFound is returned when a requested resource does not exist. var ErrNotFound = errors.New("not found") // ErrConflict is returned when a write conflicts with existing data. var ErrConflict = errors.New("conflict") // ListZones returns all zones ordered by name. func (d *DB) ListZones() ([]Zone, error) { rows, err := d.Query(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones ORDER BY name`) if err != nil { return nil, fmt.Errorf("list zones: %w", err) } defer rows.Close() var zones []Zone for rows.Next() { var z Zone if err := rows.Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt); err != nil { return nil, fmt.Errorf("scan zone: %w", err) } zones = append(zones, z) } return zones, rows.Err() } // GetZone returns a zone by name (case-insensitive). func (d *DB) GetZone(name string) (*Zone, error) { name = strings.ToLower(strings.TrimSuffix(name, ".")) var z Zone err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE name = ?`, name). Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("get zone: %w", err) } return &z, nil } // GetZoneByID returns a zone by ID. func (d *DB) GetZoneByID(id int64) (*Zone, error) { var z Zone err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE id = ?`, id). Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt) if errors.Is(err, sql.ErrNoRows) { return nil, ErrNotFound } if err != nil { return nil, fmt.Errorf("get zone by id: %w", err) } return &z, nil } // CreateZone inserts a new zone and returns it with the generated serial. func (d *DB) CreateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) { name = strings.ToLower(strings.TrimSuffix(name, ".")) serial := nextSerial(0) res, err := d.Exec(`INSERT INTO zones (name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, name, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial) if err != nil { if strings.Contains(err.Error(), "UNIQUE constraint") { return nil, fmt.Errorf("%w: zone %q already exists", ErrConflict, name) } return nil, fmt.Errorf("create zone: %w", err) } id, err := res.LastInsertId() if err != nil { return nil, fmt.Errorf("create zone: last insert id: %w", err) } return d.GetZoneByID(id) } // UpdateZone updates a zone's SOA parameters and bumps the serial. func (d *DB) UpdateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) { name = strings.ToLower(strings.TrimSuffix(name, ".")) zone, err := d.GetZone(name) if err != nil { return nil, err } serial := nextSerial(zone.Serial) now := time.Now().UTC().Format("2006-01-02T15:04:05Z") _, err = d.Exec(`UPDATE zones SET primary_ns = ?, admin_email = ?, refresh = ?, retry = ?, expire = ?, minimum_ttl = ?, serial = ?, updated_at = ? WHERE id = ?`, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial, now, zone.ID) if err != nil { return nil, fmt.Errorf("update zone: %w", err) } return d.GetZoneByID(zone.ID) } // DeleteZone deletes a zone and all its records. func (d *DB) DeleteZone(name string) error { name = strings.ToLower(strings.TrimSuffix(name, ".")) res, err := d.Exec(`DELETE FROM zones WHERE name = ?`, name) if err != nil { return fmt.Errorf("delete zone: %w", err) } n, err := res.RowsAffected() if err != nil { return fmt.Errorf("delete zone: rows affected: %w", err) } if n == 0 { return ErrNotFound } return nil } // BumpSerial increments the serial for a zone within a transaction. func (d *DB) BumpSerial(tx *sql.Tx, zoneID int64) error { var current int64 if err := tx.QueryRow(`SELECT serial FROM zones WHERE id = ?`, zoneID).Scan(¤t); err != nil { return fmt.Errorf("read serial: %w", err) } serial := nextSerial(current) now := time.Now().UTC().Format("2006-01-02T15:04:05Z") _, err := tx.Exec(`UPDATE zones SET serial = ?, updated_at = ? WHERE id = ?`, serial, now, zoneID) return err } // ZoneNames returns all zone names for the DNS handler. func (d *DB) ZoneNames() ([]string, error) { rows, err := d.Query(`SELECT name FROM zones ORDER BY name`) if err != nil { return nil, err } defer rows.Close() var names []string for rows.Next() { var name string if err := rows.Scan(&name); err != nil { return nil, err } names = append(names, name) } return names, rows.Err() } // nextSerial computes the next SOA serial in YYYYMMDDNN format. func nextSerial(current int64) int64 { today := time.Now().UTC() datePrefix, _ := strconv.ParseInt(today.Format("20060102"), 10, 64) datePrefix *= 100 // YYYYMMDD00 if current >= datePrefix { return current + 1 } return datePrefix + 1 }