Implement MCNS v1: custom Go DNS server replacing CoreDNS

Replace the CoreDNS precursor with a purpose-built authoritative DNS
server. Zones and records (A, AAAA, CNAME) are stored in SQLite and
managed via synchronized gRPC + REST APIs authenticated through MCIAS.
Non-authoritative queries are forwarded to upstream resolvers with
in-memory caching.

Key components:
- DNS server (miekg/dns) with authoritative zone handling and forwarding
- gRPC + REST management APIs with MCIAS auth (mcdsl integration)
- SQLite storage with CNAME exclusivity enforcement and auto SOA serials
- 30 tests covering database CRUD, DNS resolution, and caching

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 18:37:14 -07:00
parent a545fec658
commit f9635578e0
48 changed files with 6015 additions and 87 deletions

23
internal/db/db.go Normal file
View File

@@ -0,0 +1,23 @@
package db
import (
"database/sql"
"fmt"
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
)
// DB wraps a SQLite database connection.
type DB struct {
*sql.DB
}
// Open opens (or creates) a SQLite database at the given path with the
// standard Metacircular pragmas: WAL mode, foreign keys, busy timeout.
func Open(path string) (*DB, error) {
sqlDB, err := mcdsldb.Open(path)
if err != nil {
return nil, fmt.Errorf("db: %w", err)
}
return &DB{sqlDB}, nil
}

46
internal/db/migrate.go Normal file
View File

@@ -0,0 +1,46 @@
package db
import (
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
)
// Migrations is the ordered list of MCNS schema migrations.
var Migrations = []mcdsldb.Migration{
{
Version: 1,
Name: "zones and records",
SQL: `
CREATE TABLE IF NOT EXISTS zones (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
primary_ns TEXT NOT NULL,
admin_email TEXT NOT NULL,
refresh INTEGER NOT NULL DEFAULT 3600,
retry INTEGER NOT NULL DEFAULT 600,
expire INTEGER NOT NULL DEFAULT 86400,
minimum_ttl INTEGER NOT NULL DEFAULT 300,
serial INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
);
CREATE TABLE IF NOT EXISTS records (
id INTEGER PRIMARY KEY,
zone_id INTEGER NOT NULL REFERENCES zones(id) ON DELETE CASCADE,
name TEXT NOT NULL,
type TEXT NOT NULL CHECK (type IN ('A', 'AAAA', 'CNAME')),
value TEXT NOT NULL,
ttl INTEGER NOT NULL DEFAULT 300,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
UNIQUE(zone_id, name, type, value)
);
CREATE INDEX IF NOT EXISTS idx_records_zone_name ON records(zone_id, name);`,
},
}
// Migrate applies all pending migrations.
func (d *DB) Migrate() error {
return mcdsldb.Migrate(d.DB, Migrations)
}

308
internal/db/records.go Normal file
View File

