mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-16 21:59:56 +00:00
Clamp served TTL to remaining cache life and guard refresh by identity
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -135,14 +136,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
// this question; singleflight would dedup anyway but skipping avoids
|
||||
// a parked goroutine per stale hit under bursty load.
|
||||
if shouldRefresh && inflight == nil {
|
||||
m.scheduleRefresh(question)
|
||||
m.scheduleRefresh(question, cached)
|
||||
}
|
||||
|
||||
resp := &dns.Msg{}
|
||||
resp.SetReply(r)
|
||||
resp.Authoritative = false
|
||||
resp.RecursionAvailable = true
|
||||
resp.Answer = append(resp.Answer, cached.records...)
|
||||
resp.Answer = cloneRecordsWithTTL(cached.records, responseTTL(cached.cachedAt))
|
||||
|
||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||
|
||||
@@ -213,30 +214,35 @@ func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dn
|
||||
}
|
||||
|
||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
||||
// unique in-flight key; bursty stale hits share its channel.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question) {
|
||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
||||
return nil, m.refreshQuestion(question)
|
||||
return nil, m.refreshQuestion(question, expected)
|
||||
})
|
||||
}
|
||||
|
||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
||||
// a resolver loop by spotting a query for this same question arriving on us.
|
||||
func (m *Resolver) refreshQuestion(question dns.Question) error {
|
||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
||||
// if m.records[question] still points at it.
|
||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
||||
if err != nil {
|
||||
m.markRefreshFailed(question)
|
||||
m.markRefreshFailed(question, expected)
|
||||
return fmt.Errorf("parse domain: %w", err)
|
||||
}
|
||||
|
||||
records, err := m.lookupRecords(ctx, d, question)
|
||||
if err != nil {
|
||||
fails := m.markRefreshFailed(question)
|
||||
fails := m.markRefreshFailed(question, expected)
|
||||
logf := log.Warnf
|
||||
if fails > 1 {
|
||||
logf = log.Debugf
|
||||
@@ -249,18 +255,24 @@ func (m *Resolver) refreshQuestion(question dns.Question) error {
|
||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
||||
if len(records) == 0 {
|
||||
m.mutex.Lock()
|
||||
delete(m.records, question)
|
||||
if m.records[question] == expected {
|
||||
delete(m.records, question)
|
||||
m.mutex.Unlock()
|
||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
m.mutex.Lock()
|
||||
if _, stillCached := m.records[question]; !stillCached {
|
||||
if m.records[question] != expected {
|
||||
m.mutex.Unlock()
|
||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry was removed during refresh",
|
||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
||||
return nil
|
||||
}
|
||||
@@ -286,11 +298,11 @@ func (m *Resolver) clearRefreshing(question dns.Question) {
|
||||
|
||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
||||
// count so callers can downgrade subsequent failure logs to debug.
|
||||
func (m *Resolver) markRefreshFailed(question dns.Question) int {
|
||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
c, ok := m.records[question]
|
||||
if !ok {
|
||||
if !ok || c != expected {
|
||||
return 0
|
||||
}
|
||||
c.lastFailedRefresh = time.Now()
|
||||
@@ -529,6 +541,37 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
return domains
|
||||
}
|
||||
|
||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
||||
// so downstream resolvers don't cache an answer for longer than we will.
|
||||
func responseTTL(cachedAt time.Time) uint32 {
|
||||
remaining := cacheTTL() - time.Since(cachedAt)
|
||||
if remaining <= 0 {
|
||||
return 0
|
||||
}
|
||||
return uint32((remaining + time.Second - 1) / time.Second)
|
||||
}
|
||||
|
||||
// cloneRecordsWithTTL deep-copies A/AAAA records and stamps ttl on each header
|
||||
// so mutating the response never touches the cached slice.
|
||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
||||
out := make([]dns.RR, 0, len(records))
|
||||
for _, rr := range records {
|
||||
switch r := rr.(type) {
|
||||
case *dns.A:
|
||||
cp := *r
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.A = slices.Clone(r.A)
|
||||
out = append(out, &cp)
|
||||
case *dns.AAAA:
|
||||
cp := *r
|
||||
cp.Hdr.Ttl = ttl
|
||||
cp.AAAA = slices.Clone(r.AAAA)
|
||||
out = append(out, &cp)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
var cacheTTL = sync.OnceValue(func() time.Duration {
|
||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
||||
@@ -570,7 +613,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d
|
||||
}
|
||||
filtered = append(filtered, &dns.A{
|
||||
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: ttl},
|
||||
A: append(net.IP(nil), r.A...),
|
||||
A: slices.Clone(r.A),
|
||||
})
|
||||
case *dns.AAAA:
|
||||
if qtype != dns.TypeAAAA {
|
||||
@@ -578,7 +621,7 @@ func lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, d
|
||||
}
|
||||
filtered = append(filtered, &dns.AAAA{
|
||||
Hdr: dns.RR_Header{Name: dnsName, Rrtype: dns.TypeAAAA, Class: dns.ClassINET, Ttl: ttl},
|
||||
AAAA: append(net.IP(nil), r.AAAA...),
|
||||
AAAA: slices.Clone(r.AAAA),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user