Implement MCNS v1: custom Go DNS server replacing CoreDNS

Replace the CoreDNS precursor with a purpose-built authoritative DNS
server. Zones and records (A, AAAA, CNAME) are stored in SQLite and
managed via synchronized gRPC + REST APIs authenticated through MCIAS.
Non-authoritative queries are forwarded to upstream resolvers with
in-memory caching.

Key components:
- DNS server (miekg/dns) with authoritative zone handling and forwarding
- gRPC + REST management APIs with MCIAS auth (mcdsl integration)
- SQLite storage with CNAME exclusivity enforcement and auto SOA serials
- 30 tests covering database CRUD, DNS resolution, and caching

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 18:37:14 -07:00
parent a545fec658
commit f9635578e0
48 changed files with 6015 additions and 87 deletions

49
internal/config/config.go Normal file
View File

@@ -0,0 +1,49 @@
package config
import (
"fmt"
mcdslconfig "git.wntrmute.dev/kyle/mcdsl/config"
)
// Config is the top-level MCNS configuration.
type Config struct {
mcdslconfig.Base
DNS DNSConfig `toml:"dns"`
}
// DNSConfig holds the DNS server settings.
type DNSConfig struct {
ListenAddr string `toml:"listen_addr"`
Upstreams []string `toml:"upstreams"`
}
// Load reads a TOML config file, applies environment variable overrides
// (MCNS_ prefix), sets defaults, and validates required fields.
func Load(path string) (*Config, error) {
cfg, err := mcdslconfig.Load[Config](path, "MCNS")
if err != nil {
return nil, err
}
// Apply DNS defaults.
if cfg.DNS.ListenAddr == "" {
cfg.DNS.ListenAddr = ":53"
}
if len(cfg.DNS.Upstreams) == 0 {
cfg.DNS.Upstreams = []string{"1.1.1.1:53", "8.8.8.8:53"}
}
return cfg, nil
}
// Validate implements the mcdsl config.Validator interface.
func (c *Config) Validate() error {
if c.Database.Path == "" {
return fmt.Errorf("database.path is required")
}
if c.MCIAS.ServerURL == "" {
return fmt.Errorf("mcias.server_url is required")
}
return nil
}

23
internal/db/db.go Normal file
View File

@@ -0,0 +1,23 @@
package db
import (
"database/sql"
"fmt"
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
)
// DB wraps a SQLite database connection.
type DB struct {
*sql.DB
}
// Open opens (or creates) a SQLite database at the given path with the
// standard Metacircular pragmas: WAL mode, foreign keys, busy timeout.
func Open(path string) (*DB, error) {
sqlDB, err := mcdsldb.Open(path)
if err != nil {
return nil, fmt.Errorf("db: %w", err)
}
return &DB{sqlDB}, nil
}

46
internal/db/migrate.go Normal file
View File

@@ -0,0 +1,46 @@
package db
import (
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
)
// Migrations is the ordered list of MCNS schema migrations.
var Migrations = []mcdsldb.Migration{
{
Version: 1,
Name: "zones and records",
SQL: `
CREATE TABLE IF NOT EXISTS zones (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL UNIQUE,
primary_ns TEXT NOT NULL,
admin_email TEXT NOT NULL,
refresh INTEGER NOT NULL DEFAULT 3600,
retry INTEGER NOT NULL DEFAULT 600,
expire INTEGER NOT NULL DEFAULT 86400,
minimum_ttl INTEGER NOT NULL DEFAULT 300,
serial INTEGER NOT NULL DEFAULT 0,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
);
CREATE TABLE IF NOT EXISTS records (
id INTEGER PRIMARY KEY,
zone_id INTEGER NOT NULL REFERENCES zones(id) ON DELETE CASCADE,
name TEXT NOT NULL,
type TEXT NOT NULL CHECK (type IN ('A', 'AAAA', 'CNAME')),
value TEXT NOT NULL,
ttl INTEGER NOT NULL DEFAULT 300,
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
UNIQUE(zone_id, name, type, value)
);
CREATE INDEX IF NOT EXISTS idx_records_zone_name ON records(zone_id, name);`,
},
}
// Migrate applies all pending migrations.
func (d *DB) Migrate() error {
return mcdsldb.Migrate(d.DB, Migrations)
}

308
internal/db/records.go Normal file
View File

@@ -0,0 +1,308 @@
package db
import (
"database/sql"
"errors"
"fmt"
"net"
"strings"
"time"
)
// Record represents a DNS record stored in the database.
type Record struct {
ID int64
ZoneID int64
ZoneName string
Name string
Type string
Value string
TTL int
CreatedAt string
UpdatedAt string
}
// ListRecords returns records for a zone, optionally filtered by name and type.
func (d *DB) ListRecords(zoneName, name, recordType string) ([]Record, error) {
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
zone, err := d.GetZone(zoneName)
if err != nil {
return nil, err
}
query := `SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.zone_id = ?`
args := []any{zone.ID}
if name != "" {
query += ` AND r.name = ?`
args = append(args, strings.ToLower(name))
}
if recordType != "" {
query += ` AND r.type = ?`
args = append(args, strings.ToUpper(recordType))
}
query += ` ORDER BY r.name, r.type, r.value`
rows, err := d.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("list records: %w", err)
}
defer rows.Close()
var records []Record
for rows.Next() {
var r Record
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan record: %w", err)
}
records = append(records, r)
}
return records, rows.Err()
}
// LookupRecords returns records matching a name and type within a zone.
// Used by the DNS handler for query resolution.
func (d *DB) LookupRecords(zoneName, name, recordType string) ([]Record, error) {
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
name = strings.ToLower(name)
rows, err := d.Query(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
FROM records r JOIN zones z ON r.zone_id = z.id
WHERE z.name = ? AND r.name = ? AND r.type = ?`, zoneName, name, strings.ToUpper(recordType))
if err != nil {
return nil, fmt.Errorf("lookup records: %w", err)
}
defer rows.Close()
var records []Record
for rows.Next() {
var r Record
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan record: %w", err)
}
records = append(records, r)
}
return records, rows.Err()
}
// LookupCNAME returns CNAME records for a name within a zone.
func (d *DB) LookupCNAME(zoneName, name string) ([]Record, error) {
return d.LookupRecords(zoneName, name, "CNAME")
}
// HasRecordsForName checks if any records of the given types exist for a name.
func (d *DB) HasRecordsForName(tx *sql.Tx, zoneID int64, name string, types []string) (bool, error) {
placeholders := make([]string, len(types))
args := []any{zoneID, strings.ToLower(name)}
for i, t := range types {
placeholders[i] = "?"
args = append(args, strings.ToUpper(t))
}
query := fmt.Sprintf(`SELECT COUNT(*) FROM records WHERE zone_id = ? AND name = ? AND type IN (%s)`, strings.Join(placeholders, ","))
var count int
if err := tx.QueryRow(query, args...).Scan(&count); err != nil {
return false, err
}
return count > 0, nil
}
// GetRecord returns a record by ID.
func (d *DB) GetRecord(id int64) (*Record, error) {
var r Record
err := d.QueryRow(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.id = ?`, id).
Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("get record: %w", err)
}
return &r, nil
}
// CreateRecord inserts a new record, enforcing CNAME exclusivity and
// value validation. Bumps the zone serial within the same transaction.
func (d *DB) CreateRecord(zoneName, name, recordType, value string, ttl int) (*Record, error) {
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
name = strings.ToLower(name)
recordType = strings.ToUpper(recordType)
if err := validateRecordValue(recordType, value); err != nil {
return nil, err
}
zone, err := d.GetZone(zoneName)
if err != nil {
return nil, err
}
if ttl <= 0 {
ttl = 300
}
tx, err := d.Begin()
if err != nil {
return nil, fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
// Enforce CNAME exclusivity.
if recordType == "CNAME" {
hasAddr, err := d.HasRecordsForName(tx, zone.ID, name, []string{"A", "AAAA"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasAddr {
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
}
} else if recordType == "A" || recordType == "AAAA" {
hasCNAME, err := d.HasRecordsForName(tx, zone.ID, name, []string{"CNAME"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasCNAME {
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
}
}
res, err := tx.Exec(`INSERT INTO records (zone_id, name, type, value, ttl) VALUES (?, ?, ?, ?, ?)`,
zone.ID, name, recordType, value, ttl)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") {
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
}
return nil, fmt.Errorf("insert record: %w", err)
}
recordID, err := res.LastInsertId()
if err != nil {
return nil, fmt.Errorf("last insert id: %w", err)
}
if err := d.BumpSerial(tx, zone.ID); err != nil {
return nil, fmt.Errorf("bump serial: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit: %w", err)
}
return d.GetRecord(recordID)
}
// UpdateRecord updates an existing record's fields and bumps the zone serial.
func (d *DB) UpdateRecord(id int64, name, recordType, value string, ttl int) (*Record, error) {
name = strings.ToLower(name)
recordType = strings.ToUpper(recordType)
if err := validateRecordValue(recordType, value); err != nil {
return nil, err
}
existing, err := d.GetRecord(id)
if err != nil {
return nil, err
}
if ttl <= 0 {
ttl = 300
}
tx, err := d.Begin()
if err != nil {
return nil, fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
// Enforce CNAME exclusivity for the new type/name combo.
if recordType == "CNAME" {
hasAddr, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"A", "AAAA"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasAddr {
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
}
} else if recordType == "A" || recordType == "AAAA" {
hasCNAME, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"CNAME"})
if err != nil {
return nil, fmt.Errorf("check cname exclusivity: %w", err)
}
if hasCNAME {
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
}
}
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
_, err = tx.Exec(`UPDATE records SET name = ?, type = ?, value = ?, ttl = ?, updated_at = ? WHERE id = ?`,
name, recordType, value, ttl, now, id)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") {
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
}
return nil, fmt.Errorf("update record: %w", err)
}
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
return nil, fmt.Errorf("bump serial: %w", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("commit: %w", err)
}
return d.GetRecord(id)
}
// DeleteRecord deletes a record and bumps the zone serial.
func (d *DB) DeleteRecord(id int64) error {
existing, err := d.GetRecord(id)
if err != nil {
return err
}
tx, err := d.Begin()
if err != nil {
return fmt.Errorf("begin tx: %w", err)
}
defer func() { _ = tx.Rollback() }()
_, err = tx.Exec(`DELETE FROM records WHERE id = ?`, id)
if err != nil {
return fmt.Errorf("delete record: %w", err)
}
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
return fmt.Errorf("bump serial: %w", err)
}
return tx.Commit()
}
// validateRecordValue checks that a record value is valid for its type.
func validateRecordValue(recordType, value string) error {
switch recordType {
case "A":
ip := net.ParseIP(value)
if ip == nil || ip.To4() == nil {
return fmt.Errorf("invalid IPv4 address: %q", value)
}
case "AAAA":
ip := net.ParseIP(value)
if ip == nil || ip.To4() != nil {
return fmt.Errorf("invalid IPv6 address: %q", value)
}
case "CNAME":
if !strings.HasSuffix(value, ".") {
return fmt.Errorf("CNAME value must be a fully-qualified domain name ending with '.': %q", value)
}
default:
return fmt.Errorf("unsupported record type: %q", recordType)
}
return nil
}