@@ -0,0 +1,308 @@
package db
import (
"database/sql"
"errors"
"fmt"
"net"
"strings"
"time"
)
// Record represents a DNS record stored in the database.
type Record struct {
ID int64
ZoneID int64
ZoneName string
Name string
Type string
Value string
TTL int
CreatedAt string
UpdatedAt string
}
// ListRecords returns records for a zone, optionally filtered by name and type.
func (d *DB) ListRecords(zoneName, name, recordType string) ([]Record, error) {
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
zone, err := d.GetZone(zoneName)
if err != nil {
return nil, err
}
query := `SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.zone_id = ?`
args := []any{zone.ID}
if name != "" {
query += ` AND r.name = ?`
args = append(args, strings.ToLower(name))
}
if recordType != "" {
query += ` AND r.type = ?`
args = append(args, strings.ToUpper(recordType))
}
query += ` ORDER BY r.name, r.type, r.value`
rows, err := d.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("list records: %w", err)
}
defer rows.Close()
var records []Record
for rows.Next() {
var r Record
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan record: %w", err)
}
records = append(records, r)
}
return records, rows.Err()
}
// LookupRecords returns records matching a name and type within a zone.
// Used by the DNS handler for query resolution.
func (d *DB) LookupRecords(zoneName, name, recordType string) ([]Record, error) {
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
name = strings.ToLower(name)
rows, err := d.Query(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
FROM records r JOIN zones z ON r.zone_id = z.id
WHERE z.name = ? AND r.name = ? AND r.type = ?`, zoneName, name, strings.ToUpper(recordType))
if err != nil {
return nil, fmt.Errorf("lookup records: %w", err)
}
defer rows.Close()
var records []Record
for rows.Next() {
var r Record
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan record: %w", err)
}
records = append(records, r)
}
return records, rows.Err()
}
// LookupCNAME returns CNAME records for a name within a zone.
func (d *DB) LookupCNAME(zoneName, name string) ([]Record, error) {
return d.LookupRecords(zoneName, name, "CNAME")
}
// HasRecordsForName checks if any records of the given types exist for a name.
func (d *DB) HasRecordsForName(tx *sql.Tx, zoneID int64, name string, types []string) (bool, error) {
placeholders := make([]string, len(types))
args := []any{zoneID, strings.ToLower(name)}
for i, t := range types {
placeholders[i] = "?"
args = append(args, strings.ToUpper(t))
}
query := fmt.Sprintf(`SELECT COUNT(*) FROM records WHERE zone_id = ? AND name = ? AND type IN (%s)`, strings.Join(placeholders, ","))
var count int
if err := tx.QueryRow(query, args...).Scan(&count); err != nil {
return false, err
}
return count > 0, nil
}
// GetRecord returns a record by ID.
func (d *DB) GetRecord(id int64) (*Record, error) {
var r Record
err := d.QueryRow(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.id = ?`, id).
Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("get record: %w", err)
}
return &r, nil
}
// CreateRecord inserts a new record, enforcing CNAME exclusivity and
// value validation. Bumps the zone serial within the same transaction.
func (d *DB) CreateRecord(zoneName, name, recordType, value string, ttl int) (*Record, error) {
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
name = strings.ToLower(name)
recordType = strings.ToUpper(recordType)
if err := validateRecordValue(recordType, value); err != nil {
return nil, err
}
zone, err := d.GetZone(zoneName)
if err != nil {
return nil, err
}
if ttl <= 0 {
ttl = 300
}
tx, err := d.Begin()
if err != nil {
return nil, fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
// Enforce CNAME exclusivity.
if recordType == "CNAME" {
hasAddr, err := d.HasRecordsForName(tx, zone.ID, name, []string{"A", "AAAA"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasAddr {
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
}
} else if recordType == "A" || recordType == "AAAA" {
hasCNAME, err := d.HasRecordsForName(tx, zone.ID, name, []string{"CNAME"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasCNAME {
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
}
}
res, err := tx.Exec(`INSERT INTO records (zone_id, name, type, value, ttl) VALUES (?, ?, ?, ?, ?)`,
zone.ID, name, recordType, value, ttl)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") {
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
}
return nil, fmt.Errorf("insert record: %w", err)
}
recordID, err := res.LastInsertId()
if err != nil {
return nil, fmt.Errorf("last insert id: %w", err)
}
if err := d.BumpSerial(tx, zone.ID); err != nil {
return nil, fmt.Errorf("bump serial: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit: %w", err)
}
return d.GetRecord(recordID)
}
// UpdateRecord updates an existing record's fields and bumps the zone serial.
func (d *DB) UpdateRecord(id int64, name, recordType, value string, ttl int) (*Record, error) {
name = strings.ToLower(name)
recordType = strings.ToUpper(recordType)
if err := validateRecordValue(recordType, value); err != nil {
return nil, err
}
existing, err := d.GetRecord(id)
if err != nil {
return nil, err
}
if ttl <= 0 {
ttl = 300
}
tx, err := d.Begin()
if err != nil {
return nil, fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
// Enforce CNAME exclusivity for the new type/name combo.
if recordType == "CNAME" {
hasAddr, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"A", "AAAA"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasAddr {
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
}
} else if recordType == "A" || recordType == "AAAA" {
hasCNAME, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"CNAME"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasCNAME {
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
}
}
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
_, err = tx.Exec(`UPDATE records SET name = ?, type = ?, value = ?, ttl = ?, updated_at = ? WHERE id = ?`,
name, recordType, value, ttl, now, id)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") {
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
}
return nil, fmt.Errorf("update record: %w", err)
}
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
return nil, fmt.Errorf("bump serial: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit: %w", err)
}
return d.GetRecord(id)
}
// DeleteRecord deletes a record and bumps the zone serial.
func (d *DB) DeleteRecord(id int64) error {
existing, err := d.GetRecord(id)
if err != nil {
return err
}
tx, err := d.Begin()
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
_, err = tx.Exec(`DELETE FROM records WHERE id = ?`, id)
if err != nil {
return fmt.Errorf("delete record: %w", err)
}
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
return fmt.Errorf("bump serial: %w", err)
}
return tx.Commit()
}
// validateRecordValue checks that a record value is valid for its type.
func validateRecordValue(recordType, value string) error {
switch recordType {
case "A":
ip := net.ParseIP(value)
if ip == nil || ip.To4() == nil {
return fmt.Errorf("invalid IPv4 address: %q", value)
}
case "AAAA":
ip := net.ParseIP(value)
if ip == nil || ip.To4() != nil {
return fmt.Errorf("invalid IPv6 address: %q", value)
}
case "CNAME":
if !strings.HasSuffix(value, ".") {
return fmt.Errorf("CNAME value must be a fully-qualified domain name ending with '.': %q", value)
}
default:
return fmt.Errorf("unsupported record type: %q", recordType)
}
return nil
}

