Compare commits

..

8 Commits

12 changed files with 384 additions and 434 deletions

View File

@@ -51,20 +51,13 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -83,10 +76,9 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
}
}
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
// addNewDomains resolves and caches all domains from the update
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
}
}
}

View File

@@ -21,7 +21,6 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -1,183 +0,0 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"os"
"runtime/debug"
"slices"
"strconv"
"sync"
@@ -177,6 +178,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return fmt.Errorf("failed to get account zones: %v", err)
}
if reason.Operation == types.UpdateOperationUpdate && reason.Resource == types.UpdateResourceUser {
log.WithContext(ctx).Tracef("got an user update, stack: %s", debug.Stack())
}
for _, peer := range account.Peers {
if !c.peersUpdateManager.HasChannel(peer.ID) {
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
@@ -244,6 +249,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("UpdateAccountPeers: account %s, %d affected peers (caller: %s)", accountID, len(peerIDs), util.GetCallerName())
if len(peerIDs) == 0 {
return nil
}
@@ -251,7 +257,7 @@ func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string,
}
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers (caller: %s)", accountID, len(peerIDs), util.GetCallerName())
if !c.hasConnectedPeers(peerIDs) {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
@@ -497,7 +503,11 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
if reason.Operation == types.UpdateOperationUpdate && reason.Resource == types.UpdateResourceUser {
log.WithContext(ctx).Tracef("got an user update, stack: %s", debug.Stack())
}
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),

View File

@@ -11,7 +11,7 @@ import (
const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
)
@@ -139,13 +139,22 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
state.lastSeen = now
}
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return h.Sum64()
macs := uint64(0)
for _, na := range meta.NetworkAddresses {
for _, r := range na.Mac {
macs += uint64(r)
}
}
return h.Sum64() + macs
}

View File

@@ -164,7 +164,9 @@ func BenchmarkHashingMethods(b *testing.B) {
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
@@ -173,7 +175,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta)
resultString = builderString(meta, pubip)
}
})
@@ -181,7 +183,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta)
resultString = fnvHashToString(meta, pubip)
}
})
@@ -189,7 +191,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta)
resultUint = metaHash(meta, pubip)
}
})
@@ -197,20 +199,29 @@ func BenchmarkHashingMethods(b *testing.B) {
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta) string {
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
var b strings.Builder
b.Grow(estimatedSize)
@@ -224,10 +235,23 @@ func builderString(meta nbpeer.PeerSystemMeta) string {
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
func getMacAddress(nas []nbpeer.NetworkAddress) string {
if len(nas) == 0 {
return ""
}
macs := make([]string, 0, len(nas))
for _, na := range nas {
macs = append(macs, na.Mac)
}
return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000

View File

@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return mapError(ctx, err)
}
metahashed := metaHash(peerMeta)
metahashed := metaHash(peerMeta, sRealIP)
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
metahash := metaHash(peerMeta)
metahash := metaHash(peerMeta, realIP.String())
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta)
metahashed := metaHash(peerMeta, sRealIP)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)

View File

@@ -1916,117 +1916,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
// Establish a session so the matching-token disconnect is actually applied.
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
// Install the mock only now, so the assertion observes the disconnect, not
// the earlier connect.
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
// expected: disconnect re-armed the inactivity expiry timer
case <-time.After(time.Second):
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
// stays disabled, so disconnect must not schedule anything.
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
case <-time.After(200 * time.Millisecond):
// expected: nothing scheduled
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")

View File