289
internal/db/records_test.go Normal file
View File

@@ -0,0 +1,289 @@
package db
import (
"errors"
"testing"
)
func createTestZone(t *testing.T, db *DB) *Zone {
t.Helper()
zone, err := db.CreateZone("svc.mcp.metacircular.net", "ns.mcp.metacircular.net.", "admin.metacircular.net.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
return zone
}
func TestCreateRecordA(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if record.Name != "metacrypt" {
t.Fatalf("got name %q, want %q", record.Name, "metacrypt")
}
if record.Type != "A" {
t.Fatalf("got type %q, want %q", record.Type, "A")
}
if record.Value != "192.168.88.181" {
t.Fatalf("got value %q, want %q", record.Value, "192.168.88.181")
}
}
func TestCreateRecordAAAA(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "AAAA", "2001:db8::1", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if record.Type != "AAAA" {
t.Fatalf("got type %q, want %q", record.Type, "AAAA")
}
}
func TestCreateRecordCNAME(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if record.Type != "CNAME" {
t.Fatalf("got type %q, want %q", record.Type, "CNAME")
}
}
func TestCreateRecordInvalidIP(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "bad", "A", "not-an-ip", 300)
if err == nil {
t.Fatal("expected error for invalid IPv4")
}
}
func TestCreateRecordCNAMEExclusivity(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
// Create an A record first.
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create A record: %v", err)
}
// Trying to add a CNAME for the same name should fail.
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "CNAME", "rift.mcp.metacircular.net.", 300)
if !errors.Is(err, ErrConflict) {
t.Fatalf("expected ErrConflict, got %v", err)
}
}
func TestCreateRecordCNAMEExclusivityReverse(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
// Create a CNAME record first.
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
if err != nil {
t.Fatalf("create CNAME record: %v", err)
}
// Trying to add an A record for the same name should fail.
_, err = db.CreateRecord("svc.mcp.metacircular.net", "alias", "A", "192.168.88.181", 300)
if !errors.Is(err, ErrConflict) {
t.Fatalf("expected ErrConflict, got %v", err)
}
}
func TestCreateRecordBumpsSerial(t *testing.T) {
db := openTestDB(t)
zone := createTestZone(t, db)
originalSerial := zone.Serial
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
updated, err := db.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone: %v", err)
}
if updated.Serial <= originalSerial {
t.Fatalf("serial should have bumped: %d <= %d", updated.Serial, originalSerial)
}
}
func TestListRecords(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record 1: %v", err)
}
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
if err != nil {
t.Fatalf("create record 2: %v", err)
}
_, err = db.CreateRecord("svc.mcp.metacircular.net", "mcr", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record 3: %v", err)
}
// List all records.
records, err := db.ListRecords("svc.mcp.metacircular.net", "", "")
if err != nil {
t.Fatalf("list records: %v", err)
}
if len(records) != 3 {
t.Fatalf("got %d records, want 3", len(records))
}
// Filter by name.
records, err = db.ListRecords("svc.mcp.metacircular.net", "metacrypt", "")
if err != nil {
t.Fatalf("list records by name: %v", err)
}
if len(records) != 2 {
t.Fatalf("got %d records, want 2", len(records))
}
// Filter by type.
records, err = db.ListRecords("svc.mcp.metacircular.net", "", "A")
if err != nil {
t.Fatalf("list records by type: %v", err)
}
if len(records) != 3 {
t.Fatalf("got %d records, want 3", len(records))
}
}
func TestUpdateRecord(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
updated, err := db.UpdateRecord(record.ID, "metacrypt", "A", "10.0.0.1", 600)
if err != nil {
t.Fatalf("update record: %v", err)
}
if updated.Value != "10.0.0.1" {
t.Fatalf("got value %q, want %q", updated.Value, "10.0.0.1")
}
if updated.TTL != 600 {
t.Fatalf("got ttl %d, want 600", updated.TTL)
}
}
func TestDeleteRecord(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
if err := db.DeleteRecord(record.ID); err != nil {
t.Fatalf("delete record: %v", err)
}
_, err = db.GetRecord(record.ID)
if !errors.Is(err, ErrNotFound) {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
}
func TestDeleteRecordBumpsSerial(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
zone, err := db.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone: %v", err)
}
serialBefore := zone.Serial
if err := db.DeleteRecord(record.ID); err != nil {
t.Fatalf("delete record: %v", err)
}
zone, err = db.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone after delete: %v", err)
}
if zone.Serial <= serialBefore {
t.Fatalf("serial should have bumped: %d <= %d", zone.Serial, serialBefore)
}
}
func TestLookupRecords(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create record: %v", err)
}
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
if err != nil {
t.Fatalf("lookup records: %v", err)
}
if len(records) != 1 {
t.Fatalf("got %d records, want 1", len(records))
}
if records[0].Value != "192.168.88.181" {
t.Fatalf("got value %q, want %q", records[0].Value, "192.168.88.181")
}
}
func TestCreateRecordCNAMEMissingDot(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net", 300)
if err == nil {
t.Fatal("expected error for CNAME without trailing dot")
}
}
func TestMultipleARecords(t *testing.T) {
db := openTestDB(t)
createTestZone(t, db)
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create first A record: %v", err)
}
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
if err != nil {
t.Fatalf("create second A record: %v", err)
}
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
if err != nil {
t.Fatalf("lookup records: %v", err)
}
if len(records) != 2 {
t.Fatalf("got %d records, want 2", len(records))
}
}