289
internal/db/records_test.go Normal file
View File

@@ -0,0 +1,289 @@
package db
import (
"errors"
"testing"
)
func createTestZone(t *testing.T, db *DB) *Zone {
t.Helper()
zone, err := db.CreateZone("svc.mcp.metacircular.net", "ns.mcp.metacircular.net.", "admin.metacircular.net.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
return zone
}
func TestCreateRecordA(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if record.Name != "metacrypt" {
t.Fatalf("got name %q, want %q", record.Name, "metacrypt")
}
if record.Type != "A" {
t.Fatalf("got type %q, want %q", record.Type, "A")
}
if record.Value != "192.168.88.181" {
t.Fatalf("got value %q, want %q", record.Value, "192.168.88.181")
}
}
func TestCreateRecordAAAA(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "AAAA", "2001:db8::1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if record.Type != "AAAA" {
t.Fatalf("got type %q, want %q", record.Type, "AAAA")
}
}
func TestCreateRecordCNAME(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if record.Type != "CNAME" {
t.Fatalf("got type %q, want %q", record.Type, "CNAME")
}
}
func TestCreateRecordInvalidIP(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "bad", "A", "not-an-ip", 300)
if err == nil {
t.Fatal("expected error for invalid IPv4")
}
}
func TestCreateRecordCNAMEExclusivity(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
// Create an A record first.
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create A record: %v", err)
}
// Trying to add a CNAME for the same name should fail.
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "CNAME", "rift.mcp.metacircular.net.", 300)
if !errors.Is(err, ErrConflict) {
t.Fatalf("expected ErrConflict, got %v", err)
}
}
func TestCreateRecordCNAMEExclusivityReverse(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
// Create a CNAME record first.
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
if err != nil {
t.Fatalf("create CNAME record: %v", err)
}
// Trying to add an A record for the same name should fail.
_, err = db.CreateRecord("svc.mcp.metacircular.net", "alias", "A", "192.168.88.181", 300)
if !errors.Is(err, ErrConflict) {
t.Fatalf("expected ErrConflict, got %v", err)
}
}
func TestCreateRecordBumpsSerial(t *testing.T) {
db := openTestDB(t)
zone := createTestZone(t, db)
originalSerial := zone.Serial
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
updated, err := db.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone: %v", err)
}
if updated.Serial <= originalSerial {
t.Fatalf("serial should have bumped: %d <= %d", updated.Serial, originalSerial)
}
}
func TestListRecords(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record 1: %v", err)
}
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
if err != nil {
t.Fatalf("create record 2: %v", err)
}
_, err = db.CreateRecord("svc.mcp.metacircular.net", "mcr", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record 3: %v", err)
}
// List all records.
records, err := db.ListRecords("svc.mcp.metacircular.net", "", "")
if err != nil {
t.Fatalf("list records: %v", err)
}
if len(records) != 3 {
t.Fatalf("got %d records, want 3", len(records))
}
// Filter by name.
records, err = db.ListRecords("svc.mcp.metacircular.net", "metacrypt", "")
if err != nil {
t.Fatalf("list records by name: %v", err)
}
if len(records) != 2 {
t.Fatalf("got %d records, want 2", len(records))
}
// Filter by type.
records, err = db.ListRecords("svc.mcp.metacircular.net", "", "A")
if err != nil {
t.Fatalf("list records by type: %v", err)
}
if len(records) != 3 {
t.Fatalf("got %d records, want 3", len(records))
}
}
func TestUpdateRecord(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
updated, err := db.UpdateRecord(record.ID, "metacrypt", "A", "10.0.0.1", 600)
if err != nil {
t.Fatalf("update record: %v", err)
}
if updated.Value != "10.0.0.1" {
t.Fatalf("got value %q, want %q", updated.Value, "10.0.0.1")
}
if updated.TTL != 600 {
t.Fatalf("got ttl %d, want 600", updated.TTL)
}
}
func TestDeleteRecord(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if err := db.DeleteRecord(record.ID); err != nil {
t.Fatalf("delete record: %v", err)
}
_, err = db.GetRecord(record.ID)
if !errors.Is(err, ErrNotFound) {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}
func TestDeleteRecordBumpsSerial(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
zone, err := db.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone: %v", err)
}
serialBefore := zone.Serial
if err := db.DeleteRecord(record.ID); err != nil {
t.Fatalf("delete record: %v", err)
}
zone, err = db.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone after delete: %v", err)
}
if zone.Serial <= serialBefore {
t.Fatalf("serial should have bumped: %d <= %d", zone.Serial, serialBefore)
}
}
func TestLookupRecords(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
if err != nil {
t.Fatalf("lookup records: %v", err)
}
if len(records) != 1 {
t.Fatalf("got %d records, want 1", len(records))
}
if records[0].Value != "192.168.88.181" {
t.Fatalf("got value %q, want %q", records[0].Value, "192.168.88.181")
}
}
func TestCreateRecordCNAMEMissingDot(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net", 300)
if err == nil {
t.Fatal("expected error for CNAME without trailing dot")
}
}
func TestMultipleARecords(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create first A record: %v", err)
}
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
if err != nil {
t.Fatalf("create second A record: %v", err)
}
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
if err != nil {
t.Fatalf("lookup records: %v", err)
}
if len(records) != 2 {
t.Fatalf("got %d records, want 2", len(records))
}
}

