// Package dns implements the authoritative DNS server for MCNS. // It serves records from SQLite for authoritative zones and forwards // all other queries to configured upstream resolvers. package dns import ( "log/slog" "net" "strings" "github.com/miekg/dns" "git.wntrmute.dev/kyle/mcns/internal/db" ) // Server is the MCNS DNS server. It listens on both UDP and TCP. type Server struct { db *db.DB forwarder *Forwarder logger *slog.Logger udp *dns.Server tcp *dns.Server } // New creates a DNS server that serves records from the database and // forwards non-authoritative queries to the given upstreams. func New(database *db.DB, upstreams []string, logger *slog.Logger) *Server { s := &Server{ db: database, forwarder: NewForwarder(upstreams), logger: logger, } mux := dns.NewServeMux() mux.HandleFunc(".", s.handleQuery) s.udp = &dns.Server{Handler: mux, Net: "udp"} s.tcp = &dns.Server{Handler: mux, Net: "tcp"} return s } // ListenAndServe starts the DNS server on the given address for both // UDP and TCP. It blocks until Shutdown is called. func (s *Server) ListenAndServe(addr string) error { s.udp.Addr = addr s.tcp.Addr = addr errCh := make(chan error, 2) go func() { s.logger.Info("dns server listening", "addr", addr, "proto", "udp") errCh <- s.udp.ListenAndServe() }() go func() { s.logger.Info("dns server listening", "addr", addr, "proto", "tcp") errCh <- s.tcp.ListenAndServe() }() return <-errCh } // Shutdown gracefully stops the DNS server. func (s *Server) Shutdown() { _ = s.udp.Shutdown() _ = s.tcp.Shutdown() } // handleQuery is the main DNS query handler. It checks if the query // falls within an authoritative zone and either serves from the database // or forwards to upstream. func (s *Server) handleQuery(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { s.writeResponse(w, r, dns.RcodeFormatError, nil, nil) return } q := r.Question[0] qname := strings.ToLower(q.Name) // Find the authoritative zone for this query. zone := s.findZone(qname) if zone == nil { // Not authoritative — forward to upstream. s.forwardQuery(w, r) return } s.handleAuthoritativeQuery(w, r, zone, qname, q.Qtype) } // findZone returns the best matching zone for the query name, or nil. func (s *Server) findZone(qname string) *db.Zone { // Walk up the domain labels to find the longest matching zone. name := strings.TrimSuffix(qname, ".") parts := strings.Split(name, ".") for i := range parts { candidate := strings.Join(parts[i:], ".") zone, err := s.db.GetZone(candidate) if err == nil { return zone } } return nil } // handleAuthoritativeQuery serves a query from the database. func (s *Server) handleAuthoritativeQuery(w dns.ResponseWriter, r *dns.Msg, zone *db.Zone, qname string, qtype uint16) { // Extract the record name relative to the zone. zoneFQDN := zone.Name + "." var relName string if qname == zoneFQDN { relName = "@" } else { relName = strings.TrimSuffix(qname, "."+zoneFQDN) } // SOA queries always return the zone apex SOA regardless of query name. if qtype == dns.TypeSOA { soa := s.buildSOA(zone) s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{soa}, nil) return } // Handle NS queries at the zone apex. if qtype == dns.TypeNS && relName == "@" { ns := &dns.NS{ Hdr: dns.RR_Header{Name: zoneFQDN, Rrtype: dns.TypeNS, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)}, Ns: zone.PrimaryNS, } s.writeResponse(w, r, dns.RcodeSuccess, []dns.RR{ns}, nil) return } // Look up the requested record type. var answers []dns.RR var lookupType string switch qtype { case dns.TypeA: lookupType = "A" case dns.TypeAAAA: lookupType = "AAAA" case dns.TypeCNAME: lookupType = "CNAME" default: // For unsupported types, check if the name exists at all. // If it does, return empty answer. If not, NXDOMAIN. exists, _ := s.nameExists(zone.Name, relName) if exists { s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)}) } else { s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)}) } return } records, err := s.db.LookupRecords(zone.Name, relName, lookupType) if err != nil { s.logger.Error("dns lookup failed", "zone", zone.Name, "name", relName, "type", lookupType, "error", err) s.writeResponse(w, r, dns.RcodeServerFailure, nil, nil) return } // If no direct records, check for CNAME. if len(records) == 0 && (qtype == dns.TypeA || qtype == dns.TypeAAAA) { cnameRecords, err := s.db.LookupCNAME(zone.Name, relName) if err == nil && len(cnameRecords) > 0 { for _, rec := range cnameRecords { answers = append(answers, &dns.CNAME{ Hdr: dns.RR_Header{Name: qname, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: uint32(rec.TTL)}, Target: rec.Value, }) } s.writeResponse(w, r, dns.RcodeSuccess, answers, nil) return } } if len(records) == 0 { // Name might still exist with other record types. exists, _ := s.nameExists(zone.Name, relName) if exists { // NODATA: name exists but no records of requested type. s.writeResponse(w, r, dns.RcodeSuccess, nil, []dns.RR{s.buildSOA(zone)}) } else { // NXDOMAIN: name does not exist. s.writeResponse(w, r, dns.RcodeNameError, nil, []dns.RR{s.buildSOA(zone)}) } return } for _, rec := range records { rr := s.recordToRR(qname, rec) if rr != nil { answers = append(answers, rr) } } s.writeResponse(w, r, dns.RcodeSuccess, answers, nil) } // nameExists checks if any records exist for a name in a zone. func (s *Server) nameExists(zoneName, name string) (bool, error) { records, err := s.db.ListRecords(zoneName, name, "") if err != nil { return false, err } return len(records) > 0, nil } // recordToRR converts a database Record to a dns.RR. func (s *Server) recordToRR(qname string, rec db.Record) dns.RR { hdr := dns.RR_Header{Name: qname, Class: dns.ClassINET, Ttl: uint32(rec.TTL)} switch rec.Type { case "A": hdr.Rrtype = dns.TypeA return &dns.A{Hdr: hdr, A: parseIP(rec.Value)} case "AAAA": hdr.Rrtype = dns.TypeAAAA return &dns.AAAA{Hdr: hdr, AAAA: parseIP(rec.Value)} case "CNAME": hdr.Rrtype = dns.TypeCNAME return &dns.CNAME{Hdr: hdr, Target: rec.Value} } return nil } // buildSOA constructs a SOA record for the given zone. func (s *Server) buildSOA(zone *db.Zone) *dns.SOA { return &dns.SOA{ Hdr: dns.RR_Header{Name: zone.Name + ".", Rrtype: dns.TypeSOA, Class: dns.ClassINET, Ttl: uint32(zone.MinimumTTL)}, Ns: zone.PrimaryNS, Mbox: zone.AdminEmail, Serial: uint32(zone.Serial), Refresh: uint32(zone.Refresh), Retry: uint32(zone.Retry), Expire: uint32(zone.Expire), Minttl: uint32(zone.MinimumTTL), } } // writeResponse constructs and writes a DNS response. func (s *Server) writeResponse(w dns.ResponseWriter, r *dns.Msg, rcode int, answer []dns.RR, ns []dns.RR) { m := new(dns.Msg) m.SetReply(r) m.Authoritative = true m.Rcode = rcode m.Answer = answer m.Ns = ns if err := w.WriteMsg(m); err != nil { s.logger.Error("dns write failed", "error", err) } } // forwardQuery forwards a DNS query to upstream resolvers. func (s *Server) forwardQuery(w dns.ResponseWriter, r *dns.Msg) { resp, err := s.forwarder.Forward(r) if err != nil { s.logger.Debug("dns forward failed", "error", err) m := new(dns.Msg) m.SetReply(r) m.Rcode = dns.RcodeServerFailure _ = w.WriteMsg(m) return } resp.Id = r.Id if err := w.WriteMsg(resp); err != nil { s.logger.Error("dns write failed", "error", err) } } // parseIP parses an IP address string into a net.IP. func parseIP(s string) net.IP { return net.ParseIP(s) }