182
internal/db/zones.go Normal file
View File

@@ -0,0 +1,182 @@
package db
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"time"
)
// Zone represents a DNS zone stored in the database.
type Zone struct {
ID int64
Name string
PrimaryNS string
AdminEmail string
Refresh int
Retry int
Expire int
MinimumTTL int
Serial int64
CreatedAt string
UpdatedAt string
}
// ErrNotFound is returned when a requested resource does not exist.
var ErrNotFound = errors.New("not found")
// ErrConflict is returned when a write conflicts with existing data.
var ErrConflict = errors.New("conflict")
// ListZones returns all zones ordered by name.
func (d *DB) ListZones() ([]Zone, error) {
rows, err := d.Query(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones ORDER BY name`)
if err != nil {
return nil, fmt.Errorf("list zones: %w", err)
}
defer rows.Close()
var zones []Zone
for rows.Next() {
var z Zone
if err := rows.Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt); err != nil {
return nil, fmt.Errorf("scan zone: %w", err)
}
zones = append(zones, z)
}
return zones, rows.Err()
}
// GetZone returns a zone by name (case-insensitive).
func (d *DB) GetZone(name string) (*Zone, error) {
name = strings.ToLower(strings.TrimSuffix(name, "."))
var z Zone
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE name = ?`, name).
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("get zone: %w", err)
}
return &z, nil
}
// GetZoneByID returns a zone by ID.
func (d *DB) GetZoneByID(id int64) (*Zone, error) {
var z Zone
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE id = ?`, id).
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrNotFound
}
if err != nil {
return nil, fmt.Errorf("get zone by id: %w", err)
}
return &z, nil
}
// CreateZone inserts a new zone and returns it with the generated serial.
func (d *DB) CreateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
name = strings.ToLower(strings.TrimSuffix(name, "."))
serial := nextSerial(0)
res, err := d.Exec(`INSERT INTO zones (name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
name, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial)
if err != nil {
if strings.Contains(err.Error(), "UNIQUE constraint") {
return nil, fmt.Errorf("%w: zone %q already exists", ErrConflict, name)
}
return nil, fmt.Errorf("create zone: %w", err)
}
id, err := res.LastInsertId()
if err != nil {
return nil, fmt.Errorf("create zone: last insert id: %w", err)
}
return d.GetZoneByID(id)
}
// UpdateZone updates a zone's SOA parameters and bumps the serial.
func (d *DB) UpdateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
name = strings.ToLower(strings.TrimSuffix(name, "."))
zone, err := d.GetZone(name)
if err != nil {
return nil, err
}
serial := nextSerial(zone.Serial)
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
_, err = d.Exec(`UPDATE zones SET primary_ns = ?, admin_email = ?, refresh = ?, retry = ?, expire = ?, minimum_ttl = ?, serial = ?, updated_at = ? WHERE id = ?`,
primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial, now, zone.ID)
if err != nil {
return nil, fmt.Errorf("update zone: %w", err)
}
return d.GetZoneByID(zone.ID)
}
// DeleteZone deletes a zone and all its records.
func (d *DB) DeleteZone(name string) error {
name = strings.ToLower(strings.TrimSuffix(name, "."))
res, err := d.Exec(`DELETE FROM zones WHERE name = ?`, name)
if err != nil {
return fmt.Errorf("delete zone: %w", err)
}
n, err := res.RowsAffected()
if err != nil {
return fmt.Errorf("delete zone: rows affected: %w", err)
}
if n == 0 {
return ErrNotFound
}
return nil
}
// BumpSerial increments the serial for a zone within a transaction.
func (d *DB) BumpSerial(tx *sql.Tx, zoneID int64) error {
var current int64
if err := tx.QueryRow(`SELECT serial FROM zones WHERE id = ?`, zoneID).Scan(&current); err != nil {
return fmt.Errorf("read serial: %w", err)
}
serial := nextSerial(current)
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
_, err := tx.Exec(`UPDATE zones SET serial = ?, updated_at = ? WHERE id = ?`, serial, now, zoneID)
return err
}
// ZoneNames returns all zone names for the DNS handler.
func (d *DB) ZoneNames() ([]string, error) {
rows, err := d.Query(`SELECT name FROM zones ORDER BY name`)
if err != nil {
return nil, err
}
defer rows.Close()
var names []string
for rows.Next() {
var name string
if err := rows.Scan(&name); err != nil {
return nil, err
}
names = append(names, name)
}
return names, rows.Err()
}
// nextSerial computes the next SOA serial in YYYYMMDDNN format.
func nextSerial(current int64) int64 {
today := time.Now().UTC()
datePrefix, _ := strconv.ParseInt(today.Format("20060102"), 10, 64)
datePrefix *= 100 // YYYYMMDD00
if current >= datePrefix {
return current + 1
}
return datePrefix + 1
}

167
internal/db/zones_test.go Normal file
View File

@@ -0,0 +1,167 @@
package db
import (
"path/filepath"
"testing"
)
func openTestDB(t *testing.T) *DB {
t.Helper()
dir := t.TempDir()
database, err := Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := database.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { _ = database.Close() })
return database
}
func TestCreateZone(t *testing.T) {
db := openTestDB(t)
zone, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
if zone.Name != "example.com" {
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
}
if zone.Serial == 0 {
t.Fatal("serial should not be zero")
}
if zone.PrimaryNS != "ns1.example.com." {
t.Fatalf("got primary_ns %q, want %q", zone.PrimaryNS, "ns1.example.com.")
}
}
func TestCreateZoneDuplicate(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
_, err = db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err == nil {
t.Fatal("expected error for duplicate zone")
}
}
func TestCreateZoneNormalization(t *testing.T) {
db := openTestDB(t)
zone, err := db.CreateZone("Example.COM.", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
if zone.Name != "example.com" {
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
}
}
func TestListZones(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("b.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone b: %v", err)
}
_, err = db.CreateZone("a.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone a: %v", err)
}
zones, err := db.ListZones()
if err != nil {
t.Fatalf("list zones: %v", err)
}
if len(zones) != 2 {
t.Fatalf("got %d zones, want 2", len(zones))
}
if zones[0].Name != "a.example.com" {
t.Fatalf("zones should be ordered by name, got %q first", zones[0].Name)
}
}
func TestGetZone(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
zone, err := db.GetZone("example.com")
if err != nil {
t.Fatalf("get zone: %v", err)
}
if zone.Name != "example.com" {
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
}
_, err = db.GetZone("nonexistent.com")
if err != ErrNotFound {
t.Fatalf("expected ErrNotFound, got %v", err)
}
}
func TestUpdateZone(t *testing.T) {
db := openTestDB(t)
original, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
updated, err := db.UpdateZone("example.com", "ns2.example.com.", "newadmin.example.com.", 7200, 1200, 172800, 600)
if err != nil {
t.Fatalf("update zone: %v", err)
}
if updated.PrimaryNS != "ns2.example.com." {
t.Fatalf("got primary_ns %q, want %q", updated.PrimaryNS, "ns2.example.com.")
}
if updated.Serial <= original.Serial {
t.Fatalf("serial should have incremented: %d <= %d", updated.Serial, original.Serial)
}
}
func TestDeleteZone(t *testing.T) {
db := openTestDB(t)
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
if err := db.DeleteZone("example.com"); err != nil {
t.Fatalf("delete zone: %v", err)
}
_, err = db.GetZone("example.com")
if err != ErrNotFound {
t.Fatalf("expected ErrNotFound after delete, got %v", err)
}
if err := db.DeleteZone("nonexistent.com"); err != ErrNotFound {
t.Fatalf("expected ErrNotFound for nonexistent zone, got %v", err)
}
}
func TestNextSerial(t *testing.T) {
// A zero serial should produce a date-based serial.
s1 := nextSerial(0)
if s1 < 2026032600 {
t.Fatalf("serial %d seems too low", s1)
}
// Incrementing should increase.
s2 := nextSerial(s1)
if s2 != s1+1 {
t.Fatalf("expected %d, got %d", s1+1, s2)
}
}