182
internal/db/zones.go Normal file
View File

@@ -0,0 +1,182 @@
package db
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
)
// Zone represents a DNS zone stored in the database.
type Zone struct {
ID int64
Name string
PrimaryNS string
AdminEmail string
Refresh int
Retry int
Expire int
MinimumTTL int
Serial int64
CreatedAt string
UpdatedAt string
}
// ErrNotFound is returned when a requested resource does not exist.
var ErrNotFound = errors.New("not found")
// ErrConflict is returned when a write conflicts with existing data.
var ErrConflict = errors.New("conflict")
// ListZones returns all zones ordered by name.
func (d *DB) ListZones() ([]Zone, error) {
rows, err := d.Query(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones ORDER BY name`)
if err != nil {
return nil, fmt.Errorf("list zones: %w", err)
}
defer rows.Close()
var zones []Zone
for rows.Next() {
var z Zone
if err := rows.Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan zone: %w", err)
}
zones = append(zones, z)
}
return zones, rows.Err()
}
// GetZone returns a zone by name (case-insensitive).
func (d *DB) GetZone(name string) (*Zone, error) {
name = strings.ToLower(strings.TrimSuffix(name, "."))
var z Zone
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE name = ?`, name).
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("get zone: %w", err)
}
return &z, nil
}
// GetZoneByID returns a zone by ID.
func (d *DB) GetZoneByID(id int64) (*Zone, error) {
var z Zone
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE id = ?`, id).
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("get zone by id: %w", err)
}
return &z, nil
}
// CreateZone inserts a new zone and returns it with the generated serial.
func (d *DB) CreateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
name = strings.ToLower(strings.TrimSuffix(name, "."))
serial := nextSerial(0)
res, err := d.Exec(`INSERT INTO zones (name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
name, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") {
return nil, fmt.Errorf("%w: zone %q already exists", ErrConflict, name)
}
return nil, fmt.Errorf("create zone: %w", err)
}
id, err := res.LastInsertId()
if err != nil {
return nil, fmt.Errorf("create zone: last insert id: %w", err)
}
return d.GetZoneByID(id)
}
// UpdateZone updates a zone's SOA parameters and bumps the serial.
func (d *DB) UpdateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
name = strings.ToLower(strings.TrimSuffix(name, "."))
zone, err := d.GetZone(name)
if err != nil {
return nil, err
}
serial := nextSerial(zone.Serial)
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
_, err = d.Exec(`UPDATE zones SET primary_ns = ?, admin_email = ?, refresh = ?, retry = ?, expire = ?, minimum_ttl = ?, serial = ?, updated_at = ? WHERE id = ?`,
primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial, now, zone.ID)
if err != nil {
return nil, fmt.Errorf("update zone: %w", err)
}
return d.GetZoneByID(zone.ID)
}
// DeleteZone deletes a zone and all its records.
func (d *DB) DeleteZone(name string) error {
name = strings.ToLower(strings.TrimSuffix(name, "."))
res, err := d.Exec(`DELETE FROM zones WHERE name = ?`, name)
if err != nil {
return fmt.Errorf("delete zone: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("delete zone: rows affected: %w", err)
}
if n == 0 {
return ErrNotFound
}
return nil
}
// BumpSerial increments the serial for a zone within a transaction.
func (d *DB) BumpSerial(tx *sql.Tx, zoneID int64) error {
var current int64
if err := tx.QueryRow(`SELECT serial FROM zones WHERE id = ?`, zoneID).Scan(&current); err != nil {
return fmt.Errorf("read serial: %w", err)
}
serial := nextSerial(current)
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
_, err := tx.Exec(`UPDATE zones SET serial = ?, updated_at = ? WHERE id = ?`, serial, now, zoneID)
return err
}
// ZoneNames returns all zone names for the DNS handler.
func (d *DB) ZoneNames() ([]string, error) {
rows, err := d.Query(`SELECT name FROM zones ORDER BY name`)
if err != nil {
return nil, err
}
defer rows.Close()
var names []string
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
names = append(names, name)
}
return names, rows.Err()
}
// nextSerial computes the next SOA serial in YYYYMMDDNN format.
func nextSerial(current int64) int64 {
today := time.Now().UTC()
datePrefix, _ := strconv.ParseInt(today.Format("20060102"), 10, 64)
datePrefix *= 100 // YYYYMMDD00
if current >= datePrefix {
return current + 1
}
return datePrefix + 1
}

