mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-25 01:09:54 +00:00
Compare commits
6 Commits
t850
...
fix/mgmt-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df6e422e10 | ||
|
|
3236a4c7fd | ||
|
|
08ac4855f6 | ||
|
|
b6c79f1f71 | ||
|
|
37be8811a3 | ||
|
|
a7d85ff3ab |
23
client/internal/dns/mgmt/export_test.go
Normal file
23
client/internal/dns/mgmt/export_test.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package mgmt
|
||||
|
||||
import "time"
|
||||
|
||||
// pendingCount returns how many initial resolves are still in flight. Test-only.
|
||||
func (m *Resolver) pendingCount() int {
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return len(m.pending)
|
||||
}
|
||||
|
||||
// waitForPendingResolves blocks until all pending resolves settle or the
|
||||
// timeout elapses, returning true if all settled. Test-only.
|
||||
func (m *Resolver) waitForPendingResolves(timeout time.Duration) bool {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for m.pendingCount() > 0 {
|
||||
if time.Now().After(deadline) {
|
||||
return false
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
}
|
||||
return true
|
||||
}
|
||||
@@ -50,17 +50,31 @@ type cachedRecord struct {
|
||||
consecFailures int
|
||||
}
|
||||
|
||||
// pendingEntry marks a domain whose initial resolve is in flight, so ServeDNS
|
||||
// can wait on it instead of falling through to upstream.
|
||||
type pendingEntry struct{}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
// records, refreshing, pending, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
// ctx is the server-lifetime context for background resolves.
|
||||
ctx context.Context
|
||||
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// pending holds domains whose initial resolve is in flight, keyed by
|
||||
// punycode FQDN (trailing dot).
|
||||
pending map[string]pendingEntry
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
// resolveGroup dedups initial (cold-cache) resolves; kept separate from
|
||||
// refreshGroup so initial and stale-refresh flights don't collapse.
|
||||
resolveGroup singleflight.Group
|
||||
|
||||
// refreshing tracks questions whose refresh is running via the OS
|
||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
||||
@@ -74,10 +88,12 @@ type Resolver struct {
|
||||
}
|
||||
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
func NewResolver(ctx context.Context) *Resolver {
|
||||
return &Resolver{
|
||||
ctx: ctx,
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
pending: make(map[string]pendingEntry),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
@@ -117,6 +133,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.mutex.RLock()
|
||||
cached, found := m.records[question]
|
||||
inflight := m.refreshing[question]
|
||||
_, isPending := m.pending[question.Name]
|
||||
var shouldRefresh bool
|
||||
if found {
|
||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
||||
@@ -126,8 +143,17 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
m.mutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
m.continueToNext(w, r)
|
||||
return
|
||||
// Registered but not resolved yet: wait on the in-flight resolve
|
||||
// rather than falling through to (possibly dead) upstream.
|
||||
if isPending && m.awaitPendingResolve(question.Name) {
|
||||
m.mutex.RLock()
|
||||
cached, found = m.records[question]
|
||||
m.mutex.RUnlock()
|
||||
}
|
||||
if !found {
|
||||
m.continueToNext(w, r)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
||||
@@ -467,6 +493,13 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RequestedDomains returns the cacheable infrastructure domains (signal, relay,
|
||||
// STUN, TURN; flow excluded) so the cache handler can be registered for them
|
||||
// before resolution completes.
|
||||
func (m *Resolver) RequestedDomains(serverDomains dnsconfig.ServerDomains) domain.List {
|
||||
return m.extractDomainsFromServerDomains(serverDomains)
|
||||
}
|
||||
|
||||
// GetCachedDomains returns a list of all cached domains.
|
||||
func (m *Resolver) GetCachedDomains() domain.List {
|
||||
m.mutex.RLock()
|
||||
@@ -486,10 +519,12 @@ func (m *Resolver) GetCachedDomains() domain.List {
|
||||
return domains
|
||||
}
|
||||
|
||||
// UpdateFromServerDomains updates the cache with server domains from network configuration.
|
||||
// It merges new domains with existing ones, replacing entire domain types when updated.
|
||||
// Empty updates are ignored to prevent clearing infrastructure domains during partial updates.
|
||||
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains) (domain.List, error) {
|
||||
// UpdateFromServerDomains merges server domains into the cache and resolves
|
||||
// them. New types replace whole types; empty updates are ignored. Resolution is
|
||||
// async (off the caller's sync lock) except for cold domains when dnsWillBeServed
|
||||
// and takeover is pending, which kickoffResolve primes synchronously. ctx is the
|
||||
// server lifetime, so a fast sync won't cancel resolves but Stop will.
|
||||
func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dnsconfig.ServerDomains, dnsWillBeServed bool) (domain.List, error) {
|
||||
newDomains := m.extractDomainsFromServerDomains(serverDomains)
|
||||
var removedDomains domain.List
|
||||
|
||||
@@ -507,11 +542,136 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
m.kickoffResolve(ctx, newDomains, dnsWillBeServed)
|
||||
|
||||
return removedDomains, nil
|
||||
}
|
||||
|
||||
// kickoffResolve resolves each unresolved domain, skipping fresh/in-flight ones.
|
||||
// Cold domains resolve synchronously only before takeover (no upstream root
|
||||
// handler) and when dnsWillBeServed, to prime the cache via the working OS
|
||||
// resolver before OS DNS routes through the tunnel; otherwise async.
|
||||
func (m *Resolver) kickoffResolve(ctx context.Context, domains domain.List, dnsWillBeServed bool) {
|
||||
m.mutex.RLock()
|
||||
chain := m.chain
|
||||
maxPriority := m.chainMaxPriority
|
||||
m.mutex.RUnlock()
|
||||
preTakeover := chain == nil || !chain.HasRootHandlerAtOrBelow(maxPriority)
|
||||
|
||||
for _, d := range domains {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.Lock()
|
||||
_, hasPending := m.pending[dnsName]
|
||||
fresh := m.hasFreshRecordLocked(dnsName)
|
||||
cold := !m.hasAnyRecordLocked(dnsName)
|
||||
if !hasPending && !fresh {
|
||||
m.pending[dnsName] = pendingEntry{}
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
|
||||
if hasPending || fresh {
|
||||
continue
|
||||
}
|
||||
|
||||
if cold && preTakeover && dnsWillBeServed {
|
||||
m.resolveInitial(ctx, d, dnsName)
|
||||
continue
|
||||
}
|
||||
|
||||
m.scheduleInitialResolve(ctx, d, dnsName)
|
||||
}
|
||||
}
|
||||
|
||||
// resolveInitial resolves a cold domain synchronously, deduped via resolveGroup
|
||||
// so a concurrent ServeDNS await joins the same flight. Clears pending when done.
|
||||
func (m *Resolver) resolveInitial(ctx context.Context, d domain.Domain, dnsName string) {
|
||||
key := "initial|" + dnsName
|
||||
_, _, _ = m.resolveGroup.Do(key, func() (any, error) {
|
||||
defer m.clearPending(dnsName)
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("initial resolve mgmt domain=%s: %v", d.SafeString(), err)
|
||||
return struct{}{}, err
|
||||
}
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
return struct{}{}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// scheduleInitialResolve runs AddDomain in the background, deduped per domain
|
||||
// by resolveGroup, clearing the pending marker when it finishes. ctx is the
|
||||
// server-lifetime context so a Stop cancels in-flight resolves.
|
||||
func (m *Resolver) scheduleInitialResolve(ctx context.Context, d domain.Domain, dnsName string) {
|
||||
key := "initial|" + dnsName
|
||||
_ = m.resolveGroup.DoChan(key, func() (any, error) {
|
||||
defer m.clearPending(dnsName)
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return struct{}{}, err
|
||||
}
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
return struct{}{}, nil
|
||||
})
|
||||
}
|
||||
|
||||
// hasFreshRecordLocked reports whether a non-stale A or AAAA record exists for
|
||||
// the name. Caller holds m.mutex.
|
||||
func (m *Resolver) hasFreshRecordLocked(dnsName string) bool {
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if c, ok := m.records[q]; ok && time.Since(c.cachedAt) <= m.cacheTTL {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// hasAnyRecordLocked reports whether any A or AAAA record exists for the name,
|
||||
// fresh or stale. Caller holds m.mutex.
|
||||
func (m *Resolver) hasAnyRecordLocked(dnsName string) bool {
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *Resolver) clearPending(dnsName string) {
|
||||
m.mutex.Lock()
|
||||
delete(m.pending, dnsName)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// awaitPendingResolve joins the in-flight resolve for dnsName (bounded by
|
||||
// dnsTimeout) and reports whether a record became available.
|
||||
func (m *Resolver) awaitPendingResolve(dnsName string) bool {
|
||||
key := "initial|" + dnsName
|
||||
d, err := domain.FromString(strings.TrimSuffix(dnsName, "."))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
ch := m.resolveGroup.DoChan(key, func() (any, error) {
|
||||
defer m.clearPending(dnsName)
|
||||
if err := m.AddDomain(m.ctx, d); err != nil {
|
||||
return struct{}{}, err
|
||||
}
|
||||
return struct{}{}, nil
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-time.After(dnsTimeout):
|
||||
return false
|
||||
}
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
return m.hasFreshRecordLocked(dnsName)
|
||||
}
|
||||
|
||||
// removeStaleDomains removes cached domains not present in the target domain list.
|
||||
// Management domains are preserved and never removed during server domain updates.
|
||||
func (m *Resolver) removeStaleDomains(currentDomains, newDomains domain.List) domain.List {
|
||||
@@ -577,17 +737,6 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
for _, newDomain := range newDomains {
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.ServerDomains) domain.List {
|
||||
var domains domain.List
|
||||
|
||||
|
||||
@@ -130,7 +130,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
|
||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
r.cacheTTL = 10 * time.Millisecond
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
@@ -146,7 +146,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
r.cacheTTL = time.Hour
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
||||
@@ -162,7 +162,7 @@ func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
@@ -183,7 +183,7 @@ func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
@@ -213,7 +213,7 @@ func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
|
||||
@@ -262,7 +262,7 @@ func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
@@ -299,7 +299,7 @@ func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.hasRoot = false
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
@@ -320,7 +320,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
// ServeDNS being invoked for a question while a refresh for that question
|
||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
@@ -346,7 +346,7 @@ func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
@@ -373,7 +373,7 @@ func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
|
||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
||||
r.records[q] = &cachedRecord{
|
||||
@@ -393,7 +393,7 @@ func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
||||
|
||||
@@ -17,7 +17,7 @@ import (
|
||||
)
|
||||
|
||||
func TestResolver_NewResolver(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
|
||||
assert.NotNil(t, resolver)
|
||||
assert.NotNil(t, resolver.records)
|
||||
@@ -49,7 +49,7 @@ func TestResolveCacheTTL(t *testing.T) {
|
||||
|
||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
||||
t.Setenv(envMgmtCacheTTL, "7s")
|
||||
r := NewResolver()
|
||||
r := NewResolver(context.Background())
|
||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
||||
}
|
||||
|
||||
@@ -169,7 +169,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
|
||||
// Test with IP address - should return error since IP addresses are rejected
|
||||
mgmtURL, _ := url.Parse("https://127.0.0.1")
|
||||
@@ -184,7 +184,7 @@ func TestResolver_PopulateFromConfig(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ServeDNS(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Add a test domain to the cache - use example.org which is reserved for testing
|
||||
@@ -284,7 +284,7 @@ func TestResolver_ServeDNS(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_GetCachedDomains(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
testDomain, err := domain.FromString("example.org")
|
||||
@@ -304,7 +304,7 @@ func TestResolver_GetCachedDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_ManagementDomainProtection(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
mgmtURL, _ := url.Parse("https://example.org")
|
||||
@@ -325,10 +325,11 @@ func TestResolver_ManagementDomainProtection(t *testing.T) {
|
||||
Relay: []domain.Domain{"cloudflare.com"},
|
||||
}
|
||||
|
||||
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains)
|
||||
_, err = resolver.UpdateFromServerDomains(ctx, serverDomains, true)
|
||||
if err != nil {
|
||||
t.Logf("Server domains update failed: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
finalDomains := resolver.GetCachedDomains()
|
||||
|
||||
@@ -351,7 +352,7 @@ func extractDomainFromURL(u *url.URL) (domain.Domain, error) {
|
||||
}
|
||||
|
||||
func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Set up initial domains using resolvable domains
|
||||
@@ -362,10 +363,11 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add initial domains
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
// Verify domains were added
|
||||
cachedDomains := resolver.GetCachedDomains()
|
||||
@@ -373,7 +375,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
|
||||
// Update with empty ServerDomains (simulating partial network map update)
|
||||
emptyDomains := dnsconfig.ServerDomains{}
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains)
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, emptyDomains, true)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Verify no domains were removed
|
||||
@@ -385,7 +387,7 @@ func TestResolver_EmptyUpdateDoesNotRemoveDomains(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Set up initial complete domains using resolvable domains
|
||||
@@ -396,20 +398,22 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add initial domains
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only signal domain - this should replace signal but preserve stun/turn)
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Signal: "github.com",
|
||||
}
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
// Should remove only the old signal domain
|
||||
assert.Len(t, removedDomains, 1, "Should remove only the old signal domain")
|
||||
@@ -429,7 +433,7 @@ func TestResolver_PartialUpdateReplacesOnlyUpdatedTypes(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
resolver := NewResolver()
|
||||
resolver := NewResolver(context.Background())
|
||||
ctx := context.Background()
|
||||
|
||||
// Set up initial complete domains using resolvable domains
|
||||
@@ -440,10 +444,11 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
}
|
||||
|
||||
// Add initial domains
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains)
|
||||
_, err := resolver.UpdateFromServerDomains(ctx, initialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||
@@ -451,10 +456,11 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Flow: "github.com",
|
||||
}
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains)
|
||||
removedDomains, err := resolver.UpdateFromServerDomains(ctx, partialDomains, true)
|
||||
if err != nil {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
resolver.waitForPendingResolves(10 * time.Second)
|
||||
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||
|
||||
|
||||
@@ -282,7 +282,7 @@ func newDefaultServer(
|
||||
handlerChain := NewHandlerChain()
|
||||
ctx, stop := context.WithCancel(ctx)
|
||||
|
||||
mgmtCacheResolver := mgmt.NewResolver()
|
||||
mgmtCacheResolver := mgmt.NewResolver(ctx)
|
||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
||||
|
||||
defaultServer := &DefaultServer{
|
||||
@@ -613,7 +613,11 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
|
||||
defer s.mux.Unlock()
|
||||
|
||||
if s.mgmtCacheResolver != nil {
|
||||
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains)
|
||||
// Mirrors the Initialize guard: without it NetBird never becomes the
|
||||
// system resolver, so the mgmt cache is never queried and need not be
|
||||
// primed synchronously.
|
||||
dnsWillBeServed := !s.disableSys && !netstack.IsEnabled()
|
||||
removedDomains, err := s.mgmtCacheResolver.UpdateFromServerDomains(s.ctx, domains, dnsWillBeServed)
|
||||
if err != nil {
|
||||
return fmt.Errorf("update management cache resolver: %w", err)
|
||||
}
|
||||
@@ -622,7 +626,9 @@ func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) erro
|
||||
s.deregisterHandler(removedDomains.ToPunycodeList(), PriorityMgmtCache)
|
||||
}
|
||||
|
||||
newDomains := s.mgmtCacheResolver.GetCachedDomains()
|
||||
// Register for the requested domains, not just resolved ones: resolution
|
||||
// now runs in the background, so the cache may still be empty here.
|
||||
newDomains := s.mgmtCacheResolver.RequestedDomains(domains)
|
||||
if len(newDomains) > 0 {
|
||||
s.registerHandler(newDomains.ToPunycodeList(), s.mgmtCacheResolver, PriorityMgmtCache)
|
||||
}
|
||||
|
||||
@@ -82,12 +82,6 @@ const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
disableAutoUpdate = "disabled"
|
||||
|
||||
// systemInfoTimeout bounds how long the sync loop waits for system info / posture
|
||||
// check gathering. The gathering runs uncancellable system calls (process scan,
|
||||
// exec, os.Stat); without this bound a single stuck call freezes handleSync, and
|
||||
// thus syncMsgMux, for as long as the call hangs (observed multi-minute freezes).
|
||||
systemInfoTimeout = 15 * time.Second
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
@@ -1072,22 +1066,11 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||
}
|
||||
e.checks = checks
|
||||
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, checks)
|
||||
if !ok {
|
||||
// Gathering timed out; skip the meta sync this cycle rather than blocking the
|
||||
// sync loop (and syncMsgMux) on a stuck system call. A later sync will retry.
|
||||
return nil
|
||||
info, err := system.GetInfoWithChecks(e.ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
return fmt.Errorf("could not sync meta: error %s", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyInfoFlags sets the engine's config-derived feature flags on the gathered system info.
|
||||
func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
@@ -1106,6 +1089,12 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||
log.Errorf("could not sync meta: error %s", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
@@ -1251,15 +1240,31 @@ func (e *Engine) receiveManagementEvents() {
|
||||
e.shutdownWg.Add(1)
|
||||
go func() {
|
||||
defer e.shutdownWg.Done()
|
||||
info, ok := system.GetInfoWithChecksTimeout(e.ctx, systemInfoTimeout, e.checks)
|
||||
if !ok {
|
||||
// Gathering timed out; connect the stream with base info so management
|
||||
// connectivity still comes up rather than blocking here.
|
||||
info, err := system.GetInfoWithChecks(e.ctx, e.checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = system.GetInfo(e.ctx)
|
||||
}
|
||||
e.applyInfoFlags(info)
|
||||
info.SetFlags(
|
||||
e.config.RosenpassEnabled,
|
||||
e.config.RosenpassPermissive,
|
||||
&e.config.ServerSSHAllowed,
|
||||
e.config.DisableClientRoutes,
|
||||
e.config.DisableServerRoutes,
|
||||
e.config.DisableDNS,
|
||||
e.config.DisableFirewall,
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
e.config.EnableSSHRemotePortForwarding,
|
||||
e.config.DisableSSHAuth,
|
||||
)
|
||||
|
||||
err := e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||
if err != nil {
|
||||
// happens if management is unavailable for a long time.
|
||||
// We want to cancel the operation of the whole client
|
||||
|
||||
@@ -2,10 +2,8 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -156,7 +154,7 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
||||
processCheckPaths = append(processCheckPaths, check.GetFiles()...)
|
||||
}
|
||||
|
||||
files, err := checkFileAndProcess(ctx, processCheckPaths)
|
||||
files, err := checkFileAndProcess(processCheckPaths)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -168,39 +166,3 @@ func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, erro
|
||||
log.Debugf("all system information gathered successfully")
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// GetInfoWithChecksTimeout is GetInfoWithChecks bounded by timeout. Posture-check gathering
|
||||
// runs uncancellable system calls (process enumeration, os.Stat), so calling it inline can
|
||||
// block the caller for as long as such a call hangs. It runs in a goroutine instead: if it
|
||||
// does not return within timeout the caller gets (nil, false) and should proceed with
|
||||
// degraded behavior rather than block. On a gathering error it falls back to base GetInfo.
|
||||
//
|
||||
// The buffered channel lets the abandoned goroutine finish and exit once its blocking call
|
||||
// returns, so it does not leak beyond the duration of that call.
|
||||
func GetInfoWithChecksTimeout(ctx context.Context, timeout time.Duration, checks []*proto.Checks) (*Info, bool) {
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
infoCh := make(chan *Info, 1)
|
||||
go func() {
|
||||
info, err := GetInfoWithChecks(ctx, checks)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get system info with checks: %v", err)
|
||||
info = GetInfo(ctx)
|
||||
}
|
||||
infoCh <- info
|
||||
}()
|
||||
|
||||
select {
|
||||
case info := <-infoCh:
|
||||
return info, true
|
||||
case <-ctx.Done():
|
||||
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||
log.Warnf("gathering system info with checks timed out after %s", timeout)
|
||||
} else {
|
||||
// Parent context canceled (e.g. shutdown), not a timeout.
|
||||
log.Warnf("gathering system info with checks canceled: %v", ctx.Err())
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,7 +50,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ func GetInfo(ctx context.Context) *Info {
|
||||
sysName := string(bytes.Split(utsname.Sysname[:], []byte{0})[0])
|
||||
machine := string(bytes.Split(utsname.Machine[:], []byte{0})[0])
|
||||
release := string(bytes.Split(utsname.Release[:], []byte{0})[0])
|
||||
swVersion, err := exec.CommandContext(ctx, "sw_vers", "-productVersion").Output()
|
||||
swVersion, err := exec.Command("sw_vers", "-productVersion").Output()
|
||||
if err != nil {
|
||||
log.Warnf("got an error while retrieving macOS version with sw_vers, error: %s. Using darwin version instead.\n", err)
|
||||
swVersion = []byte(release)
|
||||
|
||||
@@ -105,7 +105,7 @@ func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
}
|
||||
|
||||
// checkFileAndProcess is a stub here: it performs no file or process checks and always returns an empty list.
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -103,7 +103,7 @@ func collectLocationInfo(info *Info) {
|
||||
}
|
||||
}
|
||||
|
||||
func checkFileAndProcess(_ context.Context, _ []string) ([]File, error) {
|
||||
func checkFileAndProcess(_ []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package system
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/grpc/metadata"
|
||||
@@ -35,20 +34,6 @@ func Test_CustomHostname(t *testing.T) {
|
||||
assert.Equal(t, want, got.Hostname)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Success(t *testing.T) {
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), 30*time.Second, nil)
|
||||
assert.True(t, ok, "expected gathering to complete within the timeout")
|
||||
assert.NotNil(t, info)
|
||||
}
|
||||
|
||||
func TestGetInfoWithChecksTimeout_Timeout(t *testing.T) {
|
||||
// A 1ns budget expires before the (real) system-info gathering can finish, so the
|
||||
// caller must get (nil, false) instead of blocking on the in-flight goroutine.
|
||||
info, ok := GetInfoWithChecksTimeout(context.Background(), time.Nanosecond, nil)
|
||||
assert.False(t, ok, "expected timeout to be reported")
|
||||
assert.Nil(t, info)
|
||||
}
|
||||
|
||||
func Test_NetAddresses(t *testing.T) {
|
||||
addr, err := networkAddresses()
|
||||
if err != nil {
|
||||
|
||||
@@ -3,30 +3,24 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"slices"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
)
|
||||
|
||||
// getRunningProcesses returns a list of running process paths. The context bounds the work:
|
||||
// the per-PID loop bails as soon as ctx is done, and the gopsutil calls honor it where they
|
||||
// can, so a stuck enumeration cannot run unbounded.
|
||||
func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||
processIDs, err := process.PidsWithContext(ctx)
|
||||
// getRunningProcesses returns a list of running process paths.
|
||||
func getRunningProcesses() ([]string, error) {
|
||||
processIDs, err := process.Pids()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
processMap := make(map[string]bool)
|
||||
for _, pID := range processIDs {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &process.Process{Pid: pID}
|
||||
|
||||
path, _ := p.ExeWithContext(ctx)
|
||||
path, _ := p.Exe()
|
||||
if path != "" {
|
||||
processMap[path] = false
|
||||
}
|
||||
@@ -41,21 +35,18 @@ func getRunningProcesses(ctx context.Context) ([]string, error) {
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(ctx context.Context, paths []string) ([]File, error) {
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
files := make([]File, len(paths))
|
||||
if len(paths) == 0 {
|
||||
return files, nil
|
||||
}
|
||||
|
||||
runningProcesses, err := getRunningProcesses(ctx)
|
||||
runningProcesses, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, path := range paths {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
file := File{Path: path}
|
||||
|
||||
_, err := os.Stat(path)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/shirou/gopsutil/v3/process"
|
||||
@@ -10,7 +9,7 @@ import (
|
||||
func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
b.Run("getRunningProcesses new", func(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
ps, err := getRunningProcesses(context.Background())
|
||||
ps, err := getRunningProcesses()
|
||||
if err != nil {
|
||||
b.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
@@ -30,38 +29,12 @@ func Benchmark_getRunningProcesses(b *testing.B) {
|
||||
}
|
||||
}
|
||||
})
|
||||
s, _ := getRunningProcesses(context.Background())
|
||||
s, _ := getRunningProcesses()
|
||||
b.Logf("getRunningProcesses returned %d processes", len(s))
|
||||
s, _ = getRunningProcessesOld()
|
||||
b.Logf("getRunningProcessesOld returned %d processes", len(s))
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_ContextCanceled(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
// With a canceled context and non-empty paths the gathering must bail with an error
|
||||
// instead of running the (potentially blocking) process scan / stat loop.
|
||||
if _, err := checkFileAndProcess(ctx, []string{"/does/not/exist"}); err == nil {
|
||||
t.Fatal("expected error on canceled context, got nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckFileAndProcess_EmptyPaths(t *testing.T) {
|
||||
// No check paths means no work to do: it must return immediately with no error,
|
||||
// even on a canceled context (nothing to scan or stat).
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
files, err := checkFileAndProcess(ctx, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error for empty paths: %v", err)
|
||||
}
|
||||
if len(files) != 0 {
|
||||
t.Fatalf("expected no files, got %d", len(files))
|
||||
}
|
||||
}
|
||||
|
||||
func getRunningProcessesOld() ([]string, error) {
|
||||
processes, err := process.Processes()
|
||||
if err != nil {
|
||||
|
||||
@@ -1170,7 +1170,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||
peer.Meta = login.Meta
|
||||
peer.UpdateMetaIfNew(ctx, login.Meta)
|
||||
|
||||
peerGroupIDs, err = getPeerGroupIDs(ctx, am.Store, accountID, peer.ID)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user