Replace the CoreDNS precursor with a purpose-built authoritative DNS server. Zones and records (A, AAAA, CNAME) are stored in SQLite and managed via synchronized gRPC + REST APIs authenticated through MCIAS. Non-authoritative queries are forwarded to upstream resolvers with in-memory caching. Key components: - DNS server (miekg/dns) with authoritative zone handling and forwarding - gRPC + REST management APIs with MCIAS auth (mcdsl integration) - SQLite storage with CNAME exclusivity enforcement and auto SOA serials - 30 tests covering database CRUD, DNS resolution, and caching Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
309 lines
8.8 KiB
Go
309 lines
8.8 KiB
Go
package db
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Record is a single DNS resource record row from the records table, joined
// with the name of the zone that owns it.
type Record struct {
	// ID is the record's row ID; ZoneID is the row ID of the owning zone.
	ID, ZoneID int64

	// ZoneName is the owning zone's name, read from zones.name via a JOIN.
	ZoneName string

	// Name is the record's owner name as stored (CreateRecord and
	// UpdateRecord lowercase it), Type is its uppercased record type, and
	// Value is its type-specific data (an address or a CNAME target).
	Name, Type, Value string

	// TTL is the record's time-to-live; CreateRecord and UpdateRecord
	// replace non-positive values with 300.
	TTL int

	// CreatedAt and UpdatedAt are timestamps kept as text. UpdateRecord
	// writes UpdatedAt in the "2006-01-02T15:04:05Z" UTC layout; CreatedAt is
	// not set by this file's INSERT, so it presumably comes from a column
	// default (NOTE(review): confirm both use the same layout).
	CreatedAt, UpdatedAt string
}
|
|
|
|
// ListRecords returns records for a zone, optionally filtered by name and type.
|
|
func (d *DB) ListRecords(zoneName, name, recordType string) ([]Record, error) {
|
|
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
|
|
|
zone, err := d.GetZone(zoneName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
query := `SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
|
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.zone_id = ?`
|
|
args := []any{zone.ID}
|
|
|
|
if name != "" {
|
|
query += ` AND r.name = ?`
|
|
args = append(args, strings.ToLower(name))
|
|
}
|
|
if recordType != "" {
|
|
query += ` AND r.type = ?`
|
|
args = append(args, strings.ToUpper(recordType))
|
|
}
|
|
|
|
query += ` ORDER BY r.name, r.type, r.value`
|
|
|
|
rows, err := d.Query(query, args...)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("list records: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var records []Record
|
|
for rows.Next() {
|
|
var r Record
|
|
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan record: %w", err)
|
|
}
|
|
records = append(records, r)
|
|
}
|
|
return records, rows.Err()
|
|
}
|
|
|
|
// LookupRecords returns records matching a name and type within a zone.
|
|
// Used by the DNS handler for query resolution.
|
|
func (d *DB) LookupRecords(zoneName, name, recordType string) ([]Record, error) {
|
|
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
|
name = strings.ToLower(name)
|
|
|
|
rows, err := d.Query(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
|
FROM records r JOIN zones z ON r.zone_id = z.id
|
|
WHERE z.name = ? AND r.name = ? AND r.type = ?`, zoneName, name, strings.ToUpper(recordType))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("lookup records: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var records []Record
|
|
for rows.Next() {
|
|
var r Record
|
|
if err := rows.Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt); err != nil {
|
|
return nil, fmt.Errorf("scan record: %w", err)
|
|
}
|
|
records = append(records, r)
|
|
}
|
|
return records, rows.Err()
|
|
}
|
|
|
|
// LookupCNAME returns CNAME records for a name within a zone.
|
|
func (d *DB) LookupCNAME(zoneName, name string) ([]Record, error) {
|
|
return d.LookupRecords(zoneName, name, "CNAME")
|
|
}
|
|
|
|
// HasRecordsForName checks if any records of the given types exist for a name.
|
|
func (d *DB) HasRecordsForName(tx *sql.Tx, zoneID int64, name string, types []string) (bool, error) {
|
|
placeholders := make([]string, len(types))
|
|
args := []any{zoneID, strings.ToLower(name)}
|
|
for i, t := range types {
|
|
placeholders[i] = "?"
|
|
args = append(args, strings.ToUpper(t))
|
|
}
|
|
query := fmt.Sprintf(`SELECT COUNT(*) FROM records WHERE zone_id = ? AND name = ? AND type IN (%s)`, strings.Join(placeholders, ","))
|
|
|
|
var count int
|
|
if err := tx.QueryRow(query, args...).Scan(&count); err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
// GetRecord returns a record by ID.
|
|
func (d *DB) GetRecord(id int64) (*Record, error) {
|
|
var r Record
|
|
err := d.QueryRow(`SELECT r.id, r.zone_id, z.name, r.name, r.type, r.value, r.ttl, r.created_at, r.updated_at
|
|
FROM records r JOIN zones z ON r.zone_id = z.id WHERE r.id = ?`, id).
|
|
Scan(&r.ID, &r.ZoneID, &r.ZoneName, &r.Name, &r.Type, &r.Value, &r.TTL, &r.CreatedAt, &r.UpdatedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("get record: %w", err)
|
|
}
|
|
return &r, nil
|
|
}
|
|
|
|
// CreateRecord inserts a new record, enforcing CNAME exclusivity and
|
|
// value validation. Bumps the zone serial within the same transaction.
|
|
func (d *DB) CreateRecord(zoneName, name, recordType, value string, ttl int) (*Record, error) {
|
|
zoneName = strings.ToLower(strings.TrimSuffix(zoneName, "."))
|
|
name = strings.ToLower(name)
|
|
recordType = strings.ToUpper(recordType)
|
|
|
|
if err := validateRecordValue(recordType, value); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
zone, err := d.GetZone(zoneName)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if ttl <= 0 {
|
|
ttl = 300
|
|
}
|
|
|
|
tx, err := d.Begin()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("begin tx: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
// Enforce CNAME exclusivity.
|
|
if recordType == "CNAME" {
|
|
hasAddr, err := d.HasRecordsForName(tx, zone.ID, name, []string{"A", "AAAA"})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
|
}
|
|
if hasAddr {
|
|
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
|
|
}
|
|
} else if recordType == "A" || recordType == "AAAA" {
|
|
hasCNAME, err := d.HasRecordsForName(tx, zone.ID, name, []string{"CNAME"})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
|
}
|
|
if hasCNAME {
|
|
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
|
|
}
|
|
}
|
|
|
|
res, err := tx.Exec(`INSERT INTO records (zone_id, name, type, value, ttl) VALUES (?, ?, ?, ?, ?)`,
|
|
zone.ID, name, recordType, value, ttl)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
|
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
|
|
}
|
|
return nil, fmt.Errorf("insert record: %w", err)
|
|
}
|
|
|
|
recordID, err := res.LastInsertId()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("last insert id: %w", err)
|
|
}
|
|
|
|
if err := d.BumpSerial(tx, zone.ID); err != nil {
|
|
return nil, fmt.Errorf("bump serial: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("commit: %w", err)
|
|
}
|
|
|
|
return d.GetRecord(recordID)
|
|
}
|
|
|
|
// UpdateRecord updates an existing record's fields and bumps the zone serial.
|
|
func (d *DB) UpdateRecord(id int64, name, recordType, value string, ttl int) (*Record, error) {
|
|
name = strings.ToLower(name)
|
|
recordType = strings.ToUpper(recordType)
|
|
|
|
if err := validateRecordValue(recordType, value); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
existing, err := d.GetRecord(id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if ttl <= 0 {
|
|
ttl = 300
|
|
}
|
|
|
|
tx, err := d.Begin()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("begin tx: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
// Enforce CNAME exclusivity for the new type/name combo.
|
|
if recordType == "CNAME" {
|
|
hasAddr, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"A", "AAAA"})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
|
}
|
|
if hasAddr {
|
|
return nil, fmt.Errorf("%w: CNAME record conflicts with existing A/AAAA record for %q", ErrConflict, name)
|
|
}
|
|
} else if recordType == "A" || recordType == "AAAA" {
|
|
hasCNAME, err := d.HasRecordsForName(tx, existing.ZoneID, name, []string{"CNAME"})
|
|
if err != nil {
|
|
return nil, fmt.Errorf("check cname exclusivity: %w", err)
|
|
}
|
|
if hasCNAME {
|
|
return nil, fmt.Errorf("%w: A/AAAA record conflicts with existing CNAME record for %q", ErrConflict, name)
|
|
}
|
|
}
|
|
|
|
now := time.Now().UTC().Format("2006-01-02T15:04:05Z")
|
|
_, err = tx.Exec(`UPDATE records SET name = ?, type = ?, value = ?, ttl = ?, updated_at = ? WHERE id = ?`,
|
|
name, recordType, value, ttl, now, id)
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), "UNIQUE constraint") {
|
|
return nil, fmt.Errorf("%w: record already exists", ErrConflict)
|
|
}
|
|
return nil, fmt.Errorf("update record: %w", err)
|
|
}
|
|
|
|
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
|
|
return nil, fmt.Errorf("bump serial: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return nil, fmt.Errorf("commit: %w", err)
|
|
}
|
|
|
|
return d.GetRecord(id)
|
|
}
|
|
|
|
// DeleteRecord deletes a record and bumps the zone serial.
|
|
func (d *DB) DeleteRecord(id int64) error {
|
|
existing, err := d.GetRecord(id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
tx, err := d.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("begin tx: %w", err)
|
|
}
|
|
defer func() { _ = tx.Rollback() }()
|
|
|
|
_, err = tx.Exec(`DELETE FROM records WHERE id = ?`, id)
|
|
if err != nil {
|
|
return fmt.Errorf("delete record: %w", err)
|
|
}
|
|
|
|
if err := d.BumpSerial(tx, existing.ZoneID); err != nil {
|
|
return fmt.Errorf("bump serial: %w", err)
|
|
}
|
|
|
|
return tx.Commit()
|
|
}
|
|
|
|
// validateRecordValue checks that value is acceptable for a record of type
// recordType. Callers pass recordType already uppercased.
//
//   - A: an IPv4 address in dotted-quad form. IPv6 text, including
//     IPv4-mapped forms such as "::ffff:192.0.2.1", is rejected.
//   - AAAA: an IPv6 address in textual form (it contains a ':'); dotted-quad
//     IPv4 text is rejected, while IPv4-mapped IPv6 is accepted.
//   - CNAME: a fully-qualified domain name ending with '.', made of non-empty
//     labels of at most 63 octets and at most 255 octets in wire format
//     (RFC 1035 §2.3.4).
//
// Any other type yields an "unsupported record type" error.
func validateRecordValue(recordType, value string) error {
	switch recordType {
	case "A":
		// IP.To4 also succeeds for IPv4-mapped IPv6 addresses, so the textual
		// form decides the family: IPv6 syntax always contains a colon.
		if ip := net.ParseIP(value); ip == nil || ip.To4() == nil || strings.Contains(value, ":") {
			return fmt.Errorf("invalid IPv4 address: %q", value)
		}
	case "AAAA":
		if ip := net.ParseIP(value); ip == nil || !strings.Contains(value, ":") {
			return fmt.Errorf("invalid IPv6 address: %q", value)
		}
	case "CNAME":
		if !strings.HasSuffix(value, ".") {
			return fmt.Errorf("CNAME value must be a fully-qualified domain name ending with '.': %q", value)
		}
		// In wire format every label gets a length octet and the name ends
		// with a zero octet, so the encoded size is len(value)+1.
		if len(value) > 254 {
			return fmt.Errorf("CNAME value exceeds 255 octets in wire format: %q", value)
		}
		for _, label := range strings.Split(strings.TrimSuffix(value, "."), ".") {
			switch {
			case label == "":
				return fmt.Errorf("CNAME value contains an empty label: %q", value)
			case len(label) > 63:
				return fmt.Errorf("CNAME value has a label longer than 63 octets: %q", value)
			}
		}
	default:
		return fmt.Errorf("unsupported record type: %q", recordType)
	}
	return nil
}
|