167
internal/db/zones_test.go Normal file
View File

@@ -0,0 +1,167 @@
package db
import (
"path/filepath"
"testing"
)
func openTestDB(t *testing.T) *DB {
t.Helper()
dir := t.TempDir()
database, err := Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := database.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { _ = database.Close() })
return database
}
func TestCreateZone(t *testing.T) {
db := openTestDB(t)
zone, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
if zone.Name != "example.com" {
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
}
if zone.Serial == 0 {
t.Fatal("serial should not be zero")
}
if zone.PrimaryNS != "ns1.example.com." {
t.Fatalf("got primary_ns %q, want %q", zone.PrimaryNS, "ns1.example.com.")
}
}
func TestCreateZoneDuplicate(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
_, err = db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err == nil {
t.Fatal("expected error for duplicate zone")
}
}
func TestCreateZoneNormalization(t *testing.T) {
db := openTestDB(t)
zone, err := db.CreateZone("Example.COM.", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
if zone.Name != "example.com" {
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
}
}
func TestListZones(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("b.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone b: %v", err)
}
_, err = db.CreateZone("a.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone a: %v", err)
}
zones, err := db.ListZones()
if err != nil {
t.Fatalf("list zones: %v", err)
}
if len(zones) != 2 {
t.Fatalf("got %d zones, want 2", len(zones))
}
if zones[0].Name != "a.example.com" {
t.Fatalf("zones should be ordered by name, got %q first", zones[0].Name)
}
}
func TestGetZone(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
zone, err := db.GetZone("example.com")
if err != nil {
t.Fatalf("get zone: %v", err)
}
if zone.Name != "example.com" {
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
}
_, err = db.GetZone("nonexistent.com")
if err != ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestUpdateZone(t *testing.T) {
db := openTestDB(t)
original, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
updated, err := db.UpdateZone("example.com", "ns2.example.com.", "newadmin.example.com.", 7200, 1200, 172800, 600)
if err != nil {
t.Fatalf("update zone: %v", err)
}
if updated.PrimaryNS != "ns2.example.com." {
t.Fatalf("got primary_ns %q, want %q", updated.PrimaryNS, "ns2.example.com.")
}
if updated.Serial <= original.Serial {
t.Fatalf("serial should have incremented: %d <= %d", updated.Serial, original.Serial)
}
}
func TestDeleteZone(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
if err := db.DeleteZone("example.com"); err != nil {
t.Fatalf("delete zone: %v", err)
}
_, err = db.GetZone("example.com")
if err != ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
if err := db.DeleteZone("nonexistent.com"); err != ErrNotFound {
t.Fatalf("expected ErrNotFound for nonexistent zone, got %v", err)
}
}
func TestNextSerial(t *testing.T) {
// A zero serial should produce a date-based serial.
s1 := nextSerial(0)
if s1 < 2026032600 {
t.Fatalf("serial %d seems too low", s1)
}
// Incrementing should increase.
s2 := nextSerial(s1)
if s2 != s1+1 {
t.Fatalf("expected %d, got %d", s1+1, s2)
}
}