Implement MCNS v1: custom Go DNS server replacing CoreDNS
Replace the CoreDNS precursor with a purpose-built authoritative DNS server. Zones and records (A, AAAA, CNAME) are stored in SQLite and managed via synchronized gRPC + REST APIs authenticated through MCIAS. Non-authoritative queries are forwarded to upstream resolvers with in-memory caching. Key components: - DNS server (miekg/dns) with authoritative zone handling and forwarding - gRPC + REST management APIs with MCIAS auth (mcdsl integration) - SQLite storage with CNAME exclusivity enforcement and auto SOA serials - 30 tests covering database CRUD, DNS resolution, and caching Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
49
internal/config/config.go
Normal file
49
internal/config/config.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
mcdslconfig "git.wntrmute.dev/kyle/mcdsl/config"
|
||||
)
|
||||
|
||||
// Config is the top-level MCNS configuration.
|
||||
type Config struct {
|
||||
mcdslconfig.Base
|
||||
DNS DNSConfig `toml:"dns"`
|
||||
}
|
||||
|
||||
// DNSConfig holds the DNS server settings.
|
||||
type DNSConfig struct {
|
||||
ListenAddr string `toml:"listen_addr"`
|
||||
Upstreams []string `toml:"upstreams"`
|
||||
}
|
||||
|
||||
// Load reads a TOML config file, applies environment variable overrides
|
||||
// (MCNS_ prefix), sets defaults, and validates required fields.
|
||||
func Load(path string) (*Config, error) {
|
||||
cfg, err := mcdslconfig.Load[Config](path, "MCNS")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Apply DNS defaults.
|
||||
if cfg.DNS.ListenAddr == "" {
|
||||
cfg.DNS.ListenAddr = ":53"
|
||||
}
|
||||
if len(cfg.DNS.Upstreams) == 0 {
|
||||
cfg.DNS.Upstreams = []string{"1.1.1.1:53", "8.8.8.8:53"}
|
||||
}
|
||||
|
||||
return cfg, nil
|
||||
}
|
||||
|
||||
// Validate implements the mcdsl config.Validator interface.
|
||||
func (c *Config) Validate() error {
|
||||
if c.Database.Path == "" {
|
||||
return fmt.Errorf("database.path is required")
|
||||
}
|
||||
if c.MCIAS.ServerURL == "" {
|
||||
return fmt.Errorf("mcias.server_url is required")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
23
internal/db/db.go
Normal file
23
internal/db/db.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
||||
)
|
||||
|
||||
// DB wraps a SQLite database connection.
|
||||
type DB struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
// Open opens (or creates) a SQLite database at the given path with the
|
||||
// standard Metacircular pragmas: WAL mode, foreign keys, busy timeout.
|
||||
func Open(path string) (*DB, error) {
|
||||
sqlDB, err := mcdsldb.Open(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("db: %w", err)
|
||||
}
|
||||
return &DB{sqlDB}, nil
|
||||
}
|
||||
46
internal/db/migrate.go
Normal file
46
internal/db/migrate.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
mcdsldb "git.wntrmute.dev/kyle/mcdsl/db"
|
||||
)
|
||||
|
||||
// Migrations is the ordered list of MCNS schema migrations.
|
||||
var Migrations = []mcdsldb.Migration{
|
||||
{
|
||||
Version: 1,
|
||||
Name: "zones and records",
|
||||
SQL: `
|
||||
CREATE TABLE IF NOT EXISTS zones (
|
||||
id INTEGER PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
primary_ns TEXT NOT NULL,
|
||||
admin_email TEXT NOT NULL,
|
||||
refresh INTEGER NOT NULL DEFAULT 3600,
|
||||
retry INTEGER NOT NULL DEFAULT 600,
|
||||
expire INTEGER NOT NULL DEFAULT 86400,
|
||||
minimum_ttl INTEGER NOT NULL DEFAULT 300,
|
||||
serial INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS records (
|
||||
id INTEGER PRIMARY KEY,
|
||||
zone_id INTEGER NOT NULL REFERENCES zones(id) ON DELETE CASCADE,
|
||||
name TEXT NOT NULL,
|
||||
type TEXT NOT NULL CHECK (type IN ('A', 'AAAA', 'CNAME')),
|
||||
value TEXT NOT NULL,
|
||||
ttl INTEGER NOT NULL DEFAULT 300,
|
||||
created_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
||||
updated_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%SZ', 'now')),
|
||||
UNIQUE(zone_id, name, type, value)
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_records_zone_name ON records(zone_id, name);`,
|
||||
},
|
||||
}
|
||||
|
||||
// Migrate applies all pending migrations.
|
||||
func (d *DB) Migrate() error {
|
||||
return mcdsldb.Migrate(d.DB, Migrations)
|
||||
}
|
||||
308
internal/db/records.go
Normal file
308
internal/db/records.go
Normal file
@@ -0,0 +1,308 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Record represents a DNS record stored in the database.
|
||||
type Record struct {
|
||||
ID int64
|
||||
ZoneID int64
|
||||
ZoneName string
|
||||
Name string
|
||||
Type string
|
||||
Value string
|
||||
TTL int
|
||||
CreatedAt string
|
||||
UpdatedAt string
|
||||
}
|
||||
|
||||
// ListRecords returns records for a zone, optionally filtered by name and type.
|
||||
func (d *DB) ListRecords(zoneName, name, recordType string) ([]Record, error) {
|
||||
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
||||
|
||||
zone, err := d.GetZone(zoneName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := `SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
||||
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.zone_id = ?`
|
||||
args := []any{zone.ID}
|
||||
|
||||
if name != "" {
|
||||
query += ` AND r.name = ?`
|
||||
args = append(args, strings.ToLower(name))
|
||||
}
|
||||
if recordType != "" {
|
||||
query += ` AND r.type = ?`
|
||||
args = append(args, strings.ToUpper(recordType))
|
||||
}
|
||||
|
||||
query += ` ORDER BY r.name, r.type, r.value`
|
||||
|
||||
rows, err := d.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []Record
|
||||
for rows.Next() {
|
||||
var r Record
|
||||
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan record: %w", err)
|
||||
}
|
||||
records = append(records, r)
|
||||
}
|
||||
return records, rows.Err()
|
||||
}
|
||||
|
||||
// LookupRecords returns records matching a name and type within a zone.
|
||||
// Used by the DNS handler for query resolution.
|
||||
func (d *DB) LookupRecords(zoneName, name, recordType string) ([]Record, error) {
|
||||
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
||||
name = strings.ToLower(name)
|
||||
|
||||
rows, err := d.Query(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
||||
FROM records r JOIN zones z ON r.zone_id = z.id
|
||||
WHERE z.name = ? AND r.name = ? AND r.type = ?`, zoneName, name, strings.ToUpper(recordType))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup records: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var records []Record
|
||||
for rows.Next() {
|
||||
var r Record
|
||||
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan record: %w", err)
|
||||
}
|
||||
records = append(records, r)
|
||||
}
|
||||
return records, rows.Err()
|
||||
}
|
||||
|
||||
// LookupCNAME returns CNAME records for a name within a zone.
|
||||
func (d *DB) LookupCNAME(zoneName, name string) ([]Record, error) {
|
||||
return d.LookupRecords(zoneName, name, "CNAME")
|
||||
}
|
||||
|
||||
// HasRecordsForName checks if any records of the given types exist for a name.
|
||||
func (d *DB) HasRecordsForName(tx *sql.Tx, zoneID int64, name string, types []string) (bool, error) {
|
||||
placeholders := make([]string, len(types))
|
||||
args := []any{zoneID, strings.ToLower(name)}
|
||||
for i, t := range types {
|
||||
placeholders[i] = "?"
|
||||
args = append(args, strings.ToUpper(t))
|
||||
}
|
||||
query := fmt.Sprintf(`SELECT COUNT(*) FROM records WHERE zone_id = ? AND name = ? AND type IN (%s)`, strings.Join(placeholders, ","))
|
||||
|
||||
var count int
|
||||
if err := tx.QueryRow(query, args...).Scan(&count); err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
// GetRecord returns a record by ID.
|
||||
func (d *DB) GetRecord(id int64) (*Record, error) {
|
||||
var r Record
|
||||
err := d.QueryRow(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
||||
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.id = ?`, id).
|
||||
Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get record: %w", err)
|
||||
}
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// CreateRecord inserts a new record, enforcing CNAME exclusivity and
|
||||
// value validation. Bumps the zone serial within the same transaction.
|
||||
func (d *DB) CreateRecord(zoneName, name, recordType, value string, ttl int) (*Record, error) {
|
||||
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
||||
name = strings.ToLower(name)
|
||||
recordType = strings.ToUpper(recordType)
|
||||
|
||||
if err := validateRecordValue(recordType, value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
zone, err := d.GetZone(zoneName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = 300
|
||||
}
|
||||
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Enforce CNAME exclusivity.
|
||||
if recordType == "CNAME" {
|
||||
hasAddr, err := d.HasRecordsForName(tx, zone.ID, name, []string{"A", "AAAA"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasAddr {
|
||||
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
|
||||
}
|
||||
} else if recordType == "A" || recordType == "AAAA" {
|
||||
hasCNAME, err := d.HasRecordsForName(tx, zone.ID, name, []string{"CNAME"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasCNAME {
|
||||
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tx.Exec(`INSERT INTO records (zone_id, name, type, value, ttl) VALUES (?, ?, ?, ?, ?)`,
|
||||
zone.ID, name, recordType, value, ttl)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
||||
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
|
||||
}
|
||||
return nil, fmt.Errorf("insert record: %w", err)
|
||||
}
|
||||
|
||||
recordID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("last insert id: %w", err)
|
||||
}
|
||||
|
||||
if err := d.BumpSerial(tx, zone.ID); err != nil {
|
||||
return nil, fmt.Errorf("bump serial: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return d.GetRecord(recordID)
|
||||
}
|
||||
|
||||
// UpdateRecord updates an existing record's fields and bumps the zone serial.
|
||||
func (d *DB) UpdateRecord(id int64, name, recordType, value string, ttl int) (*Record, error) {
|
||||
name = strings.ToLower(name)
|
||||
recordType = strings.ToUpper(recordType)
|
||||
|
||||
if err := validateRecordValue(recordType, value); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existing, err := d.GetRecord(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ttl <= 0 {
|
||||
ttl = 300
|
||||
}
|
||||
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
// Enforce CNAME exclusivity for the new type/name combo.
|
||||
if recordType == "CNAME" {
|
||||
hasAddr, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"A", "AAAA"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasAddr {
|
||||
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
|
||||
}
|
||||
} else if recordType == "A" || recordType == "AAAA" {
|
||||
hasCNAME, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"CNAME"})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
||||
}
|
||||
if hasCNAME {
|
||||
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
|
||||
}
|
||||
}
|
||||
|
||||
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
_, err = tx.Exec(`UPDATE records SET name = ?, type = ?, value = ?, ttl = ?, updated_at = ? WHERE id = ?`,
|
||||
name, recordType, value, ttl, now, id)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
||||
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
|
||||
}
|
||||
return nil, fmt.Errorf("update record: %w", err)
|
||||
}
|
||||
|
||||
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
|
||||
return nil, fmt.Errorf("bump serial: %w", err)
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
return nil, fmt.Errorf("commit: %w", err)
|
||||
}
|
||||
|
||||
return d.GetRecord(id)
|
||||
}
|
||||
|
||||
// DeleteRecord deletes a record and bumps the zone serial.
|
||||
func (d *DB) DeleteRecord(id int64) error {
|
||||
existing, err := d.GetRecord(id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("begin tx: %w", err)
|
||||
}
|
||||
defer func() { _ = tx.Rollback() }()
|
||||
|
||||
_, err = tx.Exec(`DELETE FROM records WHERE id = ?`, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete record: %w", err)
|
||||
}
|
||||
|
||||
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
|
||||
return fmt.Errorf("bump serial: %w", err)
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// validateRecordValue checks that a record value is valid for its type.
|
||||
func validateRecordValue(recordType, value string) error {
|
||||
switch recordType {
|
||||
case "A":
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil || ip.To4() == nil {
|
||||
return fmt.Errorf("invalid IPv4 address: %q", value)
|
||||
}
|
||||
case "AAAA":
|
||||
ip := net.ParseIP(value)
|
||||
if ip == nil || ip.To4() != nil {
|
||||
return fmt.Errorf("invalid IPv6 address: %q", value)
|
||||
}
|
||||
case "CNAME":
|
||||
if !strings.HasSuffix(value, ".") {
|
||||
return fmt.Errorf("CNAME value must be a fully-qualified domain name ending with '.': %q", value)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unsupported record type: %q", recordType)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
289
internal/db/records_test.go
Normal file
289
internal/db/records_test.go
Normal file
@@ -0,0 +1,289 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func createTestZone(t *testing.T, db *DB) *Zone {
|
||||
t.Helper()
|
||||
zone, err := db.CreateZone("svc.mcp.metacircular.net", "ns.mcp.metacircular.net.", "admin.metacircular.net.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
return zone
|
||||
}
|
||||
|
||||
func TestCreateRecordA(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
if record.Name != "metacrypt" {
|
||||
t.Fatalf("got name %q, want %q", record.Name, "metacrypt")
|
||||
}
|
||||
if record.Type != "A" {
|
||||
t.Fatalf("got type %q, want %q", record.Type, "A")
|
||||
}
|
||||
if record.Value != "192.168.88.181" {
|
||||
t.Fatalf("got value %q, want %q", record.Value, "192.168.88.181")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordAAAA(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "AAAA", "2001:db8::1", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
if record.Type != "AAAA" {
|
||||
t.Fatalf("got type %q, want %q", record.Type, "AAAA")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAME(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
if record.Type != "CNAME" {
|
||||
t.Fatalf("got type %q, want %q", record.Type, "CNAME")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordInvalidIP(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "bad", "A", "not-an-ip", 300)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for invalid IPv4")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAMEExclusivity(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
// Create an A record first.
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create A record: %v", err)
|
||||
}
|
||||
|
||||
// Trying to add a CNAME for the same name should fail.
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "CNAME", "rift.mcp.metacircular.net.", 300)
|
||||
if !errors.Is(err, ErrConflict) {
|
||||
t.Fatalf("expected ErrConflict, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAMEExclusivityReverse(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
// Create a CNAME record first.
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net.", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create CNAME record: %v", err)
|
||||
}
|
||||
|
||||
// Trying to add an A record for the same name should fail.
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "alias", "A", "192.168.88.181", 300)
|
||||
if !errors.Is(err, ErrConflict) {
|
||||
t.Fatalf("expected ErrConflict, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordBumpsSerial(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
zone := createTestZone(t, db)
|
||||
originalSerial := zone.Serial
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
updated, err := db.GetZone("svc.mcp.metacircular.net")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone: %v", err)
|
||||
}
|
||||
if updated.Serial <= originalSerial {
|
||||
t.Fatalf("serial should have bumped: %d <= %d", updated.Serial, originalSerial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestListRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record 1: %v", err)
|
||||
}
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record 2: %v", err)
|
||||
}
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "mcr", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record 3: %v", err)
|
||||
}
|
||||
|
||||
// List all records.
|
||||
records, err := db.ListRecords("svc.mcp.metacircular.net", "", "")
|
||||
if err != nil {
|
||||
t.Fatalf("list records: %v", err)
|
||||
}
|
||||
if len(records) != 3 {
|
||||
t.Fatalf("got %d records, want 3", len(records))
|
||||
}
|
||||
|
||||
// Filter by name.
|
||||
records, err = db.ListRecords("svc.mcp.metacircular.net", "metacrypt", "")
|
||||
if err != nil {
|
||||
t.Fatalf("list records by name: %v", err)
|
||||
}
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("got %d records, want 2", len(records))
|
||||
}
|
||||
|
||||
// Filter by type.
|
||||
records, err = db.ListRecords("svc.mcp.metacircular.net", "", "A")
|
||||
if err != nil {
|
||||
t.Fatalf("list records by type: %v", err)
|
||||
}
|
||||
if len(records) != 3 {
|
||||
t.Fatalf("got %d records, want 3", len(records))
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateRecord(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
updated, err := db.UpdateRecord(record.ID, "metacrypt", "A", "10.0.0.1", 600)
|
||||
if err != nil {
|
||||
t.Fatalf("update record: %v", err)
|
||||
}
|
||||
if updated.Value != "10.0.0.1" {
|
||||
t.Fatalf("got value %q, want %q", updated.Value, "10.0.0.1")
|
||||
}
|
||||
if updated.TTL != 600 {
|
||||
t.Fatalf("got ttl %d, want 600", updated.TTL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRecord(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteRecord(record.ID); err != nil {
|
||||
t.Fatalf("delete record: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetRecord(record.ID)
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("expected ErrNotFound after delete, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteRecordBumpsSerial(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
record, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
zone, err := db.GetZone("svc.mcp.metacircular.net")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone: %v", err)
|
||||
}
|
||||
serialBefore := zone.Serial
|
||||
|
||||
if err := db.DeleteRecord(record.ID); err != nil {
|
||||
t.Fatalf("delete record: %v", err)
|
||||
}
|
||||
|
||||
zone, err = db.GetZone("svc.mcp.metacircular.net")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone after delete: %v", err)
|
||||
}
|
||||
if zone.Serial <= serialBefore {
|
||||
t.Fatalf("serial should have bumped: %d <= %d", zone.Serial, serialBefore)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLookupRecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create record: %v", err)
|
||||
}
|
||||
|
||||
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
|
||||
if err != nil {
|
||||
t.Fatalf("lookup records: %v", err)
|
||||
}
|
||||
if len(records) != 1 {
|
||||
t.Fatalf("got %d records, want 1", len(records))
|
||||
}
|
||||
if records[0].Value != "192.168.88.181" {
|
||||
t.Fatalf("got value %q, want %q", records[0].Value, "192.168.88.181")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateRecordCNAMEMissingDot(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "rift.mcp.metacircular.net", 300)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for CNAME without trailing dot")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMultipleARecords(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
createTestZone(t, db)
|
||||
|
||||
_, err := db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create first A record: %v", err)
|
||||
}
|
||||
_, err = db.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create second A record: %v", err)
|
||||
}
|
||||
|
||||
records, err := db.LookupRecords("svc.mcp.metacircular.net", "metacrypt", "A")
|
||||
if err != nil {
|
||||
t.Fatalf("lookup records: %v", err)
|
||||
}
|
||||
if len(records) != 2 {
|
||||
t.Fatalf("got %d records, want 2", len(records))
|
||||
}
|
||||
}
|
||||
182
internal/db/zones.go
Normal file
182
internal/db/zones.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Zone represents a DNS zone stored in the database.
|
||||
type Zone struct {
|
||||
ID int64
|
||||
Name string
|
||||
PrimaryNS string
|
||||
AdminEmail string
|
||||
Refresh int
|
||||
Retry int
|
||||
Expire int
|
||||
MinimumTTL int
|
||||
Serial int64
|
||||
CreatedAt string
|
||||
UpdatedAt string
|
||||
}
|
||||
|
||||
// ErrNotFound is returned when a requested resource does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// ErrConflict is returned when a write conflicts with existing data.
|
||||
var ErrConflict = errors.New("conflict")
|
||||
|
||||
// ListZones returns all zones ordered by name.
|
||||
func (d *DB) ListZones() ([]Zone, error) {
|
||||
rows, err := d.Query(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("list zones: %w", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var zones []Zone
|
||||
for rows.Next() {
|
||||
var z Zone
|
||||
if err := rows.Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt); err != nil {
|
||||
return nil, fmt.Errorf("scan zone: %w", err)
|
||||
}
|
||||
zones = append(zones, z)
|
||||
}
|
||||
return zones, rows.Err()
|
||||
}
|
||||
|
||||
// GetZone returns a zone by name (case-insensitive).
|
||||
func (d *DB) GetZone(name string) (*Zone, error) {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
var z Zone
|
||||
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE name = ?`, name).
|
||||
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get zone: %w", err)
|
||||
}
|
||||
return &z, nil
|
||||
}
|
||||
|
||||
// GetZoneByID returns a zone by ID.
|
||||
func (d *DB) GetZoneByID(id int64) (*Zone, error) {
|
||||
var z Zone
|
||||
err := d.QueryRow(`SELECT id, name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial, created_at, updated_at FROM zones WHERE id = ?`, id).
|
||||
Scan(&z.ID, &z.Name, &z.PrimaryNS, &z.AdminEmail, &z.Refresh, &z.Retry, &z.Expire, &z.MinimumTTL, &z.Serial, &z.CreatedAt, &z.UpdatedAt)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get zone by id: %w", err)
|
||||
}
|
||||
return &z, nil
|
||||
}
|
||||
|
||||
// CreateZone inserts a new zone and returns it with the generated serial.
|
||||
func (d *DB) CreateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
serial := nextSerial(0)
|
||||
|
||||
res, err := d.Exec(`INSERT INTO zones (name, primary_ns, admin_email, refresh, retry, expire, minimum_ttl, serial) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
|
||||
name, primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial)
|
||||
if err != nil {
|
||||
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
||||
return nil, fmt.Errorf("%w: zone %q already exists", ErrConflict, name)
|
||||
}
|
||||
return nil, fmt.Errorf("create zone: %w", err)
|
||||
}
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create zone: last insert id: %w", err)
|
||||
}
|
||||
|
||||
return d.GetZoneByID(id)
|
||||
}
|
||||
|
||||
// UpdateZone updates a zone's SOA parameters and bumps the serial.
|
||||
func (d *DB) UpdateZone(name, primaryNS, adminEmail string, refresh, retry, expire, minimumTTL int) (*Zone, error) {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
|
||||
zone, err := d.GetZone(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serial := nextSerial(zone.Serial)
|
||||
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
|
||||
_, err = d.Exec(`UPDATE zones SET primary_ns = ?, admin_email = ?, refresh = ?, retry = ?, expire = ?, minimum_ttl = ?, serial = ?, updated_at = ? WHERE id = ?`,
|
||||
primaryNS, adminEmail, refresh, retry, expire, minimumTTL, serial, now, zone.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("update zone: %w", err)
|
||||
}
|
||||
|
||||
return d.GetZoneByID(zone.ID)
|
||||
}
|
||||
|
||||
// DeleteZone deletes a zone and all its records.
|
||||
func (d *DB) DeleteZone(name string) error {
|
||||
name = strings.ToLower(strings.TrimSuffix(name, "."))
|
||||
res, err := d.Exec(`DELETE FROM zones WHERE name = ?`, name)
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete zone: %w", err)
|
||||
}
|
||||
n, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return fmt.Errorf("delete zone: rows affected: %w", err)
|
||||
}
|
||||
if n == 0 {
|
||||
return ErrNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// BumpSerial increments the serial for a zone within a transaction.
|
||||
func (d *DB) BumpSerial(tx *sql.Tx, zoneID int64) error {
|
||||
var current int64
|
||||
if err := tx.QueryRow(`SELECT serial FROM zones WHERE id = ?`, zoneID).Scan(¤t); err != nil {
|
||||
return fmt.Errorf("read serial: %w", err)
|
||||
}
|
||||
serial := nextSerial(current)
|
||||
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
||||
_, err := tx.Exec(`UPDATE zones SET serial = ?, updated_at = ? WHERE id = ?`, serial, now, zoneID)
|
||||
return err
|
||||
}
|
||||
|
||||
// ZoneNames returns all zone names for the DNS handler.
|
||||
func (d *DB) ZoneNames() ([]string, error) {
|
||||
rows, err := d.Query(`SELECT name FROM zones ORDER BY name`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var names []string
|
||||
for rows.Next() {
|
||||
var name string
|
||||
if err := rows.Scan(&name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
return names, rows.Err()
|
||||
}
|
||||
|
||||
// nextSerial computes the next SOA serial in YYYYMMDDNN format.
|
||||
func nextSerial(current int64) int64 {
|
||||
today := time.Now().UTC()
|
||||
datePrefix, _ := strconv.ParseInt(today.Format("20060102"), 10, 64)
|
||||
datePrefix *= 100 // YYYYMMDD00
|
||||
|
||||
if current >= datePrefix {
|
||||
return current + 1
|
||||
}
|
||||
return datePrefix + 1
|
||||
}
|
||||
167
internal/db/zones_test.go
Normal file
167
internal/db/zones_test.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func openTestDB(t *testing.T) *DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = database.Close() })
|
||||
return database
|
||||
}
|
||||
|
||||
func TestCreateZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
zone, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
if zone.Name != "example.com" {
|
||||
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
|
||||
}
|
||||
if zone.Serial == 0 {
|
||||
t.Fatal("serial should not be zero")
|
||||
}
|
||||
if zone.PrimaryNS != "ns1.example.com." {
|
||||
t.Fatalf("got primary_ns %q, want %q", zone.PrimaryNS, "ns1.example.com.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateZoneDuplicate(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err == nil {
|
||||
t.Fatal("expected error for duplicate zone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateZoneNormalization(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
zone, err := db.CreateZone("Example.COM.", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
if zone.Name != "example.com" {
|
||||
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
|
||||
}
|
||||
}
|
||||
|
||||
func TestListZones(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("b.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone b: %v", err)
|
||||
}
|
||||
_, err = db.CreateZone("a.example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone a: %v", err)
|
||||
}
|
||||
|
||||
zones, err := db.ListZones()
|
||||
if err != nil {
|
||||
t.Fatalf("list zones: %v", err)
|
||||
}
|
||||
if len(zones) != 2 {
|
||||
t.Fatalf("got %d zones, want 2", len(zones))
|
||||
}
|
||||
if zones[0].Name != "a.example.com" {
|
||||
t.Fatalf("zones should be ordered by name, got %q first", zones[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
zone, err := db.GetZone("example.com")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone: %v", err)
|
||||
}
|
||||
if zone.Name != "example.com" {
|
||||
t.Fatalf("got name %q, want %q", zone.Name, "example.com")
|
||||
}
|
||||
|
||||
_, err = db.GetZone("nonexistent.com")
|
||||
if err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
original, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
updated, err := db.UpdateZone("example.com", "ns2.example.com.", "newadmin.example.com.", 7200, 1200, 172800, 600)
|
||||
if err != nil {
|
||||
t.Fatalf("update zone: %v", err)
|
||||
}
|
||||
|
||||
if updated.PrimaryNS != "ns2.example.com." {
|
||||
t.Fatalf("got primary_ns %q, want %q", updated.PrimaryNS, "ns2.example.com.")
|
||||
}
|
||||
if updated.Serial <= original.Serial {
|
||||
t.Fatalf("serial should have incremented: %d <= %d", updated.Serial, original.Serial)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteZone(t *testing.T) {
|
||||
db := openTestDB(t)
|
||||
|
||||
_, err := db.CreateZone("example.com", "ns1.example.com.", "admin.example.com.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteZone("example.com"); err != nil {
|
||||
t.Fatalf("delete zone: %v", err)
|
||||
}
|
||||
|
||||
_, err = db.GetZone("example.com")
|
||||
if err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound after delete, got %v", err)
|
||||
}
|
||||
|
||||
if err := db.DeleteZone("nonexistent.com"); err != ErrNotFound {
|
||||
t.Fatalf("expected ErrNotFound for nonexistent zone, got %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextSerial(t *testing.T) {
|
||||
// A zero serial should produce a date-based serial.
|
||||
s1 := nextSerial(0)
|
||||
if s1 < 2026032600 {
|
||||
t.Fatalf("serial %d seems too low", s1)
|
||||
}
|
||||
|
||||
// Incrementing should increase.
|
||||
s2 := nextSerial(s1)
|
||||
if s2 != s1+1 {
|
||||
t.Fatalf("expected %d, got %d", s1+1, s2)
|
||||
}
|
||||
}
|
||||
67
internal/dns/cache.go
Normal file
67
internal/dns/cache.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type cacheKey struct {
|
||||
Name string
|
||||
Qtype uint16
|
||||
Class uint16
|
||||
}
|
||||
|
||||
type cacheEntry struct {
|
||||
msg *dns.Msg
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
// Cache is a thread-safe in-memory DNS response cache with TTL-based expiry.
|
||||
type Cache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[cacheKey]*cacheEntry
|
||||
}
|
||||
|
||||
// NewCache creates an empty DNS cache.
|
||||
func NewCache() *Cache {
|
||||
return &Cache{
|
||||
entries: make(map[cacheKey]*cacheEntry),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns a cached response if it exists and has not expired.
|
||||
func (c *Cache) Get(name string, qtype, class uint16) *dns.Msg {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
key := cacheKey{Name: name, Qtype: qtype, Class: class}
|
||||
entry, ok := c.entries[key]
|
||||
if !ok || time.Now().After(entry.expiresAt) {
|
||||
return nil
|
||||
}
|
||||
return entry.msg
|
||||
}
|
||||
|
||||
// Set stores a DNS response in the cache with the given TTL.
|
||||
func (c *Cache) Set(name string, qtype, class uint16, msg *dns.Msg, ttl time.Duration) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
key := cacheKey{Name: name, Qtype: qtype, Class: class}
|
||||
c.entries[key] = &cacheEntry{
|
||||
msg: msg.Copy(),
|
||||
expiresAt: time.Now().Add(ttl),
|
||||
}
|
||||
|
||||
// Lazy eviction: clean up expired entries if cache is growing.
|
||||
if len(c.entries) > 1000 {
|
||||
now := time.Now()
|
||||
for k, v := range c.entries {
|
||||
if now.After(v.expiresAt) {
|
||||
delete(c.entries, k)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
81
internal/dns/cache_test.go
Normal file
81
internal/dns/cache_test.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func TestCacheSetGet(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
msg.Answer = append(msg.Answer, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
||||
A: []byte{1, 2, 3, 4},
|
||||
})
|
||||
|
||||
c.Set("example.com.", dns.TypeA, dns.ClassINET, msg, 5*time.Second)
|
||||
|
||||
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
|
||||
if cached == nil {
|
||||
t.Fatal("expected cached response")
|
||||
}
|
||||
if len(cached.Answer) != 1 {
|
||||
t.Fatalf("got %d answers, want 1", len(cached.Answer))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheMiss(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
|
||||
if cached != nil {
|
||||
t.Fatal("expected nil for cache miss")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheExpiry(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
msg := new(dns.Msg)
|
||||
msg.SetQuestion("example.com.", dns.TypeA)
|
||||
|
||||
c.Set("example.com.", dns.TypeA, dns.ClassINET, msg, 1*time.Millisecond)
|
||||
time.Sleep(2 * time.Millisecond)
|
||||
|
||||
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
|
||||
if cached != nil {
|
||||
t.Fatal("expected nil for expired entry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheDifferentTypes(t *testing.T) {
|
||||
c := NewCache()
|
||||
|
||||
msgA := new(dns.Msg)
|
||||
msgA.SetQuestion("example.com.", dns.TypeA)
|
||||
c.Set("example.com.", dns.TypeA, dns.ClassINET, msgA, 5*time.Second)
|
||||
|
||||
msgAAAA := new(dns.Msg)
|
||||
msgAAAA.SetQuestion("example.com.", dns.TypeAAAA)
|
||||
c.Set("example.com.", dns.TypeAAAA, dns.ClassINET, msgAAAA, 5*time.Second)
|
||||
|
||||
cachedA := c.Get("example.com.", dns.TypeA, dns.ClassINET)
|
||||
if cachedA == nil {
|
||||
t.Fatal("expected cached A response")
|
||||
}
|
||||
|
||||
cachedAAAA := c.Get("example.com.", dns.TypeAAAA, dns.ClassINET)
|
||||
if cachedAAAA == nil {
|
||||
t.Fatal("expected cached AAAA response")
|
||||
}
|
||||
|
||||
// Different type should not match.
|
||||
cachedMX := c.Get("example.com.", dns.TypeMX, dns.ClassINET)
|
||||
if cachedMX != nil {
|
||||
t.Fatal("expected nil for uncached type")
|
||||
}
|
||||
}
|
||||
87
internal/dns/forwarder.go
Normal file
87
internal/dns/forwarder.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Forwarder handles forwarding DNS queries to upstream resolvers.
|
||||
type Forwarder struct {
|
||||
upstreams []string
|
||||
client *dns.Client
|
||||
cache *Cache
|
||||
}
|
||||
|
||||
// NewForwarder creates a Forwarder with the given upstream addresses.
|
||||
func NewForwarder(upstreams []string) *Forwarder {
|
||||
return &Forwarder{
|
||||
upstreams: upstreams,
|
||||
client: &dns.Client{
|
||||
Timeout: 2 * time.Second,
|
||||
},
|
||||
cache: NewCache(),
|
||||
}
|
||||
}
|
||||
|
||||
// Forward sends a query to upstream resolvers and returns the response.
|
||||
// Responses are cached by (qname, qtype, qclass) with TTL-based expiry.
|
||||
func (f *Forwarder) Forward(r *dns.Msg) (*dns.Msg, error) {
|
||||
if len(r.Question) == 0 {
|
||||
return nil, fmt.Errorf("empty question")
|
||||
}
|
||||
|
||||
q := r.Question[0]
|
||||
|
||||
// Check cache.
|
||||
if cached := f.cache.Get(q.Name, q.Qtype, q.Qclass); cached != nil {
|
||||
return cached.Copy(), nil
|
||||
}
|
||||
|
||||
// Try each upstream in order.
|
||||
var lastErr error
|
||||
for _, upstream := range f.upstreams {
|
||||
resp, _, err := f.client.Exchange(r, upstream)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// Don't cache SERVFAIL or REFUSED.
|
||||
if resp.Rcode != dns.RcodeServerFailure && resp.Rcode != dns.RcodeRefused {
|
||||
ttl := minTTL(resp)
|
||||
if ttl > 300 {
|
||||
ttl = 300
|
||||
}
|
||||
if ttl > 0 {
|
||||
f.cache.Set(q.Name, q.Qtype, q.Qclass, resp, time.Duration(ttl)*time.Second)
|
||||
}
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("all upstreams failed: %w", lastErr)
|
||||
}
|
||||
|
||||
// minTTL returns the minimum TTL from all resource records in a response.
|
||||
func minTTL(msg *dns.Msg) uint32 {
|
||||
var min uint32
|
||||
first := true
|
||||
|
||||
for _, sections := range [][]dns.RR{msg.Answer, msg.Ns, msg.Extra} {
|
||||
for _, rr := range sections {
|
||||
ttl := rr.Header().Ttl
|
||||
if first || ttl < min {
|
||||
min = ttl
|
||||
first = false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if first {
|
||||
return 60 // No records; default to 60s.
|
||||
}
|
||||
return min
|
||||
}
|
||||
280
internal/dns/server.go
Normal file
280
internal/dns/server.go
Normal file
@@ -0,0 +1,280 @@
|
||||
// Package dns implements the authoritative DNS server for MCNS.
|
||||
// It serves records from SQLite for authoritative zones and forwards
|
||||
// all other queries to configured upstream resolvers.
|
||||
package dns
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
// Server is the MCNS DNS server. It listens on both UDP and TCP.
|
||||
type Server struct {
|
||||
db *db.DB
|
||||
forwarder *Forwarder
|
||||
logger *slog.Logger
|
||||
udp *dns.Server
|
||||
tcp *dns.Server
|
||||
}
|
||||
|
||||
// New creates a DNS server that serves records from the database and
|
||||
// forwards non-authoritative queries to the given upstreams.
|
||||
func New(database *db.DB, upstreams []string, logger *slog.Logger) *Server {
|
||||
s := &Server{
|
||||
db: database,
|
||||
forwarder: NewForwarder(upstreams),
|
||||
logger: logger,
|
||||
}
|
||||
|
||||
mux := dns.NewServeMux()
|
||||
mux.HandleFunc(".", s.handleQuery)
|
||||
|
||||
s.udp = &dns.Server{Handler: mux, Net: "udp"}
|
||||
s.tcp = &dns.Server{Handler: mux, Net: "tcp"}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// ListenAndServe starts the DNS server on the given address for both
|
||||
// UDP and TCP. It blocks until Shutdown is called.
|
||||
func (s *Server) ListenAndServe(addr string) error {
|
||||
s.udp.Addr = addr
|
||||
s.tcp.Addr = addr
|
||||
|
||||
errCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
s.logger.Info("dns server listening", "addr", addr, "proto", "udp")
|
||||
errCh <- s.udp.ListenAndServe()
|
||||
}()
|
||||
go func() {
|
||||
s.logger.Info("dns server listening", "addr", addr, "proto", "tcp")
|
||||
errCh <- s.tcp.ListenAndServe()
|
||||
}()
|
||||
|
||||
return <-errCh
|
||||
}
|
||||
|
||||
// Shutdown gracefully stops the DNS server.
|
||||
func (s *Server) Shutdown() {
|
||||
_ = s.udp.Shutdown()
|
||||
_ = s.tcp.Shutdown()
|
||||
}
|
||||
|
||||
// handleQuery is the main DNS query handler. It checks if the query
|
||||
// falls within an authoritative zone and either serves from the database
|
||||
// or forwards to upstream.
|
||||
func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
|
||||
if len(r.Question) == 0 {
|
||||
s.writeResponse(w, r, dns.RcodeFormatError, nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
q := r.Question[0]
|
||||
qname := strings.ToLower(q.Name)
|
||||
|
||||
// Find the authoritative zone for this query.
|
||||
zone := s.findZone(qname)
|
||||
if zone == nil {
|
||||
// Not authoritative — forward to upstream.
|
||||
s.forwardQuery(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
s.handleAuthoritativeQuery(w, r, zone, qname, q.Qtype)
|
||||
}
|
||||
|
||||
// findZone returns the best matching zone for the query name, or nil.
|
||||
func (s *Server) findZone(qname string) *db.Zone {
|
||||
// Walk up the domain labels to find the longest matching zone.
|
||||
name := strings.TrimSuffix(qname, ".")
|
||||
parts := strings.Split(name, ".")
|
||||
|
||||
for i := range parts {
|
||||
candidate := strings.Join(parts[i:], ".")
|
||||
zone, err := s.db.GetZone(candidate)
|
||||
if err == nil {
|
||||
return zone
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAuthoritativeQuery serves a query from the database.
|
||||
func (s *Server) handleAuthoritativeQuery(w dns.ResponseWriter, r *dns.Msg, zone *db.Zone, qname string, qtype uint16) {
|
||||
// Extract the record name relative to the zone.
|
||||
zoneFQDN := zone.Name + "."
|
||||
var relName string
|
||||
if qname == zoneFQDN {
|
||||
relName = "@"
|
||||
} else {
|
||||
relName = strings.TrimSuffix(qname, "."+zoneFQDN)
|
||||
}
|
||||
|
||||
// Handle SOA queries.
|
||||
if qtype == dns.TypeSOA || relName == "@" && qtype == dns.TypeSOA {
|
||||
soa := s.buildSOA(zone)
|
||||
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{soa}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle NS queries at the zone apex.
|
||||
if qtype == dns.TypeNS && relName == "@" {
|
||||
ns := &dns.NS{
|
||||
Hdr: dns.RR_Header{Name: zoneFQDN, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
|
||||
Ns: zone.PrimaryNS,
|
||||
}
|
||||
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{ns}, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Look up the requested record type.
|
||||
var answers []dns.RR
|
||||
var lookupType string
|
||||
|
||||
switch qtype {
|
||||
case dns.TypeA:
|
||||
lookupType = "A"
|
||||
case dns.TypeAAAA:
|
||||
lookupType = "AAAA"
|
||||
case dns.TypeCNAME:
|
||||
lookupType = "CNAME"
|
||||
default:
|
||||
// For unsupported types, check if the name exists at all.
|
||||
// If it does, return empty answer. If not, NXDOMAIN.
|
||||
exists, _ := s.nameExists(zone.Name, relName)
|
||||
if exists {
|
||||
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
|
||||
} else {
|
||||
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
records, err := s.db.LookupRecords(zone.Name, relName, lookupType)
|
||||
if err != nil {
|
||||
s.logger.Error("dns lookup failed", "zone", zone.Name, "name", relName, "type", lookupType, "error", err)
|
||||
s.writeResponse(w, r, dns.RcodeServerFailure, nil, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// If no direct records, check for CNAME.
|
||||
if len(records) == 0 && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
||||
cnameRecords, err := s.db.LookupCNAME(zone.Name, relName)
|
||||
if err == nil && len(cnameRecords) > 0 {
|
||||
for _, rec := range cnameRecords {
|
||||
answers = append(answers, &dns.CNAME{
|
||||
Hdr: dns.RR_Header{Name: qname, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(rec.TTL)},
|
||||
Target: rec.Value,
|
||||
})
|
||||
}
|
||||
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if len(records) == 0 {
|
||||
// Name might still exist with other record types.
|
||||
exists, _ := s.nameExists(zone.Name, relName)
|
||||
if exists {
|
||||
// NODATA: name exists but no records of requested type.
|
||||
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
|
||||
} else {
|
||||
// NXDOMAIN: name does not exist.
|
||||
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
for _, rec := range records {
|
||||
rr := s.recordToRR(qname, rec)
|
||||
if rr != nil {
|
||||
answers = append(answers, rr)
|
||||
}
|
||||
}
|
||||
|
||||
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
|
||||
}
|
||||
|
||||
// nameExists checks if any records exist for a name in a zone.
|
||||
func (s *Server) nameExists(zoneName, name string) (bool, error) {
|
||||
records, err := s.db.ListRecords(zoneName, name, "")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return len(records) > 0, nil
|
||||
}
|
||||
|
||||
// recordToRR converts a database Record to a dns.RR.
|
||||
func (s *Server) recordToRR(qname string, rec db.Record) dns.RR {
|
||||
hdr := dns.RR_Header{Name: qname, Class: dns.ClassINET, Ttl: uint32(rec.TTL)}
|
||||
|
||||
switch rec.Type {
|
||||
case "A":
|
||||
hdr.Rrtype = dns.TypeA
|
||||
return &dns.A{Hdr: hdr, A: parseIP(rec.Value)}
|
||||
case "AAAA":
|
||||
hdr.Rrtype = dns.TypeAAAA
|
||||
return &dns.AAAA{Hdr: hdr, AAAA: parseIP(rec.Value)}
|
||||
case "CNAME":
|
||||
hdr.Rrtype = dns.TypeCNAME
|
||||
return &dns.CNAME{Hdr: hdr, Target: rec.Value}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildSOA constructs a SOA record for the given zone.
|
||||
func (s *Server) buildSOA(zone *db.Zone) *dns.SOA {
|
||||
return &dns.SOA{
|
||||
Hdr: dns.RR_Header{Name: zone.Name + ".", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
|
||||
Ns: zone.PrimaryNS,
|
||||
Mbox: zone.AdminEmail,
|
||||
Serial: uint32(zone.Serial),
|
||||
Refresh: uint32(zone.Refresh),
|
||||
Retry: uint32(zone.Retry),
|
||||
Expire: uint32(zone.Expire),
|
||||
Minttl: uint32(zone.MinimumTTL),
|
||||
}
|
||||
}
|
||||
|
||||
// writeResponse constructs and writes a DNS response.
|
||||
func (s *Server) writeResponse(w dns.ResponseWriter, r *dns.Msg, rcode int, answer []dns.RR, ns []dns.RR) {
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Authoritative = true
|
||||
m.Rcode = rcode
|
||||
m.Answer = answer
|
||||
m.Ns = ns
|
||||
|
||||
if err := w.WriteMsg(m); err != nil {
|
||||
s.logger.Error("dns write failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// forwardQuery forwards a DNS query to upstream resolvers.
|
||||
func (s *Server) forwardQuery(w dns.ResponseWriter, r *dns.Msg) {
|
||||
resp, err := s.forwarder.Forward(r)
|
||||
if err != nil {
|
||||
s.logger.Debug("dns forward failed", "error", err)
|
||||
m := new(dns.Msg)
|
||||
m.SetReply(r)
|
||||
m.Rcode = dns.RcodeServerFailure
|
||||
_ = w.WriteMsg(m)
|
||||
return
|
||||
}
|
||||
|
||||
resp.Id = r.Id
|
||||
if err := w.WriteMsg(resp); err != nil {
|
||||
s.logger.Error("dns write failed", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
// parseIP parses an IP address string into a net.IP.
|
||||
func parseIP(s string) net.IP {
|
||||
return net.ParseIP(s)
|
||||
}
|
||||
142
internal/dns/server_test.go
Normal file
142
internal/dns/server_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
"log/slog"
|
||||
)
|
||||
|
||||
func openTestDB(t *testing.T) *db.DB {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
database, err := db.Open(filepath.Join(dir, "test.db"))
|
||||
if err != nil {
|
||||
t.Fatalf("open db: %v", err)
|
||||
}
|
||||
if err := database.Migrate(); err != nil {
|
||||
t.Fatalf("migrate: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = database.Close() })
|
||||
return database
|
||||
}
|
||||
|
||||
func setupTestServer(t *testing.T) (*Server, *db.DB) {
|
||||
t.Helper()
|
||||
database := openTestDB(t)
|
||||
logger := slog.Default()
|
||||
|
||||
_, err := database.CreateZone("svc.mcp.metacircular.net", "ns.mcp.metacircular.net.", "admin.metacircular.net.", 3600, 600, 86400, 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create zone: %v", err)
|
||||
}
|
||||
_, err = database.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create A record: %v", err)
|
||||
}
|
||||
_, err = database.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create A record 2: %v", err)
|
||||
}
|
||||
_, err = database.CreateRecord("svc.mcp.metacircular.net", "mcr", "AAAA", "2001:db8::1", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create AAAA record: %v", err)
|
||||
}
|
||||
_, err = database.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "metacrypt.svc.mcp.metacircular.net.", 300)
|
||||
if err != nil {
|
||||
t.Fatalf("create CNAME record: %v", err)
|
||||
}
|
||||
|
||||
srv := New(database, []string{"1.1.1.1:53"}, logger)
|
||||
return srv, database
|
||||
}
|
||||
|
||||
func TestFindZone(t *testing.T) {
|
||||
srv, _ := setupTestServer(t)
|
||||
|
||||
zone := srv.findZone("metacrypt.svc.mcp.metacircular.net.")
|
||||
if zone == nil {
|
||||
t.Fatal("expected to find zone")
|
||||
}
|
||||
if zone.Name != "svc.mcp.metacircular.net" {
|
||||
t.Fatalf("got zone %q, want %q", zone.Name, "svc.mcp.metacircular.net")
|
||||
}
|
||||
|
||||
zone = srv.findZone("nonexistent.com.")
|
||||
if zone != nil {
|
||||
t.Fatal("expected nil for nonexistent zone")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildSOA(t *testing.T) {
|
||||
srv, database := setupTestServer(t)
|
||||
zone, err := database.GetZone("svc.mcp.metacircular.net")
|
||||
if err != nil {
|
||||
t.Fatalf("get zone: %v", err)
|
||||
}
|
||||
|
||||
soa := srv.buildSOA(zone)
|
||||
if soa.Ns != "ns.mcp.metacircular.net." {
|
||||
t.Fatalf("got ns %q, want %q", soa.Ns, "ns.mcp.metacircular.net.")
|
||||
}
|
||||
if soa.Hdr.Name != "svc.mcp.metacircular.net." {
|
||||
t.Fatalf("got name %q, want %q", soa.Hdr.Name, "svc.mcp.metacircular.net.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordToRR_A(t *testing.T) {
|
||||
srv, _ := setupTestServer(t)
|
||||
|
||||
rec := db.Record{Name: "metacrypt", Type: "A", Value: "192.168.88.181", TTL: 300}
|
||||
rr := srv.recordToRR("metacrypt.svc.mcp.metacircular.net.", rec)
|
||||
if rr == nil {
|
||||
t.Fatal("expected non-nil RR")
|
||||
}
|
||||
|
||||
a, ok := rr.(*dns.A)
|
||||
if !ok {
|
||||
t.Fatalf("expected *dns.A, got %T", rr)
|
||||
}
|
||||
if a.A.String() != "192.168.88.181" {
|
||||
t.Fatalf("got IP %q, want %q", a.A.String(), "192.168.88.181")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordToRR_AAAA(t *testing.T) {
|
||||
srv, _ := setupTestServer(t)
|
||||
|
||||
rec := db.Record{Name: "mcr", Type: "AAAA", Value: "2001:db8::1", TTL: 300}
|
||||
rr := srv.recordToRR("mcr.svc.mcp.metacircular.net.", rec)
|
||||
if rr == nil {
|
||||
t.Fatal("expected non-nil RR")
|
||||
}
|
||||
|
||||
aaaa, ok := rr.(*dns.AAAA)
|
||||
if !ok {
|
||||
t.Fatalf("expected *dns.AAAA, got %T", rr)
|
||||
}
|
||||
if aaaa.AAAA.String() != "2001:db8::1" {
|
||||
t.Fatalf("got IP %q, want %q", aaaa.AAAA.String(), "2001:db8::1")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordToRR_CNAME(t *testing.T) {
|
||||
srv, _ := setupTestServer(t)
|
||||
|
||||
rec := db.Record{Name: "alias", Type: "CNAME", Value: "metacrypt.svc.mcp.metacircular.net.", TTL: 300}
|
||||
rr := srv.recordToRR("alias.svc.mcp.metacircular.net.", rec)
|
||||
if rr == nil {
|
||||
t.Fatal("expected non-nil RR")
|
||||
}
|
||||
|
||||
cname, ok := rr.(*dns.CNAME)
|
||||
if !ok {
|
||||
t.Fatalf("expected *dns.CNAME, got %T", rr)
|
||||
}
|
||||
if cname.Target != "metacrypt.svc.mcp.metacircular.net." {
|
||||
t.Fatalf("got target %q, want %q", cname.Target, "metacrypt.svc.mcp.metacircular.net.")
|
||||
}
|
||||
}
|
||||
20
internal/grpcserver/admin.go
Normal file
20
internal/grpcserver/admin.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
type adminService struct {
|
||||
pb.UnimplementedAdminServiceServer
|
||||
db *db.DB
|
||||
}
|
||||
|
||||
func (s *adminService) Health(_ context.Context, _ *pb.HealthRequest) (*pb.HealthResponse, error) {
|
||||
if err := s.db.Ping(); err != nil {
|
||||
return &pb.HealthResponse{Status: "unhealthy"}, nil
|
||||
}
|
||||
return &pb.HealthResponse{Status: "ok"}, nil
|
||||
}
|
||||
38
internal/grpcserver/auth_handler.go
Normal file
38
internal/grpcserver/auth_handler.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
|
||||
)
|
||||
|
||||
type authService struct {
|
||||
pb.UnimplementedAuthServiceServer
|
||||
auth *mcdslauth.Authenticator
|
||||
}
|
||||
|
||||
func (s *authService) Login(_ context.Context, req *pb.LoginRequest) (*pb.LoginResponse, error) {
|
||||
token, _, err := s.auth.Login(req.Username, req.Password, req.TotpCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, mcdslauth.ErrInvalidCredentials) {
|
||||
return nil, status.Error(codes.Unauthenticated, "invalid credentials")
|
||||
}
|
||||
if errors.Is(err, mcdslauth.ErrForbidden) {
|
||||
return nil, status.Error(codes.PermissionDenied, "access denied by login policy")
|
||||
}
|
||||
return nil, status.Error(codes.Unavailable, "authentication service unavailable")
|
||||
}
|
||||
return &pb.LoginResponse{Token: token}, nil
|
||||
}
|
||||
|
||||
func (s *authService) Logout(_ context.Context, req *pb.LogoutRequest) (*pb.LogoutResponse, error) {
|
||||
if err := s.auth.Logout(req.Token); err != nil {
|
||||
return nil, status.Error(codes.Internal, "logout failed")
|
||||
}
|
||||
return &pb.LogoutResponse{}, nil
|
||||
}
|
||||
45
internal/grpcserver/interceptors.go
Normal file
45
internal/grpcserver/interceptors.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
mcdslgrpc "git.wntrmute.dev/kyle/mcdsl/grpcserver"
|
||||
)
|
||||
|
||||
// methodMap builds the mcdsl grpcserver.MethodMap for MCNS.
|
||||
//
|
||||
// Adding a new RPC without adding it to the correct map is a security
|
||||
// defect — the mcdsl auth interceptor denies unmapped methods by default.
|
||||
func methodMap() mcdslgrpc.MethodMap {
|
||||
return mcdslgrpc.MethodMap{
|
||||
Public: publicMethods(),
|
||||
AuthRequired: authRequiredMethods(),
|
||||
AdminRequired: adminRequiredMethods(),
|
||||
}
|
||||
}
|
||||
|
||||
func publicMethods() map[string]bool {
|
||||
return map[string]bool{
|
||||
"/mcns.v1.AdminService/Health": true,
|
||||
"/mcns.v1.AuthService/Login": true,
|
||||
}
|
||||
}
|
||||
|
||||
func authRequiredMethods() map[string]bool {
|
||||
return map[string]bool{
|
||||
"/mcns.v1.AuthService/Logout": true,
|
||||
"/mcns.v1.ZoneService/ListZones": true,
|
||||
"/mcns.v1.ZoneService/GetZone": true,
|
||||
"/mcns.v1.RecordService/ListRecords": true,
|
||||
"/mcns.v1.RecordService/GetRecord": true,
|
||||
}
|
||||
}
|
||||
|
||||
func adminRequiredMethods() map[string]bool {
|
||||
return map[string]bool{
|
||||
"/mcns.v1.ZoneService/CreateZone": true,
|
||||
"/mcns.v1.ZoneService/UpdateZone": true,
|
||||
"/mcns.v1.ZoneService/DeleteZone": true,
|
||||
"/mcns.v1.RecordService/CreateRecord": true,
|
||||
"/mcns.v1.RecordService/UpdateRecord": true,
|
||||
"/mcns.v1.RecordService/DeleteRecord": true,
|
||||
}
|
||||
}
|
||||
110
internal/grpcserver/records.go
Normal file
110
internal/grpcserver/records.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
type recordService struct {
|
||||
pb.UnimplementedRecordServiceServer
|
||||
db *db.DB
|
||||
}
|
||||
|
||||
func (s *recordService) ListRecords(_ context.Context, req *pb.ListRecordsRequest) (*pb.ListRecordsResponse, error) {
|
||||
records, err := s.db.ListRecords(req.Zone, req.Name, req.Type)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "zone not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to list records")
|
||||
}
|
||||
|
||||
resp := &pb.ListRecordsResponse{}
|
||||
for _, r := range records {
|
||||
resp.Records = append(resp.Records, recordToProto(r))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *recordService) GetRecord(_ context.Context, req *pb.GetRecordRequest) (*pb.Record, error) {
|
||||
record, err := s.db.GetRecord(req.Id)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to get record")
|
||||
}
|
||||
return recordToProto(*record), nil
|
||||
}
|
||||
|
||||
func (s *recordService) CreateRecord(_ context.Context, req *pb.CreateRecordRequest) (*pb.Record, error) {
|
||||
record, err := s.db.CreateRecord(req.Zone, req.Name, req.Type, req.Value, int(req.Ttl))
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "zone not found")
|
||||
}
|
||||
if errors.Is(err, db.ErrConflict) {
|
||||
return nil, status.Error(codes.AlreadyExists, err.Error())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
return recordToProto(*record), nil
|
||||
}
|
||||
|
||||
func (s *recordService) UpdateRecord(_ context.Context, req *pb.UpdateRecordRequest) (*pb.Record, error) {
|
||||
record, err := s.db.UpdateRecord(req.Id, req.Name, req.Type, req.Value, int(req.Ttl))
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
if errors.Is(err, db.ErrConflict) {
|
||||
return nil, status.Error(codes.AlreadyExists, err.Error())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.InvalidArgument, err.Error())
|
||||
}
|
||||
return recordToProto(*record), nil
|
||||
}
|
||||
|
||||
func (s *recordService) DeleteRecord(_ context.Context, req *pb.DeleteRecordRequest) (*pb.DeleteRecordResponse, error) {
|
||||
err := s.db.DeleteRecord(req.Id)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "record not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to delete record")
|
||||
}
|
||||
return &pb.DeleteRecordResponse{}, nil
|
||||
}
|
||||
|
||||
func recordToProto(r db.Record) *pb.Record {
|
||||
return &pb.Record{
|
||||
Id: r.ID,
|
||||
Zone: r.ZoneName,
|
||||
Name: r.Name,
|
||||
Type: r.Type,
|
||||
Value: r.Value,
|
||||
Ttl: int32(r.TTL),
|
||||
CreatedAt: parseRecordTimestamp(r.CreatedAt),
|
||||
UpdatedAt: parseRecordTimestamp(r.UpdatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
func parseRecordTimestamp(s string) *timestamppb.Timestamp {
|
||||
t, err := parseTime(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return timestamppb.New(t)
|
||||
}
|
||||
|
||||
func parseTime(s string) (time.Time, error) {
|
||||
return time.Parse("2006-01-02T15:04:05Z", s)
|
||||
}
|
||||
50
internal/grpcserver/server.go
Normal file
50
internal/grpcserver/server.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
|
||||
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
|
||||
mcdslgrpc "git.wntrmute.dev/kyle/mcdsl/grpcserver"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
// Deps holds the dependencies injected into the gRPC server.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Authenticator *mcdslauth.Authenticator
|
||||
}
|
||||
|
||||
// Server wraps a mcdsl grpcserver.Server with MCNS-specific services.
|
||||
type Server struct {
|
||||
srv *mcdslgrpc.Server
|
||||
}
|
||||
|
||||
// New creates a configured gRPC server with MCNS services registered.
|
||||
func New(certFile, keyFile string, deps Deps, logger *slog.Logger) (*Server, error) {
|
||||
srv, err := mcdslgrpc.New(certFile, keyFile, deps.Authenticator, methodMap(), logger)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &Server{srv: srv}
|
||||
|
||||
pb.RegisterAdminServiceServer(srv.GRPCServer, &adminService{db: deps.DB})
|
||||
pb.RegisterAuthServiceServer(srv.GRPCServer, &authService{auth: deps.Authenticator})
|
||||
pb.RegisterZoneServiceServer(srv.GRPCServer, &zoneService{db: deps.DB})
|
||||
pb.RegisterRecordServiceServer(srv.GRPCServer, &recordService{db: deps.DB})
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// Serve starts the gRPC server on the given listener.
|
||||
func (s *Server) Serve(lis net.Listener) error {
|
||||
return s.srv.GRPCServer.Serve(lis)
|
||||
}
|
||||
|
||||
// GracefulStop gracefully stops the gRPC server.
|
||||
func (s *Server) GracefulStop() {
|
||||
s.srv.Stop()
|
||||
}
|
||||
134
internal/grpcserver/zones.go
Normal file
134
internal/grpcserver/zones.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package grpcserver
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
pb "git.wntrmute.dev/kyle/mcns/gen/mcns/v1"
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
type zoneService struct {
|
||||
pb.UnimplementedZoneServiceServer
|
||||
db *db.DB
|
||||
}
|
||||
|
||||
func (s *zoneService) ListZones(_ context.Context, _ *pb.ListZonesRequest) (*pb.ListZonesResponse, error) {
|
||||
zones, err := s.db.ListZones()
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to list zones")
|
||||
}
|
||||
|
||||
resp := &pb.ListZonesResponse{}
|
||||
for _, z := range zones {
|
||||
resp.Zones = append(resp.Zones, zoneToProto(z))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *zoneService) GetZone(_ context.Context, req *pb.GetZoneRequest) (*pb.Zone, error) {
|
||||
zone, err := s.db.GetZone(req.Name)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "zone not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to get zone")
|
||||
}
|
||||
return zoneToProto(*zone), nil
|
||||
}
|
||||
|
||||
func (s *zoneService) CreateZone(_ context.Context, req *pb.CreateZoneRequest) (*pb.Zone, error) {
|
||||
refresh := int(req.Refresh)
|
||||
if refresh == 0 {
|
||||
refresh = 3600
|
||||
}
|
||||
retry := int(req.Retry)
|
||||
if retry == 0 {
|
||||
retry = 600
|
||||
}
|
||||
expire := int(req.Expire)
|
||||
if expire == 0 {
|
||||
expire = 86400
|
||||
}
|
||||
minTTL := int(req.MinimumTtl)
|
||||
if minTTL == 0 {
|
||||
minTTL = 300
|
||||
}
|
||||
|
||||
zone, err := s.db.CreateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
|
||||
if errors.Is(err, db.ErrConflict) {
|
||||
return nil, status.Error(codes.AlreadyExists, err.Error())
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to create zone")
|
||||
}
|
||||
return zoneToProto(*zone), nil
|
||||
}
|
||||
|
||||
func (s *zoneService) UpdateZone(_ context.Context, req *pb.UpdateZoneRequest) (*pb.Zone, error) {
|
||||
refresh := int(req.Refresh)
|
||||
if refresh == 0 {
|
||||
refresh = 3600
|
||||
}
|
||||
retry := int(req.Retry)
|
||||
if retry == 0 {
|
||||
retry = 600
|
||||
}
|
||||
expire := int(req.Expire)
|
||||
if expire == 0 {
|
||||
expire = 86400
|
||||
}
|
||||
minTTL := int(req.MinimumTtl)
|
||||
if minTTL == 0 {
|
||||
minTTL = 300
|
||||
}
|
||||
|
||||
zone, err := s.db.UpdateZone(req.Name, req.PrimaryNs, req.AdminEmail, refresh, retry, expire, minTTL)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "zone not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to update zone")
|
||||
}
|
||||
return zoneToProto(*zone), nil
|
||||
}
|
||||
|
||||
func (s *zoneService) DeleteZone(_ context.Context, req *pb.DeleteZoneRequest) (*pb.DeleteZoneResponse, error) {
|
||||
err := s.db.DeleteZone(req.Name)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
return nil, status.Error(codes.NotFound, "zone not found")
|
||||
}
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "failed to delete zone")
|
||||
}
|
||||
return &pb.DeleteZoneResponse{}, nil
|
||||
}
|
||||
|
||||
func zoneToProto(z db.Zone) *pb.Zone {
|
||||
return &pb.Zone{
|
||||
Id: z.ID,
|
||||
Name: z.Name,
|
||||
PrimaryNs: z.PrimaryNS,
|
||||
AdminEmail: z.AdminEmail,
|
||||
Refresh: int32(z.Refresh),
|
||||
Retry: int32(z.Retry),
|
||||
Expire: int32(z.Expire),
|
||||
MinimumTtl: int32(z.MinimumTTL),
|
||||
Serial: z.Serial,
|
||||
CreatedAt: parseTimestamp(z.CreatedAt),
|
||||
UpdatedAt: parseTimestamp(z.UpdatedAt),
|
||||
}
|
||||
}
|
||||
|
||||
func parseTimestamp(s string) *timestamppb.Timestamp {
|
||||
// SQLite stores as "2006-01-02T15:04:05Z".
|
||||
t, err := parseTime(s)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return timestamppb.New(t)
|
||||
}
|
||||
62
internal/server/auth.go
Normal file
62
internal/server/auth.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
|
||||
)
|
||||
|
||||
type loginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
TOTPCode string `json:"totp_code"`
|
||||
}
|
||||
|
||||
type loginResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
func loginHandler(auth *mcdslauth.Authenticator) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req loginRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
token, _, err := auth.Login(req.Username, req.Password, req.TOTPCode)
|
||||
if err != nil {
|
||||
if errors.Is(err, mcdslauth.ErrInvalidCredentials) {
|
||||
writeError(w, http.StatusUnauthorized, "invalid credentials")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, mcdslauth.ErrForbidden) {
|
||||
writeError(w, http.StatusForbidden, "access denied by login policy")
|
||||
return
|
||||
}
|
||||
writeError(w, http.StatusServiceUnavailable, "authentication service unavailable")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, loginResponse{Token: token})
|
||||
}
|
||||
}
|
||||
|
||||
func logoutHandler(auth *mcdslauth.Authenticator) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
token := extractBearerToken(r)
|
||||
if token == "" {
|
||||
writeError(w, http.StatusUnauthorized, "authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
if err := auth.Logout(token); err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "logout failed")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
96
internal/server/middleware.go
Normal file
96
internal/server/middleware.go
Normal file
@@ -0,0 +1,96 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const tokenInfoKey contextKey = "tokenInfo"
|
||||
|
||||
// requireAuth returns middleware that validates Bearer tokens via MCIAS.
|
||||
func requireAuth(auth *mcdslauth.Authenticator) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
token := extractBearerToken(r)
|
||||
if token == "" {
|
||||
writeError(w, http.StatusUnauthorized, "authentication required")
|
||||
return
|
||||
}
|
||||
|
||||
info, err := auth.ValidateToken(token)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusUnauthorized, "invalid or expired token")
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), tokenInfoKey, info)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// requireAdmin is middleware that checks the caller has the admin role.
|
||||
func requireAdmin(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
info := tokenInfoFromContext(r.Context())
|
||||
if info == nil || !info.IsAdmin {
|
||||
writeError(w, http.StatusForbidden, "admin role required")
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// tokenInfoFromContext extracts the TokenInfo from the request context.
|
||||
func tokenInfoFromContext(ctx context.Context) *mcdslauth.TokenInfo {
|
||||
info, _ := ctx.Value(tokenInfoKey).(*mcdslauth.TokenInfo)
|
||||
return info
|
||||
}
|
||||
|
||||
// extractBearerToken extracts a bearer token from the Authorization header.
|
||||
func extractBearerToken(r *http.Request) string {
|
||||
h := r.Header.Get("Authorization")
|
||||
if h == "" {
|
||||
return ""
|
||||
}
|
||||
const prefix = "Bearer "
|
||||
if !strings.HasPrefix(h, prefix) {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSpace(h[len(prefix):])
|
||||
}
|
||||
|
||||
// loggingMiddleware logs HTTP requests.
|
||||
func loggingMiddleware(logger *slog.Logger) func(http.Handler) http.Handler {
|
||||
return func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
start := time.Now()
|
||||
sw := &statusWriter{ResponseWriter: w, status: http.StatusOK}
|
||||
next.ServeHTTP(sw, r)
|
||||
logger.Info("http",
|
||||
"method", r.Method,
|
||||
"path", r.URL.Path,
|
||||
"status", sw.status,
|
||||
"duration", time.Since(start),
|
||||
"remote", r.RemoteAddr,
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type statusWriter struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
}
|
||||
|
||||
func (w *statusWriter) WriteHeader(code int) {
|
||||
w.status = code
|
||||
w.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
174
internal/server/records.go
Normal file
174
internal/server/records.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
type createRecordRequest struct {
|
||||
Name string `json:"name"`
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
TTL int `json:"ttl"`
|
||||
}
|
||||
|
||||
func listRecordsHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
zoneName := chi.URLParam(r, "zone")
|
||||
nameFilter := r.URL.Query().Get("name")
|
||||
typeFilter := r.URL.Query().Get("type")
|
||||
|
||||
records, err := database.ListRecords(zoneName, nameFilter, typeFilter)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "zone not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to list records")
|
||||
return
|
||||
}
|
||||
if records == nil {
|
||||
records = []db.Record{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"records": records})
|
||||
}
|
||||
}
|
||||
|
||||
func getRecordHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid record ID")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := database.GetRecord(id)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "record not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to get record")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, record)
|
||||
}
|
||||
}
|
||||
|
||||
func createRecordHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
zoneName := chi.URLParam(r, "zone")
|
||||
|
||||
var req createRecordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
if req.Type == "" {
|
||||
writeError(w, http.StatusBadRequest, "type is required")
|
||||
return
|
||||
}
|
||||
if req.Value == "" {
|
||||
writeError(w, http.StatusBadRequest, "value is required")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := database.CreateRecord(zoneName, req.Name, req.Type, req.Value, req.TTL)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "zone not found")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, db.ErrConflict) {
|
||||
writeError(w, http.StatusConflict, err.Error())
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, record)
|
||||
}
|
||||
}
|
||||
|
||||
func updateRecordHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid record ID")
|
||||
return
|
||||
}
|
||||
|
||||
var req createRecordRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
if req.Type == "" {
|
||||
writeError(w, http.StatusBadRequest, "type is required")
|
||||
return
|
||||
}
|
||||
if req.Value == "" {
|
||||
writeError(w, http.StatusBadRequest, "value is required")
|
||||
return
|
||||
}
|
||||
|
||||
record, err := database.UpdateRecord(id, req.Name, req.Type, req.Value, req.TTL)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "record not found")
|
||||
return
|
||||
}
|
||||
if errors.Is(err, db.ErrConflict) {
|
||||
writeError(w, http.StatusConflict, err.Error())
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, record)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteRecordHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
idStr := chi.URLParam(r, "id")
|
||||
id, err := strconv.ParseInt(idStr, 10, 64)
|
||||
if err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid record ID")
|
||||
return
|
||||
}
|
||||
|
||||
err = database.DeleteRecord(id)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "record not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete record")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
71
internal/server/routes.go
Normal file
71
internal/server/routes.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
mcdslauth "git.wntrmute.dev/kyle/mcdsl/auth"
|
||||
"git.wntrmute.dev/kyle/mcdsl/health"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
// Deps holds dependencies injected into the REST handlers.
|
||||
type Deps struct {
|
||||
DB *db.DB
|
||||
Auth *mcdslauth.Authenticator
|
||||
Logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewRouter builds the chi router with all MCNS REST endpoints.
|
||||
func NewRouter(deps Deps) *chi.Mux {
|
||||
r := chi.NewRouter()
|
||||
r.Use(loggingMiddleware(deps.Logger))
|
||||
|
||||
// Public endpoints.
|
||||
r.Post("/v1/auth/login", loginHandler(deps.Auth))
|
||||
r.Get("/v1/health", health.Handler(deps.DB.DB))
|
||||
|
||||
// Authenticated endpoints.
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(requireAuth(deps.Auth))
|
||||
|
||||
r.Post("/v1/auth/logout", logoutHandler(deps.Auth))
|
||||
|
||||
// Zone endpoints — reads for all authenticated users, writes for admin.
|
||||
r.Get("/v1/zones", listZonesHandler(deps.DB))
|
||||
r.Get("/v1/zones/{zone}", getZoneHandler(deps.DB))
|
||||
|
||||
// Admin-only zone mutations.
|
||||
r.With(requireAdmin).Post("/v1/zones", createZoneHandler(deps.DB))
|
||||
r.With(requireAdmin).Put("/v1/zones/{zone}", updateZoneHandler(deps.DB))
|
||||
r.With(requireAdmin).Delete("/v1/zones/{zone}", deleteZoneHandler(deps.DB))
|
||||
|
||||
// Record endpoints — reads for all authenticated users, writes for admin.
|
||||
r.Get("/v1/zones/{zone}/records", listRecordsHandler(deps.DB))
|
||||
r.Get("/v1/zones/{zone}/records/{id}", getRecordHandler(deps.DB))
|
||||
|
||||
// Admin-only record mutations.
|
||||
r.With(requireAdmin).Post("/v1/zones/{zone}/records", createRecordHandler(deps.DB))
|
||||
r.With(requireAdmin).Put("/v1/zones/{zone}/records/{id}", updateRecordHandler(deps.DB))
|
||||
r.With(requireAdmin).Delete("/v1/zones/{zone}/records/{id}", deleteRecordHandler(deps.DB))
|
||||
})
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// writeJSON writes a JSON response with the given status code.
|
||||
func writeJSON(w http.ResponseWriter, status int, v any) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
_ = json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
|
||||
// writeError writes a standard error response.
|
||||
func writeError(w http.ResponseWriter, status int, message string) {
|
||||
writeJSON(w, status, map[string]string{"error": message})
|
||||
}
|
||||
|
||||
163
internal/server/zones.go
Normal file
163
internal/server/zones.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
|
||||
"git.wntrmute.dev/kyle/mcns/internal/db"
|
||||
)
|
||||
|
||||
type createZoneRequest struct {
|
||||
Name string `json:"name"`
|
||||
PrimaryNS string `json:"primary_ns"`
|
||||
AdminEmail string `json:"admin_email"`
|
||||
Refresh int `json:"refresh"`
|
||||
Retry int `json:"retry"`
|
||||
Expire int `json:"expire"`
|
||||
MinimumTTL int `json:"minimum_ttl"`
|
||||
}
|
||||
|
||||
func listZonesHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, _ *http.Request) {
|
||||
zones, err := database.ListZones()
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to list zones")
|
||||
return
|
||||
}
|
||||
if zones == nil {
|
||||
zones = []db.Zone{}
|
||||
}
|
||||
writeJSON(w, http.StatusOK, map[string]any{"zones": zones})
|
||||
}
|
||||
}
|
||||
|
||||
func getZoneHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "zone")
|
||||
zone, err := database.GetZone(name)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "zone not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to get zone")
|
||||
return
|
||||
}
|
||||
writeJSON(w, http.StatusOK, zone)
|
||||
}
|
||||
}
|
||||
|
||||
func createZoneHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var req createZoneRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.Name == "" {
|
||||
writeError(w, http.StatusBadRequest, "name is required")
|
||||
return
|
||||
}
|
||||
if req.PrimaryNS == "" {
|
||||
writeError(w, http.StatusBadRequest, "primary_ns is required")
|
||||
return
|
||||
}
|
||||
if req.AdminEmail == "" {
|
||||
writeError(w, http.StatusBadRequest, "admin_email is required")
|
||||
return
|
||||
}
|
||||
|
||||
// Apply defaults for SOA params.
|
||||
if req.Refresh == 0 {
|
||||
req.Refresh = 3600
|
||||
}
|
||||
if req.Retry == 0 {
|
||||
req.Retry = 600
|
||||
}
|
||||
if req.Expire == 0 {
|
||||
req.Expire = 86400
|
||||
}
|
||||
if req.MinimumTTL == 0 {
|
||||
req.MinimumTTL = 300
|
||||
}
|
||||
|
||||
zone, err := database.CreateZone(req.Name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
||||
if errors.Is(err, db.ErrConflict) {
|
||||
writeError(w, http.StatusConflict, err.Error())
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to create zone")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusCreated, zone)
|
||||
}
|
||||
}
|
||||
|
||||
func updateZoneHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "zone")
|
||||
|
||||
var req createZoneRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
writeError(w, http.StatusBadRequest, "invalid request body")
|
||||
return
|
||||
}
|
||||
|
||||
if req.PrimaryNS == "" {
|
||||
writeError(w, http.StatusBadRequest, "primary_ns is required")
|
||||
return
|
||||
}
|
||||
if req.AdminEmail == "" {
|
||||
writeError(w, http.StatusBadRequest, "admin_email is required")
|
||||
return
|
||||
}
|
||||
if req.Refresh == 0 {
|
||||
req.Refresh = 3600
|
||||
}
|
||||
if req.Retry == 0 {
|
||||
req.Retry = 600
|
||||
}
|
||||
if req.Expire == 0 {
|
||||
req.Expire = 86400
|
||||
}
|
||||
if req.MinimumTTL == 0 {
|
||||
req.MinimumTTL = 300
|
||||
}
|
||||
|
||||
zone, err := database.UpdateZone(name, req.PrimaryNS, req.AdminEmail, req.Refresh, req.Retry, req.Expire, req.MinimumTTL)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "zone not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to update zone")
|
||||
return
|
||||
}
|
||||
|
||||
writeJSON(w, http.StatusOK, zone)
|
||||
}
|
||||
}
|
||||
|
||||
func deleteZoneHandler(database *db.DB) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
name := chi.URLParam(r, "zone")
|
||||
|
||||
err := database.DeleteZone(name)
|
||||
if errors.Is(err, db.ErrNotFound) {
|
||||
writeError(w, http.StatusNotFound, "zone not found")
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
writeError(w, http.StatusInternalServerError, "failed to delete zone")
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user