Compare commits

...

5 Commits

Author SHA1 Message Date
Zoltán Papp
3236a4c7fd [client] Cancel ServeDNS-path mgmt resolve on server Stop
awaitPendingResolve resolved on context.Background(), so a resolve
triggered by an incoming query kept running to dnsTimeout after Stop,
making the DNS listener's ShutdownContext wait it out. Give the Resolver
the server-lifetime ctx and use it there so Stop cancels the resolve.
2026-06-22 22:34:06 +02:00
Zoltán Papp
08ac4855f6 [client] Scope mgmt cache initial resolve to server lifetime
Background initial resolves used context.Background(), so they outlived
a server Stop() and kept running (up to dnsTimeout) against a cache
that was already being discarded. Thread the server-lifetime ctx from
UpdateFromServerDomains through kickoffResolve into AddDomain so Stop()
cancels in-flight resolves. The ctx is not per-sync, so a fast-returning
sync still won't cancel them.
2026-06-22 22:19:51 +02:00
Zoltán Papp
b6c79f1f71 [client] Trim verbose comments in mgmt cache async resolve 2026-06-22 19:46:26 +02:00
Zoltán Papp
37be8811a3 [client] Move mgmt cache pending-resolve test helpers to export_test.go
WaitForPendingResolves and pendingCount existed only to let tests wait
for the background initial resolves; they had no production caller. Keep
the production API surface clean by moving them into export_test.go
(unexported, test-build only) as waitForPendingResolves/pendingCount.
2026-06-22 19:40:04 +02:00
Zoltán Papp
a7d85ff3ab [client] Resolve management cache domains asynchronously
UpdateFromServerDomains resolved the infrastructure domains (signal,
relay, STUN, TURN) synchronously, in a serial loop with a 5s timeout per
domain. It runs deep inside handleSync, under the engine syncMsgMux (and
the DNS server mutex). When DNS is unhealthy each lookup waits out its
full timeout, so the sync lock is held for many seconds — observed as a
16.6s wait for the lock by a signal handler in the field. That starves
the signal/p2p processing that shares syncMsgMux, which in turn prevents
the handshake that would make DNS healthy again: a self-reinforcing loop.

The cache cannot simply be skipped: it is a prerequisite for the relay
connection. The relay dials its server by hostname, resolved through the
NetBird DNS chain where this resolver sits on top (PriorityMgmtCache);
on a miss it falls through to upstream — the dead path the cache exists
to bypass.

Instead, record intent on sync and resolve on demand:

- UpdateFromServerDomains now marks each requested domain pending and
  kicks off resolution in the background via a dedicated singleflight
  group (resolveGroup), then returns immediately. No DNS I/O, no timeout,
  no blocking under the lock.
- ServeDNS, on a miss for a pending domain, waits on the in-flight
  resolve (joining the same singleflight flight, bounded by dnsTimeout)
  instead of falling through to the dead upstream. The wait now happens
  in the relay/DNS-serving goroutine, never under syncMsgMux.
- UpdateServerConfig registers the cache handler from the requested
  domains (RequestedDomains) rather than the already-resolved ones, so
  the resolver owns the names before resolution completes and can
  intercept the lookup.

Background resolves use context.Background() so a fast-returning sync
cannot cancel an in-flight resolve. WaitForPendingResolves lets tests
(and any warm-cache path) wait for the background work to settle.

Cache semantics are otherwise unchanged: TTL, stale-while-revalidate
refresh, backoff, and the flow-domain exclusion all stay.
2026-06-22 19:33:49 +02:00
5 changed files with 180 additions and 38 deletions

View File

@@ -0,0 +1,23 @@
package mgmt
import "time"
// pendingCount returns how many initial resolves are still in flight. Test-only.
func (m *Resolver) pendingCount() int {
m.mutex.RLock()
defer m.mutex.RUnlock()
return len(m.pending)
}
// waitForPendingResolves blocks until all pending resolves settle or the
// timeout elapses, returning true if all settled. Test-only.
func (m *Resolver) waitForPendingResolves(timeout time.Duration) bool {
deadline := time.Now().Add(timeout)
for m.pendingCount() > 0 {
if time.Now().After(deadline) {
return false
}
time.Sleep(10 * time.Millisecond)
}
return true
}

