Files
mcns/internal/db/zones.go
Kyle Isom f9635578e0 Implement MCNS v1: custom Go DNS server replacing CoreDNS
Replace the CoreDNS precursor with a purpose-built authoritative DNS
server. Zones and records (A, AAAA, CNAME) are stored in SQLite and
managed via synchronized gRPC + REST APIs authenticated through MCIAS.
Non-authoritative queries are forwarded to upstream resolvers with
in-memory caching.

Key components:
- DNS server (miekg/dns) with authoritative zone handling and forwarding
- gRPC + REST management APIs with MCIAS auth (mcdsl integration)
- SQLite storage with CNAME exclusivity enforcement and auto SOA serials
- 30 tests covering database CRUD, DNS resolution, and caching

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 18:37:14 -07:00

183 lines
5.6 KiB
Go

package db
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
)
// Zone is one row of the zones table: a DNS zone together with the SOA
// parameters and bookkeeping columns stored alongside it.
type Zone struct {
	// ID is the database row identifier.
	ID int64

	// Name is the zone name; CreateZone stores it lowercased with one
	// trailing dot removed. PrimaryNS and AdminEmail are the SOA primary
	// name server and administrator contact.
	Name, PrimaryNS, AdminEmail string

	// SOA timer values, stored as plain integers.
	// NOTE(review): presumably seconds, as in a DNS SOA record — confirm.
	Refresh, Retry, Expire, MinimumTTL int

	// Serial is the SOA serial, generated in YYYYMMDDNN form.
	Serial int64

	// Timestamps as stored by the database. UpdatedAt is written as a UTC
	// "2006-01-02T15:04:05Z" string whenever the zone or its serial changes;
	// CreatedAt is presumably filled by a column default (verify in schema).
	CreatedAt, UpdatedAt string
}
// Sentinel errors returned by the database layer. Callers should compare
// against them with errors.Is, since they may arrive wrapped with context.
var (
	// ErrNotFound reports that the requested resource does not exist.
	ErrNotFound = errors.New("not found")

	// ErrConflict reports that a write clashes with data already stored.
	ErrConflict = errors.New("conflict")
)
// ListZones returns every zone in the zones table, ordered by name.
//
// The returned slice is nil when the table holds no zones. On any failure —
// running the query, scanning a row, or iterating the result set — it
// returns a nil slice and an error wrapped with context.
func (d *DB) ListZones() ([]Zone, error) {
	rows, err := d.Query(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones ORDER BY name`)
	if err != nil {
		return nil, fmt.Errorf("list zones: %w", err)
	}
	defer rows.Close()

	var zones []Zone
	for rows.Next() {
		var z Zone
		if err := rows.Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt); err != nil {
			return nil, fmt.Errorf("scan zone: %w", err)
		}
		zones = append(zones, z)
	}
	// rows.Next also stops when iteration fails part-way through; report that
	// with context instead of returning a truncated list next to a bare error.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("list zones: %w", err)
	}
	return zones, nil
}
// GetZone looks a zone up by name.
//
// The supplied name is lowercased and has one trailing dot removed before it
// is matched against the name column, so "Example.COM." and "example.com"
// resolve to the same row. It returns ErrNotFound when no row matches and a
// wrapped error for any other failure.
func (d *DB) GetZone(name string) (*Zone, error) {
	const query = `SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE name = ?`

	key := strings.ToLower(strings.TrimSuffix(name, "."))
	zone := new(Zone)
	row := d.QueryRow(query, key)
	switch err := row.Scan(&zone.ID, &zone.Name, &zone.PrimaryNS, &zone.AdminEmail, &zone.Refresh, &zone.Retry, &zone.Expire, &zone.MinimumTTL, &zone.Serial, &zone.CreatedAt, &zone.UpdatedAt); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("get zone: %w", err)
	}
	return zone, nil
}
// GetZoneByID looks a zone up by its row ID.
//
// It returns ErrNotFound when no row has that ID and a wrapped error for any
// other failure.
func (d *DB) GetZoneByID(id int64) (*Zone, error) {
	const query = `SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE id = ?`

	zone := new(Zone)
	row := d.QueryRow(query, id)
	switch err := row.Scan(&zone.ID, &zone.Name, &zone.PrimaryNS, &zone.AdminEmail, &zone.Refresh, &zone.Retry, &zone.Expire, &zone.MinimumTTL, &zone.Serial, &zone.CreatedAt, &zone.UpdatedAt); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, fmt.Errorf("get zone by id: %w", err)
	}
	return zone, nil
}
// CreateZone stores a new zone and returns the row as persisted.
//
// The name is lowercased and has one trailing dot removed before insertion,
// and the initial serial comes from nextSerial(0). If the insert fails with
// an error whose text mentions a UNIQUE constraint, the result wraps
// ErrConflict; other failures are wrapped with "create zone" context. The
// returned Zone is re-read by ID after the insert.
func (d *DB) CreateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
	const insert = `INSERT INTO zones (name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`

	zoneName := strings.ToLower(strings.TrimSuffix(name, "."))
	initial := nextSerial(0)

	result, err := d.Exec(insert, zoneName, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, initial)
	if err != nil {
		// Uniqueness violations are recognised by message text only.
		if msg := err.Error(); !strings.Contains(msg, "UNIQUE constraint") {
			return nil, fmt.Errorf("create zone: %w", err)
		}
		return nil, fmt.Errorf("%w: zone %q already exists", ErrConflict, zoneName)
	}

	newID, err := result.LastInsertId()
	if err != nil {
		return nil, fmt.Errorf("create zone: last insert id: %w", err)
	}
	return d.GetZoneByID(newID)
}
// UpdateZone overwrites the SOA parameters of an existing zone, advances its
// serial via nextSerial, and stamps updated_at with the current UTC time.
//
// The zone is located by name (lowercased, trailing dot removed); any error
// from that lookup, including ErrNotFound, is returned unchanged. The
// returned Zone is re-read by ID after the update.
func (d *DB) UpdateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
	const update = `UPDATE zones SET primary_ns = ?, admin_email = ?, refresh = ?, retry = ?, expire = ?, minimum_ttl = ?, serial = ?, updated_at = ? WHERE id = ?`

	existing, err := d.GetZone(strings.ToLower(strings.TrimSuffix(name, ".")))
	if err != nil {
		return nil, err
	}

	bumped := nextSerial(existing.Serial)
	stamp := time.Now().UTC().Format("2006-01-02T15:04:05Z")
	if _, err := d.Exec(update, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, bumped, stamp, existing.ID); err != nil {
		return nil, fmt.Errorf("update zone: %w", err)
	}
	return d.GetZoneByID(existing.ID)
}
// DeleteZone removes the zone row whose name matches (after lowercasing and
// removing one trailing dot). It returns ErrNotFound when no row was deleted.
//
// NOTE(review): removal of the zone's records is not performed here; it
// presumably relies on a cascading foreign key in the schema — confirm.
func (d *DB) DeleteZone(name string) error {
	target := strings.ToLower(strings.TrimSuffix(name, "."))

	result, err := d.Exec(`DELETE FROM zones WHERE name = ?`, target)
	if err != nil {
		return fmt.Errorf("delete zone: %w", err)
	}

	switch affected, err := result.RowsAffected(); {
	case err != nil:
		return fmt.Errorf("delete zone: rows affected: %w", err)
	case affected == 0:
		return ErrNotFound
	default:
		return nil
	}
}
// BumpSerial advances the serial of the zone with the given ID using the
// caller's transaction, and stamps updated_at with the current UTC time.
//
// Both the read and the write go through tx; this method neither commits nor
// rolls back, so the change only takes effect when the caller commits. Errors
// from either step are wrapped with context ("read serial" / "write serial").
func (d *DB) BumpSerial(tx *sql.Tx, zoneID int64) error {
	var current int64
	if err := tx.QueryRow(`SELECT serial FROM zones WHERE id = ?`, zoneID).Scan(&current); err != nil {
		return fmt.Errorf("read serial: %w", err)
	}

	serial := nextSerial(current)
	now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
	// Wrap the write failure too, so callers can tell which step failed.
	if _, err := tx.Exec(`UPDATE zones SET serial = ?, updated_at = ? WHERE id = ?`, serial, now, zoneID); err != nil {
		return fmt.Errorf("write serial: %w", err)
	}
	return nil
}
// ZoneNames returns the names of all zones, ordered by name, for the DNS
// handler.
//
// The returned slice is nil when there are no zones. On any failure —
// running the query, scanning a row, or iterating the result set — it
// returns a nil slice and an error wrapped with context.
func (d *DB) ZoneNames() ([]string, error) {
	rows, err := d.Query(`SELECT name FROM zones ORDER BY name`)
	if err != nil {
		return nil, fmt.Errorf("zone names: %w", err)
	}
	defer rows.Close()

	var names []string
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return nil, fmt.Errorf("scan zone name: %w", err)
		}
		names = append(names, name)
	}
	// rows.Next also stops when iteration fails part-way through; surface
	// that with context rather than returning a truncated list.
	if err := rows.Err(); err != nil {
		return nil, fmt.Errorf("zone names: %w", err)
	}
	return names, nil
}
// nextSerial returns the SOA serial that should follow current, using the
// YYYYMMDDNN layout built from today's UTC date.
//
// When current is below today's YYYYMMDD00 base, the result is the first
// serial of the day (YYYYMMDD01). Otherwise current is simply incremented;
// no upper bound is applied to the NN part.
func nextSerial(current int64) int64 {
	stamp := time.Now().UTC().Format("20060102")
	// The formatted date is a run of decimal digits, so the parse error is
	// deliberately ignored.
	base, _ := strconv.ParseInt(stamp, 10, 64)
	base *= 100 // YYYYMMDD00

	if current < base {
		return base + 1
	}
	return current + 1
}