Implement MCNS v1: custom Go DNS server replacing CoreDNS

Replace the CoreDNS precursor with a purpose-built authoritative DNS
server. Zones and records (A, AAAA, CNAME) are stored in SQLite and
managed via synchronized gRPC + REST APIs authenticated through MCIAS.
Non-authoritative queries are forwarded to upstream resolvers with
in-memory caching.

Key components:
- DNS server (miekg/dns) with authoritative zone handling and forwarding
- gRPC + REST management APIs with MCIAS auth (mcdsl integration)
- SQLite storage with CNAME exclusivity enforcement and auto SOA serials
- 30 tests covering database CRUD, DNS resolution, and caching

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-26 18:37:14 -07:00
parent a545fec658
commit f9635578e0
48 changed files with 6015 additions and 87 deletions

67
internal/dns/cache.go Normal file
View File

@@ -0,0 +1,67 @@
package dns
import (
"sync"
"time"
"github.com/miekg/dns"
)
type cacheKey struct {
Name string
Qtype uint16
Class uint16
}
type cacheEntry struct {
msg *dns.Msg
expiresAt time.Time
}
// Cache is a thread-safe in-memory DNS response cache with TTL-based expiry.
type Cache struct {
mu sync.RWMutex
entries map[cacheKey]*cacheEntry
}
// NewCache creates an empty DNS cache.
func NewCache() *Cache {
return &Cache{
entries: make(map[cacheKey]*cacheEntry),
}
}
// Get returns a cached response if it exists and has not expired.
func (c *Cache) Get(name string, qtype, class uint16) *dns.Msg {
c.mu.RLock()
defer c.mu.RUnlock()
key := cacheKey{Name: name, Qtype: qtype, Class: class}
entry, ok := c.entries[key]
if !ok || time.Now().After(entry.expiresAt) {
return nil
}
return entry.msg
}
// Set stores a DNS response in the cache with the given TTL.
func (c *Cache) Set(name string, qtype, class uint16, msg *dns.Msg, ttl time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
key := cacheKey{Name: name, Qtype: qtype, Class: class}
c.entries[key] = &cacheEntry{
msg: msg.Copy(),
expiresAt: time.Now().Add(ttl),
}
// Lazy eviction: clean up expired entries if cache is growing.
if len(c.entries) > 1000 {
now := time.Now()
for k, v := range c.entries {
if now.After(v.expiresAt) {
delete(c.entries, k)
}
}
}
}

View File

