diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index cef960366..d6bb3d91a 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -7,6 +7,7 @@ import ( "net" "net/url" "os" + "slices" "strings" "sync" "sync/atomic" @@ -135,14 +136,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // this question; singleflight would dedup anyway but skipping avoids // a parked goroutine per stale hit under bursty load. if shouldRefresh && inflight == nil { - m.scheduleRefresh(question) + m.scheduleRefresh(question, cached) } resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - resp.Answer = append(resp.Answer, cached.records...) + resp.Answer = cloneRecordsWithTTL(cached.records, responseTTL(cached.cachedAt)) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -213,30 +214,35 @@ func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dn } // scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per -// unique in-flight key; bursty stale hits share its channel. -func (m *Resolver) scheduleRefresh(question dns.Question) { +// unique in-flight key; bursty stale hits share its channel. expected is the +// cachedRecord pointer observed by the caller; the refresh only mutates the +// cache if that pointer is still the one stored, so a stale in-flight refresh +// can't clobber a newer entry written by AddDomain or a competing refresh. +func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) { key := question.Name + "|" + dns.TypeToString[question.Qtype] _ = m.refreshGroup.DoChan(key, func() (any, error) { - return nil, m.refreshQuestion(question) + return nil, m.refreshQuestion(question, expected) }) } // refreshQuestion replaces the cached records on success, or marks the entry // failed (arming the backoff) on failure. While this runs, ServeDNS can detect // a resolver loop by spotting a query for this same question arriving on us. -func (m *Resolver) refreshQuestion(question dns.Question) error { +// expected pins the cache entry observed at schedule time; mutations only apply +// if m.records[question] still points at it. +func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error { ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) defer cancel() d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) if err != nil { - m.markRefreshFailed(question) + m.markRefreshFailed(question, expected) return fmt.Errorf("parse domain: %w", err) } records, err := m.lookupRecords(ctx, d, question) if err != nil { - fails := m.markRefreshFailed(question) + fails := m.markRefreshFailed(question, expected) logf := log.Warnf if fails > 1 { logf = log.Debugf @@ -249,18 +255,24 @@ func (m *Resolver) refreshQuestion(question dns.Question) error { // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. if len(records) == 0 { m.mutex.Lock() - delete(m.records, question) + if m.records[question] == expected { + delete(m.records, question) + m.mutex.Unlock() + log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } m.mutex.Unlock() - log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh", d.SafeString(), dns.TypeToString[question.Qtype]) return nil } now := time.Now() m.mutex.Lock() - if _, stillCached := m.records[question]; !stillCached { + if m.records[question] != expected { m.mutex.Unlock() - log.Debugf("skipping refresh write for domain=%s type=%s: entry was removed during refresh", + log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh", d.SafeString(), dns.TypeToString[question.Qtype]) return nil } @@ -286,11 +298,11 @@ func (m *Resolver) clearRefreshing(question dns.Question) { // markRefreshFailed arms the backoff and returns the new consecutive-failure // count so callers can downgrade subsequent failure logs to debug. -func (m *Resolver) markRefreshFailed(question dns.Question) int { +func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int { m.mutex.Lock() defer m.mutex.Unlock() c, ok := m.records[question] - if !ok { + if !ok || c != expected { return 0 } c.lastFailedRefresh = time.Now() @@ -529,6 +541,37 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } +// responseTTL returns the remaining cache lifetime in seconds (rounded up), +// so downstream resolvers don't cache an answer for longer than we will. +func responseTTL(cachedAt time.Time) uint32 { + remaining := cacheTTL() - time.Since(cachedAt) + if remaining <= 0 { + return 0 + } + return uint32((remaining + time.Second - 1) / time.Second) +} + +// cloneRecordsWithTTL deep-copies A/AAAA records and stamps ttl on each header +// so mutating the response never touches the cached slice. +func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { + out := make([]dns.RR, 0, len(records)) + for _, rr := range records { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + out = append(out, &cp) + case *dns.AAAA: + cp := *r + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + out = append(out, &cp) + } + } + return out +} + var cacheTTL = sync.OnceValue(func() time.Duration { if v := os.Getenv(envMgmtCacheTTL); v != "" { if d, err := time.ParseDuration(v); err == nil && d > 0 { @@ -570,7 +613,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d } filtered = append(filtered, &dns.A{ Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl}, - A: append(net.IP(nil), r.A...), + A: slices.Clone(r.A), }) case *dns.AAAA: if qtype != dns.TypeAAAA { @@ -578,7 +621,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d } filtered = append(filtered, &dns.AAAA{ Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl}, - AAAA: append(net.IP(nil), r.AAAA...), + AAAA: slices.Clone(r.AAAA), }) } }