67
internal/dns/cache.go Normal file
View File

@@ -0,0 +1,67 @@
package dns
import (
"sync"
"time"
"github.com/miekg/dns"
)
type cacheKey struct {
Name string
Qtype uint16
Class uint16
}
type cacheEntry struct {
msg *dns.Msg
expiresAt time.Time
}
// Cache is a thread-safe in-memory DNS response cache with TTL-based expiry.
type Cache struct {
mu sync.RWMutex
entries map[cacheKey]*cacheEntry
}
// NewCache creates an empty DNS cache.
func NewCache() *Cache {
return &Cache{
entries: make(map[cacheKey]*cacheEntry),
}
}
// Get returns a cached response if it exists and has not expired.
func (c *Cache) Get(name string, qtype, class uint16) *dns.Msg {
c.mu.RLock()
defer c.mu.RUnlock()
key := cacheKey{Name: name, Qtype: qtype, Class: class}
entry, ok := c.entries[key]
if !ok || time.Now().After(entry.expiresAt) {
return nil
}
return entry.msg
}
// Set stores a DNS response in the cache with the given TTL.
func (c *Cache) Set(name string, qtype, class uint16, msg *dns.Msg, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
key := cacheKey{Name: name, Qtype: qtype, Class: class}
c.entries[key] = &cacheEntry{
msg: msg.Copy(),
expiresAt: time.Now().Add(ttl),
}
// Lazy eviction: clean up expired entries if cache is growing.
if len(c.entries) > 1000 {
now := time.Now()
for k, v := range c.entries {
if now.After(v.expiresAt) {
delete(c.entries, k)
}
}
}
}

View File

@@ -0,0 +1,81 @@
package dns
import (
"testing"
"time"
"github.com/miekg/dns"
)
func TestCacheSetGet(t *testing.T) {
c := NewCache()
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: []byte{1, 2, 3, 4},
})
c.Set("example.com.", dns.TypeA, dns.ClassINET, msg, 5*time.Second)
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cached == nil {
t.Fatal("expected cached response")
}
if len(cached.Answer) != 1 {
t.Fatalf("got %d answers, want 1", len(cached.Answer))
}
}
func TestCacheMiss(t *testing.T) {
c := NewCache()
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cached != nil {
t.Fatal("expected nil for cache miss")
}
}
func TestCacheExpiry(t *testing.T) {
c := NewCache()
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
c.Set("example.com.", dns.TypeA, dns.ClassINET, msg, 1*time.Millisecond)
time.Sleep(2 * time.Millisecond)
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cached != nil {
t.Fatal("expected nil for expired entry")
}
}
func TestCacheDifferentTypes(t *testing.T) {
c := NewCache()
msgA := new(dns.Msg)
msgA.SetQuestion("example.com.", dns.TypeA)
c.Set("example.com.", dns.TypeA, dns.ClassINET, msgA, 5*time.Second)
msgAAAA := new(dns.Msg)
msgAAAA.SetQuestion("example.com.", dns.TypeAAAA)
c.Set("example.com.", dns.TypeAAAA, dns.ClassINET, msgAAAA, 5*time.Second)
cachedA := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cachedA == nil {
t.Fatal("expected cached A response")
}
cachedAAAA := c.Get("example.com.", dns.TypeAAAA, dns.ClassINET)
if cachedAAAA == nil {
t.Fatal("expected cached AAAA response")
}
// Different type should not match.
cachedMX := c.Get("example.com.", dns.TypeMX, dns.ClassINET)
if cachedMX != nil {
t.Fatal("expected nil for uncached type")
}
}

87
internal/dns/forwarder.go Normal file
View File

@@ -0,0 +1,87 @@
package dns
import (
"fmt"
"time"
"github.com/miekg/dns"
)
// Forwarder handles forwarding DNS queries to upstream resolvers.
type Forwarder struct {
upstreams []string
client *dns.Client
cache *Cache
}
// NewForwarder creates a Forwarder with the given upstream addresses.
func NewForwarder(upstreams []string) *Forwarder {
return &Forwarder{
upstreams: upstreams,
client: &dns.Client{
Timeout: 2 * time.Second,
},
cache: NewCache(),
}
}
// Forward sends a query to upstream resolvers and returns the response.
// Responses are cached by (qname, qtype, qclass) with TTL-based expiry.
func (f *Forwarder) Forward(r *dns.Msg) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
q := r.Question[0]
// Check cache.
if cached := f.cache.Get(q.Name, q.Qtype, q.Qclass); cached != nil {
return cached.Copy(), nil
}
// Try each upstream in order.
var lastErr error
for _, upstream := range f.upstreams {
resp, _, err := f.client.Exchange(r, upstream)
if err != nil {
lastErr = err
continue
}
// Don't cache SERVFAIL or REFUSED.
if resp.Rcode != dns.RcodeServerFailure && resp.Rcode != dns.RcodeRefused {
ttl := minTTL(resp)
if ttl > 300 {
ttl = 300
}
if ttl > 0 {
f.cache.Set(q.Name, q.Qtype, q.Qclass, resp, time.Duration(ttl)*time.Second)
}
}
return resp, nil
}
return nil, fmt.Errorf("all upstreams failed: %w", lastErr)
}
// minTTL returns the minimum TTL from all resource records in a response.
func minTTL(msg *dns.Msg) uint32 {
var min uint32
first := true
for _, sections := range [][]dns.RR{msg.Answer, msg.Ns, msg.Extra} {
for _, rr := range sections {
ttl := rr.Header().Ttl
if first || ttl < min {
min = ttl
first = false
}
}
}
if first {
return 60 // No records; default to 60s.
}
return min
}

280
internal/dns/server.go Normal file
View File

