- Migration v2: INSERT → INSERT OR IGNORE for idempotency - Config: validate server.tls_cert and server.tls_key are non-empty - gRPC: add input validation matching REST handlers - gRPC: add logger to zone/record services, log timestamp parse errors - REST+gRPC: extract SOA defaults into shared db.ApplySOADefaults() - DNS: simplify SOA query condition (remove dead code from precedence bug) - Startup: consolidate shutdown into shutdownAll(), clean up gRPC listener on error path, shut down sibling servers when one fails Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
281 lines
7.5 KiB
Go
281 lines
7.5 KiB
Go
// Package dns implements the authoritative DNS server for MCNS.
|
|
// It serves records from SQLite for authoritative zones and forwards
|
|
// all other queries to configured upstream resolvers.
|
|
package dns
|
|
|
|
import (
|
|
"log/slog"
|
|
"net"
|
|
"strings"
|
|
|
|
"github.com/miekg/dns"
|
|
|
|
"git.wntrmute.dev/kyle/mcns/internal/db"
|
|
)
|
|
|
|
// Server is the MCNS DNS server. It listens on both UDP and TCP.
|
|
type Server struct {
|
|
db *db.DB
|
|
forwarder *Forwarder
|
|
logger *slog.Logger
|
|
udp *dns.Server
|
|
tcp *dns.Server
|
|
}
|
|
|
|
// New creates a DNS server that serves records from the database and
|
|
// forwards non-authoritative queries to the given upstreams.
|
|
func New(database *db.DB, upstreams []string, logger *slog.Logger) *Server {
|
|
s := &Server{
|
|
db: database,
|
|
forwarder: NewForwarder(upstreams),
|
|
logger: logger,
|
|
}
|
|
|
|
mux := dns.NewServeMux()
|
|
mux.HandleFunc(".", s.handleQuery)
|
|
|
|
s.udp = &dns.Server{Handler: mux, Net: "udp"}
|
|
s.tcp = &dns.Server{Handler: mux, Net: "tcp"}
|
|
|
|
return s
|
|
}
|
|
|
|
// ListenAndServe starts the DNS server on the given address for both
|
|
// UDP and TCP. It blocks until Shutdown is called.
|
|
func (s *Server) ListenAndServe(addr string) error {
|
|
s.udp.Addr = addr
|
|
s.tcp.Addr = addr
|
|
|
|
errCh := make(chan error, 2)
|
|
|
|
go func() {
|
|
s.logger.Info("dns server listening", "addr", addr, "proto", "udp")
|
|
errCh <- s.udp.ListenAndServe()
|
|
}()
|
|
go func() {
|
|
s.logger.Info("dns server listening", "addr", addr, "proto", "tcp")
|
|
errCh <- s.tcp.ListenAndServe()
|
|
}()
|
|
|
|
return <-errCh
|
|
}
|
|
|
|
// Shutdown gracefully stops the DNS server.
|
|
func (s *Server) Shutdown() {
|
|
_ = s.udp.Shutdown()
|
|
_ = s.tcp.Shutdown()
|
|
}
|
|
|
|
// handleQuery is the main DNS query handler. It checks if the query
|
|
// falls within an authoritative zone and either serves from the database
|
|
// or forwards to upstream.
|
|
func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) {
|
|
if len(r.Question) == 0 {
|
|
s.writeResponse(w, r, dns.RcodeFormatError, nil, nil)
|
|
return
|
|
}
|
|
|
|
q := r.Question[0]
|
|
qname := strings.ToLower(q.Name)
|
|
|
|
// Find the authoritative zone for this query.
|
|
zone := s.findZone(qname)
|
|
if zone == nil {
|
|
// Not authoritative — forward to upstream.
|
|
s.forwardQuery(w, r)
|
|
return
|
|
}
|
|
|
|
s.handleAuthoritativeQuery(w, r, zone, qname, q.Qtype)
|
|
}
|
|
|
|
// findZone returns the best matching zone for the query name, or nil.
|
|
func (s *Server) findZone(qname string) *db.Zone {
|
|
// Walk up the domain labels to find the longest matching zone.
|
|
name := strings.TrimSuffix(qname, ".")
|
|
parts := strings.Split(name, ".")
|
|
|
|
for i := range parts {
|
|
candidate := strings.Join(parts[i:], ".")
|
|
zone, err := s.db.GetZone(candidate)
|
|
if err == nil {
|
|
return zone
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleAuthoritativeQuery serves a query from the database.
|
|
func (s *Server) handleAuthoritativeQuery(w dns.ResponseWriter, r *dns.Msg, zone *db.Zone, qname string, qtype uint16) {
|
|
// Extract the record name relative to the zone.
|
|
zoneFQDN := zone.Name + "."
|
|
var relName string
|
|
if qname == zoneFQDN {
|
|
relName = "@"
|
|
} else {
|
|
relName = strings.TrimSuffix(qname, "."+zoneFQDN)
|
|
}
|
|
|
|
// SOA queries always return the zone apex SOA regardless of query name.
|
|
if qtype == dns.TypeSOA {
|
|
soa := s.buildSOA(zone)
|
|
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{soa}, nil)
|
|
return
|
|
}
|
|
|
|
// Handle NS queries at the zone apex.
|
|
if qtype == dns.TypeNS && relName == "@" {
|
|
ns := &dns.NS{
|
|
Hdr: dns.RR_Header{Name: zoneFQDN, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
|
|
Ns: zone.PrimaryNS,
|
|
}
|
|
s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{ns}, nil)
|
|
return
|
|
}
|
|
|
|
// Look up the requested record type.
|
|
var answers []dns.RR
|
|
var lookupType string
|
|
|
|
switch qtype {
|
|
case dns.TypeA:
|
|
lookupType = "A"
|
|
case dns.TypeAAAA:
|
|
lookupType = "AAAA"
|
|
case dns.TypeCNAME:
|
|
lookupType = "CNAME"
|
|
default:
|
|
// For unsupported types, check if the name exists at all.
|
|
// If it does, return empty answer. If not, NXDOMAIN.
|
|
exists, _ := s.nameExists(zone.Name, relName)
|
|
if exists {
|
|
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
|
|
} else {
|
|
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
|
|
}
|
|
return
|
|
}
|
|
|
|
records, err := s.db.LookupRecords(zone.Name, relName, lookupType)
|
|
if err != nil {
|
|
s.logger.Error("dns lookup failed", "zone", zone.Name, "name", relName, "type", lookupType, "error", err)
|
|
s.writeResponse(w, r, dns.RcodeServerFailure, nil, nil)
|
|
return
|
|
}
|
|
|
|
// If no direct records, check for CNAME.
|
|
if len(records) == 0 && (qtype == dns.TypeA || qtype == dns.TypeAAAA) {
|
|
cnameRecords, err := s.db.LookupCNAME(zone.Name, relName)
|
|
if err == nil && len(cnameRecords) > 0 {
|
|
for _, rec := range cnameRecords {
|
|
answers = append(answers, &dns.CNAME{
|
|
Hdr: dns.RR_Header{Name: qname, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(rec.TTL)},
|
|
Target: rec.Value,
|
|
})
|
|
}
|
|
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
|
|
return
|
|
}
|
|
}
|
|
|
|
if len(records) == 0 {
|
|
// Name might still exist with other record types.
|
|
exists, _ := s.nameExists(zone.Name, relName)
|
|
if exists {
|
|
// NODATA: name exists but no records of requested type.
|
|
s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)})
|
|
} else {
|
|
// NXDOMAIN: name does not exist.
|
|
s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)})
|
|
}
|
|
return
|
|
}
|
|
|
|
for _, rec := range records {
|
|
rr := s.recordToRR(qname, rec)
|
|
if rr != nil {
|
|
answers = append(answers, rr)
|
|
}
|
|
}
|
|
|
|
s.writeResponse(w, r, dns.RcodeSuccess, answers, nil)
|
|
}
|
|
|
|
// nameExists checks if any records exist for a name in a zone.
|
|
func (s *Server) nameExists(zoneName, name string) (bool, error) {
|
|
records, err := s.db.ListRecords(zoneName, name, "")
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return len(records) > 0, nil
|
|
}
|
|
|
|
// recordToRR converts a database Record to a dns.RR.
|
|
func (s *Server) recordToRR(qname string, rec db.Record) dns.RR {
|
|
hdr := dns.RR_Header{Name: qname, Class: dns.ClassINET, Ttl: uint32(rec.TTL)}
|
|
|
|
switch rec.Type {
|
|
case "A":
|
|
hdr.Rrtype = dns.TypeA
|
|
return &dns.A{Hdr: hdr, A: parseIP(rec.Value)}
|
|
case "AAAA":
|
|
hdr.Rrtype = dns.TypeAAAA
|
|
return &dns.AAAA{Hdr: hdr, AAAA: parseIP(rec.Value)}
|
|
case "CNAME":
|
|
hdr.Rrtype = dns.TypeCNAME
|
|
return &dns.CNAME{Hdr: hdr, Target: rec.Value}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// buildSOA constructs a SOA record for the given zone.
|
|
func (s *Server) buildSOA(zone *db.Zone) *dns.SOA {
|
|
return &dns.SOA{
|
|
Hdr: dns.RR_Header{Name: zone.Name + ".", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)},
|
|
Ns: zone.PrimaryNS,
|
|
Mbox: zone.AdminEmail,
|
|
Serial: uint32(zone.Serial),
|
|
Refresh: uint32(zone.Refresh),
|
|
Retry: uint32(zone.Retry),
|
|
Expire: uint32(zone.Expire),
|
|
Minttl: uint32(zone.MinimumTTL),
|
|
}
|
|
}
|
|
|
|
// writeResponse constructs and writes a DNS response.
|
|
func (s *Server) writeResponse(w dns.ResponseWriter, r *dns.Msg, rcode int, answer []dns.RR, ns []dns.RR) {
|
|
m := new(dns.Msg)
|
|
m.SetReply(r)
|
|
m.Authoritative = true
|
|
m.Rcode = rcode
|
|
m.Answer = answer
|
|
m.Ns = ns
|
|
|
|
if err := w.WriteMsg(m); err != nil {
|
|
s.logger.Error("dns write failed", "error", err)
|
|
}
|
|
}
|
|
|
|
// forwardQuery forwards a DNS query to upstream resolvers.
|
|
func (s *Server) forwardQuery(w dns.ResponseWriter, r *dns.Msg) {
|
|
resp, err := s.forwarder.Forward(r)
|
|
if err != nil {
|
|
s.logger.Debug("dns forward failed", "error", err)
|
|
m := new(dns.Msg)
|
|
m.SetReply(r)
|
|
m.Rcode = dns.RcodeServerFailure
|
|
_ = w.WriteMsg(m)
|
|
return
|
|
}
|
|
|
|
resp.Id = r.Id
|
|
if err := w.WriteMsg(resp); err != nil {
|
|
s.logger.Error("dns write failed", "error", err)
|
|
}
|
|
}
|
|
|
|
// parseIP converts the textual form of an IPv4 or IPv6 address into a
// net.IP. It yields nil when the text is not a valid address.
func parseIP(s string) net.IP {
	ip := net.ParseIP(s)
	return ip
}
|