package dns

import (
	"errors"
	"fmt"
	"time"

	"github.com/miekg/dns"
)

// Forwarder handles forwarding DNS queries to upstream resolvers and caches
// their responses.
//
// NOTE(review): Forwarder adds no locking of its own; if Forward is called
// from multiple goroutines, Cache must be safe for concurrent use — confirm
// against Cache's implementation.
type Forwarder struct {
	upstreams []string    // addresses passed to dns.Client.Exchange, tried in order
	client    *dns.Client // shared client used for every upstream exchange
	cache     *Cache      // responses keyed by (qname, qtype, qclass)
}

// NewForwarder creates a Forwarder that queries the given upstream addresses
// in order, with a 2-second timeout per exchange.
//
// The upstreams slice is copied, so later modifications by the caller do not
// affect the Forwarder.
func NewForwarder(upstreams []string) *Forwarder {
	client := &dns.Client{Timeout: 2 * time.Second}
	return &Forwarder{
		upstreams: append([]string(nil), upstreams...),
		client:    client,
		cache:     NewCache(),
	}
}

// Forward sends r to the upstream resolvers in order and returns the first
// response received, whatever its RCODE. Only r.Question[0] is used to key
// the cache.
//
// Responses are cached by (qname, qtype, qclass) with TTL-based expiry (see
// cacheResponse). On a cache hit a copy of the cached message is returned
// with its ID rewritten to r.Id, so it matches the query it answers.
//
// It returns an error if r is nil or has no question, if no upstreams are
// configured, or if every upstream exchange fails (wrapping the last error).
func (f *Forwarder) Forward(r *dns.Msg) (*dns.Msg, error) {
	if r == nil || len(r.Question) == 0 {
		return nil, errors.New("empty question")
	}
	q := r.Question[0]

	if cached := f.cache.Get(q.Name, q.Qtype, q.Qclass); cached != nil {
		resp := cached.Copy()
		// The cached message still carries the ID of whichever query filled
		// the cache; DNS clients match replies to queries by ID.
		resp.Id = r.Id
		return resp, nil
	}

	if len(f.upstreams) == 0 {
		return nil, errors.New("no upstreams configured")
	}

	var lastErr error
	for _, upstream := range f.upstreams {
		resp, _, err := f.client.Exchange(r, upstream)
		if err != nil {
			lastErr = err
			continue
		}
		f.cacheResponse(q, resp)
		return resp, nil
	}
	return nil, fmt.Errorf("all upstreams failed: %w", lastErr)
}

// cacheResponse stores a copy of resp under q's (name, type, class) key when
// the response is cacheable.
//
// SERVFAIL and REFUSED responses are not cached, so a later identical query
// goes upstream again. Truncated (TC) responses are not cached because they
// are incomplete. The lifetime is minTTL(resp) capped at 300 seconds; a
// lifetime of 0 means the response is not cached at all.
func (f *Forwarder) cacheResponse(q dns.Question, resp *dns.Msg) {
	const maxTTL = 300 // seconds

	if resp.Truncated ||
		resp.Rcode == dns.RcodeServerFailure ||
		resp.Rcode == dns.RcodeRefused {
		return
	}

	ttl := minTTL(resp)
	if ttl > maxTTL {
		ttl = maxTTL
	}
	if ttl == 0 {
		return
	}
	// Store a private copy: resp is also returned to the caller, who may
	// modify it (e.g. rewrite the ID) after Forward returns.
	f.cache.Set(q.Name, q.Qtype, q.Qclass, resp.Copy(), time.Duration(ttl)*time.Second)
}

// minTTL returns the smallest TTL, in seconds, among the resource records in
// msg's Answer, Ns (authority) and Extra (additional) sections, or 60 when
// msg has no TTL-bearing records.
//
// EDNS0 OPT pseudo-records are skipped: their TTL header field holds the
// extended RCODE, EDNS version and flags (RFC 6891), not a TTL, and is 0 for
// a plain EDNS0 reply — counting it would make such replies uncacheable.
// For SOA records in the authority section, the SOA MINIMUM field is also
// considered, per the negative-caching rule of RFC 2308 §5.
func minTTL(msg *dns.Msg) uint32 {
	const defaultTTL = 60 // seconds; used when msg has no TTL-bearing records

	var lowest uint32
	found := false
	observe := func(ttl uint32) {
		if !found || ttl < lowest {
			lowest = ttl
			found = true
		}
	}

	for _, section := range [][]dns.RR{msg.Answer, msg.Ns, msg.Extra} {
		for _, rr := range section {
			hdr := rr.Header()
			if hdr.Rrtype == dns.TypeOPT {
				continue
			}
			observe(hdr.Ttl)
		}
	}
	// A negative answer may be cached for no longer than the lesser of the
	// SOA record's TTL (observed above) and its MINIMUM field.
	for _, rr := range msg.Ns {
		if soa, ok := rr.(*dns.SOA); ok {
			observe(soa.Minttl)
		}
	}

	if !found {
		return defaultTTL
	}
	return lowest
}