From dc8c2edf50a53eda77044fc1e9280cd240c7da57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Zolt=C3=A1n=20Papp?= Date: Thu, 23 Apr 2026 21:29:46 +0200 Subject: [PATCH] Revert "[client] Add TTL-based refresh to mgmt DNS cache via handler chain (#5945)" This reverts commit 801de8c68d4725f313344d4b6f684c3a86e59b90. --- client/internal/dns/handler_chain.go | 94 ---- client/internal/dns/handler_chain_test.go | 164 ------ client/internal/dns/mgmt/mgmt.go | 489 ++++-------------- client/internal/dns/mgmt/mgmt_refresh_test.go | 408 --------------- client/internal/dns/mgmt/mgmt_test.go | 55 -- client/internal/dns/server.go | 1 - 6 files changed, 97 insertions(+), 1114 deletions(-) delete mode 100644 client/internal/dns/mgmt/mgmt_refresh_test.go diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 57e7722d4..6fbdedc59 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,10 +1,7 @@ package dns import ( - "context" "fmt" - "math" - "net" "slices" "strconv" "strings" @@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() { } func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - c.dispatch(w, r, math.MaxInt) -} - -// dispatch routes a DNS request through the chain, skipping handlers with -// priority > maxPriority. Shared by ServeDNS and ResolveInternal. -func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) { if len(r.Question) == 0 { return } @@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in // Try handlers in priority order for _, entry := range handlers { - if entry.Priority > maxPriority { - continue - } if !c.isHandlerMatch(qname, entry) { continue } @@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q cw.response.Len(), meta, time.Since(startTime)) } -// ResolveInternal runs an in-process DNS query against the chain, skipping any -// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt -// cache refresher) that must bypass themselves to avoid loops. Honors ctx -// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own -// (bounded by the invoked handler's internal timeout). -func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) { - if len(r.Question) == 0 { - return nil, fmt.Errorf("empty question") - } - - base := &internalResponseWriter{} - done := make(chan struct{}) - go func() { - c.dispatch(base, r, maxPriority) - close(done) - }() - - select { - case <-done: - case <-ctx.Done(): - // Prefer a completed response if dispatch finished concurrently with cancellation. - select { - case <-done: - default: - return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) - } - } - - if base.response == nil || base.response.Rcode == dns.RcodeRefused { - return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d", - strings.ToLower(r.Question[0].Name), maxPriority) - } - return base.response, nil -} - -// HasRootHandlerAtOrBelow reports whether any "." handler is registered at -// priority ≤ maxPriority. -func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool { - c.mu.RLock() - defer c.mu.RUnlock() - - for _, h := range c.handlers { - if h.Pattern == "." && h.Priority <= maxPriority { - return true - } - } - return false -} - func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { switch { case entry.Pattern == ".": @@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { } } } - -// internalResponseWriter captures a dns.Msg for in-process chain queries. -type internalResponseWriter struct { - response *dns.Msg -} - -func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil } -func (w *internalResponseWriter) LocalAddr() net.Addr { return nil } -func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil } - -// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg -// still surface their answer to ResolveInternal. -func (w *internalResponseWriter) Write(p []byte) (int, error) { - msg := new(dns.Msg) - if err := msg.Unpack(p); err != nil { - return 0, err - } - w.response = msg - return len(p), nil -} - -func (w *internalResponseWriter) Close() error { return nil } -func (w *internalResponseWriter) TsigStatus() error { return nil } - -// TsigTimersOnly is part of dns.ResponseWriter. -func (w *internalResponseWriter) TsigTimersOnly(bool) { - // no-op: in-process queries carry no TSIG state. -} - -// Hijack is part of dns.ResponseWriter. -func (w *internalResponseWriter) Hijack() { - // no-op: in-process queries have no underlying connection to hand off. -} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index 034a760dc..fa9525069 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,15 +1,11 @@ package dns_test import ( - "context" - "net" "testing" - "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/test" @@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { }) } } - -// answeringHandler writes a fixed A record to ack the query. Used to verify -// which handler ResolveInternal dispatches to. -type answeringHandler struct { - name string - ip string -} - -func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - resp := &dns.Msg{} - resp.SetReply(r) - resp.Answer = []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP(h.ip).To4(), - }} - _ = w.WriteMsg(resp) -} - -func (h *answeringHandler) String() string { return h.name } - -func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) { - chain := nbdns.NewHandlerChain() - - high := &answeringHandler{name: "high", ip: "10.0.0.1"} - low := &answeringHandler{name: "low", ip: "10.0.0.2"} - - chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) - chain.AddHandler("example.com.", low, nbdns.PriorityUpstream) - - r := new(dns.Msg) - r.SetQuestion("example.com.", dns.TypeA) - - resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) - assert.NoError(t, err) - assert.NotNil(t, resp) - assert.Equal(t, 1, len(resp.Answer)) - a, ok := resp.Answer[0].(*dns.A) - assert.True(t, ok) - assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream") -} - -func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) { - chain := nbdns.NewHandlerChain() - high := &answeringHandler{name: "high", ip: "10.0.0.1"} - chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) - - r := new(dns.Msg) - r.SetQuestion("example.com.", dns.TypeA) - - _, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) - assert.Error(t, err, "no handler at or below maxPriority should error") -} - -// rawWriteHandler packs a response and calls ResponseWriter.Write directly -// (instead of WriteMsg), exercising the internalResponseWriter.Write path. -type rawWriteHandler struct { - ip string -} - -func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - resp := &dns.Msg{} - resp.SetReply(r) - resp.Answer = []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP(h.ip).To4(), - }} - packed, err := resp.Pack() - if err != nil { - return - } - _, _ = w.Write(packed) -} - -func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) { - chain := nbdns.NewHandlerChain() - chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream) - - r := new(dns.Msg) - r.SetQuestion("example.com.", dns.TypeA) - - resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) - assert.NoError(t, err) - require.NotNil(t, resp) - require.Len(t, resp.Answer, 1) - a, ok := resp.Answer[0].(*dns.A) - require.True(t, ok) - assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer") -} - -func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) { - chain := nbdns.NewHandlerChain() - _, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream) - assert.Error(t, err) -} - -// hangingHandler blocks indefinitely until closed, simulating a wedged upstream. -type hangingHandler struct { - block chan struct{} -} - -func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { - <-h.block - resp := &dns.Msg{} - resp.SetReply(r) - _ = w.WriteMsg(resp) -} - -func (h *hangingHandler) String() string { return "hangingHandler" } - -func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) { - chain := nbdns.NewHandlerChain() - h := &hangingHandler{block: make(chan struct{})} - defer close(h.block) - - chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) - - r := new(dns.Msg) - r.SetQuestion("example.com.", dns.TypeA) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - start := time.Now() - _, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream) - elapsed := time.Since(start) - - assert.Error(t, err) - assert.ErrorIs(t, err, context.DeadlineExceeded) - assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline") -} - -func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) { - chain := nbdns.NewHandlerChain() - h := &answeringHandler{name: "h", ip: "10.0.0.1"} - - assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain") - - chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) - assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count") - - chain.AddHandler(".", h, nbdns.PriorityMgmtCache) - assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded") - - chain.AddHandler(".", h, nbdns.PriorityDefault) - assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included") - - chain.RemoveHandler(".", nbdns.PriorityDefault) - assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) - - // Primary nsgroup case: root handler lands at PriorityUpstream. - chain.AddHandler(".", h, nbdns.PriorityUpstream) - assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included") - chain.RemoveHandler(".", nbdns.PriorityUpstream) - - // Fallback case: original /etc/resolv.conf entries land at PriorityFallback. - chain.AddHandler(".", h, nbdns.PriorityFallback) - assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included") - chain.RemoveHandler(".", nbdns.PriorityFallback) - assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) -} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 988e427fb..314af51d9 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -2,83 +2,40 @@ package mgmt import ( "context" - "errors" "fmt" "net" + "net/netip" "net/url" - "os" - "slices" "strings" "sync" - "sync/atomic" "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" - "golang.org/x/sync/singleflight" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" - "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/shared/management/domain" ) -const ( - dnsTimeout = 5 * time.Second - defaultTTL = 300 * time.Second - refreshBackoff = 30 * time.Second +const dnsTimeout = 5 * time.Second - // envMgmtCacheTTL overrides defaultTTL for integration/dev testing. - envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" -) - -// ChainResolver lets the cache refresh stale entries through the DNS handler -// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the -// system resolver. -type ChainResolver interface { - ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) - HasRootHandlerAtOrBelow(maxPriority int) bool -} - -// cachedRecord holds DNS records plus timestamps used for TTL refresh. -// records and cachedAt are set at construction and treated as immutable; -// lastFailedRefresh and consecFailures are mutable and must be accessed under -// Resolver.mutex. -type cachedRecord struct { - records []dns.RR - cachedAt time.Time - lastFailedRefresh time.Time - consecFailures int -} - -// Resolver caches critical NetBird infrastructure domains. -// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex. +// Resolver caches critical NetBird infrastructure domains type Resolver struct { - records map[dns.Question]*cachedRecord + records map[dns.Question][]dns.RR mgmtDomain *domain.Domain serverDomains *dnsconfig.ServerDomains mutex sync.RWMutex +} - chain ChainResolver - chainMaxPriority int - refreshGroup singleflight.Group - - // refreshing tracks questions whose refresh is running via the OS - // fallback path. A ServeDNS hit for a question in this map indicates - // the OS resolver routed the recursive query back to us (loop). Only - // the OS path arms this so chain-path refreshes don't produce false - // positives. The atomic bool is CAS-flipped once per refresh to - // throttle the warning log. - refreshing map[dns.Question]*atomic.Bool - - cacheTTL time.Duration +type ipsResponse struct { + ips []netip.Addr + err error } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - records: make(map[dns.Question]*cachedRecord), - refreshing: make(map[dns.Question]*atomic.Bool), - cacheTTL: resolveCacheTTL(), + records: make(map[dns.Question][]dns.RR), } } @@ -87,19 +44,7 @@ func (m *Resolver) String() string { return "MgmtCacheResolver" } -// SetChainResolver wires the handler chain used to refresh stale cache entries. -// maxPriority caps which handlers may answer refresh queries (typically -// PriorityUpstream, so upstream/default/fallback handlers are consulted and -// mgmt/route/local handlers are skipped). -func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { - m.mutex.Lock() - m.chain = chain - m.chainMaxPriority = maxPriority - m.mutex.Unlock() -} - -// ServeDNS serves cached A/AAAA records. Stale entries are returned -// immediately and refreshed asynchronously (stale-while-revalidate). +// ServeDNS implements dns.Handler interface. func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - cached, found := m.records[question] - inflight := m.refreshing[question] - var shouldRefresh bool - if found { - stale := time.Since(cached.cachedAt) > m.cacheTTL - inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff - shouldRefresh = stale && !inBackoff - } + records, found := m.records[question] m.mutex.RUnlock() if !found { @@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } - if inflight != nil && inflight.CompareAndSwap(false, true) { - log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)", - question.Name) - } - - // Skip scheduling a refresh goroutine if one is already inflight for - // this question; singleflight would dedup anyway but skipping avoids - // a parked goroutine per stale hit under bursty load. - if shouldRefresh && inflight == nil { - m.scheduleRefresh(question, cached) - } - resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt)) + + resp.Answer = append(resp.Answer, records...) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { } } -// AddDomain resolves a domain and stores its A/AAAA records in the cache. -// A family that resolves NODATA (nil err, zero records) evicts any stale -// entry for that qtype. +// AddDomain manually adds a domain to cache by resolving it. func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) - - if errA != nil && errAAAA != nil { - return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) + ips, err := lookupIPWithExtraTimeout(ctx, d) + if err != nil { + return err } - if len(aRecords) == 0 && len(aaaaRecords) == 0 { - if err := errors.Join(errA, errAAAA); err != nil { - return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) + var aRecords, aaaaRecords []dns.RR + for _, ip := range ips { + if ip.Is4() { + rr := &dns.A{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: ip.AsSlice(), + } + aRecords = append(aRecords, rr) + } else if ip.Is6() { + rr := &dns.AAAA{ + Hdr: dns.RR_Header{ + Name: dnsName, + Rrtype: dns.TypeAAAA, + Class: dns.ClassINET, + Ttl: 300, + }, + AAAA: ip.AsSlice(), + } + aaaaRecords = append(aaaaRecords, rr) } - return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } - now := time.Now() m.mutex.Lock() - defer m.mutex.Unlock() - m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) - m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) + if len(aRecords) > 0 { + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + m.records[aQuestion] = aRecords + } - log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", + if len(aaaaRecords) > 0 { + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + m.records[aaaaQuestion] = aaaaRecords + } + + m.mutex.Unlock() + + log.Debugf("added domain=%s with %d A records and %d AAAA records", d.SafeString(), len(aRecords), len(aaaaRecords)) return nil } -// applyFamilyRecords writes records, evicts on NODATA, leaves the cache -// untouched on error. Caller holds m.mutex. -func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { - q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} - switch { - case len(records) > 0: - m.records[q] = &cachedRecord{records: records, cachedAt: now} - case err == nil: - delete(m.records, q) - } -} +func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { + log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) + defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) + resultChan := make(chan *ipsResponse, 1) -// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per -// unique in-flight key; bursty stale hits share its channel. expected is the -// cachedRecord pointer observed by the caller; the refresh only mutates the -// cache if that pointer is still the one stored, so a stale in-flight refresh -// can't clobber a newer entry written by AddDomain or a competing refresh. -func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) { - key := question.Name + "|" + dns.TypeToString[question.Qtype] - _ = m.refreshGroup.DoChan(key, func() (any, error) { - return nil, m.refreshQuestion(question, expected) - }) -} - -// refreshQuestion replaces the cached records on success, or marks the entry -// failed (arming the backoff) on failure. While this runs, ServeDNS can detect -// a resolver loop by spotting a query for this same question arriving on us. -// expected pins the cache entry observed at schedule time; mutations only apply -// if m.records[question] still points at it. -func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error { - ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) - defer cancel() - - d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) - if err != nil { - m.markRefreshFailed(question, expected) - return fmt.Errorf("parse domain: %w", err) - } - - records, err := m.lookupRecords(ctx, d, question) - if err != nil { - fails := m.markRefreshFailed(question, expected) - logf := log.Warnf - if fails == 0 || fails > 1 { - logf = log.Debugf + go func() { + ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) + resultChan <- &ipsResponse{ + err: err, + ips: ips, } - logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", - d.SafeString(), dns.TypeToString[question.Qtype], err, fails) - return err + }() + + var resp *ipsResponse + + select { + case <-time.After(dnsTimeout + time.Millisecond*500): + log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) + return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) + case <-ctx.Done(): + return nil, ctx.Err() + case resp = <-resultChan: } - // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. - if len(records) == 0 { - m.mutex.Lock() - if m.records[question] == expected { - delete(m.records, question) - m.mutex.Unlock() - log.Infof("removed mgmt cache domain=%s type=%s: no records returned", - d.SafeString(), dns.TypeToString[question.Qtype]) - return nil - } - m.mutex.Unlock() - log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh", - d.SafeString(), dns.TypeToString[question.Qtype]) - return nil + if resp.err != nil { + return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) } - - now := time.Now() - m.mutex.Lock() - if m.records[question] != expected { - m.mutex.Unlock() - log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh", - d.SafeString(), dns.TypeToString[question.Qtype]) - return nil - } - m.records[question] = &cachedRecord{records: records, cachedAt: now} - m.mutex.Unlock() - - log.Infof("refreshed mgmt cache domain=%s type=%s", - d.SafeString(), dns.TypeToString[question.Qtype]) - return nil -} - -func (m *Resolver) markRefreshing(question dns.Question) { - m.mutex.Lock() - m.refreshing[question] = &atomic.Bool{} - m.mutex.Unlock() -} - -func (m *Resolver) clearRefreshing(question dns.Question) { - m.mutex.Lock() - delete(m.refreshing, question) - m.mutex.Unlock() -} - -// markRefreshFailed arms the backoff and returns the new consecutive-failure -// count so callers can downgrade subsequent failure logs to debug. -func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int { - m.mutex.Lock() - defer m.mutex.Unlock() - c, ok := m.records[question] - if !ok || c != expected { - return 0 - } - c.lastFailedRefresh = time.Now() - c.consecFailures++ - return c.consecFailures -} - -// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let -// callers tell records, NODATA (nil err, no records), and failure apart. -func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { - m.mutex.RLock() - chain := m.chain - maxPriority := m.chainMaxPriority - m.mutex.RUnlock() - - if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { - aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) - aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) - return - } - - // TODO: drop once every supported OS registers a fallback resolver. Safe - // today: no root handler at priority ≤ PriorityUpstream means NetBird is - // not the system resolver, so net.DefaultResolver will not loop back. - aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) - aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) - return -} - -// lookupRecords resolves a single record type via chain or OS. The OS branch -// arms the loop detector for the duration of its call so that ServeDNS can -// spot the OS resolver routing the recursive query back to us. -func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { - m.mutex.RLock() - chain := m.chain - maxPriority := m.chainMaxPriority - m.mutex.RUnlock() - - if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { - return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) - } - - // TODO: drop once every supported OS registers a fallback resolver. - m.markRefreshing(q) - defer m.clearRefreshing(q) - - return m.osLookup(ctx, d, q.Name, q.Qtype) -} - -// lookupViaChain resolves via the handler chain and rewrites each RR to use -// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache -// target-owned records or upstream TTLs. NODATA returns (nil, nil). -func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { - msg := &dns.Msg{} - msg.SetQuestion(dnsName, qtype) - msg.RecursionDesired = true - - resp, err := chain.ResolveInternal(ctx, msg, maxPriority) - if err != nil { - return nil, fmt.Errorf("chain resolve: %w", err) - } - if resp == nil { - return nil, fmt.Errorf("chain resolve returned nil response") - } - if resp.Rcode != dns.RcodeSuccess { - return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) - } - - ttl := uint32(m.cacheTTL.Seconds()) - owners := cnameOwners(dnsName, resp.Answer) - var filtered []dns.RR - for _, rr := range resp.Answer { - h := rr.Header() - if h.Class != dns.ClassINET || h.Rrtype != qtype { - continue - } - if !owners[strings.ToLower(dns.Fqdn(h.Name))] { - continue - } - if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { - filtered = append(filtered, cp) - } - } - return filtered, nil -} - -// osLookup resolves a single family via net.DefaultResolver using resutil, -// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA -// returns (nil, nil). -func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { - network := resutil.NetworkForQtype(qtype) - if network == "" { - return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) - } - - log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) - defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) - - result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) - if result.Rcode == dns.RcodeSuccess { - return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil - } - - if result.Err != nil { - return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) - } - return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) -} - -// responseTTL returns the remaining cache lifetime in seconds (rounded up), -// so downstream resolvers don't cache an answer for longer than we will. -func (m *Resolver) responseTTL(cachedAt time.Time) uint32 { - remaining := m.cacheTTL - time.Since(cachedAt) - if remaining <= 0 { - return 0 - } - return uint32((remaining + time.Second - 1) / time.Second) + return resp.ips, nil } // PopulateFromConfig extracts and caches domains from the client configuration. @@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} - qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} - delete(m.records, qA) - delete(m.records, qAAAA) - delete(m.refreshing, qA) - delete(m.refreshing, qAAAA) + aQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeA, + Qclass: dns.ClassINET, + } + delete(m.records, aQuestion) + + aaaaQuestion := dns.Question{ + Name: dnsName, + Qtype: dns.TypeAAAA, + Qclass: dns.ClassINET, + } + delete(m.records, aaaaQuestion) log.Debugf("removed domain=%s from cache", d.SafeString()) return nil @@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } - -// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non -// A/AAAA records return nil. -func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { - switch r := rr.(type) { - case *dns.A: - cp := *r - cp.Hdr.Name = owner - cp.Hdr.Ttl = ttl - cp.A = slices.Clone(r.A) - return &cp - case *dns.AAAA: - cp := *r - cp.Hdr.Name = owner - cp.Hdr.Ttl = ttl - cp.AAAA = slices.Clone(r.AAAA) - return &cp - } - return nil -} - -// cloneRecordsWithTTL clones A/AAAA records preserving their owner and -// stamping ttl so the response shares no memory with the cached slice. -func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { - out := make([]dns.RR, 0, len(records)) - for _, rr := range records { - if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil { - out = append(out, cp) - } - } - return out -} - -// cnameOwners returns dnsName plus every target reachable by following CNAMEs -// in answer, iterating until fixed point so out-of-order chains resolve. -func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { - owners := map[string]bool{dnsName: true} - for { - added := false - for _, rr := range answer { - cname, ok := rr.(*dns.CNAME) - if !ok { - continue - } - name := strings.ToLower(dns.Fqdn(cname.Hdr.Name)) - if !owners[name] { - continue - } - target := strings.ToLower(dns.Fqdn(cname.Target)) - if !owners[target] { - owners[target] = true - added = true - } - } - if !added { - return owners - } - } -} - -// resolveCacheTTL reads the cache TTL override env var; invalid or empty -// values fall back to defaultTTL. Called once per Resolver from NewResolver. -func resolveCacheTTL() time.Duration { - if v := os.Getenv(envMgmtCacheTTL); v != "" { - if d, err := time.ParseDuration(v); err == nil && d > 0 { - return d - } - } - return defaultTTL -} diff --git a/client/internal/dns/mgmt/mgmt_refresh_test.go b/client/internal/dns/mgmt/mgmt_refresh_test.go deleted file mode 100644 index 9faa5a0b8..000000000 --- a/client/internal/dns/mgmt/mgmt_refresh_test.go +++ /dev/null @@ -1,408 +0,0 @@ -package mgmt - -import ( - "context" - "errors" - "net" - "sync" - "sync/atomic" - "testing" - "time" - - "github.com/miekg/dns" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/netbirdio/netbird/client/internal/dns/test" - "github.com/netbirdio/netbird/shared/management/domain" -) - -type fakeChain struct { - mu sync.Mutex - calls map[string]int - answers map[string][]dns.RR - err error - hasRoot bool - onLookup func() -} - -func newFakeChain() *fakeChain { - return &fakeChain{ - calls: map[string]int{}, - answers: map[string][]dns.RR{}, - hasRoot: true, - } -} - -func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool { - f.mu.Lock() - defer f.mu.Unlock() - return f.hasRoot -} - -func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) { - f.mu.Lock() - q := msg.Question[0] - key := q.Name + "|" + dns.TypeToString[q.Qtype] - f.calls[key]++ - answers := f.answers[key] - err := f.err - onLookup := f.onLookup - f.mu.Unlock() - - if onLookup != nil { - onLookup() - } - if err != nil { - return nil, err - } - resp := &dns.Msg{} - resp.SetReply(msg) - resp.Answer = answers - return resp, nil -} - -func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) { - f.mu.Lock() - defer f.mu.Unlock() - key := name + "|" + dns.TypeToString[qtype] - hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60} - switch qtype { - case dns.TypeA: - f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}} - case dns.TypeAAAA: - f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}} - } -} - -func (f *fakeChain) callCount(name string, qtype uint16) int { - f.mu.Lock() - defer f.mu.Unlock() - return f.calls[name+"|"+dns.TypeToString[qtype]] -} - -// waitFor polls the predicate until it returns true or the deadline passes. -func waitFor(t *testing.T, d time.Duration, fn func() bool) { - t.Helper() - deadline := time.Now().Add(d) - for time.Now().Before(deadline) { - if fn() { - return - } - time.Sleep(5 * time.Millisecond) - } - t.Fatalf("condition not met within %s", d) -} - -func queryA(t *testing.T, r *Resolver, name string) *dns.Msg { - t.Helper() - msg := new(dns.Msg) - msg.SetQuestion(name, dns.TypeA) - w := &test.MockResponseWriter{} - r.ServeDNS(w, msg) - return w.GetLastResponse() -} - -func firstA(t *testing.T, resp *dns.Msg) string { - t.Helper() - require.NotNil(t, resp) - require.Greater(t, len(resp.Answer), 0, "expected at least one answer") - a, ok := resp.Answer[0].(*dns.A) - require.True(t, ok, "expected A record") - return a.A.String() -} - -func TestResolver_CacheTTLGatesRefresh(t *testing.T) { - // Same cached entry age, different cacheTTL values: the shorter TTL must - // trigger a background refresh, the longer one must not. Proves that the - // per-Resolver cacheTTL field actually drives the stale decision. - cachedAt := time.Now().Add(-100 * time.Millisecond) - - newRec := func() *cachedRecord { - return &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: cachedAt, - } - } - q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} - - t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) { - r := NewResolver() - r.cacheTTL = 10 * time.Millisecond - chain := newFakeChain() - chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") - r.SetChainResolver(chain, 50) - r.records[q] = newRec() - - resp := queryA(t, r, q.Name) - assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") - - waitFor(t, time.Second, func() bool { - return chain.callCount(q.Name, dns.TypeA) >= 1 - }) - }) - - t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) { - r := NewResolver() - r.cacheTTL = time.Hour - chain := newFakeChain() - chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") - r.SetChainResolver(chain, 50) - r.records[q] = newRec() - - resp := queryA(t, r, q.Name) - assert.Equal(t, "10.0.0.1", firstA(t, resp)) - - time.Sleep(50 * time.Millisecond) - assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh") - }) -} - -func TestResolver_ServeFresh_NoRefresh(t *testing.T) { - r := NewResolver() - chain := newFakeChain() - chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") - r.SetChainResolver(chain, 50) - - r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: time.Now(), // fresh - } - - resp := queryA(t, r, "mgmt.example.com.") - assert.Equal(t, "10.0.0.1", firstA(t, resp)) - - time.Sleep(20 * time.Millisecond) - assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh") -} - -func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) { - r := NewResolver() - chain := newFakeChain() - chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") - r.SetChainResolver(chain, 50) - - q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} - r.records[q] = &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: time.Now().Add(-2 * defaultTTL), // stale - } - - // First query: serves stale immediately. - resp := queryA(t, r, "mgmt.example.com.") - assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") - - waitFor(t, time.Second, func() bool { - return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1 - }) - - // Next query should now return the refreshed IP. - waitFor(t, time.Second, func() bool { - resp := queryA(t, r, "mgmt.example.com.") - return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2" - }) -} - -func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) { - r := NewResolver() - chain := newFakeChain() - chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") - - var inflight atomic.Int32 - var maxInflight atomic.Int32 - chain.onLookup = func() { - cur := inflight.Add(1) - defer inflight.Add(-1) - for { - prev := maxInflight.Load() - if cur <= prev || maxInflight.CompareAndSwap(prev, cur) { - break - } - } - time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide - } - - r.SetChainResolver(chain, 50) - - q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} - r.records[q] = &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: time.Now().Add(-2 * defaultTTL), - } - - var wg sync.WaitGroup - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { - defer wg.Done() - queryA(t, r, "mgmt.example.com.") - }() - } - wg.Wait() - - waitFor(t, 2*time.Second, func() bool { - return inflight.Load() == 0 - }) - - calls := chain.callCount("mgmt.example.com.", dns.TypeA) - assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls) - assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently") -} - -func TestResolver_RefreshFailureArmsBackoff(t *testing.T) { - r := NewResolver() - chain := newFakeChain() - chain.err = errors.New("boom") - r.SetChainResolver(chain, 50) - - q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} - r.records[q] = &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: time.Now().Add(-2 * defaultTTL), - } - - // First stale hit triggers a refresh attempt that fails. - resp := queryA(t, r, "mgmt.example.com.") - assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails") - - waitFor(t, time.Second, func() bool { - return chain.callCount("mgmt.example.com.", dns.TypeA) == 1 - }) - waitFor(t, time.Second, func() bool { - r.mutex.RLock() - defer r.mutex.RUnlock() - c, ok := r.records[q] - return ok && !c.lastFailedRefresh.IsZero() - }) - - // Subsequent stale hits within backoff window should not schedule more refreshes. - for i := 0; i < 10; i++ { - queryA(t, r, "mgmt.example.com.") - } - time.Sleep(50 * time.Millisecond) - assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes") -} - -func TestResolver_NoRootHandler_SkipsChain(t *testing.T) { - r := NewResolver() - chain := newFakeChain() - chain.hasRoot = false - chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") - r.SetChainResolver(chain, 50) - - // With hasRoot=false the chain must not be consulted. Use a short - // deadline so the OS fallback returns quickly without waiting on a - // real network call in CI. - ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) - defer cancel() - _, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.") - - assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), - "chain must not be used when no root handler is registered at the bound priority") -} - -func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) { - // ServeDNS being invoked for a question while a refresh for that question - // is inflight indicates a resolver loop (OS resolver sent the recursive - // query back to us). The inflightRefresh.loopLoggedOnce flag must be set. - r := NewResolver() - - q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} - r.records[q] = &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: time.Now(), - } - - // Simulate an inflight refresh. - r.markRefreshing(q) - defer r.clearRefreshing(q) - - resp := queryA(t, r, "mgmt.example.com.") - assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries") - - r.mutex.RLock() - inflight := r.refreshing[q] - r.mutex.RUnlock() - require.NotNil(t, inflight) - assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed") -} - -func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) { - r := NewResolver() - - q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} - r.records[q] = &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: time.Now(), - } - - r.markRefreshing(q) - defer r.clearRefreshing(q) - - // Multiple ServeDNS calls during the same refresh must not re-set the flag - // (CompareAndSwap from false -> true returns true only on the first call). - for range 5 { - queryA(t, r, "mgmt.example.com.") - } - - r.mutex.RLock() - inflight := r.refreshing[q] - r.mutex.RUnlock() - assert.True(t, inflight.Load()) -} - -func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) { - r := NewResolver() - - q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} - r.records[q] = &cachedRecord{ - records: []dns.RR{&dns.A{ - Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, - A: net.ParseIP("10.0.0.1").To4(), - }}, - cachedAt: time.Now(), - } - - queryA(t, r, "mgmt.example.com.") - - r.mutex.RLock() - _, ok := r.refreshing[q] - r.mutex.RUnlock() - assert.False(t, ok, "no refresh inflight means no loop tracking") -} - -func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) { - r := NewResolver() - chain := newFakeChain() - chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") - chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2") - r.SetChainResolver(chain, 50) - - require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com"))) - - resp := queryA(t, r, "mgmt.example.com.") - assert.Equal(t, "10.0.0.2", firstA(t, resp)) - assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA)) - assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA)) -} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 276cbba0a..9e8a746f3 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -6,7 +6,6 @@ import ( "net/url" "strings" "testing" - "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) { assert.False(t, resolver.MatchSubdomains()) } -func TestResolveCacheTTL(t *testing.T) { - tests := []struct { - name string - value string - want time.Duration - }{ - {"unset falls back to default", "", defaultTTL}, - {"valid duration", "45s", 45 * time.Second}, - {"valid minutes", "2m", 2 * time.Minute}, - {"malformed falls back to default", "not-a-duration", defaultTTL}, - {"zero falls back to default", "0s", defaultTTL}, - {"negative falls back to default", "-5s", defaultTTL}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Setenv(envMgmtCacheTTL, tc.value) - got := resolveCacheTTL() - assert.Equal(t, tc.want, got, "parsed TTL should match") - }) - } -} - -func TestNewResolver_CacheTTLFromEnv(t *testing.T) { - t.Setenv(envMgmtCacheTTL, "7s") - r := NewResolver() - assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env") -} - -func TestResolver_ResponseTTL(t *testing.T) { - now := time.Now() - tests := []struct { - name string - cacheTTL time.Duration - cachedAt time.Time - wantMin uint32 - wantMax uint32 - }{ - {"fresh entry returns full TTL", 60 * time.Second, now, 59, 60}, - {"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31}, - {"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0}, - {"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - r := &Resolver{cacheTTL: tc.cacheTTL} - got := r.responseTTL(tc.cachedAt) - assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin") - assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax") - }) - } -} - func TestResolver_ExtractDomainFromURL(t *testing.T) { tests := []struct { name string diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index d4f54dec5..f7865047b 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -212,7 +212,6 @@ func newDefaultServer( ctx, stop := context.WithCancel(ctx) mgmtCacheResolver := mgmt.NewResolver() - mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream) defaultServer := &DefaultServer{ ctx: ctx,