View File

@@ -50,17 +50,31 @@ type cachedRecord struct {
consecFailures int
}
// pendingEntry marks a domain whose initial resolve is in flight, so ServeDNS
// can wait on it instead of falling through to upstream.
type pendingEntry struct{}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// records, refreshing, pending, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
// ctx is the server-lifetime context for background resolves.
ctx context.Context
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// pending holds domains whose initial resolve is in flight, keyed by
// punycode FQDN (trailing dot).
pending map[string]pendingEntry
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// resolveGroup dedups initial (cold-cache) resolves; kept separate from
// refreshGroup so initial and stale-refresh flights don't collapse.
resolveGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
@@ -74,10 +88,12 @@ type Resolver struct {
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
func NewResolver(ctx context.Context) *Resolver {
return &Resolver{
ctx: ctx,
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
pending: make(map[string]pendingEntry),
cacheTTL: resolveCacheTTL(),
}
}
@@ -117,6 +133,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
_, isPending := m.pending[question.Name]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
@@ -126,8 +143,17 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
m.mutex.RUnlock()
if !found {
m.continueToNext(w, r)
return
// Registered but not resolved yet: wait on the in-flight resolve
// rather than falling through to (possibly dead) upstream.
if isPending && m.awaitPendingResolve(question.Name) {
m.mutex.RLock()
cached, found = m.records[question]
m.mutex.RUnlock()
}
if !found {
m.continueToNext(w, r)
return
}
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
@@ -467,6 +493,13 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
return nil
}
// RequestedDomains returns the cacheable infrastructure domains (signal, relay,
// STUN, TURN; flow excluded) so the cache handler can be registered for them
// before resolution completes.
func (m *Resolver) RequestedDomains(serverDomains dnsconfig.ServerDomains) domain.List {
return m.extractDomainsFromServerDomains(serverDomains)
}
// GetCachedDomains returns a list of all cached domains.
func (m *Resolver) GetCachedDomains() domain.List {
m.mutex.RLock()
@@ -489,6 +522,11 @@ func (m *Resolver) GetCachedDomains() domain.List {
// UpdateFromServerDomains updates the cache with server domains from network configuration.
// It merges new domains with existing ones, replacing entire domain types when updated.
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
// UpdateFromServerDomains records the requested domains and kicks off their
// resolution in the background, returning without blocking on DNS so it stays
// off the engine sync lock held by the caller. ctx scopes the background
// resolves to the server lifetime: it is not per-sync, so a fast-returning
// sync won't cancel them, but a server Stop will.
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
newDomains := m.extractDomainsFromServerDomains(serverDomains)
var removedDomains domain.List
@@ -507,11 +545,95 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
}
m.addNewDomains(ctx, newDomains)
m.kickoffResolve(ctx, newDomains)
return removedDomains, nil
}
// kickoffResolve marks each domain pending and starts a background resolve,
// skipping ones already fresh or in flight. Returns immediately.
func (m *Resolver) kickoffResolve(ctx context.Context, domains domain.List) {
for _, d := range domains {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.Lock()
_, hasPending := m.pending[dnsName]
cached := m.hasFreshRecordLocked(dnsName)
if !hasPending && !cached {
m.pending[dnsName] = pendingEntry{}
}
m.mutex.Unlock()
if hasPending || cached {
continue
}
m.scheduleInitialResolve(ctx, d, dnsName)
}
}
// scheduleInitialResolve runs AddDomain in the background, deduped per domain
// by resolveGroup, clearing the pending marker when it finishes. ctx is the
// server-lifetime context so a Stop cancels in-flight resolves.
func (m *Resolver) scheduleInitialResolve(ctx context.Context, d domain.Domain, dnsName string) {
key := "initial|" + dnsName
_ = m.resolveGroup.DoChan(key, func() (any, error) {
defer m.clearPending(dnsName)
if err := m.AddDomain(ctx, d); err != nil {
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return struct{}{}, err
}
log.Debugf("added/updated management cache domain=%s", d.SafeString())
return struct{}{}, nil
})
}
// hasFreshRecordLocked reports whether a non-stale A or AAAA record exists for
// the name. Caller holds m.mutex.
func (m *Resolver) hasFreshRecordLocked(dnsName string) bool {
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if c, ok := m.records[q]; ok && time.Since(c.cachedAt) <= m.cacheTTL {
return true
}
}
return false
}
func (m *Resolver) clearPending(dnsName string) {
m.mutex.Lock()
delete(m.pending, dnsName)
m.mutex.Unlock()
}
// awaitPendingResolve joins the in-flight resolve for dnsName (bounded by
// dnsTimeout) and reports whether a record became available.
func (m *Resolver) awaitPendingResolve(dnsName string) bool {
key := "initial|" + dnsName
d, err := domain.FromString(strings.TrimSuffix(dnsName, "."))
if err != nil {
return false
}
ch := m.resolveGroup.DoChan(key, func() (any, error) {
defer m.clearPending(dnsName)
if err := m.AddDomain(m.ctx, d); err != nil {
return struct{}{}, err
}
return struct{}{}, nil
})
select {
case <-ch:
case <-time.After(dnsTimeout):
return false
}
m.mutex.RLock()
defer m.mutex.RUnlock()
return m.hasFreshRecordLocked(dnsName)
}
// removeStaleDomains removes cached domains not present in the target domain list.
// Management domains are preserved and never removed during server domain updates.
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
@@ -577,17 +699,6 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
var domains domain.List

