Compare commits

...

6 Commits

Author SHA1 Message Date
Zoltan Papp
df6e422e10 [client] Resolve cold mgmt cache domains synchronously before DNS takeover
Async server-domain resolution risked blackholing infra DNS at bootstrap:
if OS DNS was reconfigured to route through a dead exit node before the
background resolve ran, the cache never populated.

Resolve cold domains (no cached record) synchronously while NetBird has not
yet taken over the system resolver and will serve DNS (dnsWillBeServed), so
the cache is primed via the working OS resolver before takeover. Stale and
post-takeover resolves stay async to keep the engine sync lock unblocked.
2026-06-23 13:38:16 +02:00
Zoltán Papp
3236a4c7fd [client] Cancel ServeDNS-path mgmt resolve on server Stop
awaitPendingResolve resolved on context.Background(), so a resolve
triggered by an incoming query kept running to dnsTimeout after Stop,
making the DNS listener's ShutdownContext wait it out. Give the Resolver
the server-lifetime ctx and use it there so Stop cancels the resolve.
2026-06-22 22:34:06 +02:00
Zoltán Papp
08ac4855f6 [client] Scope mgmt cache initial resolve to server lifetime
Background initial resolves used context.Background(), so they outlived
a server Stop() and kept running (up to dnsTimeout) against a cache
that was already being discarded. Thread the server-lifetime ctx from
UpdateFromServerDomains through kickoffResolve into AddDomain so Stop()
cancels in-flight resolves. The ctx is not per-sync, so a fast-returning
sync still won't cancel them.
2026-06-22 22:19:51 +02:00
Zoltán Papp
b6c79f1f71 [client] Trim verbose comments in mgmt cache async resolve 2026-06-22 19:46:26 +02:00
Zoltán Papp
37be8811a3 [client] Move mgmt cache pending-resolve test helpers to export_test.go
WaitForPendingResolves and pendingCount existed only to let tests wait
for the background initial resolves; they had no production caller. Keep
the production API surface clean by moving them into export_test.go
(unexported, test-build only) as waitForPendingResolves/pendingCount.
2026-06-22 19:40:04 +02:00
Zoltán Papp
a7d85ff3ab [client] Resolve management cache domains asynchronously
UpdateFromServerDomains resolved the infrastructure domains (signal,
relay, STUN, TURN) synchronously, in a serial loop with a 5s timeout per
domain. It runs deep inside handleSync, under the engine syncMsgMux (and
the DNS server mutex). When DNS is unhealthy each lookup waits out its
full timeout, so the sync lock is held for many seconds — observed as a
16.6s wait for the lock by a signal handler in the field. That starves
the signal/p2p processing that shares syncMsgMux, which in turn prevents
the handshake that would make DNS healthy again: a self-reinforcing loop.

The cache cannot simply be skipped: it is a prerequisite for the relay
connection. The relay dials its server by hostname, resolved through the
NetBird DNS chain where this resolver sits on top (PriorityMgmtCache);
on a miss it falls through to upstream — the dead path the cache exists
to bypass.

Instead, record intent on sync and resolve on demand:

- UpdateFromServerDomains now marks each requested domain pending and
  kicks off resolution in the background via a dedicated singleflight
  group (resolveGroup), then returns immediately. No DNS I/O, no timeout,
  no blocking under the lock.
- ServeDNS, on a miss for a pending domain, waits on the in-flight
  resolve (joining the same singleflight flight, bounded by dnsTimeout)
  instead of falling through to the dead upstream. The wait now happens
  in the relay/DNS-serving goroutine, never under syncMsgMux.
- UpdateServerConfig registers the cache handler from the requested
  domains (RequestedDomains) rather than the already-resolved ones, so
  the resolver owns the names before resolution completes and can
  intercept the lookup.

Background resolves use context.Background() so a fast-returning sync
cannot cancel an in-flight resolve. WaitForPendingResolves lets tests
(and any warm-cache path) wait for the background work to settle.

Cache semantics are otherwise unchanged: TTL, stale-while-revalidate
refresh, backoff, and the flow-domain exclusion all stay.
2026-06-22 19:33:49 +02:00
5 changed files with 234 additions and 50 deletions

View File

@@ -0,0 +1,23 @@
package mgmt
import "time"
// pendingCount returns how many initial resolves are still in flight. Test-only.
func (m *Resolver) pendingCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.pending)
}
// waitForPendingResolves blocks until all pending resolves settle or the
// timeout elapses, returning true if all settled. Test-only.
func (m *Resolver) waitForPendingResolves(timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for m.pendingCount() > 0 {
if time.Now().After(deadline) {
return false
}
time.Sleep(10 * time.Millisecond)
}
return true
}

View File