@@ -11,6 +11,7 @@ import (
"strings"
"time"
"github.com/netbirdio/netbird/util"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
@@ -188,15 +189,6 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
}
}
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
} else if settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
return nil
}
@@ -209,14 +201,14 @@ func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *
if am.geo == nil || realIP == nil {
return nil
}
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return nil
}
location, err := am.geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return nil
}
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
return nil
}
return &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: location.Country.ISOCode,
@@ -1051,8 +1043,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged, metaDiff.Hostname) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1063,6 +1055,29 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return peer, nmap, resPostureChecks, dnsFwdPort, nil
}
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
reason := ""
switch {
case isStatusChanged:
reason = "status changed"
case updateAccountPeers:
reason = "update account peers"
case ipv6CapabilityChanged:
reason = "ipv6 capability changed"
case metaDiffAffectsPosture:
reason = "meta diff affects posture"
case versionChanged:
reason = "version changed"
case hostname:
reason = "hostname changed"
default:
return false
}
log.WithContext(ctx).Tracef("peer update required: %s", reason)
return true
}
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
// peer's own validated network map is bidirectional for policy and routing
// reachability, so when the peer stays valid and no source-posture gate is in
@@ -1486,6 +1501,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
log.WithContext(ctx).Tracef("update account peers for account %s from caller: %s with reason %s/%s", accountID, util.GetCallerName(), reason.Operation, reason.Resource)
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
}
@@ -1584,6 +1600,7 @@ func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
log.WithContext(ctx).Tracef("buffering update account peers for account %s from caller: %s with reason %s/%s", accountID, util.GetCallerName(), reason.Operation, reason.Resource)
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
}

View File

@@ -49,6 +49,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@@ -2893,3 +2894,141 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
require.NoError(t, err, "renaming to unique FQDN should succeed")
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
}
// fakeGeo is a configurable geolocation.Geolocation implementation for tests. It
// returns a record built from the configured city geoname id, or an error when set.
type fakeGeo struct {
geoNameID uint
isoCode string
cityName string
err error
}
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
if g.err != nil {
return nil, g.err
}
record := &geolocation.Record{}
record.City.GeonameID = g.geoNameID
record.City.Names.En = g.cityName
record.Country.ISOCode = g.isoCode
return record, nil
}
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) { return nil, nil }
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) { return nil, nil }
func (g *fakeGeo) Stop() error { return nil }
func TestResolvePeerLocation(t *testing.T) {
realIP := net.ParseIP("203.0.113.10")
tests := []struct {
name string
geo geolocation.Geolocation
peer *nbpeer.Peer
realIP net.IP
want *nbpeer.Location
wantNil bool
}{
{
name: "no geo configured returns nil",
geo: nil,
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
wantNil: true,
},
{
name: "nil real IP returns nil",
geo: &fakeGeo{geoNameID: 100},
peer: &nbpeer.Peer{ID: "p1"},
realIP: nil,
wantNil: true,
},
{
name: "lookup error returns nil",
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
wantNil: true,
},
{
name: "same IP and same geoname returns nil",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: realIP,
GeoNameID: 100,
},
},
realIP: realIP,
wantNil: true,
},
{
name: "same IP but changed geoname returns location",
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: realIP,
GeoNameID: 100,
},
},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City B",
GeoNameID: 200,
},
},
{
name: "different IP returns location",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: net.ParseIP("198.51.100.7"),
GeoNameID: 100,
},
},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City A",
GeoNameID: 100,
},
},
{
name: "no prior location returns location",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City A",
GeoNameID: 100,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
am := &DefaultAccountManager{geo: tt.geo}
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
if tt.wantNil {
assert.Nil(t, got, "resolved location should be nil")
return
}
require.NotNil(t, got, "resolved location should not be nil")
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
})
}
}

View File

