mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-24 16:59:55 +00:00
Compare commits
8 Commits
feature/me
...
add-trace-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7f24dc5fa7 | ||
|
|
ecd133ca70 | ||
|
|
3d4a70deeb | ||
|
|
330a03ce75 | ||
|
|
4b89f3be8a | ||
|
|
b4c1db17e4 | ||
|
|
58cd0eae4e | ||
|
|
49c8d571b2 |
@@ -51,20 +51,13 @@ type cachedRecord struct {
|
||||
}
|
||||
|
||||
// Resolver caches critical NetBird infrastructure domains.
|
||||
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
|
||||
// guarded by mutex.
|
||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
||||
type Resolver struct {
|
||||
records map[dns.Question]*cachedRecord
|
||||
mgmtDomain *domain.Domain
|
||||
serverDomains *dnsconfig.ServerDomains
|
||||
mutex sync.RWMutex
|
||||
|
||||
// failedResolves records the last failed initial resolve per domain so a
|
||||
// domain that never resolves isn't retried on every server-domains update
|
||||
// until refreshBackoff elapses. Entries are cleared on success and pruned
|
||||
// to the current server-domains set.
|
||||
failedResolves map[domain.Domain]time.Time
|
||||
|
||||
chain ChainResolver
|
||||
chainMaxPriority int
|
||||
refreshGroup singleflight.Group
|
||||
@@ -83,10 +76,9 @@ type Resolver struct {
|
||||
// NewResolver creates a new management domains cache resolver.
|
||||
func NewResolver() *Resolver {
|
||||
return &Resolver{
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
failedResolves: make(map[domain.Domain]time.Time),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
records: make(map[dns.Question]*cachedRecord),
|
||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
||||
cacheTTL: resolveCacheTTL(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -181,9 +173,7 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
||||
// entry for that qtype. When one family hard-errors while the other succeeds,
|
||||
// the resolved family is still cached but AddDomain returns an error so the
|
||||
// caller retries the incomplete resolve rather than treating it as complete.
|
||||
// entry for that qtype.
|
||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
@@ -213,10 +203,6 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||
|
||||
if errA != nil || errAAAA != nil {
|
||||
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -476,7 +462,6 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
||||
delete(m.records, qAAAA)
|
||||
delete(m.refreshing, qA)
|
||||
delete(m.refreshing, qAAAA)
|
||||
delete(m.failedResolves, d)
|
||||
|
||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||
return nil
|
||||
@@ -520,7 +505,6 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
|
||||
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
|
||||
currentDomains := m.GetCachedDomains()
|
||||
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
|
||||
m.pruneFailedResolves(allDomains)
|
||||
}
|
||||
|
||||
m.addNewDomains(ctx, newDomains)
|
||||
@@ -593,85 +577,13 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
|
||||
return m.mgmtDomain != nil && domain == *m.mgmtDomain
|
||||
}
|
||||
|
||||
// addNewDomains resolves and caches domains that are not yet in the cache,
|
||||
// running the lookups concurrently. Domains already cached are skipped and left
|
||||
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
|
||||
// synchronously: once NetBird owns the OS resolver the resolve runs through the
|
||||
// handler chain and would otherwise dial the managed upstreams under the engine
|
||||
// sync lock on every update.
|
||||
// addNewDomains resolves and caches all domains from the update
|
||||
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
|
||||
var wg sync.WaitGroup
|
||||
seen := make(map[domain.Domain]struct{}, len(newDomains))
|
||||
for _, newDomain := range newDomains {
|
||||
if _, dup := seen[newDomain]; dup {
|
||||
continue
|
||||
}
|
||||
seen[newDomain] = struct{}{}
|
||||
|
||||
if !m.needsResolve(newDomain) {
|
||||
continue
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func(d domain.Domain) {
|
||||
defer wg.Done()
|
||||
if err := m.AddDomain(ctx, d); err != nil {
|
||||
m.markResolveFailed(d)
|
||||
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
|
||||
return
|
||||
}
|
||||
m.clearResolveFailed(d)
|
||||
log.Debugf("added/updated management cache domain=%s", d.SafeString())
|
||||
}(newDomain)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// needsResolve reports whether d should be resolved now. A recent failed or
|
||||
// incomplete resolve gates retries on the backoff even when one family is
|
||||
// already cached, so a transiently-failed family is retried instead of being
|
||||
// treated as fully resolved. Otherwise a domain with any cached record is left
|
||||
// to the stale-while-revalidate refresh path.
|
||||
func (m *Resolver) needsResolve(d domain.Domain) bool {
|
||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||
|
||||
m.mutex.RLock()
|
||||
defer m.mutex.RUnlock()
|
||||
|
||||
if failedAt, ok := m.failedResolves[d]; ok {
|
||||
return time.Since(failedAt) >= refreshBackoff
|
||||
}
|
||||
|
||||
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
|
||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
||||
if _, ok := m.records[q]; ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Resolver) markResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
m.failedResolves[d] = time.Now()
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *Resolver) clearResolveFailed(d domain.Domain) {
|
||||
m.mutex.Lock()
|
||||
delete(m.failedResolves, d)
|
||||
m.mutex.Unlock()
|
||||
}
|
||||
|
||||
// pruneFailedResolves drops failure markers for domains no longer present in
|
||||
// the server-domains set, keeping the map bounded to the current set (a
|
||||
// failed-only domain has no cached record, so RemoveDomain never sees it).
|
||||
func (m *Resolver) pruneFailedResolves(domains domain.List) {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
for d := range m.failedResolves {
|
||||
if !slices.Contains(domains, d) {
|
||||
delete(m.failedResolves, d)
|
||||
if err := m.AddDomain(ctx, newDomain); err != nil {
|
||||
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
|
||||
} else {
|
||||
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ type fakeChain struct {
|
||||
mu sync.Mutex
|
||||
calls map[string]int
|
||||
answers map[string][]dns.RR
|
||||
qErr map[string]error
|
||||
err error
|
||||
hasRoot bool
|
||||
onLookup func()
|
||||
@@ -31,7 +30,6 @@ func newFakeChain() *fakeChain {
|
||||
return &fakeChain{
|
||||
calls: map[string]int{},
|
||||
answers: map[string][]dns.RR{},
|
||||
qErr: map[string]error{},
|
||||
hasRoot: true,
|
||||
}
|
||||
}
|
||||
@@ -49,9 +47,6 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
|
||||
f.calls[key]++
|
||||
answers := f.answers[key]
|
||||
err := f.err
|
||||
if err == nil {
|
||||
err = f.qErr[key]
|
||||
}
|
||||
onLookup := f.onLookup
|
||||
f.mu.Unlock()
|
||||
|
||||
@@ -80,12 +75,6 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
|
||||
}
|
||||
|
||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
@@ -1,183 +0,0 @@
|
||||
package mgmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
)
|
||||
|
||||
// A domain already in the cache must not be re-resolved on a subsequent server
|
||||
// domains update; it is left to the stale-while-revalidate refresh path.
|
||||
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must resolve the domain")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"cached domain must not be re-resolved on a subsequent update")
|
||||
}
|
||||
|
||||
// New domains in a single update must resolve concurrently rather than serially.
|
||||
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
|
||||
var inflight, maxInflight atomic.Int32
|
||||
chain.onLookup = func() {
|
||||
n := inflight.Add(1)
|
||||
for {
|
||||
old := maxInflight.Load()
|
||||
if n <= old || maxInflight.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
inflight.Add(-1)
|
||||
}
|
||||
|
||||
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
|
||||
for _, d := range relays {
|
||||
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
|
||||
}
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
start := time.Now()
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
|
||||
require.NoError(t, err)
|
||||
elapsed := time.Since(start)
|
||||
|
||||
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
|
||||
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
|
||||
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
|
||||
}
|
||||
|
||||
// A domain that fails to resolve must not be retried on every update; the
|
||||
// failure backoff suppresses re-resolution until it expires.
|
||||
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"first update must attempt the resolve")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
|
||||
"failed resolve must back off and not retry on the next update")
|
||||
}
|
||||
|
||||
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
|
||||
// the same host) must be resolved once per update, not once per occurrence.
|
||||
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
sd := dnsconfig.ServerDomains{
|
||||
Stuns: []domain.Domain{"dup.example.com"},
|
||||
Turns: []domain.Domain{"dup.example.com"},
|
||||
}
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), sd)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
|
||||
"a domain appearing under multiple server-domain types must resolve once")
|
||||
}
|
||||
|
||||
// A failure marker must be dropped once its domain leaves the server-domains set
|
||||
// so the map stays bounded to the current set.
|
||||
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.err = errors.New("resolve boom")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, marked, "failed resolve must be recorded")
|
||||
|
||||
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
|
||||
require.NoError(t, err)
|
||||
r.mutex.RLock()
|
||||
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
|
||||
}
|
||||
|
||||
// When one family hard-errors while the other resolves, the domain is cached
|
||||
// for the working family but recorded as incomplete so the failed family is
|
||||
// retried under backoff instead of being treated as fully resolved forever.
|
||||
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
|
||||
d := domain.Domain("relay.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
|
||||
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
require.True(t, aCached, "the working family must still be cached")
|
||||
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
|
||||
|
||||
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
|
||||
|
||||
r.mutex.Lock()
|
||||
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
|
||||
r.mutex.Unlock()
|
||||
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
|
||||
}
|
||||
|
||||
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
|
||||
// not a failure: the domain must not be marked for retry, otherwise it would be
|
||||
// re-resolved on every sync.
|
||||
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
|
||||
d := domain.Domain("v4only.example.com")
|
||||
r := NewResolver()
|
||||
chain := newFakeChain()
|
||||
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
|
||||
r.SetChainResolver(chain, 50)
|
||||
|
||||
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
|
||||
require.NoError(t, err)
|
||||
|
||||
r.mutex.RLock()
|
||||
_, marked := r.failedResolves[d]
|
||||
r.mutex.RUnlock()
|
||||
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
|
||||
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
|
||||
}
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime/debug"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
@@ -177,6 +178,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
||||
return fmt.Errorf("failed to get account zones: %v", err)
|
||||
}
|
||||
|
||||
if reason.Operation == types.UpdateOperationUpdate && reason.Resource == types.UpdateResourceUser {
|
||||
log.WithContext(ctx).Tracef("got an user update, stack: %s", debug.Stack())
|
||||
}
|
||||
|
||||
for _, peer := range account.Peers {
|
||||
if !c.peersUpdateManager.HasChannel(peer.ID) {
|
||||
log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID)
|
||||
@@ -244,6 +249,7 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
|
||||
|
||||
// UpdateAffectedPeers updates only the specified peers that belong to an account.
|
||||
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("UpdateAccountPeers: account %s, %d affected peers (caller: %s)", accountID, len(peerIDs), util.GetCallerName())
|
||||
if len(peerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -251,7 +257,7 @@ func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string,
|
||||
}
|
||||
|
||||
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers (caller: %s)", accountID, len(peerIDs), util.GetCallerName())
|
||||
|
||||
if !c.hasConnectedPeers(peerIDs) {
|
||||
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
|
||||
@@ -497,7 +503,11 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
|
||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
|
||||
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
|
||||
if reason.Operation == types.UpdateOperationUpdate && reason.Resource == types.UpdateResourceUser {
|
||||
log.WithContext(ctx).Tracef("got an user update, stack: %s", debug.Stack())
|
||||
}
|
||||
|
||||
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
|
||||
peerIDs: make(map[string]struct{}),
|
||||
@@ -610,10 +620,12 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
|
||||
startPosture := time.Now()
|
||||
postureChecks, err := c.getPeerPostureChecks(account, peerID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
|
||||
|
||||
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
|
||||
if err != nil {
|
||||
|
||||
@@ -1916,117 +1916,6 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: true,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
// Establish a session so the matching-token disconnect is actually applied.
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
// Install the mock only now, so the assertion observes the disconnect, not
|
||||
// the earlier connect.
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
// expected: disconnect re-armed the inactivity expiry timer
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
|
||||
require.NoError(t, err, "unable to create an account")
|
||||
|
||||
key, err := wgtypes.GenerateKey()
|
||||
require.NoError(t, err, "unable to generate WireGuard key")
|
||||
peerPubKey := key.PublicKey().String()
|
||||
|
||||
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
|
||||
Key: peerPubKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
|
||||
InactivityExpirationEnabled: true,
|
||||
}, false)
|
||||
require.NoError(t, err, "unable to add peer")
|
||||
|
||||
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
|
||||
// stays disabled, so disconnect must not schedule anything.
|
||||
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
|
||||
PeerLoginExpiration: time.Hour,
|
||||
PeerLoginExpirationEnabled: true,
|
||||
PeerInactivityExpiration: time.Hour,
|
||||
PeerInactivityExpirationEnabled: false,
|
||||
Extra: &types.ExtraSettings{},
|
||||
})
|
||||
require.NoError(t, err, "expecting to update account settings successfully but got error")
|
||||
|
||||
streamStartTime := time.Now().UTC()
|
||||
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
|
||||
require.NoError(t, err, "unable to mark peer connected")
|
||||
|
||||
scheduled := make(chan struct{}, 1)
|
||||
manager.peerInactivityExpiry = &MockScheduler{
|
||||
CancelFunc: func(ctx context.Context, IDs []string) {},
|
||||
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
|
||||
select {
|
||||
case scheduled <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
|
||||
require.NoError(t, err, "unable to mark peer disconnected")
|
||||
|
||||
select {
|
||||
case <-scheduled:
|
||||
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
|
||||
case <-time.After(200 * time.Millisecond):
|
||||
// expected: nothing scheduled
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
|
||||
manager, _, err := createManager(t)
|
||||
require.NoError(t, err, "unable to create account manager")
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
@@ -188,15 +189,6 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
|
||||
}
|
||||
}
|
||||
|
||||
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
|
||||
} else if settings.PeerInactivityExpirationEnabled {
|
||||
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -209,14 +201,14 @@ func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *
|
||||
if am.geo == nil || realIP == nil {
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
|
||||
return nil
|
||||
}
|
||||
location, err := am.geo.Lookup(realIP)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
|
||||
return nil
|
||||
}
|
||||
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
|
||||
return nil
|
||||
}
|
||||
return &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: location.Country.ISOCode,
|
||||
@@ -1052,7 +1044,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
}
|
||||
|
||||
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
|
||||
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged() || metaDiff.HostnameChanged() {
|
||||
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged, metaDiff.Hostname) {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1063,6 +1055,29 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return peer, nmap, resPostureChecks, dnsFwdPort, nil
|
||||
}
|
||||
|
||||
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
|
||||
reason := ""
|
||||
switch {
|
||||
case isStatusChanged:
|
||||
reason = "status changed"
|
||||
case updateAccountPeers:
|
||||
reason = "update account peers"
|
||||
case ipv6CapabilityChanged:
|
||||
reason = "ipv6 capability changed"
|
||||
case metaDiffAffectsPosture:
|
||||
reason = "meta diff affects posture"
|
||||
case versionChanged:
|
||||
reason = "version changed"
|
||||
case hostname:
|
||||
reason = "hostname changed"
|
||||
default:
|
||||
return false
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("peer update required: %s", reason)
|
||||
return true
|
||||
}
|
||||
|
||||
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
|
||||
// peer's own validated network map is bidirectional for policy and routing
|
||||
// reachability, so when the peer stays valid and no source-posture gate is in
|
||||
@@ -1486,6 +1501,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
|
||||
// UpdateAccountPeers updates all peers that belong to an account.
|
||||
// Should be called when changes have to be synced to peers.
|
||||
func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
log.WithContext(ctx).Tracef("update account peers for account %s from caller: %s with reason %s/%s", accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
@@ -1584,6 +1600,7 @@ func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
||||
log.WithContext(ctx).Tracef("buffering update account peers for account %s from caller: %s with reason %s/%s", accountID, util.GetCallerName(), reason.Operation, reason.Resource)
|
||||
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
|
||||
}
|
||||
|
||||
|
||||
@@ -107,15 +107,6 @@ type Location struct {
|
||||
GeoNameID uint // city level geoname id
|
||||
}
|
||||
|
||||
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
|
||||
// IP.Equal, not ==.
|
||||
func (l Location) equal(other Location) bool {
|
||||
return l.CountryCode == other.CountryCode &&
|
||||
l.CityName == other.CityName &&
|
||||
l.GeoNameID == other.GeoNameID &&
|
||||
l.ConnectionIP.Equal(other.ConnectionIP)
|
||||
}
|
||||
|
||||
// NetworkAddress is the IP address with network and MAC address of a network interface
|
||||
type NetworkAddress struct {
|
||||
NetIP netip.Prefix `gorm:"serializer:json"`
|
||||
@@ -276,139 +267,183 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLoca
|
||||
return MetaDiff{}
|
||||
}
|
||||
|
||||
versionChanged := p.Meta.WtVersion != meta.WtVersion
|
||||
|
||||
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
|
||||
if meta.UIVersion == "" {
|
||||
meta.UIVersion = p.Meta.UIVersion
|
||||
}
|
||||
|
||||
effectiveLocation := p.Location
|
||||
if newLocation != nil {
|
||||
effectiveLocation = *newLocation
|
||||
}
|
||||
oldVersion := p.Meta.WtVersion
|
||||
|
||||
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
|
||||
if diff.Updated() {
|
||||
diff := diffMeta(p.Meta, meta)
|
||||
if diff.Any() {
|
||||
p.Meta = meta
|
||||
}
|
||||
p.Location = effectiveLocation
|
||||
diff.VersionChanged = versionChanged
|
||||
|
||||
if diff.Updated() {
|
||||
log.WithContext(ctx).Debug(diff.LogSummary())
|
||||
locationInfo := ""
|
||||
if newLocation != nil {
|
||||
p.Location = *newLocation
|
||||
diff.LocationChanged = true
|
||||
locationInfo = fmt.Sprintf("location changed to %s, ", newLocation.ConnectionIP)
|
||||
}
|
||||
|
||||
versionInfo := ""
|
||||
if diff.VersionChanged {
|
||||
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
|
||||
}
|
||||
|
||||
if diff.Any() || diff.VersionChanged || diff.LocationChanged {
|
||||
log.WithContext(ctx).
|
||||
Debugf("peer meta updated, %s%s%d field(s) changed: %s", versionInfo, locationInfo, len(diff.Changed), strings.Join(diff.Changed, ", "))
|
||||
}
|
||||
|
||||
return diff
|
||||
}
|
||||
|
||||
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
|
||||
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
|
||||
// checks read it). Changed lists what moved, for logging and the persistence decision;
|
||||
// the snapshots let a posture check be replayed against old and new. Everything is derived
|
||||
// from these fields, so there are no parallel per-field flags to keep in sync.
|
||||
// MetaDiff records which PeerSystemMeta fields differ between two metas. Each bool
|
||||
// maps to a single struct field, except Environment, which is split into Cloud and
|
||||
// Platform. Changed holds the human-readable `field: <old> -> <new>` entries so the
|
||||
// existing log line and isEqual can be derived from the same comparison.
|
||||
//
|
||||
// VersionChanged and LocationChanged sit outside the per-meta-field set:
|
||||
// VersionChanged tracks the WireGuard client version specifically (compared before
|
||||
// the UIVersion fixup, to signal client upgrades) and LocationChanged tracks the
|
||||
// peer's connection geo location, which lives on Peer rather than PeerSystemMeta.
|
||||
// Neither contributes an entry to Changed, so the field-coverage accounting stays
|
||||
// driven purely by the PeerSystemMeta comparison.
|
||||
type MetaDiff struct {
|
||||
OldMeta PeerSystemMeta
|
||||
NewMeta PeerSystemMeta
|
||||
OldLocation Location
|
||||
NewLocation Location
|
||||
Hostname bool
|
||||
GoOS bool
|
||||
Kernel bool
|
||||
KernelVersion bool
|
||||
Core bool
|
||||
Platform bool
|
||||
OS bool
|
||||
OSVersion bool
|
||||
WtVersion bool
|
||||
UIVersion bool
|
||||
SystemSerialNumber bool
|
||||
SystemProductName bool
|
||||
SystemManufacturer bool
|
||||
EnvironmentCloud bool
|
||||
EnvironmentPlatform bool
|
||||
Flags bool
|
||||
Capabilities bool
|
||||
NetworkAddresses bool
|
||||
Files bool
|
||||
|
||||
VersionChanged bool
|
||||
LocationChanged bool
|
||||
|
||||
Changed []string
|
||||
}
|
||||
|
||||
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
|
||||
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
|
||||
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
|
||||
func (d *MetaDiff) Updated() bool {
|
||||
// Any reports whether any PeerSystemMeta field changed.
|
||||
func (d MetaDiff) Any() bool {
|
||||
return len(d.Changed) != 0
|
||||
}
|
||||
|
||||
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
|
||||
func (d *MetaDiff) VersionChanged() bool {
|
||||
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
|
||||
}
|
||||
|
||||
// HostnameChanged reports whether the peer's hostname changed.
|
||||
func (d *MetaDiff) HostnameChanged() bool {
|
||||
return d.OldMeta.Hostname != d.NewMeta.Hostname
|
||||
}
|
||||
|
||||
// LogSummary renders the changed fields as a single human-readable line.
|
||||
func (d *MetaDiff) LogSummary() string {
|
||||
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
|
||||
len(d.Changed), strings.Join(d.Changed, ", "))
|
||||
// Updated reports whether the peer needs to be persisted: any meta field changed
|
||||
// or the geo location changed. The version flag alone does not imply a write,
|
||||
// since a version change is also reflected in the WtVersion meta field.
|
||||
func (d MetaDiff) Updated() bool {
|
||||
return d.Any() || d.LocationChanged || d.VersionChanged
|
||||
}
|
||||
|
||||
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
|
||||
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
|
||||
return diffMeta(oldMeta, newMeta).Changed
|
||||
}
|
||||
|
||||
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
|
||||
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
|
||||
// list, so the log line and the persistence decision can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
|
||||
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
|
||||
// diffMeta compares two metas field by field, returning both a per-field flag set
|
||||
// (for callers that need to know exactly what changed, e.g. matching against
|
||||
// posture checks) and the human-readable Changed list. It is the single source of
|
||||
// truth for meta comparison: isEqual reports equality as an empty diff, so the log
|
||||
// line, the change decision, and the flags can never disagree.
|
||||
func diffMeta(oldMeta, newMeta PeerSystemMeta) MetaDiff {
|
||||
var d MetaDiff
|
||||
add := func(field string, oldVal, newVal any) {
|
||||
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
|
||||
}
|
||||
|
||||
if oldMeta.Hostname != newMeta.Hostname {
|
||||
d.Hostname = true
|
||||
add("hostname", oldMeta.Hostname, newMeta.Hostname)
|
||||
}
|
||||
if oldMeta.GoOS != newMeta.GoOS {
|
||||
d.GoOS = true
|
||||
add("goos", oldMeta.GoOS, newMeta.GoOS)
|
||||
}
|
||||
if oldMeta.Kernel != newMeta.Kernel {
|
||||
d.Kernel = true
|
||||
add("kernel", oldMeta.Kernel, newMeta.Kernel)
|
||||
}
|
||||
if oldMeta.KernelVersion != newMeta.KernelVersion {
|
||||
d.KernelVersion = true
|
||||
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
|
||||
}
|
||||
if oldMeta.Core != newMeta.Core {
|
||||
d.Core = true
|
||||
add("core", oldMeta.Core, newMeta.Core)
|
||||
}
|
||||
if oldMeta.Platform != newMeta.Platform {
|
||||
d.Platform = true
|
||||
add("platform", oldMeta.Platform, newMeta.Platform)
|
||||
}
|
||||
if oldMeta.OS != newMeta.OS {
|
||||
d.OS = true
|
||||
add("os", oldMeta.OS, newMeta.OS)
|
||||
}
|
||||
if oldMeta.OSVersion != newMeta.OSVersion {
|
||||
d.OSVersion = true
|
||||
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
|
||||
}
|
||||
if oldMeta.WtVersion != newMeta.WtVersion {
|
||||
d.WtVersion = true
|
||||
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
|
||||
}
|
||||
if oldMeta.UIVersion != newMeta.UIVersion {
|
||||
d.UIVersion = true
|
||||
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
|
||||
}
|
||||
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
|
||||
d.SystemSerialNumber = true
|
||||
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
|
||||
}
|
||||
if oldMeta.SystemProductName != newMeta.SystemProductName {
|
||||
d.SystemProductName = true
|
||||
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
|
||||
}
|
||||
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
|
||||
d.SystemManufacturer = true
|
||||
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
|
||||
}
|
||||
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
|
||||
d.EnvironmentCloud = true
|
||||
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
|
||||
}
|
||||
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
|
||||
d.EnvironmentPlatform = true
|
||||
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
|
||||
}
|
||||
if !oldMeta.Flags.isEqual(newMeta.Flags) {
|
||||
d.Flags = true
|
||||
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
|
||||
}
|
||||
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
|
||||
d.Capabilities = true
|
||||
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
|
||||
}
|
||||
|
||||
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
|
||||
d.NetworkAddresses = true
|
||||
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
|
||||
}
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
if !oldLocation.equal(newLocation) {
|
||||
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
|
||||
if !sameMultiset(oldMeta.Files, newMeta.Files) {
|
||||
d.Files = true
|
||||
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
|
||||
}
|
||||
|
||||
return d
|
||||
|
||||
@@ -49,6 +49,7 @@ import (
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/geolocation"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
@@ -2893,3 +2894,141 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
|
||||
require.NoError(t, err, "renaming to unique FQDN should succeed")
|
||||
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
|
||||
}
|
||||
|
||||
// fakeGeo is a configurable geolocation.Geolocation implementation for tests. It
|
||||
// returns a record built from the configured city geoname id, or an error when set.
|
||||
type fakeGeo struct {
|
||||
geoNameID uint
|
||||
isoCode string
|
||||
cityName string
|
||||
err error
|
||||
}
|
||||
|
||||
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
|
||||
if g.err != nil {
|
||||
return nil, g.err
|
||||
}
|
||||
record := &geolocation.Record{}
|
||||
record.City.GeonameID = g.geoNameID
|
||||
record.City.Names.En = g.cityName
|
||||
record.Country.ISOCode = g.isoCode
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) { return nil, nil }
|
||||
|
||||
func (g *fakeGeo) Stop() error { return nil }
|
||||
|
||||
func TestResolvePeerLocation(t *testing.T) {
|
||||
realIP := net.ParseIP("203.0.113.10")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
geo geolocation.Geolocation
|
||||
peer *nbpeer.Peer
|
||||
realIP net.IP
|
||||
want *nbpeer.Location
|
||||
wantNil bool
|
||||
}{
|
||||
{
|
||||
name: "no geo configured returns nil",
|
||||
geo: nil,
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "nil real IP returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: nil,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "lookup error returns nil",
|
||||
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP and same geoname returns nil",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
wantNil: true,
|
||||
},
|
||||
{
|
||||
name: "same IP but changed geoname returns location",
|
||||
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City B",
|
||||
GeoNameID: 200,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "different IP returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{
|
||||
ID: "p1",
|
||||
Location: nbpeer.Location{
|
||||
ConnectionIP: net.ParseIP("198.51.100.7"),
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no prior location returns location",
|
||||
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
|
||||
peer: &nbpeer.Peer{ID: "p1"},
|
||||
realIP: realIP,
|
||||
want: &nbpeer.Location{
|
||||
ConnectionIP: realIP,
|
||||
CountryCode: "US",
|
||||
CityName: "City A",
|
||||
GeoNameID: 100,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
am := &DefaultAccountManager{geo: tt.geo}
|
||||
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
|
||||
if tt.wantNil {
|
||||
assert.Nil(t, got, "resolved location should be nil")
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got, "resolved location should not be nil")
|
||||
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
|
||||
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
|
||||
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
|
||||
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
|
||||
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
|
||||
return &nbpeer.MetaDiff{
|
||||
OldMeta: oldMeta,
|
||||
NewMeta: newMeta,
|
||||
OldLocation: oldLoc,
|
||||
NewLocation: newLoc,
|
||||
}
|
||||
}
|
||||
|
||||
func checks(def ChecksDefinition) []*Checks {
|
||||
return []*Checks{{Checks: def}}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NilDiff(t *testing.T) {
|
||||
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
})))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NBVersion(t *testing.T) {
|
||||
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
oldVer, newVer string
|
||||
want bool
|
||||
}{
|
||||
{"both above min, no flip", "1.3.0", "1.4.0", false},
|
||||
{"both below min, no flip", "1.0.0", "1.1.0", false},
|
||||
{"crosses up below->above", "1.1.0", "1.3.0", true},
|
||||
{"crosses down above->below", "1.3.0", "1.1.0", true},
|
||||
{"unparsable old only -> flip", "garbage", "1.3.0", true},
|
||||
{"unparsable both -> no flip", "garbage", "junk", false},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
|
||||
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
|
||||
}})
|
||||
|
||||
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
|
||||
withinMin := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
|
||||
|
||||
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
|
||||
crossesDown := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
|
||||
// failing linux kernel flips the verdict pass -> fail.
|
||||
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
|
||||
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
|
||||
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
|
||||
// flips the verdict.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
|
||||
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
|
||||
// A tracked process stays running while an unrelated file is added: the verdict does
|
||||
// not move, so posture is not affected.
|
||||
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
|
||||
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
|
||||
}})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
|
||||
{Path: "/usr/bin/foo", ProcessIsRunning: true},
|
||||
{Path: "/usr/bin/bar", ProcessIsRunning: true},
|
||||
}},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_GeoLocation(t *testing.T) {
|
||||
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
|
||||
Action: CheckActionAllow,
|
||||
Locations: []Location{{CountryCode: "DE"}},
|
||||
}})
|
||||
|
||||
// Moving within allowed countries keeps the verdict; moving out flips it.
|
||||
stayAllowed := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
|
||||
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
|
||||
|
||||
moveOut := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{CountryCode: "DE"},
|
||||
nbpeer.Location{CountryCode: "FR"},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
|
||||
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
|
||||
// moving within it does not.
|
||||
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
|
||||
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Action: CheckActionAllow,
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
|
||||
}})
|
||||
|
||||
movesOutOfRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
|
||||
)
|
||||
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
|
||||
|
||||
staysInRange := diffFrom(
|
||||
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
|
||||
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
|
||||
// Hostname changes but no check reads it: not affected even with checks present.
|
||||
c := checks(ChecksDefinition{
|
||||
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
|
||||
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
|
||||
})
|
||||
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
|
||||
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
|
||||
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, c))
|
||||
}
|
||||
|
||||
func TestAffectsPosture_NoChecks(t *testing.T) {
|
||||
diff := diffFrom(
|
||||
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
|
||||
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
|
||||
nbpeer.Location{}, nbpeer.Location{},
|
||||
)
|
||||
assert.False(t, AffectsPosture(context.Background(), diff, nil))
|
||||
}
|
||||
130
management/server/posture/affects_test.go
Normal file
130
management/server/posture/affects_test.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package posture
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
)
|
||||
|
||||
func TestAffectsPosture(t *testing.T) {
|
||||
processCheck := &Checks{Checks: ChecksDefinition{ProcessCheck: &ProcessCheck{}}}
|
||||
osCheck := &Checks{Checks: ChecksDefinition{OSVersionCheck: &OSVersionCheck{}}}
|
||||
nbCheck := &Checks{Checks: ChecksDefinition{NBVersionCheck: &NBVersionCheck{}}}
|
||||
geoCheck := &Checks{Checks: ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{}}}
|
||||
|
||||
privateRangeCheck := &Checks{Checks: ChecksDefinition{
|
||||
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
|
||||
},
|
||||
}}
|
||||
publicRangeCheck := &Checks{Checks: ChecksDefinition{
|
||||
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Ranges: []netip.Prefix{netip.MustParsePrefix("203.0.113.0/24")},
|
||||
},
|
||||
}}
|
||||
mixedRangeCheck := &Checks{Checks: ChecksDefinition{
|
||||
PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
|
||||
Ranges: []netip.Prefix{
|
||||
netip.MustParsePrefix("203.0.113.0/24"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
},
|
||||
},
|
||||
}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
diff *nbpeer.MetaDiff
|
||||
checks []*Checks
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "nil diff never affects posture",
|
||||
diff: nil,
|
||||
checks: []*Checks{processCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "process check affected by files change",
|
||||
diff: &nbpeer.MetaDiff{Files: true},
|
||||
checks: []*Checks{processCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "process check ignores unrelated change",
|
||||
diff: &nbpeer.MetaDiff{Hostname: true},
|
||||
checks: []*Checks{processCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "os check affected by os version change",
|
||||
diff: &nbpeer.MetaDiff{OSVersion: true},
|
||||
checks: []*Checks{osCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nb check affected by wt version change",
|
||||
diff: &nbpeer.MetaDiff{WtVersion: true},
|
||||
checks: []*Checks{nbCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "geo check affected by location change",
|
||||
diff: &nbpeer.MetaDiff{LocationChanged: true},
|
||||
checks: []*Checks{geoCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "network range check not affected without network address or location change",
|
||||
diff: &nbpeer.MetaDiff{Hostname: true},
|
||||
checks: []*Checks{privateRangeCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "private range check affected by network address change",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
|
||||
checks: []*Checks{privateRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "public range check not affected by network address change alone",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
|
||||
checks: []*Checks{publicRangeCheck},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "public range check affected by location change alone",
|
||||
diff: &nbpeer.MetaDiff{LocationChanged: true},
|
||||
checks: []*Checks{publicRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "private range check affected by location change alone",
|
||||
diff: &nbpeer.MetaDiff{LocationChanged: true},
|
||||
checks: []*Checks{privateRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "public range check affected when location also changed",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true, LocationChanged: true},
|
||||
checks: []*Checks{publicRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "mixed ranges affected by network address change due to private range",
|
||||
diff: &nbpeer.MetaDiff{NetworkAddresses: true},
|
||||
checks: []*Checks{mixedRangeCheck},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := AffectsPosture(context.Background(), tt.diff, tt.checks)
|
||||
assert.Equal(t, tt.want, got, "AffectsPosture result should match expectation")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -53,46 +53,47 @@ type Checks struct {
|
||||
Checks ChecksDefinition `gorm:"serializer:json"`
|
||||
}
|
||||
|
||||
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
|
||||
// replays each check against the peer's old and new state and compares verdicts, so a
|
||||
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
|
||||
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
|
||||
// evaluation error counts.
|
||||
// AffectsPosture reports whether the peer metadata changes described by diff can
|
||||
// alter the outcome of any of the given posture checks. It maps each check kind to
|
||||
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
|
||||
// does not force a posture re-evaluation.
|
||||
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
|
||||
if diff == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
|
||||
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
|
||||
|
||||
for _, c := range checks {
|
||||
for _, check := range c.GetChecks() {
|
||||
if verdictChanged(ctx, check, oldPeer, newPeer) {
|
||||
if c.Checks.ProcessCheck != nil && diff.Files {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by files change", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by OS version check", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by NB version change", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by location change", c.Name)
|
||||
return true
|
||||
}
|
||||
if c.Checks.PeerNetworkRangeCheck != nil && (diff.NetworkAddresses || diff.LocationChanged) {
|
||||
if diff.LocationChanged {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by location change", c.Name)
|
||||
return true
|
||||
}
|
||||
for _, r := range c.Checks.PeerNetworkRangeCheck.Ranges {
|
||||
if r.Addr().IsPrivate() {
|
||||
log.WithContext(ctx).Tracef("posture check %s is affected by network address change", c.Name)
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// verdictChanged replays check against old and new state and reports whether the verdict
|
||||
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
|
||||
// verdict (no change), an error on one side only is a flip.
|
||||
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
|
||||
oldPass, oldErr := check.Check(ctx, oldPeer)
|
||||
newPass, newErr := check.Check(ctx, newPeer)
|
||||
|
||||
oldVerdict := oldPass && oldErr == nil
|
||||
newVerdict := newPass && newErr == nil
|
||||
changed := oldVerdict != newVerdict
|
||||
|
||||
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
|
||||
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
|
||||
|
||||
return changed
|
||||
}
|
||||
|
||||
// ChecksDefinition contains definition of actual check
|
||||
type ChecksDefinition struct {
|
||||
NBVersionCheck *NBVersionCheck `json:",omitempty"`
|
||||
|
||||
Reference in New Issue
Block a user