@@ -0,0 +1,81 @@
package dns
import (
"testing"
"time"
"github.com/miekg/dns"
)
func TestCacheSetGet(t *testing.T) {
c := NewCache()
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
msg.Answer = append(msg.Answer, &dns.A{
Hdr: dns.RR_Header{Name: "example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: []byte{1, 2, 3, 4},
})
c.Set("example.com.", dns.TypeA, dns.ClassINET, msg, 5*time.Second)
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cached == nil {
t.Fatal("expected cached response")
}
if len(cached.Answer) != 1 {
t.Fatalf("got %d answers, want 1", len(cached.Answer))
}
}
func TestCacheMiss(t *testing.T) {
c := NewCache()
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cached != nil {
t.Fatal("expected nil for cache miss")
}
}
func TestCacheExpiry(t *testing.T) {
c := NewCache()
msg := new(dns.Msg)
msg.SetQuestion("example.com.", dns.TypeA)
c.Set("example.com.", dns.TypeA, dns.ClassINET, msg, 1*time.Millisecond)
time.Sleep(2 * time.Millisecond)
cached := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cached != nil {
t.Fatal("expected nil for expired entry")
}
}
func TestCacheDifferentTypes(t *testing.T) {
c := NewCache()
msgA := new(dns.Msg)
msgA.SetQuestion("example.com.", dns.TypeA)
c.Set("example.com.", dns.TypeA, dns.ClassINET, msgA, 5*time.Second)
msgAAAA := new(dns.Msg)
msgAAAA.SetQuestion("example.com.", dns.TypeAAAA)
c.Set("example.com.", dns.TypeAAAA, dns.ClassINET, msgAAAA, 5*time.Second)
cachedA := c.Get("example.com.", dns.TypeA, dns.ClassINET)
if cachedA == nil {
t.Fatal("expected cached A response")
}
cachedAAAA := c.Get("example.com.", dns.TypeAAAA, dns.ClassINET)
if cachedAAAA == nil {
t.Fatal("expected cached AAAA response")
}
// Different type should not match.
cachedMX := c.Get("example.com.", dns.TypeMX, dns.ClassINET)
if cachedMX != nil {
t.Fatal("expected nil for uncached type")
}
}

87
internal/dns/forwarder.go Normal file
View File

@@ -0,0 +1,87 @@
package dns
import (
"fmt"
"time"
"github.com/miekg/dns"
)
// Forwarder handles forwarding DNS queries to upstream resolvers.
type Forwarder struct {
upstreams []string
client *dns.Client
cache *Cache
}
// NewForwarder creates a Forwarder with the given upstream addresses.
func NewForwarder(upstreams []string) *Forwarder {
return &Forwarder{
upstreams: upstreams,
client: &dns.Client{
Timeout: 2 * time.Second,
},
cache: NewCache(),
}
}
// Forward sends a query to upstream resolvers and returns the response.
// Responses are cached by (qname, qtype, qclass) with TTL-based expiry.
func (f *Forwarder) Forward(r *dns.Msg) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
q := r.Question[0]
// Check cache.
if cached := f.cache.Get(q.Name, q.Qtype, q.Qclass); cached != nil {
return cached.Copy(), nil
}
// Try each upstream in order.
var lastErr error
for _, upstream := range f.upstreams {
resp, _, err := f.client.Exchange(r, upstream)
if err != nil {
lastErr = err
continue
}
// Don't cache SERVFAIL or REFUSED.
if resp.Rcode != dns.RcodeServerFailure && resp.Rcode != dns.RcodeRefused {
ttl := minTTL(resp)
if ttl > 300 {
ttl = 300
}
if ttl > 0 {
f.cache.Set(q.Name, q.Qtype, q.Qclass, resp, time.Duration(ttl)*time.Second)
}
}
return resp, nil
}
return nil, fmt.Errorf("all upstreams failed: %w", lastErr)
}
// minTTL returns the minimum TTL from all resource records in a response.
func minTTL(msg *dns.Msg) uint32 {
var min uint32
first := true
for _, sections := range [][]dns.RR{msg.Answer, msg.Ns, msg.Extra} {
for _, rr := range sections {
ttl := rr.Header().Ttl
if first || ttl < min {
min = ttl
first = false
}
}
}
if first {
return 60 // No records; default to 60s.
}
return min
}

280
internal/dns/server.go Normal file
View File