@@ -0,0 +1,280 @@
// Package dns implements the authoritative DNS server for MCNS.
// It serves records from SQLite for authoritative zones and forwards
// all other queries to configured upstream resolvers.
package dns
import (
"log/slog"
"net"
"strings"
"github.com/miekg/dns"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
// Server is the MCNS DNS server. It listens on both UDP and TCP.
type Server struct {
db *db.DB
forwarder *Forwarder
logger *slog.Logger
udp *dns.Server
tcp *dns.Server
}
// New creates a DNS server that serves records from the database and
// forwards non-authoritative queries to the given upstreams.
func New(database *db.DB, upstreams []string, logger *slog.Logger) *Server {
s := &Server{
db: database,
forwarder: NewForwarder(upstreams),
logger: logger,
}
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
s.udp = &dns.Server{Handler: mux, Net: "udp"}
s.tcp = &dns.Server{Handler: mux, Net: "tcp"}
return s
}
// ListenAndServe starts the DNS server on the given address for both
// UDP and TCP. It blocks until Shutdown is called.
func (s *Server) ListenAndServe(addr string) error {
s.udp.Addr = addr
s.tcp.Addr = addr
errCh := make(chan error, 2)
go func() {
s.logger.Info("dns server listening", "addr", addr, "proto", "udp")
errCh <- s.udp.ListenAndServe()
}()
go func() {
s.logger.Info("dns server listening", "addr", addr, "proto", "tcp")
errCh <- s.tcp.ListenAndServe()
}()
return <-errCh
}
// Shutdown gracefully stops the DNS server.
func (s *Server) Shutdown() {
_ = s.udp.Shutdown()
_ = s.tcp.Shutdown()
}
// handleQuery is the main DNS query handler. It checks if the query
// falls within an authoritative zone and either serves from the database
// or forwards to upstream.
func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
s.writeResponse(w, r, dns.RcodeFormatError, nil, nil)
return
}
q := r.Question[0]
qname := strings.ToLower(q.Name)
// Find the authoritative zone for this query.
zone := s.findZone(qname)
if zone == nil {
// Not authoritative — forward to upstream.
s.forwardQuery(w, r)
return
}
s.handleAuthoritativeQuery(w, r, zone, qname, q.Qtype)
}
// findZone returns the best matching zone for the query name, or nil.
func (s *Server) findZone(qname string) *db.Zone {
// Walk up the domain labels to find the longest matching zone.
name := strings.TrimSuffix(qname, ".")
parts := strings.Split(name, ".")
for i := range parts {
candidate := strings.Join(parts[i:], ".")
zone, err := s.db.GetZone(candidate)
if err == nil {
return zone
}
}
return nil
}
// handleAuthoritativeQuery serves a query from the database.
func (s *Server) handleAuthoritativeQuery(w dns.ResponseWriter, r *dns.Msg, zone *db.Zone, qname string, qtype uint16) {
// Extract the record name relative to the zone.
zoneFQDN := zone.Name + "."
var relName string
if qname == zoneFQDN {
relName = "@"
} else {
relName = strings.TrimSuffix(qname, "."+zoneFQDN)
}
// Handle SOA queries.
if qtype == dns.TypeSOA || relName == "@" && qtype == dns.TypeSOA {
soa := s.buildSOA(zone)
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{soa}, nil)
return
}
// Handle NS queries at the zone apex.
if qtype == dns.TypeNS && relName == "@" {
ns := &dns.NS{
Hdr: dns.RR_Header{Name: zoneFQDN, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
Ns: zone.PrimaryNS,
}
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{ns}, nil)
return
}
// Look up the requested record type.
var answers []dns.RR
var lookupType string
switch qtype {
case dns.TypeA:
lookupType = "A"
case dns.TypeAAAA:
lookupType = "AAAA"
case dns.TypeCNAME:
lookupType = "CNAME"
default:
// For unsupported types, check if the name exists at all.
// If it does, return empty answer. If not, NXDOMAIN.
exists, _ := s.nameExists(zone.Name, relName)
if exists {
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
} else {
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
}
return
}
records, err := s.db.LookupRecords(zone.Name, relName, lookupType)
if err != nil {
s.logger.Error("dns lookup failed", "zone", zone.Name, "name", relName, "type", lookupType, "error", err)
s.writeResponse(w, r, dns.RcodeServerFailure, nil, nil)
return
}
// If no direct records, check for CNAME.
if len(records) == 0 && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
cnameRecords, err := s.db.LookupCNAME(zone.Name, relName)
if err == nil && len(cnameRecords) > 0 {
for _, rec := range cnameRecords {
answers = append(answers, &dns.CNAME{
Hdr: dns.RR_Header{Name: qname, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(rec.TTL)},
Target: rec.Value,
})
}
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
return
}
}
if len(records) == 0 {
// Name might still exist with other record types.
exists, _ := s.nameExists(zone.Name, relName)
if exists {
// NODATA: name exists but no records of requested type.
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
} else {
// NXDOMAIN: name does not exist.
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
}
return
}
for _, rec := range records {
rr := s.recordToRR(qname, rec)
if rr != nil {
answers = append(answers, rr)
}
}
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
}
// nameExists checks if any records exist for a name in a zone.
func (s *Server) nameExists(zoneName, name string) (bool, error) {
records, err := s.db.ListRecords(zoneName, name, "")
if err != nil {
return false, err
}
return len(records) > 0, nil
}
// recordToRR converts a database Record to a dns.RR.
func (s *Server) recordToRR(qname string, rec db.Record) dns.RR {
hdr := dns.RR_Header{Name: qname, Class: dns.ClassINET, Ttl: uint32(rec.TTL)}
switch rec.Type {
case "A":
hdr.Rrtype = dns.TypeA
return &dns.A{Hdr: hdr, A: parseIP(rec.Value)}
case "AAAA":
hdr.Rrtype = dns.TypeAAAA
return &dns.AAAA{Hdr: hdr, AAAA: parseIP(rec.Value)}
case "CNAME":
hdr.Rrtype = dns.TypeCNAME
return &dns.CNAME{Hdr: hdr, Target: rec.Value}
}
return nil
}
// buildSOA constructs a SOA record for the given zone.
func (s *Server) buildSOA(zone *db.Zone) *dns.SOA {
return &dns.SOA{
Hdr: dns.RR_Header{Name: zone.Name + ".", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
Ns: zone.PrimaryNS,
Mbox: zone.AdminEmail,
Serial: uint32(zone.Serial),
Refresh: uint32(zone.Refresh),
Retry: uint32(zone.Retry),
Expire: uint32(zone.Expire),
Minttl: uint32(zone.MinimumTTL),
}
}
// writeResponse constructs and writes a DNS response.
func (s *Server) writeResponse(w dns.ResponseWriter, r *dns.Msg, rcode int, answer []dns.RR, ns []dns.RR) {
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
m.Rcode = rcode
m.Answer = answer
m.Ns = ns
if err := w.WriteMsg(m); err != nil {
s.logger.Error("dns write failed", "error", err)
}
}
// forwardQuery forwards a DNS query to upstream resolvers.
func (s *Server) forwardQuery(w dns.ResponseWriter, r *dns.Msg) {
resp, err := s.forwarder.Forward(r)
if err != nil {
s.logger.Debug("dns forward failed", "error", err)
m := new(dns.Msg)
m.SetReply(r)
m.Rcode = dns.RcodeServerFailure
_ = w.WriteMsg(m)
return
}
resp.Id = r.Id
if err := w.WriteMsg(resp); err != nil {
s.logger.Error("dns write failed", "error", err)
}
}
// parseIP parses an IP address string into a net.IP.
func parseIP(s string) net.IP {
return net.ParseIP(s)
}

142
internal/dns/server_test.go Normal file
View File

@@ -0,0 +1,142 @@
package dns
import (
"path/filepath"
"testing"
"github.com/miekg/dns"
"git.wntrmute.dev/kyle/mcns/internal/db"
"log/slog"
)
func openTestDB(t *testing.T) *db.DB {
t.Helper()
dir := t.TempDir()
database, err := db.Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := database.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { _ = database.Close() })
return database
}
func setupTestServer(t *testing.T) (*Server, *db.DB) {
t.Helper()
database := openTestDB(t)
logger := slog.Default()
_, err := database.CreateZone("svc.mcp.metacircular.net", "ns.mcp.metacircular.net.", "admin.metacircular.net.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create A record: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
if err != nil {
t.Fatalf("create A record 2: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "mcr", "AAAA", "2001:db8::1", 300)
if err != nil {
t.Fatalf("create AAAA record: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "metacrypt.svc.mcp.metacircular.net.", 300)
if err != nil {
t.Fatalf("create CNAME record: %v", err)
}
srv := New(database, []string{"1.1.1.1:53"}, logger)
return srv, database
}
func TestFindZone(t *testing.T) {
srv, _ := setupTestServer(t)
zone := srv.findZone("metacrypt.svc.mcp.metacircular.net.")
if zone == nil {
t.Fatal("expected to find zone")
}
if zone.Name != "svc.mcp.metacircular.net" {
t.Fatalf("got zone %q, want %q", zone.Name, "svc.mcp.metacircular.net")
}
zone = srv.findZone("nonexistent.com.")
if zone != nil {
t.Fatal("expected nil for nonexistent zone")
}
}
func TestBuildSOA(t *testing.T) {
srv, database := setupTestServer(t)
zone, err := database.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone: %v", err)
}
soa := srv.buildSOA(zone)
if soa.Ns != "ns.mcp.metacircular.net." {
t.Fatalf("got ns %q, want %q", soa.Ns, "ns.mcp.metacircular.net.")
}
if soa.Hdr.Name != "svc.mcp.metacircular.net." {
t.Fatalf("got name %q, want %q", soa.Hdr.Name, "svc.mcp.metacircular.net.")
}
}
func TestRecordToRR_A(t *testing.T) {
srv, _ := setupTestServer(t)
rec := db.Record{Name: "metacrypt", Type: "A", Value: "192.168.88.181", TTL: 300}
rr := srv.recordToRR("metacrypt.svc.mcp.metacircular.net.", rec)
if rr == nil {
t.Fatal("expected non-nil RR")
}
a, ok := rr.(*dns.A)
if !ok {
t.Fatalf("expected *dns.A, got %T", rr)
}
if a.A.String() != "192.168.88.181" {
t.Fatalf("got IP %q, want %q", a.A.String(), "192.168.88.181")
}
}
func TestRecordToRR_AAAA(t *testing.T) {
srv, _ := setupTestServer(t)
rec := db.Record{Name: "mcr", Type: "AAAA", Value: "2001:db8::1", TTL: 300}
rr := srv.recordToRR("mcr.svc.mcp.metacircular.net.", rec)
if rr == nil {
t.Fatal("expected non-nil RR")
}
aaaa, ok := rr.(*dns.AAAA)
if !ok {
t.Fatalf("expected *dns.AAAA, got %T", rr)
}
if aaaa.AAAA.String() != "2001:db8::1" {
t.Fatalf("got IP %q, want %q", aaaa.AAAA.String(), "2001:db8::1")
}
}
func TestRecordToRR_CNAME(t *testing.T) {
srv, _ := setupTestServer(t)
rec := db.Record{Name: "alias", Type: "CNAME", Value: "metacrypt.svc.mcp.metacircular.net.", TTL: 300}
rr := srv.recordToRR("alias.svc.mcp.metacircular.net.", rec)
if rr == nil {
t.Fatal("expected non-nil RR")
}
cname, ok := rr.(*dns.CNAME)
if !ok {
t.Fatalf("expected *dns.CNAME, got %T", rr)
}
if cname.Target != "metacrypt.svc.mcp.metacircular.net." {
t.Fatalf("got target %q, want %q", cname.Target, "metacrypt.svc.mcp.metacircular.net.")
}
}

View File

@@ -0,0 +1,20 @@
package grpcserver
import (
"context"
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
type adminService struct {
pb.UnimplementedAdminServiceServer
db *db.DB
}
func (s *adminService) Health(_ context.Context, _ *pb.HealthRequest) (*pb.HealthResponse, error) {
if err := s.db.Ping(); err != nil {
return &pb.HealthResponse{Status: "unhealthy"}, nil
}
return &pb.HealthResponse{Status: "ok"}, nil
}

View File

@@ -0,0 +1,38 @@
package grpcserver
import (
"context"
"errors"
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
)
type authService struct {
pb.UnimplementedAuthServiceServer
auth *mcdslauth.Authenticator
}
func (s *authService) Login(_ context.Context, req *pb.LoginRequest) (*pb.LoginResponse, error) {
token, _, err := s.auth.Login(req.Username, req.Password, req.TotpCode)
if err != nil {
if errors.Is(err, mcdslauth.ErrInvalidCredentials) {
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
}
if errors.Is(err, mcdslauth.ErrForbidden) {
return nil, status.Error(codes.PermissionDenied, "access denied by login policy")
}
return nil, status.Error(codes.Unavailable, "authentication service unavailable")
}
return &pb.LoginResponse{Token: token}, nil
}
func (s *authService) Logout(_ context.Context, req *pb.LogoutRequest) (*pb.LogoutResponse, error) {
if err := s.auth.Logout(req.Token); err != nil {
return nil, status.Error(codes.Internal, "logout failed")
}
return &pb.LogoutResponse{}, nil
}

View File

@@ -0,0 +1,45 @@
package grpcserver
import (
mcdslgrpc "git.wntrmute.dev/kyle/mcdsl/grpcserver"
)
// methodMap builds the mcdsl grpcserver.MethodMap for MCNS.
//
// Adding a new RPC without adding it to the correct map is a security
// defect — the mcdsl auth interceptor denies unmapped methods by default.
func methodMap() mcdslgrpc.MethodMap {
return mcdslgrpc.MethodMap{
Public: publicMethods(),
AuthRequired: authRequiredMethods(),
AdminRequired: adminRequiredMethods(),
}
}
func publicMethods() map[string]bool {
return map[string]bool{
"/mcns.v1.AdminService/Health": true,
"/mcns.v1.AuthService/Login": true,
}
}
func authRequiredMethods() map[string]bool {
return map[string]bool{
"/mcns.v1.AuthService/Logout": true,
"/mcns.v1.ZoneService/ListZones": true,
"/mcns.v1.ZoneService/GetZone": true,
"/mcns.v1.RecordService/ListRecords": true,
"/mcns.v1.RecordService/GetRecord": true,
}
}
func adminRequiredMethods() map[string]bool {
return map[string]bool{
"/mcns.v1.ZoneService/CreateZone": true,
"/mcns.v1.ZoneService/UpdateZone": true,
"/mcns.v1.ZoneService/DeleteZone": true,
"/mcns.v1.RecordService/CreateRecord": true,
"/mcns.v1.RecordService/UpdateRecord": true,
"/mcns.v1.RecordService/DeleteRecord": true,
}
}

View File

@@ -0,0 +1,110 @@
package grpcserver
import (
"context"
"errors"
"time"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
type recordService struct {
pb.UnimplementedRecordServiceServer
db *db.DB
}
func (s *recordService) ListRecords(_ context.Context, req *pb.ListRecordsRequest) (*pb.ListRecordsResponse, error) {
records, err := s.db.ListRecords(req.Zone, req.Name, req.Type)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
}
if err != nil {
return nil, status.Error(codes.Internal, "failed to list records")
}
resp := &pb.ListRecordsResponse{}
for _, r := range records {
resp.Records = append(resp.Records, recordToProto(r))
}
return resp, nil
}
func (s *recordService) GetRecord(_ context.Context, req *pb.GetRecordRequest) (*pb.Record, error) {
record, err := s.db.GetRecord(req.Id)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "record not found")
}
if err != nil {
return nil, status.Error(codes.Internal, "failed to get record")
}
return recordToProto(*record), nil
}
func (s *recordService) CreateRecord(_ context.Context, req *pb.CreateRecordRequest) (*pb.Record, error) {
record, err := s.db.CreateRecord(req.Zone, req.Name, req.Type, req.Value, int(req.Ttl))
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
}
if errors.Is(err, db.ErrConflict) {
return nil, status.Error(codes.AlreadyExists, err.Error())
}
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
return recordToProto(*record), nil
}
func (s *recordService) UpdateRecord(_ context.Context, req *pb.UpdateRecordRequest) (*pb.Record, error) {
record, err := s.db.UpdateRecord(req.Id, req.Name, req.Type, req.Value, int(req.Ttl))
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "record not found")
}
if errors.Is(err, db.ErrConflict) {
return nil, status.Error(codes.AlreadyExists, err.Error())
}
if err != nil {
return nil, status.Error(codes.InvalidArgument, err.Error())
}
return recordToProto(*record), nil
}
func (s *recordService) DeleteRecord(_ context.Context, req *pb.DeleteRecordRequest) (*pb.DeleteRecordResponse, error) {
err := s.db.DeleteRecord(req.Id)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "record not found")
}
if err != nil {
return nil, status.Error(codes.Internal, "failed to delete record")
}
return &pb.DeleteRecordResponse{}, nil
}
func recordToProto(r db.Record) *pb.Record {
return &pb.Record{
Id: r.ID,
Zone: r.ZoneName,
Name: r.Name,
Type: r.Type,
Value: r.Value,
Ttl: int32(r.TTL),
CreatedAt: parseRecordTimestamp(r.CreatedAt),
UpdatedAt: parseRecordTimestamp(r.UpdatedAt),
}
}
func parseRecordTimestamp(s string) *timestamppb.Timestamp {
t, err := parseTime(s)
if err != nil {
return nil
}
return timestamppb.New(t)
}
func parseTime(s string) (time.Time, error) {
return time.Parse("2006-01-02T15:04:05Z", s)
}