@@ -0,0 +1,130 @@
package posture
import (
"context"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
func TestAffectsPosture(t *testing.T) {
processCheck := &Checks{Checks: ChecksDefinition{ProcessCheck: &ProcessCheck{}}}
osCheck := &Checks{Checks: ChecksDefinition{OSVersionCheck: &OSVersionCheck{}}}
nbCheck := &Checks{Checks: ChecksDefinition{NBVersionCheck: &NBVersionCheck{}}}
geoCheck := &Checks{Checks: ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{}}}
privateRangeCheck := &Checks{Checks: ChecksDefinition{
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Ranges: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
},
}}
publicRangeCheck := &Checks{Checks: ChecksDefinition{
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Ranges: []netip.Prefix{netip.MustParsePrefix("203.0.113.0/24")},
},
}}
mixedRangeCheck := &Checks{Checks: ChecksDefinition{
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Ranges: []netip.Prefix{
netip.MustParsePrefix("203.0.113.0/24"),
netip.MustParsePrefix("192.168.0.0/16"),
},
},
}}
tests := []struct {
name string
diff *nbpeer.MetaDiff
checks []*Checks
want bool
}{
{
name: "nil diff never affects posture",
diff: nil,
checks: []*Checks{processCheck},
want: false,
},
{
name: "process check affected by files change",
diff: &nbpeer.MetaDiff{Files: true},
checks: []*Checks{processCheck},
want: true,
},
{
name: "process check ignores unrelated change",
diff: &nbpeer.MetaDiff{Hostname: true},
checks: []*Checks{processCheck},
want: false,
},
{
name: "os check affected by os version change",
diff: &nbpeer.MetaDiff{OSVersion: true},
checks: []*Checks{osCheck},
want: true,
},
{
name: "nb check affected by wt version change",
diff: &nbpeer.MetaDiff{WtVersion: true},
checks: []*Checks{nbCheck},
want: true,
},
{
name: "geo check affected by location change",
diff: &nbpeer.MetaDiff{LocationChanged: true},
checks: []*Checks{geoCheck},
want: true,
},
{
name: "network range check not affected without network address or location change",
diff: &nbpeer.MetaDiff{Hostname: true},
checks: []*Checks{privateRangeCheck},
want: false,
},
{
name: "private range check affected by network address change",
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
checks: []*Checks{privateRangeCheck},
want: true,
},
{
name: "public range check not affected by network address change alone",
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
checks: []*Checks{publicRangeCheck},
want: false,
},
{
name: "public range check affected by location change alone",
diff: &nbpeer.MetaDiff{LocationChanged: true},
checks: []*Checks{publicRangeCheck},
want: true,
},
{
name: "private range check affected by location change alone",
diff: &nbpeer.MetaDiff{LocationChanged: true},
checks: []*Checks{privateRangeCheck},
want: true,
},
{
name: "public range check affected when location also changed",
diff: &nbpeer.MetaDiff{NetworkAddresses: true, LocationChanged: true},
checks: []*Checks{publicRangeCheck},
want: true,
},
{
name: "mixed ranges affected by network address change due to private range",
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
checks: []*Checks{mixedRangeCheck},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := AffectsPosture(context.Background(), tt.diff, tt.checks)
assert.Equal(t, tt.want, got, "AffectsPosture result should match expectation")
})
}
}

View File

@@ -7,6 +7,7 @@ import (
"regexp"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -56,25 +57,38 @@ type Checks struct {
// alter the outcome of any of the given posture checks. It maps each check kind to
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
// does not force a posture re-evaluation.
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
if diff == nil {
return false
}
for _, c := range checks {
if c.Checks.ProcessCheck != nil && diff.Files {
log.WithContext(ctx).Tracef("posture check %s is affected by files change", c.Name)
return true
}
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
log.WithContext(ctx).Tracef("posture check %s is affected by OS version check", c.Name)
return true
}
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
log.WithContext(ctx).Tracef("posture check %s is affected by NB version change", c.Name)
return true
}
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
log.WithContext(ctx).Tracef("posture check %s is affected by location change", c.Name)
return true
}
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
return true
if c.Checks.PeerNetworkRangeCheck != nil && (diff.NetworkAddresses || diff.LocationChanged) {
if diff.LocationChanged {
log.WithContext(ctx).Tracef("posture check %s is affected by location change", c.Name)
return true
}
for _, r := range c.Checks.PeerNetworkRangeCheck.Ranges {
if r.Addr().IsPrivate() {
log.WithContext(ctx).Tracef("posture check %s is affected by network address change", c.Name)
return true
}
}
}
}
return false