@@ -0,0 +1,280 @@
// Package dns implements the authoritative DNS server for MCNS.
// It serves records from SQLite for authoritative zones and forwards
// all other queries to configured upstream resolvers.
package dns
import (
"log/slog"
"net"
"strings"
"github.com/miekg/dns"
"git.wntrmute.dev/kyle/mcns/internal/db"
)
// Server is the MCNS DNS server. It listens on both UDP and TCP.
type Server struct {
db *db.DB
forwarder *Forwarder
logger *slog.Logger
udp *dns.Server
tcp *dns.Server
}
// New creates a DNS server that serves records from the database and
// forwards non-authoritative queries to the given upstreams.
func New(database *db.DB, upstreams []string, logger *slog.Logger) *Server {
s := &Server{
db: database,
forwarder: NewForwarder(upstreams),
logger: logger,
}
mux := dns.NewServeMux()
mux.HandleFunc(".", s.handleQuery)
s.udp = &dns.Server{Handler: mux, Net: "udp"}
s.tcp = &dns.Server{Handler: mux, Net: "tcp"}
return s
}
// ListenAndServe starts the DNS server on the given address for both
// UDP and TCP. It blocks until Shutdown is called.
func (s *Server) ListenAndServe(addr string) error {
s.udp.Addr = addr
s.tcp.Addr = addr
errCh := make(chan error, 2)
go func() {
s.logger.Info("dns server listening", "addr", addr, "proto", "udp")
errCh <- s.udp.ListenAndServe()
}()
go func() {
s.logger.Info("dns server listening", "addr", addr, "proto", "tcp")
errCh <- s.tcp.ListenAndServe()
}()
return <-errCh
}
// Shutdown gracefully stops the DNS server.
func (s *Server) Shutdown() {
_ = s.udp.Shutdown()
_ = s.tcp.Shutdown()
}
// handleQuery is the main DNS query handler. It checks if the query
// falls within an authoritative zone and either serves from the database
// or forwards to upstream.
func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
s.writeResponse(w, r, dns.RcodeFormatError, nil, nil)
return
}
q := r.Question[0]
qname := strings.ToLower(q.Name)
// Find the authoritative zone for this query.
zone := s.findZone(qname)
if zone == nil {
// Not authoritative — forward to upstream.
s.forwardQuery(w, r)
return
}
s.handleAuthoritativeQuery(w, r, zone, qname, q.Qtype)
}
// findZone returns the best matching zone for the query name, or nil.
func (s *Server) findZone(qname string) *db.Zone {
// Walk up the domain labels to find the longest matching zone.
name := strings.TrimSuffix(qname, ".")
parts := strings.Split(name, ".")
for i := range parts {
candidate := strings.Join(parts[i:], ".")
zone, err := s.db.GetZone(candidate)
if err == nil {
return zone
}
}
return nil
}
// handleAuthoritativeQuery serves a query from the database.
func (s *Server) handleAuthoritativeQuery(w dns.ResponseWriter, r *dns.Msg, zone *db.Zone, qname string, qtype uint16) {
// Extract the record name relative to the zone.
zoneFQDN := zone.Name + "."
var relName string
if qname == zoneFQDN {
relName = "@"
} else {
relName = strings.TrimSuffix(qname, "."+zoneFQDN)
}
// Handle SOA queries.
if qtype == dns.TypeSOA || relName == "@" && qtype == dns.TypeSOA {
soa := s.buildSOA(zone)
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{soa}, nil)
return
}
// Handle NS queries at the zone apex.
if qtype == dns.TypeNS && relName == "@" {
ns := &dns.NS{
Hdr: dns.RR_Header{Name: zoneFQDN, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
Ns: zone.PrimaryNS,
}
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{ns}, nil)
return
}
// Look up the requested record type.
var answers []dns.RR
var lookupType string
switch qtype {
case dns.TypeA:
lookupType = "A"
case dns.TypeAAAA:
lookupType = "AAAA"
case dns.TypeCNAME:
lookupType = "CNAME"
default:
// For unsupported types, check if the name exists at all.
// If it does, return empty answer. If not, NXDOMAIN.
exists, _ := s.nameExists(zone.Name, relName)
if exists {
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
} else {
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
}
return
}
records, err := s.db.LookupRecords(zone.Name, relName, lookupType)
if err != nil {
s.logger.Error("dns lookup failed", "zone", zone.Name, "name", relName, "type", lookupType, "error", err)
s.writeResponse(w, r, dns.RcodeServerFailure, nil, nil)
return
}
// If no direct records, check for CNAME.
if len(records) == 0 && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
cnameRecords, err := s.db.LookupCNAME(zone.Name, relName)
if err == nil && len(cnameRecords) > 0 {
for _, rec := range cnameRecords {
answers = append(answers, &dns.CNAME{
Hdr: dns.RR_Header{Name: qname, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(rec.TTL)},
Target: rec.Value,
})
}
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
return
}
}
if len(records) == 0 {
// Name might still exist with other record types.
exists, _ := s.nameExists(zone.Name, relName)
if exists {
// NODATA: name exists but no records of requested type.
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
} else {
// NXDOMAIN: name does not exist.
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
}
return
}
for _, rec := range records {
rr := s.recordToRR(qname, rec)
if rr != nil {
answers = append(answers, rr)
}
}
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
}
// nameExists checks if any records exist for a name in a zone.
func (s *Server) nameExists(zoneName, name string) (bool, error) {
records, err := s.db.ListRecords(zoneName, name, "")
if err != nil {
return false, err
}
return len(records) > 0, nil
}
// recordToRR converts a database Record to a dns.RR.
func (s *Server) recordToRR(qname string, rec db.Record) dns.RR {
hdr := dns.RR_Header{Name: qname, Class: dns.ClassINET, Ttl: uint32(rec.TTL)}
switch rec.Type {
case "A":
hdr.Rrtype = dns.TypeA
return &dns.A{Hdr: hdr, A: parseIP(rec.Value)}
case "AAAA":
hdr.Rrtype = dns.TypeAAAA
return &dns.AAAA{Hdr: hdr, AAAA: parseIP(rec.Value)}
case "CNAME":
hdr.Rrtype = dns.TypeCNAME
return &dns.CNAME{Hdr: hdr, Target: rec.Value}
}
return nil
}
// buildSOA constructs a SOA record for the given zone.
func (s *Server) buildSOA(zone *db.Zone) *dns.SOA {
return &dns.SOA{
Hdr: dns.RR_Header{Name: zone.Name + ".", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
Ns: zone.PrimaryNS,
Mbox: zone.AdminEmail,
Serial: uint32(zone.Serial),
Refresh: uint32(zone.Refresh),
Retry: uint32(zone.Retry),
Expire: uint32(zone.Expire),
Minttl: uint32(zone.MinimumTTL),
}
}
// writeResponse constructs and writes a DNS response.
func (s *Server) writeResponse(w dns.ResponseWriter, r *dns.Msg, rcode int, answer []dns.RR, ns []dns.RR) {
m := new(dns.Msg)
m.SetReply(r)
m.Authoritative = true
m.Rcode = rcode
m.Answer = answer
m.Ns = ns
if err := w.WriteMsg(m); err != nil {
s.logger.Error("dns write failed", "error", err)
}
}
// forwardQuery forwards a DNS query to upstream resolvers.
func (s *Server) forwardQuery(w dns.ResponseWriter, r *dns.Msg) {
resp, err := s.forwarder.Forward(r)
if err != nil {
s.logger.Debug("dns forward failed", "error", err)
m := new(dns.Msg)
m.SetReply(r)
m.Rcode = dns.RcodeServerFailure
_ = w.WriteMsg(m)
return
}
resp.Id = r.Id
if err := w.WriteMsg(resp); err != nil {
s.logger.Error("dns write failed", "error", err)
}
}
// parseIP parses an IP address string into a net.IP.
func parseIP(s string) net.IP {
return net.ParseIP(s)
}

142
internal/dns/server_test.go Normal file
View File

@@ -0,0 +1,142 @@
package dns
import (
"path/filepath"
"testing"
"github.com/miekg/dns"
"git.wntrmute.dev/kyle/mcns/internal/db"
"log/slog"
)
func openTestDB(t *testing.T) *db.DB {
t.Helper()
dir := t.TempDir()
database, err := db.Open(filepath.Join(dir, "test.db"))
if err != nil {
t.Fatalf("open db: %v", err)
}
if err := database.Migrate(); err != nil {
t.Fatalf("migrate: %v", err)
}
t.Cleanup(func() { _ = database.Close() })
return database
}
func setupTestServer(t *testing.T) (*Server, *db.DB) {
t.Helper()
database := openTestDB(t)
logger := slog.Default()
_, err := database.CreateZone("svc.mcp.metacircular.net", "ns.mcp.metacircular.net.", "admin.metacircular.net.", 3600, 600, 86400, 300)
if err != nil {
t.Fatalf("create zone: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "192.168.88.181", 300)
if err != nil {
t.Fatalf("create A record: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "metacrypt", "A", "100.95.252.120", 300)
if err != nil {
t.Fatalf("create A record 2: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "mcr", "AAAA", "2001:db8::1", 300)
if err != nil {
t.Fatalf("create AAAA record: %v", err)
}
_, err = database.CreateRecord("svc.mcp.metacircular.net", "alias", "CNAME", "metacrypt.svc.mcp.metacircular.net.", 300)
if err != nil {
t.Fatalf("create CNAME record: %v", err)
}
srv := New(database, []string{"1.1.1.1:53"}, logger)
return srv, database
}
func TestFindZone(t *testing.T) {
srv, _ := setupTestServer(t)
zone := srv.findZone("metacrypt.svc.mcp.metacircular.net.")
if zone == nil {
t.Fatal("expected to find zone")
}
if zone.Name != "svc.mcp.metacircular.net" {
t.Fatalf("got zone %q, want %q", zone.Name, "svc.mcp.metacircular.net")
}
zone = srv.findZone("nonexistent.com.")
if zone != nil {
t.Fatal("expected nil for nonexistent zone")
}
}
func TestBuildSOA(t *testing.T) {
srv, database := setupTestServer(t)
zone, err := database.GetZone("svc.mcp.metacircular.net")
if err != nil {
t.Fatalf("get zone: %v", err)
}
soa := srv.buildSOA(zone)
if soa.Ns != "ns.mcp.metacircular.net." {
t.Fatalf("got ns %q, want %q", soa.Ns, "ns.mcp.metacircular.net.")
}
if soa.Hdr.Name != "svc.mcp.metacircular.net." {
t.Fatalf("got name %q, want %q", soa.Hdr.Name, "svc.mcp.metacircular.net.")
}
}
func TestRecordToRR_A(t *testing.T) {
srv, _ := setupTestServer(t)
rec := db.Record{Name: "metacrypt", Type: "A", Value: "192.168.88.181", TTL: 300}
rr := srv.recordToRR("metacrypt.svc.mcp.metacircular.net.", rec)
if rr == nil {
t.Fatal("expected non-nil RR")
}
a, ok := rr.(*dns.A)
if !ok {
t.Fatalf("expected *dns.A, got %T", rr)
}
if a.A.String() != "192.168.88.181" {
t.Fatalf("got IP %q, want %q", a.A.String(), "192.168.88.181")
}
}
func TestRecordToRR_AAAA(t *testing.T) {
srv, _ := setupTestServer(t)
rec := db.Record{Name: "mcr", Type: "AAAA", Value: "2001:db8::1", TTL: 300}
rr := srv.recordToRR("mcr.svc.mcp.metacircular.net.", rec)
if rr == nil {
t.Fatal("expected non-nil RR")
}
aaaa, ok := rr.(*dns.AAAA)
if !ok {
t.Fatalf("expected *dns.AAAA, got %T", rr)
}
if aaaa.AAAA.String() != "2001:db8::1" {
t.Fatalf("got IP %q, want %q", aaaa.AAAA.String(), "2001:db8::1")
}
}
func TestRecordToRR_CNAME(t *testing.T) {
srv, _ := setupTestServer(t)
rec := db.Record{Name: "alias", Type: "CNAME", Value: "metacrypt.svc.mcp.metacircular.net.", TTL: 300}
rr := srv.recordToRR("alias.svc.mcp.metacircular.net.", rec)
if rr == nil {
t.Fatal("expected non-nil RR")
}
cname, ok := rr.(*dns.CNAME)
if !ok {
t.Fatalf("expected *dns.CNAME, got %T", rr)
}
if cname.Target != "metacrypt.svc.mcp.metacircular.net." {
t.Fatalf("got target %q, want %q", cname.Target, "metacrypt.svc.mcp.metacircular.net.")
}
}