View File

@@ -0,0 +1,50 @@
package grpcserver
import (
"log/slog"
"net"
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
mcdslgrpc "git.wntrmute.dev/kyle/mcdsl/grpcserver"
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
// Deps holds the dependencies injected into the gRPC server.
type Deps struct {
DB *db.DB
Authenticator *mcdslauth.Authenticator
}
// Server wraps a mcdsl grpcserver.Server with MCNS-specific services.
type Server struct {
srv *mcdslgrpc.Server
}
// New creates a configured gRPC server with MCNS services registered.
func New(certFile, keyFile string, deps Deps, logger *slog.Logger) (*Server, error) {
srv, err := mcdslgrpc.New(certFile, keyFile, deps.Authenticator, methodMap(), logger)
if err != nil {
return nil, err
}
s := &Server{srv: srv}
pb.RegisterAdminServiceServer(srv.GRPCServer, &adminService{db: deps.DB})
pb.RegisterAuthServiceServer(srv.GRPCServer, &authService{auth: deps.Authenticator})
pb.RegisterZoneServiceServer(srv.GRPCServer, &zoneService{db: deps.DB})
pb.RegisterRecordServiceServer(srv.GRPCServer, &recordService{db: deps.DB})
return s, nil
}
// Serve starts the gRPC server on the given listener.
func (s *Server) Serve(lis net.Listener) error {
return s.srv.GRPCServer.Serve(lis)
}
// GracefulStop gracefully stops the gRPC server.
func (s *Server) GracefulStop() {
s.srv.Stop()
}

View File

@@ -0,0 +1,134 @@
package grpcserver
import (
"context"
"errors"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
type zoneService struct {
pb.UnimplementedZoneServiceServer
db *db.DB
}
func (s *zoneService) ListZones(_ context.Context, _ *pb.ListZonesRequest) (*pb.ListZonesResponse, error) {
zones, err := s.db.ListZones()
if err != nil {
return nil, status.Error(codes.Internal, "failed to list zones")
}
resp := &pb.ListZonesResponse{}
for _, z := range zones {
resp.Zones = append(resp.Zones, zoneToProto(z))
}
return resp, nil
}
func (s *zoneService) GetZone(_ context.Context, req *pb.GetZoneRequest) (*pb.Zone, error) {
zone, err := s.db.GetZone(req.Name)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
}
if err != nil {
return nil, status.Error(codes.Internal, "failed to get zone")
}
return zoneToProto(*zone), nil
}
func (s *zoneService) CreateZone(_ context.Context, req *pb.CreateZoneRequest) (*pb.Zone, error) {
refresh := int(req.Refresh)
if refresh == 0 {
refresh = 3600
}
retry := int(req.Retry)
if retry == 0 {
retry = 600
}
expire := int(req.Expire)
if expire == 0 {
expire = 86400
}
minTTL := int(req.MinimumTtl)
if minTTL == 0 {
minTTL = 300
}
zone, err := s.db.CreateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
if errors.Is(err, db.ErrConflict) {
return nil, status.Error(codes.AlreadyExists, err.Error())
}
if err != nil {
return nil, status.Error(codes.Internal, "failed to create zone")
}
return zoneToProto(*zone), nil
}
func (s *zoneService) UpdateZone(_ context.Context, req *pb.UpdateZoneRequest) (*pb.Zone, error) {
refresh := int(req.Refresh)
if refresh == 0 {
refresh = 3600
}
retry := int(req.Retry)
if retry == 0 {
retry = 600
}
expire := int(req.Expire)
if expire == 0 {
expire = 86400
}
minTTL := int(req.MinimumTtl)
if minTTL == 0 {
minTTL = 300
}
zone, err := s.db.UpdateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
}
if err != nil {
return nil, status.Error(codes.Internal, "failed to update zone")
}
return zoneToProto(*zone), nil
}
func (s *zoneService) DeleteZone(_ context.Context, req *pb.DeleteZoneRequest) (*pb.DeleteZoneResponse, error) {
err := s.db.DeleteZone(req.Name)
if errors.Is(err, db.ErrNotFound) {
return nil, status.Error(codes.NotFound, "zone not found")
}
if err != nil {
return nil, status.Error(codes.Internal, "failed to delete zone")
}
return &pb.DeleteZoneResponse{}, nil
}
func zoneToProto(z db.Zone) *pb.Zone {
return &pb.Zone{
Id: z.ID,
Name: z.Name,
PrimaryNs: z.PrimaryNS,
AdminEmail: z.AdminEmail,
Refresh: int32(z.Refresh),
Retry: int32(z.Retry),
Expire: int32(z.Expire),
MinimumTtl: int32(z.MinimumTTL),
Serial: z.Serial,
CreatedAt: parseTimestamp(z.CreatedAt),
UpdatedAt: parseTimestamp(z.UpdatedAt),
}
}
func parseTimestamp(s string) *timestamppb.Timestamp {
// SQLite stores as "2006-01-02T15:04:05Z".
t, err := parseTime(s)
if err != nil {
return nil
}
return timestamppb.New(t)
}