@@ -50,17 +50,31 @@ type cachedRecord struct {
consecFailures int
}
// pendingEntry marks a domain whose initial resolve is in flight, so ServeDNS
// can wait on it instead of falling through to upstream.
type pendingEntry struct{}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// records, refreshing, pending, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
// ctx is the server-lifetime context for background resolves.
ctx context.Context
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// pending holds domains whose initial resolve is in flight, keyed by
// punycode FQDN (trailing dot).
pending map[string]pendingEntry
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// resolveGroup dedups initial (cold-cache) resolves; kept separate from
// refreshGroup so initial and stale-refresh flights don't collapse.
resolveGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
@@ -74,10 +88,12 @@ type Resolver struct {
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
func NewResolver(ctx context.Context) *Resolver {
return &Resolver{
ctx: ctx,
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
pending: make(map[string]pendingEntry),
cacheTTL: resolveCacheTTL(),
}
}
@@ -117,6 +133,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
_, isPending := m.pending[question.Name]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
@@ -126,8 +143,17 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.mutex.RUnlock()
if !found {
m.continueToNext(w, r)
return
// Registered but not resolved yet: wait on the in-flight resolve
// rather than falling through to (possibly dead) upstream.
if isPending && m.awaitPendingResolve(question.Name) {
m.mutex.RLock()
cached, found = m.records[question]
m.mutex.RUnlock()
}
if !found {
m.continueToNext(w, r)
return
}
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
@@ -467,6 +493,13 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
return nil
}
// RequestedDomains returns the cacheable infrastructure domains (signal, relay,
// STUN, TURN; flow excluded) so the cache handler can be registered for them
// before resolution completes.
func (m *Resolver) RequestedDomains(serverDomains dnsconfig.ServerDomains) domain.List {
return m.extractDomainsFromServerDomains(serverDomains)
}
// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
@@ -486,10 +519,12 @@ func (m *Resolver) GetCachedDomains() domain.List {
return domains
}
// UpdateFromServerDomains updates the cache with server domains from network configuration.
// It merges new domains with existing ones, replacing entire domain types when updated.
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
// UpdateFromServerDomains merges server domains into the cache and resolves
// them. New types replace whole types; empty updates are ignored. Resolution is
// async (off the caller's sync lock) except for cold domains when dnsWillBeServed
// and takeover is pending, which kickoffResolve primes synchronously. ctx is the
// server lifetime, so a fast sync won't cancel resolves but Stop will.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains, dnsWillBeServed bool) (domain.List, error) {
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains domain.List
@@ -507,11 +542,136 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
}
m.addNewDomains(ctx, newDomains)
m.kickoffResolve(ctx, newDomains, dnsWillBeServed)
return removedDomains, nil
}
// kickoffResolve resolves each unresolved domain, skipping fresh/in-flight ones.
// Cold domains resolve synchronously only before takeover (no upstream root
// handler) and when dnsWillBeServed, to prime the cache via the working OS
// resolver before OS DNS routes through the tunnel; otherwise async.
func (m *Resolver) kickoffResolve(ctx context.Context, domains domain.List, dnsWillBeServed bool) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
preTakeover := chain == nil || !chain.HasRootHandlerAtOrBelow(maxPriority)
for _, d := range domains {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.Lock()
_, hasPending := m.pending[dnsName]
fresh := m.hasFreshRecordLocked(dnsName)
cold := !m.hasAnyRecordLocked(dnsName)
if !hasPending && !fresh {
m.pending[dnsName] = pendingEntry{}
}
m.mutex.Unlock()
if hasPending || fresh {
continue
}
if cold && preTakeover && dnsWillBeServed {
m.resolveInitial(ctx, d, dnsName)
continue
}
m.scheduleInitialResolve(ctx, d, dnsName)
}
}
// resolveInitial resolves a cold domain synchronously, deduped via resolveGroup
// so a concurrent ServeDNS await joins the same flight. Clears pending when done.
func (m *Resolver) resolveInitial(ctx context.Context, d domain.Domain, dnsName string) {
key := "initial|" + dnsName
_, _, _ = m.resolveGroup.Do(key, func() (any, error) {
defer m.clearPending(dnsName)
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("initial resolve mgmt domain=%s: %v", d.SafeString(), err)
return struct{}{}, err
}
log.Debugf("added/updated management cache domain=%s", d.SafeString())
return struct{}{}, nil
})
}
// scheduleInitialResolve runs AddDomain in the background, deduped per domain
// by resolveGroup, clearing the pending marker when it finishes. ctx is the
// server-lifetime context so a Stop cancels in-flight resolves.
func (m *Resolver) scheduleInitialResolve(ctx context.Context, d domain.Domain, dnsName string) {
key := "initial|" + dnsName
_ = m.resolveGroup.DoChan(key, func() (any, error) {
defer m.clearPending(dnsName)
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return struct{}{}, err
}
log.Debugf("added/updated management cache domain=%s", d.SafeString())
return struct{}{}, nil
})
}
// hasFreshRecordLocked reports whether a non-stale A or AAAA record exists for
// the name. Caller holds m.mutex.
func (m *Resolver) hasFreshRecordLocked(dnsName string) bool {
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if c, ok := m.records[q]; ok && time.Since(c.cachedAt) <= m.cacheTTL {
return true
}
}
return false
}
// hasAnyRecordLocked reports whether any A or AAAA record exists for the name,
// fresh or stale. Caller holds m.mutex.
func (m *Resolver) hasAnyRecordLocked(dnsName string) bool {
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return true
}
}
return false
}
func (m *Resolver) clearPending(dnsName string) {
m.mutex.Lock()
delete(m.pending, dnsName)
m.mutex.Unlock()
}
// awaitPendingResolve joins the in-flight resolve for dnsName (bounded by
// dnsTimeout) and reports whether a record became available.
func (m *Resolver) awaitPendingResolve(dnsName string) bool {
key := "initial|" + dnsName
d, err := domain.FromString(strings.TrimSuffix(dnsName, "."))
if err != nil {
return false
}
ch := m.resolveGroup.DoChan(key, func() (any, error) {
defer m.clearPending(dnsName)
if err := m.AddDomain(m.ctx, d); err != nil {
return struct{}{}, err
}
return struct{}{}, nil
})
select {
case <-ch:
case <-time.After(dnsTimeout):
return false
}
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.hasFreshRecordLocked(dnsName)
}
// removeStaleDomains removes cached domains not present in the target domain list.
// Management domains are preserved and never removed during server domain updates.
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
@@ -577,17 +737,6 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains domain.List