View File

@@ -130,7 +130,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
@@ -146,7 +146,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
@@ -162,7 +162,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
@@ -183,7 +183,7 @@ func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
@@ -213,7 +213,7 @@ func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
@@ -262,7 +262,7 @@ func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
@@ -299,7 +299,7 @@ func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
@@ -320,7 +320,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -346,7 +346,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -373,7 +373,7 @@ func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
@@ -393,7 +393,7 @@ func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
r := NewResolver(context.Background())
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")

View File

@@ -17,7 +17,7 @@ import (
)
func TestResolver_NewResolver(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
assert.NotNil(t, resolver)
assert.NotNil(t, resolver.records)
@@ -49,7 +49,7 @@ func TestResolveCacheTTL(t *testing.T) {
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
r := NewResolver(context.Background())
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
@@ -169,7 +169,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := NewResolver()
resolver := NewResolver(context.Background())
// Test with IP address - should return error since IP addresses are rejected
mgmtURL, _ := url.Parse("https://127.0.0.1")
@@ -184,7 +184,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
}
func TestResolver_ServeDNS(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Add a test domain to the cache - use example.org which is reserved for testing
@@ -284,7 +284,7 @@ func TestResolver_ServeDNS(t *testing.T) {
}
func TestResolver_GetCachedDomains(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
testDomain, err := domain.FromString("example.org")
@@ -304,7 +304,7 @@ func TestResolver_GetCachedDomains(t *testing.T) {
}
func TestResolver_ManagementDomainProtection(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
mgmtURL, _ := url.Parse("https://example.org")
@@ -329,6 +329,7 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
if err != nil {
t.Logf("Server domains update failed: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
finalDomains := resolver.GetCachedDomains()
@@ -351,7 +352,7 @@ func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
}
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial domains using resolvable domains
@@ -366,6 +367,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
// Verify domains were added
cachedDomains := resolver.GetCachedDomains()
@@ -385,7 +387,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
}
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial complete domains using resolvable domains
@@ -400,6 +402,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
@@ -410,6 +413,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
// Should remove only the old signal domain
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
@@ -429,7 +433,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
}
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
resolver := NewResolver()
resolver := NewResolver(context.Background())
ctx := context.Background()
// Set up initial complete domains using resolvable domains
@@ -444,6 +448,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, resolver.GetCachedDomains(), 3)
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
@@ -455,6 +460,7 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
if err != nil {
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
}
resolver.waitForPendingResolves(10 * time.Second)
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")

View File

@@ -282,7 +282,7 @@ func newDefaultServer(
handlerChain := NewHandlerChain()
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver := mgmt.NewResolver(ctx)
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
@@ -622,7 +622,9 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
}
newDomains := s.mgmtCacheResolver.GetCachedDomains()
// Register for the requested domains, not just resolved ones: resolution
// now runs in the background, so the cache may still be empty here.
newDomains := s.mgmtCacheResolver.RequestedDomains(domains)
if len(newDomains) > 0 {
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
}