Implement MCNS v1: custom Go DNS server replacing CoreDNS
Replace the CoreDNS precursor with a purpose-built authoritative DNS server. Zones and records (A, AAAA, CNAME) are stored in SQLite and managed via synchronized gRPC + REST APIs authenticated through MCIAS. Non-authoritative queries are forwarded to upstream resolvers with in-memory caching. Key components: - DNS server (miekg/dns) with authoritative zone handling and forwarding - gRPC + REST management APIs with MCIAS auth (mcdsl integration) - SQLite storage with CNAME exclusivity enforcement and auto SOA serials - 30 tests covering database CRUD, DNS resolution, and caching Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
23
internal/db/db.go
Normal file
23
internal/db/db.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
||||
)
|
||||
|
||||
// DB wraps a SQLite database connection.
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
// Open opens (or creates) a SQLite database at the given path with the
|
||||
// standard Metacircular pragmas: WAL mode, foreign keys, busy timeout.
|
||||
func Open(path string) (*DB, error) {
|
||||
sqlDB, err := mcdsldb.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: %w", err)
|
||||
}
|
||||
return &DB{sqlDB}, nil
|
||||
}
|
||||
46
internal/db/migrate.go
Normal file
46
internal/db/migrate.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
||||
)
|
||||
|
||||
// Migrations is the ordered list of MCNS schema migrations.
|
||||
var Migrations = []mcdsldb.Migration{
|
||||
{
|
||||
Version: 1,
|
||||
Name: "zones and records",
|
||||
SQL: `
|
||||
CREATE TABLE IF NOT EXISTS zones (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
primary_ns TEXT NOT NULL,
|
||||
admin_email TEXT NOT NULL,
|
||||
refresh INTEGER NOT NULL DEFAULT 3600,
|
||||
retry INTEGER NOT NULL DEFAULT 600,
|
||||
expire INTEGER NOT NULL DEFAULT 86400,
|
||||
minimum_ttl INTEGER NOT NULL DEFAULT 300,
|
||||
serial INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
zone_id INTEGER NOT NULL REFERENCES zones(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL CHECK (type IN ('A', 'AAAA', 'CNAME')),
|
||||
value TEXT NOT NULL,
|
||||
ttl INTEGER NOT NULL DEFAULT 300,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
||||
UNIQUE(zone_id, name, type, value)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_records_zone_name ON records(zone_id, name);`,
|
||||
},
|
||||
}
|
||||
|
||||
// Migrate applies all pending migrations.
|
||||
func (d *DB) Migrate() error {
|
||||
return mcdsldb.Migrate(d.DB, Migrations)
|
||||
}
|
||||
308
internal/db/records.go
Normal file
308
internal/db/records.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Record represents a DNS record stored in the database.
|
||||
type Record struct {
|
||||
ID int64
|
||||
ZoneID int64
|
||||
ZoneName string
|
||||
Name string
|
||||
Type string
|
||||
Value string
|
||||
TTL int
|
||||
CreatedAt string
|
||||
UpdatedAt string
|
||||
}
|
||||
|
||||
// ListRecords returns records for a zone, optionally filtered by name and type.
|
||||
func (d *DB) ListRecords(zoneName, name, recordType string) ([]Record, error) {
|
||||
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
||||
|
||||
zone, err := d.GetZone(zoneName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := `SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
||||
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.zone_id = ?`
|
||||
args := []any{zone.ID}
|
||||
|
||||
if name != "" {
|
||||
query += ` AND r.name = ?`
|
||||
args = append(args, strings.ToLower(name))
|
||||
}
|
||||
if recordType != "" {
|
||||
query += ` AND r.type = ?`
|
||||
args = append(args, strings.ToUpper(recordType))
|
||||
}
|
||||
|
||||
query += ` ORDER BY r.name, r.type, r.value`
|
||||
|
||||
rows, err := d.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []Record
|
||||
for rows.Next() {
|
||||
var r Record
|
||||
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan record: %w", err)
|
||||
}
|
||||
records = append(records, r)
|
||||
}
|
||||
return records, rows.Err()
|
||||
}
|
||||
|
||||
// LookupRecords returns records matching a name and type within a zone.
|
||||
// Used by the DNS handler for query resolution.
|
||||
func (d *DB) LookupRecords(zoneName, name, recordType string) ([]Record, error) {
|
||||
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
||||
name = strings.ToLower(name)
|
||||
|
||||
rows, err := d.Query(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
||||
FROM records r JOIN zones z ON r.zone_id = z.id
|
||||
WHERE z.name = ? AND r.name = ? AND r.type = ?`, zoneName, name, strings.ToUpper(recordType))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []Record
|
||||
for rows.Next() {
|
||||
var r Record
|
||||
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan record: %w", err)
|
||||
}
|
||||
records = append(records, r)
|
||||
}
|
||||
return records, rows.Err()
|
||||
}
|
||||
|
||||
// LookupCNAME returns CNAME records for a name within a zone.
|
||||
func (d *DB) LookupCNAME(zoneName, name string) ([]Record, error) {
|
||||
return d.LookupRecords(zoneName, name, "CNAME")
|
||||
}
|
||||
|
||||
// HasRecordsForName checks if any records of the given types exist for a name.
|
||||
func (d *DB) HasRecordsForName(tx *sql.Tx, zoneID int64, name string, types []string) (bool, error) {
|
||||
placeholders := make([]string, len(types))
|
||||
args := []any{zoneID, strings.ToLower(name)}
|
||||
for i, t := range types {
|
||||
placeholders[i] = "?"
|
||||
args = append(args, strings.ToUpper(t))
|
||||
}
|
||||
query := fmt.Sprintf(`SELECT COUNT(*) FROM records WHERE zone_id = ? AND name = ? AND type IN (%s)`, strings.Join(placeholders, ","))
|
||||
|
||||
var count int
|
||||
if err := tx.QueryRow(query, args...).Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetRecord returns a record by ID.
|
||||
func (d *DB) GetRecord(id int64) (*Record, error) {
|
||||
var r Record
|
||||
err := d.QueryRow(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
||||
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.id = ?`, id).
|
||||
Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get record: %w", err)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// CreateRecord inserts a new record, enforcing CNAME exclusivity and
|
||||
// value validation. Bumps the zone serial within the same transaction.
|
||||
func (d *DB) CreateRecord(zoneName, name, recordType, value string, ttl int) (*Record, error) {
|
||||
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
||||
name = strings.ToLower(name)
|
||||
recordType = strings.ToUpper(recordType)
|
||||
|
||||
if err := validateRecordValue(recordType, value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zone, err := d.GetZone(zoneName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = 300
|
||||
}
|
||||
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Enforce CNAME exclusivity.
|
||||
if recordType == "CNAME" {
|
||||
hasAddr, err := d.HasRecordsForName(tx, zone.ID, name, []string{"A", "AAAA"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasAddr {
|
||||
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
|
||||
}
|
||||
} else if recordType == "A" || recordType == "AAAA" {
|
||||
hasCNAME, err := d.HasRecordsForName(tx, zone.ID, name, []string{"CNAME"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasCNAME {
|
||||
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tx.Exec(`INSERT INTO records (zone_id, name, type, value, ttl) VALUES (?, ?, ?, ?, ?)`,
|
||||
zone.ID, name, recordType, value, ttl)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
||||
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
|
||||
}
|
||||
return nil, fmt.Errorf("insert record: %w", err)
|
||||
}
|
||||
|
||||
recordID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("last insert id: %w", err)
|
||||
}
|
||||
|
||||
if err := d.BumpSerial(tx, zone.ID); err != nil {
|
||||
return nil, fmt.Errorf("bump serial: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return d.GetRecord(recordID)
|
||||
}
|
||||
|
||||
// UpdateRecord updates an existing record's fields and bumps the zone serial.
|
||||
func (d *DB) UpdateRecord(id int64, name, recordType, value string, ttl int) (*Record, error) {
|
||||
name = strings.ToLower(name)
|
||||
recordType = strings.ToUpper(recordType)
|
||||
|
||||
if err := validateRecordValue(recordType, value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := d.GetRecord(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = 300
|
||||
}
|
||||
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Enforce CNAME exclusivity for the new type/name combo.
|
||||
if recordType == "CNAME" {
|
||||
hasAddr, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"A", "AAAA"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasAddr {
|
||||
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
|
||||
}
|
||||
} else if recordType == "A" || recordType == "AAAA" {
|
||||
hasCNAME, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"CNAME"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasCNAME {
|
||||
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
_, err = tx.Exec(`UPDATE records SET name = ?, type = ?, value = ?, ttl = ?, updated_at = ? WHERE id = ?`,
|
||||
name, recordType, value, ttl, now, id)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
||||
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
|
||||
}
|
||||
return nil, fmt.Errorf("update record: %w", err)
|
||||
}
|
||||
|
||||
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
|
||||
return nil, fmt.Errorf("bump serial: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return d.GetRecord(id)
|
||||
}
|
||||
|
||||
// DeleteRecord deletes a record and bumps the zone serial.
|
||||
func (d *DB) DeleteRecord(id int64) error {
|
||||
existing, err := d.GetRecord(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
_, err = tx.Exec(`DELETE FROM records WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete record: %w", err)
|
||||
}
|
||||
|
||||
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
|
||||
return fmt.Errorf("bump serial: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// validateRecordValue checks that a record value is valid for its type.
|
||||
func validateRecordValue(recordType, value string) error {
|
||||
switch recordType {
|
||||
case "A":
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return fmt.Errorf("invalid IPv4 address: %q", value)
|
||||
}
|
||||
case "AAAA":
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil || ip.To4() != nil {
|
||||
return fmt.Errorf("invalid IPv6 address: %q", value)
|
||||
}
|
||||
case "CNAME":
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
return fmt.Errorf("CNAME value must be a fully-qualified domain name ending with '.': %q", value)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported record type: %q", recordType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
289
internal/db/records_test.go
Normal file
289
internal/db/records_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func createTestZone(t *testing.T, db *DB) *Zone {
|
||||
t.Helper()
|
||||
zone, err := db.CreateZone("svc.mcp.metacircular.net", "ns.mcp.metacircular.net.", "admin.metacircular.net.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
return zone
|
||||
}
|
||||
|
||||
func TestCreateRecordA(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
if record.Name != "metacrypt" {
|
||||
t.Fatalf("got name %q, want %q", record.Name, "metacrypt")
|
||||
}
|
||||
if record.Type != "A" {
|
||||
t.Fatalf("got type %q, want %q", record.Type, "A")
|
||||
}
|
||||
if record.Value != "192.168.88.181" {
|
||||
t.Fatalf("got value %q, want %q", record.Value, "192.168.88.181")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordAAAA(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "AAAA", "2001:db8::1", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
if record.Type != "AAAA" {
|
||||
t.Fatalf("got type %q, want %q", record.Type, "AAAA")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAME(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
if record.Type != "CNAME" {
|
||||
t.Fatalf("got type %q, want %q", record.Type, "CNAME")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordInvalidIP(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "bad", "A", "not-an-ip", 300)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid IPv4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAMEExclusivity(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
// Create an A record first.
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create A record: %v", err)
|
||||
}
|
||||
|
||||
// Trying to add a CNAME for the same name should fail.
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "CNAME", "rift.mcp.metacircular.net.", 300)
|
||||
if !errors.Is(err, ErrConflict) {
|
||||
t.Fatalf("expected ErrConflict, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAMEExclusivityReverse(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
// Create a CNAME record first.
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create CNAME record: %v", err)
|
||||
}
|
||||
|
||||
// Trying to add an A record for the same name should fail.
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "alias", "A", "192.168.88.181", 300)
|
||||
if !errors.Is(err, ErrConflict) {
|
||||
t.Fatalf("expected ErrConflict, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordBumpsSerial(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
zone := createTestZone(t, db)
|
||||
originalSerial := zone.Serial
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
updated, err := db.GetZone("svc.mcp.metacircular.net")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone: %v", err)
|
||||
}
|
||||
if updated.Serial <= originalSerial {
|
||||
t.Fatalf("serial should have bumped: %d <= %d", updated.Serial, originalSerial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record 1: %v", err)
|
||||
}
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record 2: %v", err)
|
||||
}
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "mcr", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record 3: %v", err)
|
||||
}
|
||||
|
||||
// List all records.
|
||||
records, err := db.ListRecords("svc.mcp.metacircular.net", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("list records: %v", err)
|
||||
}
|
||||
if len(records) != 3 {
|
||||
t.Fatalf("got %d records, want 3", len(records))
|
||||
}
|
||||
|
||||
// Filter by name.
|
||||
records, err = db.ListRecords("svc.mcp.metacircular.net", "metacrypt", "")
|
||||
if err != nil {
|
||||
t.Fatalf("list records by name: %v", err)
|
||||
}
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("got %d records, want 2", len(records))
|
||||
}
|
||||
|
||||
// Filter by type.
|
||||
records, err = db.ListRecords("svc.mcp.metacircular.net", "", "A")
|
||||
if err != nil {
|
||||
t.Fatalf("list records by type: %v", err)
|
||||
}
|
||||
if len(records) != 3 {
|
||||
t.Fatalf("got %d records, want 3", len(records))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRecord(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
updated, err := db.UpdateRecord(record.ID, "metacrypt", "A", "10.0.0.1", 600)
|
||||
if err != nil {
|
||||
t.Fatalf("update record: %v", err)
|
||||
}
|
||||
if updated.Value != "10.0.0.1" {
|
||||
t.Fatalf("got value %q, want %q", updated.Value, "10.0.0.1")
|
||||
}
|
||||
if updated.TTL != 600 {
|
||||
t.Fatalf("got ttl %d, want 600", updated.TTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRecord(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteRecord(record.ID); err != nil {
|
||||
t.Fatalf("delete record: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetRecord(record.ID)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("expected ErrNotFound after delete, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRecordBumpsSerial(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
zone, err := db.GetZone("svc.mcp.metacircular.net")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone: %v", err)
|
||||
}
|
||||
serialBefore := zone.Serial
|
||||
|
||||
if err := db.DeleteRecord(record.ID); err != nil {
|
||||
t.Fatalf("delete record: %v", err)
|
||||
}
|
||||
|
||||
zone, err = db.GetZone("svc.mcp.metacircular.net")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone after delete: %v", err)
|
||||
}
|
||||
if zone.Serial <= serialBefore {
|
||||
t.Fatalf("serial should have bumped: %d <= %d", zone.Serial, serialBefore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
|
||||
if err != nil {
|
||||
t.Fatalf("lookup records: %v", err)
|
||||
}
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("got %d records, want 1", len(records))
|
||||
}
|
||||
if records[0].Value != "192.168.88.181" {
|
||||
t.Fatalf("got value %q, want %q", records[0].Value, "192.168.88.181")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAMEMissingDot(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net", 300)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for CNAME without trailing dot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleARecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create first A record: %v", err)
|
||||
}
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create second A record: %v", err)
|
||||
}
|
||||
|
||||
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
|
||||
if err != nil {
|
||||
t.Fatalf("lookup records: %v", err)
|
||||
}
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("got %d records, want 2", len(records))
|
||||
}
|
||||
}
|
||||
182
internal/db/zones.go
Normal file
182
internal/db/zones.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Zone represents a DNS zone stored in the database.
|
||||
type Zone struct {
|
||||
ID int64
|
||||
Name string
|
||||
PrimaryNS string
|
||||
AdminEmail string
|
||||
Refresh int
|
||||
Retry int
|
||||
Expire int
|
||||
MinimumTTL int
|
||||
Serial int64
|
||||
CreatedAt string
|
||||
UpdatedAt string
|
||||
}
|
||||
|
||||
// ErrNotFound is returned when a requested resource does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// ErrConflict is returned when a write conflicts with existing data.
|
||||
var ErrConflict = errors.New("conflict")
|
||||
|
||||
// ListZones returns all zones ordered by name.
|
||||
func (d *DB) ListZones() ([]Zone, error) {
|
||||
rows, err := d.Query(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list zones: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var zones []Zone
|
||||
for rows.Next() {
|
||||
var z Zone
|
||||
if err := rows.Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan zone: %w", err)
|
||||
}
|
||||
zones = append(zones, z)
|
||||
}
|
||||
return zones, rows.Err()
|
||||
}
|
||||
|
||||
// GetZone returns a zone by name (case-insensitive).
|
||||
func (d *DB) GetZone(name string) (*Zone, error) {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
var z Zone
|
||||
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE name = ?`, name).
|
||||
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get zone: %w", err)
|
||||
}
|
||||
return &z, nil
|
||||
}
|
||||
|
||||
// GetZoneByID returns a zone by ID.
|
||||
func (d *DB) GetZoneByID(id int64) (*Zone, error) {
|
||||
var z Zone
|
||||
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE id = ?`, id).
|
||||
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get zone by id: %w", err)
|
||||
}
|
||||
return &z, nil
|
||||
}
|
||||
|
||||
// CreateZone inserts a new zone and returns it with the generated serial.
|
||||
func (d *DB) CreateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
serial := nextSerial(0)
|
||||
|
||||
res, err := d.Exec(`INSERT INTO zones (name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
name, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
||||
return nil, fmt.Errorf("%w: zone %q already exists", ErrConflict, name)
|
||||
}
|
||||
return nil, fmt.Errorf("create zone: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create zone: last insert id: %w", err)
|
||||
}
|
||||
|
||||
return d.GetZoneByID(id)
|
||||
}
|
||||
|
||||
// UpdateZone updates a zone's SOA parameters and bumps the serial.
|
||||
func (d *DB) UpdateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
|
||||
zone, err := d.GetZone(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serial := nextSerial(zone.Serial)
|
||||
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
|
||||
_, err = d.Exec(`UPDATE zones SET primary_ns = ?, admin_email = ?, refresh = ?, retry = ?, expire = ?, minimum_ttl = ?, serial = ?, updated_at = ? WHERE id = ?`,
|
||||
primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial, now, zone.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update zone: %w", err)
|
||||
}
|
||||
|
||||
return d.GetZoneByID(zone.ID)
|
||||
}
|
||||
|
||||
// DeleteZone deletes a zone and all its records.
|
||||
func (d *DB) DeleteZone(name string) error {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
res, err := d.Exec(`DELETE FROM zones WHERE name = ?`, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete zone: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete zone: rows affected: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BumpSerial increments the serial for a zone within a transaction.
|
||||
func (d *DB) BumpSerial(tx *sql.Tx, zoneID int64) error {
|
||||
var current int64
|
||||
if err := tx.QueryRow(`SELECT serial FROM zones WHERE id = ?`, zoneID).Scan(¤t); err != nil {
|
||||
return fmt.Errorf("read serial: %w", err)
|
||||
}
|
||||
serial := nextSerial(current)
|
||||
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
_, err := tx.Exec(`UPDATE zones SET serial = ?, updated_at = ? WHERE id = ?`, serial, now, zoneID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ZoneNames returns all zone names for the DNS handler.
|
||||
func (d *DB) ZoneNames() ([]string, error) {
|
||||
rows, err := d.Query(`SELECT name FROM zones ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var names []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return names, rows.Err()
|
||||
}
|
||||
|
||||
// nextSerial computes the next SOA serial in YYYYMMDDNN format.
|
||||
func nextSerial(current int64) int64 {
|
||||
today := time.Now().UTC()
|
||||
datePrefix, _ := strconv.ParseInt(today.Format("20060102"), 10, 64)
|
||||
datePrefix *= 100 // YYYYMMDD00
|
||||
|
||||
if current >= datePrefix {
|
||||
return current + 1
|
||||
}
|
||||
return datePrefix + 1
|
||||
}
|
||||
167
internal/db/zones_test.go
Normal file
167
internal/db/zones_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func openTestDB(t *testing.T) *DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = database.Close() })
|
||||
return database
|
||||
}
|
||||
|
||||
func TestCreateZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
zone, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
if zone.Name != "example.com" {
|
||||
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
|
||||
}
|
||||
if zone.Serial == 0 {
|
||||
t.Fatal("serial should not be zero")
|
||||
}
|
||||
if zone.PrimaryNS != "ns1.example.com." {
|
||||
t.Fatalf("got primary_ns %q, want %q", zone.PrimaryNS, "ns1.example.com.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateZoneDuplicate(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate zone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateZoneNormalization(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
zone, err := db.CreateZone("Example.COM.", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
if zone.Name != "example.com" {
|
||||
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListZones(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("b.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone b: %v", err)
|
||||
}
|
||||
_, err = db.CreateZone("a.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone a: %v", err)
|
||||
}
|
||||
|
||||
zones, err := db.ListZones()
|
||||
if err != nil {
|
||||
t.Fatalf("list zones: %v", err)
|
||||
}
|
||||
if len(zones) != 2 {
|
||||
t.Fatalf("got %d zones, want 2", len(zones))
|
||||
}
|
||||
if zones[0].Name != "a.example.com" {
|
||||
t.Fatalf("zones should be ordered by name, got %q first", zones[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
zone, err := db.GetZone("example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone: %v", err)
|
||||
}
|
||||
if zone.Name != "example.com" {
|
||||
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
|
||||
}
|
||||
|
||||
_, err = db.GetZone("nonexistent.com")
|
||||
if err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
original, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
updated, err := db.UpdateZone("example.com", "ns2.example.com.", "newadmin.example.com.", 7200, 1200, 172800, 600)
|
||||
if err != nil {
|
||||
t.Fatalf("update zone: %v", err)
|
||||
}
|
||||
|
||||
if updated.PrimaryNS != "ns2.example.com." {
|
||||
t.Fatalf("got primary_ns %q, want %q", updated.PrimaryNS, "ns2.example.com.")
|
||||
}
|
||||
if updated.Serial <= original.Serial {
|
||||
t.Fatalf("serial should have incremented: %d <= %d", updated.Serial, original.Serial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteZone("example.com"); err != nil {
|
||||
t.Fatalf("delete zone: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetZone("example.com")
|
||||
if err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound after delete, got %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteZone("nonexistent.com"); err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound for nonexistent zone, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextSerial(t *testing.T) {
|
||||
// A zero serial should produce a date-based serial.
|
||||
s1 := nextSerial(0)
|
||||
if s1 < 2026032600 {
|
||||
t.Fatalf("serial %d seems too low", s1)
|
||||
}
|
||||
|
||||
// Incrementing should increase.
|
||||
s2 := nextSerial(s1)
|
||||
if s2 != s1+1 {
|
||||
t.Fatalf("expected %d, got %d", s1+1, s2)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user