62
internal/server/auth.go Normal file
View File

@@ -0,0 +1,62 @@
package server
import (
"encoding/json"
"errors"
"net/http"
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
)
type loginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
TOTPCode string `json:"totp_code"`
}
type loginResponse struct {
Token string `json:"token"`
}
func loginHandler(auth *mcdslauth.Authenticator) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req loginRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
token, _, err := auth.Login(req.Username, req.Password, req.TOTPCode)
if err != nil {
if errors.Is(err, mcdslauth.ErrInvalidCredentials) {
writeError(w, http.StatusUnauthorized, "invalid credentials")
return
}
if errors.Is(err, mcdslauth.ErrForbidden) {
writeError(w, http.StatusForbidden, "access denied by login policy")
return
}
writeError(w, http.StatusServiceUnavailable, "authentication service unavailable")
return
}
writeJSON(w, http.StatusOK, loginResponse{Token: token})
}
}
func logoutHandler(auth *mcdslauth.Authenticator) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
token := extractBearerToken(r)
if token == "" {
writeError(w, http.StatusUnauthorized, "authentication required")
return
}
if err := auth.Logout(token); err != nil {
writeError(w, http.StatusInternalServerError, "logout failed")
return
}
w.WriteHeader(http.StatusNoContent)
}
}

View File

@@ -0,0 +1,96 @@
package server
import (
"context"
"log/slog"
"net/http"
"strings"
"time"
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
)
type contextKey string
const tokenInfoKey contextKey = "tokenInfo"
// requireAuth returns middleware that validates Bearer tokens via MCIAS.
func requireAuth(auth *mcdslauth.Authenticator) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := extractBearerToken(r)
if token == "" {
writeError(w, http.StatusUnauthorized, "authentication required")
return
}
info, err := auth.ValidateToken(token)
if err != nil {
writeError(w, http.StatusUnauthorized, "invalid or expired token")
return
}
ctx := context.WithValue(r.Context(), tokenInfoKey, info)
next.ServeHTTP(w, r.WithContext(ctx))
})
}
}
// requireAdmin is middleware that checks the caller has the admin role.
func requireAdmin(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
info := tokenInfoFromContext(r.Context())
if info == nil || !info.IsAdmin {
writeError(w, http.StatusForbidden, "admin role required")
return
}
next.ServeHTTP(w, r)
})
}
// tokenInfoFromContext extracts the TokenInfo from the request context.
func tokenInfoFromContext(ctx context.Context) *mcdslauth.TokenInfo {
info, _ := ctx.Value(tokenInfoKey).(*mcdslauth.TokenInfo)
return info
}
// extractBearerToken extracts a bearer token from the Authorization header.
func extractBearerToken(r *http.Request) string {
h := r.Header.Get("Authorization")
if h == "" {
return ""
}
const prefix = "Bearer "
if !strings.HasPrefix(h, prefix) {
return ""
}
return strings.TrimSpace(h[len(prefix):])
}
// loggingMiddleware logs HTTP requests.
func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
start := time.Now()
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
next.ServeHTTP(sw, r)
logger.Info("http",
"method", r.Method,
"path", r.URL.Path,
"status", sw.status,
"duration", time.Since(start),
"remote", r.RemoteAddr,
)
})
}
}
type statusWriter struct {
http.ResponseWriter
status int
}
func (w *statusWriter) WriteHeader(code int) {
w.status = code
w.ResponseWriter.WriteHeader(code)
}