View File

@@ -130,7 +130,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
@@ -146,7 +146,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
@@ -162,7 +162,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
@@ -183,7 +183,7 @@ func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
@@ -213,7 +213,7 @@ func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
@@ -262,7 +262,7 @@ func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
@@ -299,7 +299,7 @@ func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
@@ -320,7 +320,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -346,7 +346,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -373,7 +373,7 @@ func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -393,7 +393,7 @@ func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")

View File

@@ -17,7 +17,7 @@ import (
)
func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
assert.NotNil(t, resolver)
assert.NotNil(t, resolver.records)
@@ -49,7 +49,7 @@ func TestResolveCacheTTL(t *testing.T) {
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
r := NewResolver(context.Background())
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
@@ -169,7 +169,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := NewResolver()
resolver := NewResolver(context.Background())
// Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1")
@@ -184,7 +184,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
}
func TestResolver_ServeDNS(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Add a test domain to the cache - use example.org which is reserved for testing
@@ -284,7 +284,7 @@ func TestResolver_ServeDNS(t *testing.T) {
}
func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
testDomain, err := domain.FromString("example.org")
@@ -304,7 +304,7 @@ func TestResolver_GetCachedDomains(t *testing.T) {
}
func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
mgmtURL, _ := url.Parse("https://example.org")
@@ -325,10 +325,11 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
Relay: []domain.Domain{"cloudflare.com"},
}
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains, true)
if err != nil {
t.Logf("Server domains update failed: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
finalDomains := resolver.GetCachedDomains()
@@ -351,7 +352,7 @@ func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
}
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial domains using resolvable domains
@@ -362,10 +363,11 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
// Verify domains were added
cachedDomains := resolver.GetCachedDomains()
@@ -373,7 +375,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
// Update with empty ServerDomains (simulating partial network map update)
emptyDomains := dnsconfig.ServerDomains{}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains, true)
assert.NoError(t, err)
// Verify no domains were removed
@@ -385,7 +387,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
}
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial complete domains using resolvable domains
@@ -396,20 +398,22 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
partialDomains := dnsconfig.ServerDomains{
Signal: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains, true)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
// Should remove only the old signal domain
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
@@ -429,7 +433,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
}
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial complete domains using resolvable domains
@@ -440,10 +444,11 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
}
// Add initial domains
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
@@ -451,10 +456,11 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
partialDomains := dnsconfig.ServerDomains{
Flow: "github.com",
}
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains, true)
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")

View File

@@ -282,7 +282,7 @@ func newDefaultServer(
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver := mgmt.NewResolver(ctx)
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
@@ -613,7 +613,11 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
defer s.mux.Unlock()
if s.mgmtCacheResolver != nil {
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
// Mirrors the Initialize guard: without it NetBird never becomes the
// system resolver, so the mgmt cache is never queried and need not be
// primed synchronously.
dnsWillBeServed := !s.disableSys && !netstack.IsEnabled()
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains, dnsWillBeServed)
if err != nil {
return fmt.Errorf("update management cache resolver: %w", err)
}
@@ -622,7 +626,9 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
// Register for the requested domains, not just resolved ones: resolution
// now runs in the background, so the cache may still be empty here.
newDomains := s.mgmtCacheResolver.RequestedDomains(domains)
if len(newDomains) > 0 {
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
}