mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-24 16:59:55 +00:00
Compare commits
8 Commits
feature/le
...
add-trace-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f24dc5fa7 | ||
|
|
ecd133ca70 | ||
|
|
3d4a70deeb | ||
|
|
330a03ce75 | ||
|
|
4b89f3be8a | ||
|
|
b4c1db17e4 | ||
|
|
58cd0eae4e | ||
|
|
49c8d571b2 |
@@ -51,20 +51,13 @@ type cachedRecord struct {
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
@@ -83,10 +76,9 @@ type Resolver struct {
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -177,6 +178,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
if reason.Operation == types.UpdateOperationUpdate && reason.Resource == types.UpdateResourceUser {
|
||||
log.WithContext(ctx).Tracef("got an user update, stack: %s", debug.Stack())
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
@@ -244,6 +249,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("UpdateAccountPeers: account %s, %d affected peers (caller: %s)", accountID, len(peerIDs), util.GetCallerName())
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -251,7 +257,7 @@ func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string,
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers (caller: %s)", accountID, len(peerIDs), util.GetCallerName())
|
||||
|
||||
if !c.hasConnectedPeers(peerIDs) {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
|
||||
@@ -497,7 +503,11 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
|
||||
if reason.Operation == types.UpdateOperationUpdate && reason.Resource == types.UpdateResourceUser {
|
||||
log.WithContext(ctx).Tracef("got an user update, stack: %s", debug.Stack())
|
||||
}
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
const (
|
||||
reconnThreshold = 5 * time.Minute
|
||||
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
|
||||
reconnLimitForBan = 30 // Number of reconnections within the reconnTreshold that triggers a ban
|
||||
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
|
||||
)
|
||||
@@ -139,13 +139,22 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
|
||||
state.lastSeen = now
|
||||
}
|
||||
|
||||
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
|
||||
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
|
||||
h := fnv.New64a()
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return h.Sum64()
|
||||
macs := uint64(0)
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
for _, r := range na.Mac {
|
||||
macs += uint64(r)
|
||||
}
|
||||
}
|
||||
|
||||
return h.Sum64() + macs
|
||||
}
|
||||
|
||||
@@ -164,7 +164,9 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
KernelVersion: "5.15.0-76-generic",
|
||||
Hostname: "prod-server-database-01",
|
||||
SystemSerialNumber: "PC-1234567890",
|
||||
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
|
||||
}
|
||||
pubip := "8.8.8.8"
|
||||
|
||||
var resultString string
|
||||
var resultUint uint64
|
||||
@@ -173,7 +175,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = builderString(meta)
|
||||
resultString = builderString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -181,7 +183,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultString = fnvHashToString(meta)
|
||||
resultString = fnvHashToString(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -189,7 +191,7 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
resultUint = metaHash(meta)
|
||||
resultUint = metaHash(meta, pubip)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -197,20 +199,29 @@ func BenchmarkHashingMethods(b *testing.B) {
|
||||
_ = resultUint
|
||||
}
|
||||
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
|
||||
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
h := fnv.New64a()
|
||||
|
||||
if len(meta.NetworkAddresses) != 0 {
|
||||
for _, na := range meta.NetworkAddresses {
|
||||
h.Write([]byte(na.Mac))
|
||||
}
|
||||
}
|
||||
|
||||
h.Write([]byte(meta.WtVersion))
|
||||
h.Write([]byte(meta.OSVersion))
|
||||
h.Write([]byte(meta.KernelVersion))
|
||||
h.Write([]byte(meta.Hostname))
|
||||
h.Write([]byte(meta.SystemSerialNumber))
|
||||
h.Write([]byte(pubip))
|
||||
|
||||
return strconv.FormatUint(h.Sum64(), 16)
|
||||
}
|
||||
|
||||
func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
|
||||
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
|
||||
mac := getMacAddress(meta.NetworkAddresses)
|
||||
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
|
||||
len(pubip) + len(mac) + 6
|
||||
|
||||
var b strings.Builder
|
||||
b.Grow(estimatedSize)
|
||||
@@ -224,10 +235,23 @@ func builderString(meta nbpeer.PeerSystemMeta) string {
|
||||
b.WriteString(meta.Hostname)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(meta.SystemSerialNumber)
|
||||
b.WriteByte('|')
|
||||
b.WriteString(pubip)
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func getMacAddress(nas []nbpeer.NetworkAddress) string {
|
||||
if len(nas) == 0 {
|
||||
return ""
|
||||
}
|
||||
macs := make([]string, 0, len(nas))
|
||||
for _, na := range nas {
|
||||
macs = append(macs, na.Mac)
|
||||
}
|
||||
return strings.Join(macs, "/")
|
||||
}
|
||||
|
||||
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
|
||||
filter := newLoginFilterWithCfg(testAdvancedCfg())
|
||||
numKeys := 100000
|
||||
|
||||
@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
return mapError(ctx, err)
|
||||
}
|
||||
|
||||
metahashed := metaHash(peerMeta)
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.appMetrics != nil {
|
||||
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
|
||||
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
|
||||
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
|
||||
}
|
||||
|
||||
metahash := metaHash(peerMeta)
|
||||
metahash := metaHash(peerMeta, realIP.String())
|
||||
s.loginFilter.addLogin(peerKey.String(), metahash)
|
||||
|
||||
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
|
||||
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
|
||||
}
|
||||
|
||||
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
|
||||
metahashed := metaHash(peerMeta)
|
||||
metahashed := metaHash(peerMeta, sRealIP)
|
||||
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
|
||||
if s.logBlockedPeers {
|
||||
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
|
||||
|
||||
@@ -1916,117 +1916,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
// Establish a session so the matching-token disconnect is actually applied.
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
// Install the mock only now, so the assertion observes the disconnect, not
|
||||
// the earlier connect.
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
// expected: disconnect re-armed the inactivity expiry timer
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
|
||||
// stays disabled, so disconnect must not schedule anything.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// expected: nothing scheduled
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -188,15 +189,6 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
}
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
|
||||
} else if settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -209,14 +201,14 @@ func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return nil
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
|
||||
return nil
|
||||
}
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
@@ -1051,8 +1043,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return nil, nil, nil, 0, err
|
||||
}
|
||||
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
|
||||
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged, metaDiff.Hostname) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1063,6 +1055,29 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return peer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
|
||||
reason := ""
|
||||
switch {
|
||||
case isStatusChanged:
|
||||
reason = "status changed"
|
||||
case updateAccountPeers:
|
||||
reason = "update account peers"
|
||||
case ipv6CapabilityChanged:
|
||||
reason = "ipv6 capability changed"
|
||||
case metaDiffAffectsPosture:
|
||||
reason = "meta diff affects posture"
|
||||
case versionChanged:
|
||||
reason = "version changed"
|
||||
case hostname:
|
||||
reason = "hostname changed"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("peer update required: %s", reason)
|
||||
return true
|
||||
}
|
||||
|
||||
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
|
||||
// peer's own validated network map is bidirectional for policy and routing
|
||||
// reachability, so when the peer stays valid and no source-posture gate is in
|
||||
@@ -1486,6 +1501,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
// UpdateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
log.WithContext(ctx).Tracef("update account peers for account %s from caller: %s with reason %s/%s", accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
@@ -1584,6 +1600,7 @@ func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
log.WithContext(ctx).Tracef("buffering update account peers for account %s from caller: %s with reason %s/%s", accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -2893,3 +2894,141 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
require.NoError(t, err, "renaming to unique FQDN should succeed")
|
||||
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
|
||||
}
|
||||
|
||||
// fakeGeo is a configurable geolocation.Geolocation implementation for tests. It
|
||||
// returns a record built from the configured city geoname id, or an error when set.
|
||||
type fakeGeo struct {
|
||||
geoNameID uint
|
||||
isoCode string
|
||||
cityName string
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
|
||||
if g.err != nil {
|
||||
return nil, g.err
|
||||
}
|
||||
record := &geolocation.Record{}
|
||||
record.City.GeonameID = g.geoNameID
|
||||
record.City.Names.En = g.cityName
|
||||
record.Country.ISOCode = g.isoCode
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) Stop() error { return nil }
|
||||
|
||||
func TestResolvePeerLocation(t *testing.T) {
|
||||
realIP := net.ParseIP("203.0.113.10")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
geo geolocation.Geolocation
|
||||
peer *nbpeer.Peer
|
||||
realIP net.IP
|
||||
want *nbpeer.Location
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "no geo configured returns nil",
|
||||
geo: nil,
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "nil real IP returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "lookup error returns nil",
|
||||
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP and same geoname returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP but changed geoname returns location",
|
||||
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City B",
|
||||
GeoNameID: 200,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different IP returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("198.51.100.7"),
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no prior location returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
am := &DefaultAccountManager{geo: tt.geo}
|
||||
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
|
||||
if tt.wantNil {
|
||||
assert.Nil(t, got, "resolved location should be nil")
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got, "resolved location should not be nil")
|
||||
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
|
||||
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
|
||||
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
|
||||
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
130
management/server/posture/affects_test.go
Normal file
130
management/server/posture/affects_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestAffectsPosture(t *testing.T) {
|
||||
processCheck := &Checks{Checks: ChecksDefinition{ProcessCheck: &ProcessCheck{}}}
|
||||
osCheck := &Checks{Checks: ChecksDefinition{OSVersionCheck: &OSVersionCheck{}}}
|
||||
nbCheck := &Checks{Checks: ChecksDefinition{NBVersionCheck: &NBVersionCheck{}}}
|
||||
geoCheck := &Checks{Checks: ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{}}}
|
||||
|
||||
privateRangeCheck := &Checks{Checks: ChecksDefinition{
|
||||
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
},
|
||||
}}
|
||||
publicRangeCheck := &Checks{Checks: ChecksDefinition{
|
||||
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix("203.0.113.0/24")},
|
||||
},
|
||||
}}
|
||||
mixedRangeCheck := &Checks{Checks: ChecksDefinition{
|
||||
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("203.0.113.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
diff *nbpeer.MetaDiff
|
||||
checks []*Checks
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil diff never affects posture",
|
||||
diff: nil,
|
||||
checks: []*Checks{processCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "process check affected by files change",
|
||||
diff: &nbpeer.MetaDiff{Files: true},
|
||||
checks: []*Checks{processCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "process check ignores unrelated change",
|
||||
diff: &nbpeer.MetaDiff{Hostname: true},
|
||||
checks: []*Checks{processCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "os check affected by os version change",
|
||||
diff: &nbpeer.MetaDiff{OSVersion: true},
|
||||
checks: []*Checks{osCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nb check affected by wt version change",
|
||||
diff: &nbpeer.MetaDiff{WtVersion: true},
|
||||
checks: []*Checks{nbCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "geo check affected by location change",
|
||||
diff: &nbpeer.MetaDiff{LocationChanged: true},
|
||||
checks: []*Checks{geoCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "network range check not affected without network address or location change",
|
||||
diff: &nbpeer.MetaDiff{Hostname: true},
|
||||
checks: []*Checks{privateRangeCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "private range check affected by network address change",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
|
||||
checks: []*Checks{privateRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "public range check not affected by network address change alone",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
|
||||
checks: []*Checks{publicRangeCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "public range check affected by location change alone",
|
||||
diff: &nbpeer.MetaDiff{LocationChanged: true},
|
||||
checks: []*Checks{publicRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "private range check affected by location change alone",
|
||||
diff: &nbpeer.MetaDiff{LocationChanged: true},
|
||||
checks: []*Checks{privateRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "public range check affected when location also changed",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true, LocationChanged: true},
|
||||
checks: []*Checks{publicRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "mixed ranges affected by network address change due to private range",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
|
||||
checks: []*Checks{mixedRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := AffectsPosture(context.Background(), tt.diff, tt.checks)
|
||||
assert.Equal(t, tt.want, got, "AffectsPosture result should match expectation")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"regexp"
|
||||
|
||||
"github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
@@ -56,25 +57,38 @@ type Checks struct {
|
||||
// alter the outcome of any of the given posture checks. It maps each check kind to
|
||||
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
|
||||
// does not force a posture re-evaluation.
|
||||
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
if diff == nil {
|
||||
return false
|
||||
}
|
||||
for _, c := range checks {
|
||||
if c.Checks.ProcessCheck != nil && diff.Files {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by files change", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by OS version check", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by NB version change", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by location change", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
|
||||
return true
|
||||
if c.Checks.PeerNetworkRangeCheck != nil && (diff.NetworkAddresses || diff.LocationChanged) {
|
||||
if diff.LocationChanged {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by location change", c.Name)
|
||||
return true
|
||||
}
|
||||
for _, r := range c.Checks.PeerNetworkRangeCheck.Ranges {
|
||||
if r.Addr().IsPrivate() {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by network address change", c.Name)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
Reference in New Issue
Block a user