174
internal/server/records.go Normal file
View File

@@ -0,0 +1,174 @@
package server
import (
"encoding/json"
"errors"
"net/http"
"strconv"
"github.com/go-chi/chi/v5"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
type createRecordRequest struct {
Name string `json:"name"`
Type string `json:"type"`
Value string `json:"value"`
TTL int `json:"ttl"`
}
func listRecordsHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
zoneName := chi.URLParam(r, "zone")
nameFilter := r.URL.Query().Get("name")
typeFilter := r.URL.Query().Get("type")
records, err := database.ListRecords(zoneName, nameFilter, typeFilter)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "zone not found")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to list records")
return
}
if records == nil {
records = []db.Record{}
}
writeJSON(w, http.StatusOK, map[string]any{"records": records})
}
}
func getRecordHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid record ID")
return
}
record, err := database.GetRecord(id)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "record not found")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to get record")
return
}
writeJSON(w, http.StatusOK, record)
}
}
func createRecordHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
zoneName := chi.URLParam(r, "zone")
var req createRecordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.Name == "" {
writeError(w, http.StatusBadRequest, "name is required")
return
}
if req.Type == "" {
writeError(w, http.StatusBadRequest, "type is required")
return
}
if req.Value == "" {
writeError(w, http.StatusBadRequest, "value is required")
return
}
record, err := database.CreateRecord(zoneName, req.Name, req.Type, req.Value, req.TTL)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "zone not found")
return
}
if errors.Is(err, db.ErrConflict) {
writeError(w, http.StatusConflict, err.Error())
return
}
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
writeJSON(w, http.StatusCreated, record)
}
}
func updateRecordHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid record ID")
return
}
var req createRecordRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.Name == "" {
writeError(w, http.StatusBadRequest, "name is required")
return
}
if req.Type == "" {
writeError(w, http.StatusBadRequest, "type is required")
return
}
if req.Value == "" {
writeError(w, http.StatusBadRequest, "value is required")
return
}
record, err := database.UpdateRecord(id, req.Name, req.Type, req.Value, req.TTL)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "record not found")
return
}
if errors.Is(err, db.ErrConflict) {
writeError(w, http.StatusConflict, err.Error())
return
}
if err != nil {
writeError(w, http.StatusBadRequest, err.Error())
return
}
writeJSON(w, http.StatusOK, record)
}
}
func deleteRecordHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
idStr := chi.URLParam(r, "id")
id, err := strconv.ParseInt(idStr, 10, 64)
if err != nil {
writeError(w, http.StatusBadRequest, "invalid record ID")
return
}
err = database.DeleteRecord(id)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "record not found")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to delete record")
return
}
w.WriteHeader(http.StatusNoContent)
}
}

71
internal/server/routes.go Normal file
View File

@@ -0,0 +1,71 @@
package server
import (
"encoding/json"
"log/slog"
"net/http"
"github.com/go-chi/chi/v5"
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
"git.wntrmute.dev/kyle/mcdsl/health"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
// Deps holds dependencies injected into the REST handlers.
type Deps struct {
DB *db.DB
Auth *mcdslauth.Authenticator
Logger *slog.Logger
}
// NewRouter builds the chi router with all MCNS REST endpoints.
func NewRouter(deps Deps) *chi.Mux {
r := chi.NewRouter()
r.Use(loggingMiddleware(deps.Logger))
// Public endpoints.
r.Post("/v1/auth/login", loginHandler(deps.Auth))
r.Get("/v1/health", health.Handler(deps.DB.DB))
// Authenticated endpoints.
r.Group(func(r chi.Router) {
r.Use(requireAuth(deps.Auth))
r.Post("/v1/auth/logout", logoutHandler(deps.Auth))
// Zone endpoints — reads for all authenticated users, writes for admin.
r.Get("/v1/zones", listZonesHandler(deps.DB))
r.Get("/v1/zones/{zone}", getZoneHandler(deps.DB))
// Admin-only zone mutations.
r.With(requireAdmin).Post("/v1/zones", createZoneHandler(deps.DB))
r.With(requireAdmin).Put("/v1/zones/{zone}", updateZoneHandler(deps.DB))
r.With(requireAdmin).Delete("/v1/zones/{zone}", deleteZoneHandler(deps.DB))
// Record endpoints — reads for all authenticated users, writes for admin.
r.Get("/v1/zones/{zone}/records", listRecordsHandler(deps.DB))
r.Get("/v1/zones/{zone}/records/{id}", getRecordHandler(deps.DB))
// Admin-only record mutations.
r.With(requireAdmin).Post("/v1/zones/{zone}/records", createRecordHandler(deps.DB))
r.With(requireAdmin).Put("/v1/zones/{zone}/records/{id}", updateRecordHandler(deps.DB))
r.With(requireAdmin).Delete("/v1/zones/{zone}/records/{id}", deleteRecordHandler(deps.DB))
})
return r
}
// writeJSON writes a JSON response with the given status code.
func writeJSON(w http.ResponseWriter, status int, v any) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(status)
_ = json.NewEncoder(w).Encode(v)
}
// writeError writes a standard error response.
func writeError(w http.ResponseWriter, status int, message string) {
writeJSON(w, status, map[string]string{"error": message})
}

163
internal/server/zones.go Normal file
View File

@@ -0,0 +1,163 @@
package server
import (
"encoding/json"
"errors"
"net/http"
"github.com/go-chi/chi/v5"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
type createZoneRequest struct {
Name string `json:"name"`
PrimaryNS string `json:"primary_ns"`
AdminEmail string `json:"admin_email"`
Refresh int `json:"refresh"`
Retry int `json:"retry"`
Expire int `json:"expire"`
MinimumTTL int `json:"minimum_ttl"`
}
func listZonesHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, _ *http.Request) {
zones, err := database.ListZones()
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to list zones")
return
}
if zones == nil {
zones = []db.Zone{}
}
writeJSON(w, http.StatusOK, map[string]any{"zones": zones})
}
}
func getZoneHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "zone")
zone, err := database.GetZone(name)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "zone not found")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to get zone")
return
}
writeJSON(w, http.StatusOK, zone)
}
}
func createZoneHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var req createZoneRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.Name == "" {
writeError(w, http.StatusBadRequest, "name is required")
return
}
if req.PrimaryNS == "" {
writeError(w, http.StatusBadRequest, "primary_ns is required")
return
}
if req.AdminEmail == "" {
writeError(w, http.StatusBadRequest, "admin_email is required")
return
}
// Apply defaults for SOA params.
if req.Refresh == 0 {
req.Refresh = 3600
}
if req.Retry == 0 {
req.Retry = 600
}
if req.Expire == 0 {
req.Expire = 86400
}
if req.MinimumTTL == 0 {
req.MinimumTTL = 300
}
zone, err := database.CreateZone(req.Name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
if errors.Is(err, db.ErrConflict) {
writeError(w, http.StatusConflict, err.Error())
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to create zone")
return
}
writeJSON(w, http.StatusCreated, zone)
}
}
func updateZoneHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "zone")
var req createZoneRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
writeError(w, http.StatusBadRequest, "invalid request body")
return
}
if req.PrimaryNS == "" {
writeError(w, http.StatusBadRequest, "primary_ns is required")
return
}
if req.AdminEmail == "" {
writeError(w, http.StatusBadRequest, "admin_email is required")
return
}
if req.Refresh == 0 {
req.Refresh = 3600
}
if req.Retry == 0 {
req.Retry = 600
}
if req.Expire == 0 {
req.Expire = 86400
}
if req.MinimumTTL == 0 {
req.MinimumTTL = 300
}
zone, err := database.UpdateZone(name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "zone not found")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to update zone")
return
}
writeJSON(w, http.StatusOK, zone)
}
}
func deleteZoneHandler(database *db.DB) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
name := chi.URLParam(r, "zone")
err := database.DeleteZone(name)
if errors.Is(err, db.ErrNotFound) {
writeError(w, http.StatusNotFound, "zone not found")
return
}
if err != nil {
writeError(w, http.StatusInternalServerError, "failed to delete zone")
return
}
w.WriteHeader(http.StatusNoContent)
}
}