Clamp served TTL to remaining cache life and guard refresh by identity

This commit is contained in:
Viktor Liu
2026-04-22 11:24:16 +02:00
parent 236e99af63
commit 4c25ac674a

View File

@@ -7,6 +7,7 @@ import (
"net"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
@@ -135,14 +136,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question)
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = append(resp.Answer, cached.records...)
resp.Answer = cloneRecordsWithTTL(cached.records, responseTTL(cached.cachedAt))
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -213,30 +214,35 @@ func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dn
}
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
// unique in-flight key; bursty stale hits share its channel.
func (m *Resolver) scheduleRefresh(question dns.Question) {
// unique in-flight key; bursty stale hits share its channel. expected is the
// cachedRecord pointer observed by the caller; the refresh only mutates the
// cache if that pointer is still the one stored, so a stale in-flight refresh
// can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question)
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
func (m *Resolver) refreshQuestion(question dns.Question) error {
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question)
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question)
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails > 1 {
logf = log.Debugf
@@ -249,18 +255,24 @@ func (m *Resolver) refreshQuestion(question dns.Question) error {
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
if len(records) == 0 {
m.mutex.Lock()
delete(m.records, question)
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
now := time.Now()
m.mutex.Lock()
if _, stillCached := m.records[question]; !stillCached {
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry was removed during refresh",
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
@@ -286,11 +298,11 @@ func (m *Resolver) clearRefreshing(question dns.Question) {
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question) int {
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok {
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
@@ -529,6 +541,37 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// responseTTL returns the remaining cache lifetime in whole seconds, rounded
// up, so downstream resolvers don't cache an answer for longer than we will.
//
// cachedAt is the instant the entry was stored. The result is 0 once the
// entry has outlived cacheTTL(), and is capped at 2^31-1: RFC 2181 section 8
// defines that as the largest valid TTL, and receivers are told to treat
// values with the top bit set as zero. Without the cap, a very large
// configured cache TTL would wrap or be discarded by clients.
func responseTTL(cachedAt time.Time) uint32 {
	const maxTTLSeconds = 1<<31 - 1

	remaining := cacheTTL() - time.Since(cachedAt)
	if remaining <= 0 {
		return 0
	}

	// Round up via quotient/remainder rather than adding (time.Second - 1),
	// which could overflow int64 for durations near the top of the range.
	secs := remaining / time.Second
	if remaining%time.Second != 0 {
		secs++
	}
	if secs > maxTTLSeconds {
		return maxTTLSeconds
	}
	return uint32(secs)
}
// cloneRecordsWithTTL builds a fresh slice holding independent copies of the
// A and AAAA records found in records, with each copy's header TTL set to ttl.
// The address bytes are cloned as well, so the caller may freely mutate the
// result without touching the source records. Records of any other type are
// left out of the result.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
	cloned := make([]dns.RR, 0, len(records))
	for i := range records {
		var dup dns.RR
		switch src := records[i].(type) {
		case *dns.A:
			a := *src
			a.A = slices.Clone(src.A)
			a.Hdr.Ttl = ttl
			dup = &a
		case *dns.AAAA:
			aaaa := *src
			aaaa.AAAA = slices.Clone(src.AAAA)
			aaaa.Hdr.Ttl = ttl
			dup = &aaaa
		default:
			// Only address records are copied.
			continue
		}
		cloned = append(cloned, dup)
	}
	return cloned
}
var cacheTTL = sync.OnceValue(func() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
@@ -570,7 +613,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d
}
filtered = append(filtered, &dns.A{
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
A: append(net.IP(nil), r.A...),
A: slices.Clone(r.A),
})
case *dns.AAAA:
if qtype != dns.TypeAAAA {
@@ -578,7 +621,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d
}
filtered = append(filtered, &dns.AAAA{
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl},
AAAA: append(net.IP(nil), r.AAAA...),
AAAA: slices.Clone(r.AAAA),
})
}
}