Files
mcns/internal/dns/forwarder.go
Kyle Isom f9635578e0 Implement MCNS v1: custom Go DNS server replacing CoreDNS
Replace the CoreDNS precursor with a purpose-built authoritative DNS
server. Zones and records (A, AAAA, CNAME) are stored in SQLite and
managed via synchronized gRPC + REST APIs authenticated through MCIAS.
Non-authoritative queries are forwarded to upstream resolvers with
in-memory caching.

Key components:
- DNS server (miekg/dns) with authoritative zone handling and forwarding
- gRPC + REST management APIs with MCIAS auth (mcdsl integration)
- SQLite storage with CNAME exclusivity enforcement and auto SOA serials
- 30 tests covering database CRUD, DNS resolution, and caching

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-26 18:37:14 -07:00

88 lines
1.8 KiB
Go

package dns
import (
"fmt"
"time"
"github.com/miekg/dns"
)
// Forwarder handles forwarding DNS queries to upstream resolvers.
// Instances are built by NewForwarder and queried through Forward.
type Forwarder struct {
// upstreams lists the resolver addresses that Forward tries, in order.
upstreams []string
// client performs the message exchange with each upstream.
client *dns.Client
// cache stores upstream responses keyed by (qname, qtype, qclass).
cache *Cache
}
// NewForwarder builds a Forwarder that queries the supplied upstream
// addresses. Its DNS client is configured with a 2-second timeout and its
// response cache is obtained from NewCache.
func NewForwarder(upstreams []string) *Forwarder {
	client := new(dns.Client)
	client.Timeout = 2 * time.Second

	fwd := new(Forwarder)
	fwd.upstreams = upstreams
	fwd.client = client
	fwd.cache = NewCache()
	return fwd
}
// Forward resolves a query through the upstream resolvers and returns the
// response.
//
// Only the first question in r is considered. Responses are cached by
// (qname, qtype, qclass); the cache lifetime is the smallest record TTL in
// the response (see minTTL), capped at 300 seconds. SERVFAIL, REFUSED, and
// truncated responses are never cached. A cache hit is returned as a private
// copy whose ID is rewritten to match r, so it can be relayed to the client
// without further fix-up.
//
// Upstreams are tried in order and the first one that answers wins. An error
// is returned when r carries no question, when no upstreams are configured,
// or when every upstream exchange fails (the last exchange error is wrapped).
func (f *Forwarder) Forward(r *dns.Msg) (*dns.Msg, error) {
	// Upper bound, in seconds, on how long a response stays cached.
	const maxCacheTTL = 300

	if len(r.Question) == 0 {
		return nil, fmt.Errorf("empty question")
	}
	q := r.Question[0]

	// Serve from cache when possible.
	if cached := f.cache.Get(q.Name, q.Qtype, q.Qclass); cached != nil {
		hit := cached.Copy()
		// The stored message still carries the ID of the query that
		// populated the entry; a reply must echo the ID of this request.
		hit.Id = r.Id
		return hit, nil
	}

	// Without this guard the final error would wrap a nil lastErr and
	// render as "%!w(<nil>)".
	if len(f.upstreams) == 0 {
		return nil, fmt.Errorf("no upstream resolvers configured")
	}

	// Try each upstream in order.
	var lastErr error
	for _, upstream := range f.upstreams {
		resp, _, err := f.client.Exchange(r, upstream)
		if err != nil {
			lastErr = err
			continue
		}

		// Don't cache SERVFAIL or REFUSED (transient or policy failures),
		// nor truncated responses (incomplete record sets).
		cacheable := resp.Rcode != dns.RcodeServerFailure &&
			resp.Rcode != dns.RcodeRefused &&
			!resp.Truncated
		if cacheable {
			ttl := minTTL(resp)
			if ttl > maxCacheTTL {
				ttl = maxCacheTTL
			}
			if ttl > 0 {
				// Store a copy so that whatever the caller does to the
				// returned message cannot alter the cached entry.
				f.cache.Set(q.Name, q.Qtype, q.Qclass, resp.Copy(), time.Duration(ttl)*time.Second)
			}
		}
		return resp, nil
	}
	return nil, fmt.Errorf("all upstreams failed: %w", lastErr)
}
// minTTL returns the smallest TTL, in seconds, among the resource records in
// the answer, authority, and additional sections of msg.
//
// EDNS0 OPT pseudo-records are skipped: their header TTL field carries the
// extended RCODE, version, and flags (RFC 6891) rather than a lifetime. That
// field is frequently zero, which would otherwise drag the minimum down to 0
// and make the response look uncacheable. When msg holds no other records, a
// default of 60 seconds is returned.
func minTTL(msg *dns.Msg) uint32 {
	// Fallback lifetime, in seconds, for responses without any real records.
	const defaultTTL = 60

	var lowest uint32
	found := false
	for _, section := range [][]dns.RR{msg.Answer, msg.Ns, msg.Extra} {
		for _, rr := range section {
			hdr := rr.Header()
			if hdr.Rrtype == dns.TypeOPT {
				continue // Not a real TTL; see the function comment.
			}
			if !found || hdr.Ttl < lowest {
				lowest = hdr.Ttl
				found = true
			}
		}
	}
	if !found {
		return defaultTTL
	}
	return lowest
}