Compare commits

..

7 Commits

24 changed files with 1383 additions and 526 deletions

View File

@@ -51,13 +51,20 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -76,9 +83,10 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
}
}
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved right now.
//
// A recorded failure takes precedence over the cache: while a failure marker
// exists for d, the answer depends solely on whether refreshBackoff has elapsed
// since that failure, even if one address family is already cached. Without a
// marker, d needs resolving only when neither its A nor its AAAA question has a
// cached record; otherwise the cached entry is left alone.
func (m *Resolver) needsResolve(d domain.Domain) bool {
	fqdn := strings.ToLower(dns.Fqdn(d.PunycodeString()))
	questionFor := func(qtype uint16) dns.Question {
		return dns.Question{Name: fqdn, Qtype: qtype, Qclass: dns.ClassINET}
	}

	m.mutex.RLock()
	defer m.mutex.RUnlock()

	if lastFailure, failed := m.failedResolves[d]; failed {
		// Retry only once the backoff window since the last failure is over.
		backoffElapsed := time.Since(lastFailure) >= refreshBackoff
		return backoffElapsed
	}

	_, hasA := m.records[questionFor(dns.TypeA)]
	_, hasAAAA := m.records[questionFor(dns.TypeAAAA)]
	return !hasA && !hasAAAA
}
// markResolveFailed stores the current time as the most recent failed resolve
// of d, taking the write lock for the duration of the update.
func (m *Resolver) markResolveFailed(d domain.Domain) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	m.failedResolves[d] = time.Now()
}
// clearResolveFailed removes any failure marker recorded for d, taking the
// write lock for the duration of the update. It is a no-op when d has no marker.
func (m *Resolver) clearResolveFailed(d domain.Domain) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	delete(m.failedResolves, d)
}
// pruneFailedResolves removes every failure marker whose domain is not listed
// in domains, so the marker map never outgrows the set passed in. The whole
// sweep runs under the write lock.
func (m *Resolver) pruneFailedResolves(domains domain.List) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	for failed := range m.failedResolves {
		if slices.Contains(domains, failed) {
			// Still part of the current set: keep its backoff state.
			continue
		}
		delete(m.failedResolves, failed)
	}
}

View File

@@ -21,6 +21,7 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
// setErr registers err as the per-question error for (name, qtype). The entry
// is keyed as "<name>|<TYPE>" and stored in qErr under the fake's mutex.
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
	key := name + "|" + dns.TypeToString[qtype]

	f.mu.Lock()
	f.qErr[key] = err
	f.mu.Unlock()
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -0,0 +1,183 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// TestResolver_AddNewDomains_ResolvesConcurrently checks that the new domains of
// a single server-domains update are resolved concurrently rather than serially.
// It observes both the peak number of overlapping lookups and the wall-clock
// duration of the update.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
// inflight is the number of lookups currently inside the hook; maxInflight is
// the highest value inflight has reached.
var inflight, maxInflight atomic.Int32
// presumably the fake chain invokes onLookup once per lookup — confirm against
// fakeChain.ResolveInternal, which is only partly visible here.
chain.onLookup = func() {
n := inflight.Add(1)
// Raise maxInflight to n with a compare-and-swap loop; stop as soon as another
// goroutine has already recorded a value at least as large.
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
// Hold the lookup open long enough for other lookups to overlap with it.
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
// NOTE(review): if a single AddDomain issues its A and AAAA lookups in
// parallel, one domain alone could satisfy this bound — confirm that this
// assertion really distinguishes per-domain concurrency.
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
// NOTE(review): this is a wall-clock bound and may be sensitive to slow or
// heavily loaded CI machines.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// TestResolver_AddNewDomains_RetriesPartialFamilyFailure covers a domain whose A
// lookup succeeds while its AAAA lookup returns an error. The A record must be
// cached, a failure marker must be recorded, and needsResolve must stay false
// inside the backoff window and turn true once the marker is older than
// refreshBackoff.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
	const fqdn = "relay.example.com."
	relay := domain.Domain("relay.example.com")

	upstream := newFakeChain()
	upstream.setAnswer(fqdn, dns.TypeA, "10.0.0.2")
	upstream.setErr(fqdn, dns.TypeAAAA, errors.New("servfail"))

	resolver := NewResolver()
	resolver.SetChainResolver(upstream, 50)

	_, err := resolver.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{relay}})
	require.NoError(t, err)

	// Snapshot cache and marker state under a single read lock.
	aQuestion := dns.Question{Name: fqdn, Qtype: dns.TypeA, Qclass: dns.ClassINET}
	resolver.mutex.RLock()
	_, aCached := resolver.records[aQuestion]
	_, marked := resolver.failedResolves[relay]
	resolver.mutex.RUnlock()

	require.True(t, aCached, "the working family must still be cached")
	require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
	assert.False(t, resolver.needsResolve(relay), "within the backoff window the domain is not retried")

	// Age the marker well past the backoff instead of sleeping.
	expired := time.Now().Add(-2 * refreshBackoff)
	resolver.mutex.Lock()
	resolver.failedResolves[relay] = expired
	resolver.mutex.Unlock()

	assert.True(t, resolver.needsResolve(relay), "after the backoff elapses the domain is retried to pick up the missing family")
}
// TestResolver_AddNewDomains_NodataIsNotFailure covers a host with an A answer
// and no AAAA answer or error configured on the fake chain. Such a domain must
// not receive a failure marker and must not be reported as needing a resolve.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
	v4Only := domain.Domain("v4only.example.com")

	// Only the A family is configured; AAAA has neither an answer nor an error.
	upstream := newFakeChain()
	upstream.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")

	resolver := NewResolver()
	resolver.SetChainResolver(upstream, 50)

	domains := dnsconfig.ServerDomains{Relay: []domain.Domain{v4Only}}
	_, err := resolver.UpdateFromServerDomains(context.Background(), domains)
	require.NoError(t, err)

	resolver.mutex.RLock()
	_, marked := resolver.failedResolves[v4Only]
	resolver.mutex.RUnlock()

	assert.False(t, marked, "a NODATA family must not be recorded as a failure")
	assert.False(t, resolver.needsResolve(v4Only), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -418,7 +418,14 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
case args.showProfiles:
s.showProfilesUI()
case args.showQuickActions:
s.showQuickActionsUI()
// Suppress the on-boot Quick Actions popup when the daemon
// reports DisableAutoConnect=true — that flag carries both the
// user's "Connect on Startup = off" preference AND any MDM-
// enforced override (applyMDMPolicy writes the policy value
// into the same Config field). See netbirdio/netbird#5744.
if !s.disableAutoConnectFromDaemon() {
s.showQuickActionsUI()
}
case args.showUpdate:
s.showUpdateProgress(ctx, args.showUpdateVersion)
}
@@ -1338,6 +1345,40 @@ func (s *serviceClient) getFeatures() (*proto.GetFeaturesResponse, error) {
return features, nil
}
// disableAutoConnectFromDaemon asks the daemon for the config of the active
// profile (scoped to the current OS user) and returns its DisableAutoConnect
// value.
//
// The --quick-actions startup path uses it to decide whether to suppress the
// on-boot popup. Per the note at that call site, the same config field also
// carries any MDM-enforced value, so both opt-out cases are covered by this
// single check.
//
// Every failure along the way (active profile lookup, current user lookup,
// daemon client, GetConfig RPC) is logged as a warning and reported as false,
// so a daemon hiccup never suppresses the popup.
func (s *serviceClient) disableAutoConnectFromDaemon() bool {
	profile, err := s.profileManager.GetActiveProfile()
	if err != nil {
		log.Warnf("disableAutoConnectFromDaemon: get active profile: %v", err)
		return false
	}

	osUser, err := user.Current()
	if err != nil {
		log.Warnf("disableAutoConnectFromDaemon: get current user: %v", err)
		return false
	}

	client, err := s.getSrvClient(failFastTimeout)
	if err != nil {
		log.Warnf("disableAutoConnectFromDaemon: get daemon client: %v", err)
		return false
	}

	request := &proto.GetConfigRequest{
		ProfileName: profile.ID.String(),
		Username:    osUser.Username,
	}
	cfg, err := client.GetConfig(s.ctx, request)
	if err != nil {
		log.Warnf("disableAutoConnectFromDaemon: GetConfig RPC: %v", err)
		return false
	}

	return cfg.GetDisableAutoConnect()
}
// getSrvConfig from the service to show it in the settings window.
func (s *serviceClient) getSrvConfig() {
s.managementURL = profilemanager.DefaultManagementURL

View File

@@ -497,7 +497,7 @@ func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID st
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s with reason %s/%s", len(peerIDs), accountID, util.GetCallerName(), reason.Operation, reason.Resource)
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
@@ -610,12 +610,10 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, 0, err
}
startPosture := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, peerID)
if err != nil {
return nil, nil, 0, err
}
log.WithContext(ctx).Debugf("getPeerPostureChecks took %s", time.Since(startPosture))
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {

View File

@@ -11,7 +11,7 @@ import (
const (
reconnThreshold = 5 * time.Minute
baseBlockDuration = 10 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
baseBlockDuration = 30 * time.Minute // Duration for which a peer is banned after exceeding the reconnection limit
reconnLimitForBan = 30 // Number of reconnections within the reconnThreshold that triggers a ban
metaChangeLimit = 3 // Number of reconnections with different metadata that triggers a ban of one peer
)
@@ -139,22 +139,13 @@ func (l *loginFilter) addLogin(wgPubKey string, metaHash uint64) {
state.lastSeen = now
}
func metaHash(meta nbpeer.PeerSystemMeta, pubip string) uint64 {
func metaHash(meta nbpeer.PeerSystemMeta) uint64 {
h := fnv.New64a()
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
macs := uint64(0)
for _, na := range meta.NetworkAddresses {
for _, r := range na.Mac {
macs += uint64(r)
}
}
return h.Sum64() + macs
return h.Sum64()
}

View File

@@ -164,9 +164,7 @@ func BenchmarkHashingMethods(b *testing.B) {
KernelVersion: "5.15.0-76-generic",
Hostname: "prod-server-database-01",
SystemSerialNumber: "PC-1234567890",
NetworkAddresses: []nbpeer.NetworkAddress{{Mac: "00:1B:44:11:3A:B7"}, {Mac: "00:1B:44:11:3A:B8"}},
}
pubip := "8.8.8.8"
var resultString string
var resultUint uint64
@@ -175,7 +173,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = builderString(meta, pubip)
resultString = builderString(meta)
}
})
@@ -183,7 +181,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultString = fnvHashToString(meta, pubip)
resultString = fnvHashToString(meta)
}
})
@@ -191,7 +189,7 @@ func BenchmarkHashingMethods(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
resultUint = metaHash(meta, pubip)
resultUint = metaHash(meta)
}
})
@@ -199,29 +197,20 @@ func BenchmarkHashingMethods(b *testing.B) {
_ = resultUint
}
func fnvHashToString(meta nbpeer.PeerSystemMeta, pubip string) string {
func fnvHashToString(meta nbpeer.PeerSystemMeta) string {
h := fnv.New64a()
if len(meta.NetworkAddresses) != 0 {
for _, na := range meta.NetworkAddresses {
h.Write([]byte(na.Mac))
}
}
h.Write([]byte(meta.WtVersion))
h.Write([]byte(meta.OSVersion))
h.Write([]byte(meta.KernelVersion))
h.Write([]byte(meta.Hostname))
h.Write([]byte(meta.SystemSerialNumber))
h.Write([]byte(pubip))
return strconv.FormatUint(h.Sum64(), 16)
}
func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
mac := getMacAddress(meta.NetworkAddresses)
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) +
len(pubip) + len(mac) + 6
func builderString(meta nbpeer.PeerSystemMeta) string {
estimatedSize := len(meta.WtVersion) + len(meta.OSVersion) + len(meta.KernelVersion) + len(meta.Hostname) + len(meta.SystemSerialNumber) + 4
var b strings.Builder
b.Grow(estimatedSize)
@@ -235,23 +224,10 @@ func builderString(meta nbpeer.PeerSystemMeta, pubip string) string {
b.WriteString(meta.Hostname)
b.WriteByte('|')
b.WriteString(meta.SystemSerialNumber)
b.WriteByte('|')
b.WriteString(pubip)
return b.String()
}
// getMacAddress joins the Mac field of every address in nas with "/" in slice
// order. It returns the empty string when nas is empty.
func getMacAddress(nas []nbpeer.NetworkAddress) string {
	macs := make([]string, len(nas))
	for i := range nas {
		macs[i] = nas[i].Mac
	}
	// Joining an empty slice yields "", which covers the no-address case.
	return strings.Join(macs, "/")
}
func BenchmarkLoginFilter_ParallelLoad(b *testing.B) {
filter := newLoginFilterWithCfg(testAdvancedCfg())
numKeys := 100000

View File

@@ -254,7 +254,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
return mapError(ctx, err)
}
metahashed := metaHash(peerMeta, sRealIP)
metahashed := metaHash(peerMeta)
if userID == "" && !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.appMetrics != nil {
s.appMetrics.GRPCMetrics().CountSyncRequestBlocked()
@@ -306,7 +306,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S
log.WithContext(ctx).Tracef("peer system meta has to be provided on sync. Peer %s, remote addr %s", peerKey.String(), realIP)
}
metahash := metaHash(peerMeta, realIP.String())
metahash := metaHash(peerMeta)
s.loginFilter.addLogin(peerKey.String(), metahash)
peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, syncStart)
@@ -732,7 +732,7 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
}
peerMeta := extractPeerMeta(ctx, loginReq.GetMeta())
metahashed := metaHash(peerMeta, sRealIP)
metahashed := metaHash(peerMeta)
if !s.loginFilter.allowLogin(peerKey.String(), metahashed) {
if s.logBlockedPeers {
log.WithContext(ctx).Tracef("peer %s with meta hash %d is blocked from login", peerKey.String(), metahashed)
@@ -788,7 +788,11 @@ func (s *Server) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto
ExtraDNSLabels: loginReq.GetDnsLabels(),
})
if err != nil {
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
if errors.Is(err, internalStatus.ErrNoAuthMethodProvided) {
log.WithContext(ctx).Tracef("failed logging in peer %s: %s", peerKey, err)
} else {
log.WithContext(ctx).Warnf("failed logging in peer %s: %s", peerKey, err)
}
return nil, mapError(ctx, err)
}

View File

@@ -107,7 +107,9 @@ func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
affected := resolveAffected(t, s.manager.Store, s.accountID, change)
assert.ElementsMatch(t, affected, mustContain, "expected peer to be affected")
assert.NotContains(t, affected, mustExclude, "peer must not be affected")
for _, peerID := range mustExclude {
assert.NotContains(t, affected, peerID, "peer must not be affected")
}
})
}
}

View File

@@ -251,7 +251,9 @@ func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSou
}
}
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
// A disabled sibling router routes to nobody, so updating a resource on its network
// must NOT refresh its peer (the enabled router carries the bridge instead).
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouterNotBridged(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -274,13 +276,18 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
require.NoError(t, err)
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
enabledCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(disabledCh)
settleAffectedUpdates(disabledCh, enabledCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, disabledCh)
peerShouldReceiveUpdate(t, enabledCh)
peerShouldNotReceiveUpdate(t, disabledCh)
close(done)
}()
@@ -298,7 +305,7 @@ func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
t.Error("timeout")
}
}

View File

@@ -682,6 +682,9 @@ func TestAffectedPeers_AllRoutingPeers_Network(t *testing.T) {
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
}
// A disabled router in the snapshot routes to nobody, so it is skipped when the
// walk scans existing account data: a policy edit still folds the literal source
// group, but not the disabled router's peer.
func TestAffectedPeers_DisabledRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -694,11 +697,13 @@ func TestAffectedPeers_DisabledRouter(t *testing.T) {
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
assert.NotContains(t, affected, s.routerPeerID,
"a disabled router routes to nobody, so its peer must not be folded from snapshot data")
}
// A disabled resource in the snapshot is skipped: the policy edit still folds the
// literal source group, but the resource no longer bridges to its network's router.
func TestAffectedPeers_DisabledResource(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
@@ -710,9 +715,9 @@ func TestAffectedPeers_DisabledResource(t *testing.T) {
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
assert.Contains(t, affected, s.sourcePeerID, "source peer (literal policy source group) must be affected")
assert.NotContains(t, affected, s.routerPeerID,
"a disabled resource routes to nobody, so its network's router must not be folded from snapshot data")
}
func TestAffectedPeers_DisabledRule(t *testing.T) {

View File

@@ -338,6 +338,7 @@ func TestCollectGroupChange_NetworkRouterLinked(t *testing.T) {
AccountID: accountID,
PeerGroups: []string{groupIDs[0]},
Peer: peerIDs[3],
Enabled: true,
})
require.NoError(t, err)
@@ -368,6 +369,7 @@ func TestCollectGroupChange_NetworkRouterPeerOnlyNoGroups(t *testing.T) {
NetworkID: net1.ID,
AccountID: accountID,
Peer: peerIDs[4],
Enabled: true,
})
require.NoError(t, err)
@@ -490,8 +492,9 @@ func TestResolveAffectedPeers_PolicyBetweenTwoGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is unrelated to the route; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
}
func TestResolveAffectedPeers_PolicyThreeGroups(t *testing.T) {
@@ -544,8 +547,9 @@ func TestResolveAffectedPeers_RoutePeerGroups(t *testing.T) {
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[1]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1]}, result)
// peerIDs[2] is in no policy; only its own map can change, so it refreshes itself.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[2]}, result)
}
func TestResolveAffectedPeers_RouteWithDirectPeer(t *testing.T) {
@@ -602,9 +606,9 @@ func TestResolveAffectedPeers_RouteWithAccessControlGroups(t *testing.T) {
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[2]})
assert.ElementsMatch(t, []string{peerIDs[0], peerIDs[1], peerIDs[2]}, result)
// peer3 is unrelated
// peer3 is unrelated to the route; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[3]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[3]}, result)
}
func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
@@ -625,6 +629,7 @@ func TestResolveAffectedPeers_NetworkRouter(t *testing.T) {
AccountID: accountID,
PeerGroups: []string{groupIDs[0]},
Peer: peerIDs[3],
Enabled: true,
})
require.NoError(t, err)
@@ -896,8 +901,9 @@ func TestAffectedPeers_IsolatedPolicies(t *testing.T) {
assert.NotContains(t, result, peerIDs[0])
assert.NotContains(t, result, peerIDs[1])
// peerIDs[4] is in neither isolated policy; only its own map can change.
result = manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[4]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[4]}, result)
}
func TestAffectedPeers_IsolatedRouteAndPolicy(t *testing.T) {
@@ -1019,12 +1025,13 @@ func TestAffectedPeers_GroupUpdateOnlyAffectsLinkedPeers(t *testing.T) {
})
}
func TestAffectedPeers_UnlinkedGroupChange_NoUpdates(t *testing.T) {
// A peer in no policy/route refreshes only itself — no other peer is affected.
func TestAffectedPeers_UnlinkedPeerChange_RefreshesSelfOnly(t *testing.T) {
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
ctx := context.Background()
result := manager.resolveAffectedPeersForPeerChanges(ctx, s, accountID, []string{peerIDs[0]})
assert.Empty(t, result)
assert.ElementsMatch(t, []string{peerIDs[0]}, result)
}
// TestAffectedPeers_PolicyChange_UnrelatedPeerNoUpdate verifies that creating/deleting a
@@ -1374,6 +1381,7 @@ func TestAffectedPeers_NetworkRouterUnlinkedPeerNoUpdate(t *testing.T) {
NetworkID: net1.ID,
AccountID: accountID,
PeerGroups: []string{"nr-grpA"},
Enabled: true,
})
require.NoError(t, err)
@@ -1797,7 +1805,9 @@ func TestCollectAffectedFromProxyServices_GroupContainingTargetPeerChanged(t *te
assert.Contains(t, directPeers, peerIDs[1], "target peer must be refreshed")
}
func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing.T) {
// A disabled service in the snapshot proxies nothing, so it is skipped: a changed
// target peer does not pull in the service's proxy peer.
func TestCollectAffectedFromProxyServices_DisabledServiceSkipped(t *testing.T) {
manager, s, accountID, peerIDs, _ := setupAffectedPeersTest(t)
ctx := context.Background()
@@ -1823,8 +1833,7 @@ func TestCollectAffectedFromProxyServices_DisabledServiceStillMatches(t *testing
require.NoError(t, s.CreateService(ctx, svc))
_, directPeers := collectPeerChangeAffectedGroups(ctx, manager.Store, accountID, nil, []string{peerIDs[1]})
assert.Contains(t, directPeers, peerIDs[0], "disabled service should still trigger a refresh so peers are ready when re-enabled")
assert.Contains(t, directPeers, peerIDs[1], "disabled target should still trigger a refresh")
assert.NotContains(t, directPeers, peerIDs[0], "a disabled service proxies nothing, so its proxy peer must not be folded")
}
func TestCollectAffectedFromProxyServices_NonPeerTargetType(t *testing.T) {

File diff suppressed because it is too large Load Diff

View File

@@ -10,8 +10,8 @@ import (
"github.com/netbirdio/netbird/management/server/types"
)
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
// direct peers) the resolver folds in, for asserting the pure logic.
// policyGroupsAndPeers mirrors the both-sides extraction (RuleGroups + direct peers)
// the resolver folds in for a changed policy, for asserting the pure logic.
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
peerSet := map[string]struct{}{}
for _, p := range policies {
@@ -19,7 +19,14 @@ func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []s
continue
}
groups = append(groups, p.RuleGroups()...)
collectPolicyDirectPeers(p, peerSet)
for _, rule := range p.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
for id := range peerSet {
peers = append(peers, id)
@@ -87,24 +94,9 @@ func TestPolicyReferencesPostureChecks(t *testing.T) {
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
}
func TestCollectPolicyDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
}, {
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
peerSet := map[string]struct{}{}
collectPolicyDirectPeers(policy, peerSet)
assert.Contains(t, peerSet, "p1")
assert.Contains(t, peerSet, "p2")
assert.NotContains(t, peerSet, "r1")
}
func TestCollectPolicySources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Enabled: true,
Sources: []string{"g1"},
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
Destinations: []string{"g2"},

View File

@@ -520,7 +520,12 @@ func collectDeletableGroups(ctx context.Context, transaction store.Store, accoun
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{ChangedGroupIDs: []string{groupID}}
// A membership change affects only the peer itself and the opposite side of THIS
// group's policies — not the group's other members, and not the peer's other
// groups. LinkGroups walks only this group (matched, not expanded); OutputPeerIDs
// refreshes the peer without seeding its other group memberships. For an
// intra-group policy the opposite side is the group, so its members still refresh.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
@@ -586,10 +591,11 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var snap *affectedpeers.Snapshot
change := affectedpeers.Change{
ChangedGroupIDs: []string{groupID},
RemovedPeersByGroup: map[string][]string{groupID: {peerID}},
}
// Same as GroupAddPeer: the removed peer and the opposite side of THIS group's
// policies refresh, not the group's other members or the peer's other groups. The
// peer is no longer in the group's index, but LinkGroups still drives the
// opposite-side walk, and OutputPeerIDs refreshes the removed peer itself.
change := affectedpeers.Change{OutputPeerIDs: []string{peerID}, LinkGroups: []string{groupID}}
err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
@@ -600,8 +606,6 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
// The removed peer is carried in change.RemovedPeersByGroup and folded in
// only when the group is linked, so loading post-removal is correct.
var err error
if snap, err = affectedpeers.Load(ctx, transaction, accountID, change); err != nil {
return err

View File

@@ -220,7 +220,7 @@ func (h *handler) getAllUsers(w http.ResponseWriter, r *http.Request) {
}
includeServiceUser, err := strconv.ParseBool(serviceUser)
log.WithContext(r.Context()).Debugf("Should include service user: %v", includeServiceUser)
log.WithContext(r.Context()).Tracef("Should include service user: %v", includeServiceUser)
if err != nil {
util.WriteError(r.Context(), status.Errorf(status.InvalidArgument, "invalid service_user query parameter"), w)
return

View File

@@ -209,14 +209,14 @@ func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *
if am.geo == nil || realIP == nil {
return nil
}
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return nil
}
location, err := am.geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return nil
}
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) && peer.Location.GeoNameID == location.City.GeonameID {
return nil
}
return &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: location.Country.ISOCode,
@@ -730,7 +730,7 @@ func (am *DefaultAccountManager) handleSetupKeyAddedPeer(ctx context.Context, en
func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKey, userID string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.Network, []*posture.Checks, bool, error) {
if setupKey == "" && userID == "" && !peer.ProxyMeta.Embedded {
// no auth method provided => reject access
return nil, nil, nil, false, status.Errorf(status.Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
return nil, nil, nil, false, status.ErrNoAuthMethodProvided
}
upperKey := strings.ToUpper(setupKey)
@@ -1051,8 +1051,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
metaDiffAffectsPosture := posture.AffectsPosture(&metaDiff, resPostureChecks)
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged || metaDiff.Hostname {
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
if requiresPeerUpdate(ctx, isStatusChanged, sync.UpdateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, metaDiff.VersionChanged(), metaDiff.HostnameChanged()) {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1063,6 +1063,29 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return peer, nmap, resPostureChecks, dnsFwdPort, nil
}
// requiresPeerUpdate reports whether any of the given sync triggers fired. The
// triggers are examined in a fixed priority order and only the first one that fired
// is named in the trace log; when none fired it returns false without logging.
func requiresPeerUpdate(ctx context.Context, isStatusChanged, updateAccountPeers, ipv6CapabilityChanged, metaDiffAffectsPosture, versionChanged, hostname bool) bool {
	triggers := [...]struct {
		fired  bool
		reason string
	}{
		{isStatusChanged, "status changed"},
		{updateAccountPeers, "update account peers"},
		{ipv6CapabilityChanged, "ipv6 capability changed"},
		{metaDiffAffectsPosture, "meta diff affects posture"},
		{versionChanged, "version changed"},
		{hostname, "hostname changed"},
	}

	for _, trigger := range triggers {
		if !trigger.fired {
			continue
		}
		// Only the highest-priority trigger is reported, even if several fired.
		log.WithContext(ctx).Tracef("peer update required: %s", trigger.reason)
		return true
	}
	return false
}
// syncPeerAffectedPeers resolves the peers affected by a SyncPeer change. The
// peer's own validated network map is bidirectional for policy and routing
// reachability, so when the peer stays valid and no source-posture gate is in

View File

@@ -107,6 +107,15 @@ type Location struct {
GeoNameID uint // city level geoname id
}
// equal reports whether l and other describe the same location: country code, city
// name and geoname id must be identical and the connection IPs must match.
// ConnectionIP is a net.IP slice, so it is compared with IP.Equal rather than ==.
func (l Location) equal(other Location) bool {
	if l.CountryCode != other.CountryCode {
		return false
	}
	if l.CityName != other.CityName {
		return false
	}
	if l.GeoNameID != other.GeoNameID {
		return false
	}
	return l.ConnectionIP.Equal(other.ConnectionIP)
}
// NetworkAddress is the IP address with network and MAC address of a network interface
type NetworkAddress struct {
NetIP netip.Prefix `gorm:"serializer:json"`
@@ -267,185 +276,141 @@ func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLoca
return MetaDiff{}
}
versionChanged := p.Meta.WtVersion != meta.WtVersion
// Avoid overwriting UIVersion if the update was triggered solely by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
oldVersion := p.Meta.WtVersion
effectiveLocation := p.Location
if newLocation != nil {
effectiveLocation = *newLocation
}
diff := diffMeta(p.Meta, meta)
if diff.Any() {
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
if diff.Updated() {
p.Meta = meta
}
diff.VersionChanged = versionChanged
p.Location = effectiveLocation
locationInfo := ""
if newLocation != nil {
p.Location = *newLocation
diff.LocationChanged = true
locationInfo = fmt.Sprintf("location changed to %s, ", newLocation.ConnectionIP)
}
versionInfo := ""
if diff.VersionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if diff.Any() || diff.VersionChanged || diff.LocationChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%s%d field(s) changed: %s", versionInfo, locationInfo, len(diff.Changed), strings.Join(diff.Changed, ", "))
if diff.Updated() {
log.WithContext(ctx).Debug(diff.LogSummary())
}
return diff
}
// MetaDiff records which PeerSystemMeta fields differ between two metas. Each bool
// maps to a single struct field, except Environment, which is split into Cloud and
// Platform. Changed holds the human-readable `field: <old> -> <new>` entries so the
// existing log line and isEqual can be derived from the same comparison.
//
// VersionChanged and LocationChanged sit outside the per-meta-field set:
// VersionChanged tracks the WireGuard client version specifically (compared before
// the UIVersion fixup, to signal client upgrades) and LocationChanged tracks the
// peer's connection geo location, which lives on Peer rather than PeerSystemMeta.
// Neither contributes an entry to Changed, so the field-coverage accounting stays
// driven purely by the PeerSystemMeta comparison.
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
// checks read it). Changed lists what moved, for logging and the persistence decision;
// the snapshots let a posture check be replayed against old and new. Everything is derived
// from these fields, so there are no parallel per-field flags to keep in sync.
type MetaDiff struct {
Hostname bool
GoOS bool
Kernel bool
KernelVersion bool
Core bool
Platform bool
OS bool
OSVersion bool
WtVersion bool
UIVersion bool
SystemSerialNumber bool
SystemProductName bool
SystemManufacturer bool
EnvironmentCloud bool
EnvironmentPlatform bool
Flags bool
Capabilities bool
NetworkAddresses bool
Files bool
VersionChanged bool
LocationChanged bool
OldMeta PeerSystemMeta
NewMeta PeerSystemMeta
OldLocation Location
NewLocation Location
Changed []string
}
// Any reports whether any PeerSystemMeta field changed.
func (d MetaDiff) Any() bool {
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
func (d *MetaDiff) Updated() bool {
return len(d.Changed) != 0
}
// Updated reports whether the peer needs to be persisted: any meta field changed
// or the geo location changed. The version flag alone does not imply a write,
// since a version change is also reflected in the WtVersion meta field.
func (d MetaDiff) Updated() bool {
return d.Any() || d.LocationChanged || d.VersionChanged
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
func (d *MetaDiff) VersionChanged() bool {
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
}
// HostnameChanged reports whether the hostname in the old meta snapshot differs from
// the hostname in the new meta snapshot.
func (d *MetaDiff) HostnameChanged() bool {
	before, after := d.OldMeta.Hostname, d.NewMeta.Hostname
	return before != after
}
// LogSummary renders the diff as one human-readable line: the number of Changed
// entries followed by the entries themselves, joined with ", ".
func (d *MetaDiff) LogSummary() string {
	count := len(d.Changed)
	fields := strings.Join(d.Changed, ", ")
	return fmt.Sprintf("peer meta updated, %d field(s) changed: %s", count, fields)
}
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
return diffMeta(oldMeta, newMeta).Changed
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
}
// diffMeta compares two metas field by field, returning both a per-field flag set
// (for callers that need to know exactly what changed, e.g. matching against
// posture checks) and the human-readable Changed list. It is the single source of
// truth for meta comparison: isEqual reports equality as an empty diff, so the log
// line, the change decision, and the flags can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta) MetaDiff {
var d MetaDiff
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
// list, so the log line and the persistence decision can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
add := func(field string, oldVal, newVal any) {
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
d.Hostname = true
add("hostname", oldMeta.Hostname, newMeta.Hostname)
}
if oldMeta.GoOS != newMeta.GoOS {
d.GoOS = true
add("goos", oldMeta.GoOS, newMeta.GoOS)
}
if oldMeta.Kernel != newMeta.Kernel {
d.Kernel = true
add("kernel", oldMeta.Kernel, newMeta.Kernel)
}
if oldMeta.KernelVersion != newMeta.KernelVersion {
d.KernelVersion = true
add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
}
if oldMeta.Core != newMeta.Core {
d.Core = true
add("core", oldMeta.Core, newMeta.Core)
}
if oldMeta.Platform != newMeta.Platform {
d.Platform = true
add("platform", oldMeta.Platform, newMeta.Platform)
}
if oldMeta.OS != newMeta.OS {
d.OS = true
add("os", oldMeta.OS, newMeta.OS)
}
if oldMeta.OSVersion != newMeta.OSVersion {
d.OSVersion = true
add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
}
if oldMeta.WtVersion != newMeta.WtVersion {
d.WtVersion = true
add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
}
if oldMeta.UIVersion != newMeta.UIVersion {
d.UIVersion = true
add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
}
if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
d.SystemSerialNumber = true
add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
}
if oldMeta.SystemProductName != newMeta.SystemProductName {
d.SystemProductName = true
add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
}
if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
d.SystemManufacturer = true
add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
}
if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
d.EnvironmentCloud = true
add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
}
if oldMeta.Environment.Platform != newMeta.Environment.Platform {
d.EnvironmentPlatform = true
add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
}
if !oldMeta.Flags.isEqual(newMeta.Flags) {
d.Flags = true
add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
}
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
d.Capabilities = true
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
d.NetworkAddresses = true
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
d.Files = true
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
if !oldLocation.equal(newLocation) {
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
}
return d
}

View File

@@ -49,6 +49,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/geolocation"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
@@ -2893,3 +2894,141 @@ func TestUpdatePeer_DnsLabelUniqueName(t *testing.T) {
require.NoError(t, err, "renaming to unique FQDN should succeed")
assert.Equal(t, "api-server", updated.DNSLabel, "DNS label should be first label of FQDN")
}
// fakeGeo is a configurable test double for geolocation.Geolocation. Lookup ignores the
// IP it is given: it returns err when err is set, otherwise a record carrying the
// configured city geoname id, city name and country ISO code. The other methods are
// no-op stubs.
type fakeGeo struct {
	geoNameID uint
	isoCode   string
	cityName  string
	err       error
}

// Lookup returns the configured error, or a record populated from the configured fields.
func (g *fakeGeo) Lookup(net.IP) (*geolocation.Record, error) {
	if g.err != nil {
		return nil, g.err
	}

	var rec geolocation.Record
	rec.Country.ISOCode = g.isoCode
	rec.City.GeonameID = g.geoNameID
	rec.City.Names.En = g.cityName
	return &rec, nil
}

// GetAllCountries is a stub: no countries, no error.
func (g *fakeGeo) GetAllCountries() ([]geolocation.Country, error) {
	return nil, nil
}

// GetCitiesByCountry is a stub: no cities, no error.
func (g *fakeGeo) GetCitiesByCountry(string) ([]geolocation.City, error) {
	return nil, nil
}

// Stop is a stub that always succeeds.
func (g *fakeGeo) Stop() error {
	return nil
}
func TestResolvePeerLocation(t *testing.T) {
realIP := net.ParseIP("203.0.113.10")
tests := []struct {
name string
geo geolocation.Geolocation
peer *nbpeer.Peer
realIP net.IP
want *nbpeer.Location
wantNil bool
}{
{
name: "no geo configured returns nil",
geo: nil,
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
wantNil: true,
},
{
name: "nil real IP returns nil",
geo: &fakeGeo{geoNameID: 100},
peer: &nbpeer.Peer{ID: "p1"},
realIP: nil,
wantNil: true,
},
{
name: "lookup error returns nil",
geo: &fakeGeo{err: fmt.Errorf("lookup boom")},
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
wantNil: true,
},
{
name: "same IP and same geoname returns nil",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: realIP,
GeoNameID: 100,
},
},
realIP: realIP,
wantNil: true,
},
{
name: "same IP but changed geoname returns location",
geo: &fakeGeo{geoNameID: 200, isoCode: "US", cityName: "City B"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: realIP,
GeoNameID: 100,
},
},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City B",
GeoNameID: 200,
},
},
{
name: "different IP returns location",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{
ID: "p1",
Location: nbpeer.Location{
ConnectionIP: net.ParseIP("198.51.100.7"),
GeoNameID: 100,
},
},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City A",
GeoNameID: 100,
},
},
{
name: "no prior location returns location",
geo: &fakeGeo{geoNameID: 100, isoCode: "US", cityName: "City A"},
peer: &nbpeer.Peer{ID: "p1"},
realIP: realIP,
want: &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: "US",
CityName: "City A",
GeoNameID: 100,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
am := &DefaultAccountManager{geo: tt.geo}
got := am.resolvePeerLocation(context.Background(), tt.peer, tt.realIP)
if tt.wantNil {
assert.Nil(t, got, "resolved location should be nil")
return
}
require.NotNil(t, got, "resolved location should not be nil")
assert.True(t, tt.want.ConnectionIP.Equal(got.ConnectionIP), "connection IP should match")
assert.Equal(t, tt.want.CountryCode, got.CountryCode, "country code should match")
assert.Equal(t, tt.want.CityName, got.CityName, "city name should match")
assert.Equal(t, tt.want.GeoNameID, got.GeoNameID, "geoname id should match")
})
}
}

View File

@@ -0,0 +1,202 @@
package posture
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
// diffFrom assembles a MetaDiff holding only the old/new meta and location snapshots
// that AffectsPosture replays checks against; Changed is left empty.
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
	diff := new(nbpeer.MetaDiff)
	diff.OldMeta, diff.NewMeta = oldMeta, newMeta
	diff.OldLocation, diff.NewLocation = oldLoc, newLoc
	return diff
}
// checks wraps a single definition in a one-element []*Checks, the shape AffectsPosture
// accepts.
func checks(def ChecksDefinition) []*Checks {
	wrapped := &Checks{Checks: def}
	return []*Checks{wrapped}
}
func TestAffectsPosture_NilDiff(t *testing.T) {
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
})))
}
// TestAffectsPosture_NBVersion checks that a client version change affects posture only
// when it flips the NBVersionCheck verdict: crossing the minimum in either direction, or
// an unparsable version on exactly one side.
func TestAffectsPosture_NBVersion(t *testing.T) {
	ctx := context.Background()
	minVersion := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})

	cases := []struct {
		name     string
		from, to string
		affected bool
	}{
		{name: "both above min, no flip", from: "1.3.0", to: "1.4.0", affected: false},
		{name: "both below min, no flip", from: "1.0.0", to: "1.1.0", affected: false},
		{name: "crosses up below->above", from: "1.1.0", to: "1.3.0", affected: true},
		{name: "crosses down above->below", from: "1.3.0", to: "1.1.0", affected: true},
		{name: "unparsable old only -> flip", from: "garbage", to: "1.3.0", affected: true},
		{name: "unparsable both -> no flip", from: "garbage", to: "junk", affected: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			before := nbpeer.PeerSystemMeta{WtVersion: tc.from}
			after := nbpeer.PeerSystemMeta{WtVersion: tc.to}
			diff := diffFrom(before, after, nbpeer.Location{}, nbpeer.Location{})

			assert.Equal(t, tc.affected, AffectsPosture(ctx, diff, minVersion))
		})
	}
}
// TestAffectsPosture_OSVersion_KernelBumpWithinMin covers a Linux minimum-kernel check: a
// kernel change that stays above the minimum is not reported, one that drops below it is.
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
	ctx := context.Background()
	minKernel := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
		Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
	}})
	linuxKernel := func(version string) nbpeer.PeerSystemMeta {
		return nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: version}
	}
	var noLocation nbpeer.Location

	// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
	bump := diffFrom(linuxKernel("5.10.0-arch1"), linuxKernel("5.15.0-arch2"), noLocation, noLocation)
	assert.False(t, AffectsPosture(ctx, bump, minKernel))

	// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
	downgrade := diffFrom(linuxKernel("5.10.0-arch1"), linuxKernel("4.19.0-arch1"), noLocation, noLocation)
	assert.True(t, AffectsPosture(ctx, downgrade, minKernel))
}
// TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict: only Linux is constrained. Per the
// test's premise an OS outside the check's switch (freebsd) passes, so moving to a Linux
// kernel below the minimum flips the verdict pass -> fail.
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
	linuxOnly := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
		Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
	}})
	before := nbpeer.PeerSystemMeta{GoOS: "freebsd"}
	after := nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"}

	diff := diffFrom(before, after, nbpeer.Location{}, nbpeer.Location{})

	assert.True(t, AffectsPosture(context.Background(), diff, linuxOnly))
}
// TestAffectsPosture_Process_GoOSSwitchFlipsVerdict: the tracked process is configured
// with a Linux path only. The reported files stay the same, but switching GoOS to windows
// (no WindowsPath configured) is expected to flip the verdict.
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
	processCheck := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
		Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
	}})
	running := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
	onLinux := nbpeer.PeerSystemMeta{GoOS: "linux", Files: running}
	onWindows := nbpeer.PeerSystemMeta{GoOS: "windows", Files: running}

	diff := diffFrom(onLinux, onWindows, nbpeer.Location{}, nbpeer.Location{})

	assert.True(t, AffectsPosture(context.Background(), diff, processCheck))
}
// TestAffectsPosture_Process_UnrelatedFileChange: the tracked process keeps running while
// an unrelated file appears in the report. The verdict does not move, so posture is not
// affected.
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
	processCheck := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
		Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
	}})
	tracked := nbpeer.File{Path: "/usr/bin/foo", ProcessIsRunning: true}
	unrelated := nbpeer.File{Path: "/usr/bin/bar", ProcessIsRunning: true}

	before := nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{tracked}}
	after := nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{tracked, unrelated}}
	diff := diffFrom(before, after, nbpeer.Location{}, nbpeer.Location{})

	assert.False(t, AffectsPosture(context.Background(), diff, processCheck))
}
// TestAffectsPosture_GeoLocation uses an allow-list containing only DE: moving between
// cities inside DE keeps the verdict, moving to FR flips it.
func TestAffectsPosture_GeoLocation(t *testing.T) {
	ctx := context.Background()
	allowGermany := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
		Action:    CheckActionAllow,
		Locations: []Location{{CountryCode: "DE"}},
	}})
	var sameMeta nbpeer.PeerSystemMeta
	in := func(country, city string) nbpeer.Location {
		return nbpeer.Location{CountryCode: country, CityName: city}
	}

	// Moving within the allowed country keeps the verdict.
	stayAllowed := diffFrom(sameMeta, sameMeta, in("DE", "Berlin"), in("DE", "Munich"))
	assert.False(t, AffectsPosture(ctx, stayAllowed, allowGermany))

	// Moving out of the allowed country flips it.
	moveOut := diffFrom(sameMeta, sameMeta, in("DE", ""), in("FR", ""))
	assert.True(t, AffectsPosture(ctx, moveOut, allowGermany))
}
// TestAffectsPosture_PeerNetworkRange_ConnectionIP uses an allow-list of 10.0.0.0/8 and
// changes only the connection IP between the two snapshots. Per the test's premise the
// check reads the connection IP: moving out of the allowed range flips the verdict,
// moving within it does not.
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
	ctx := context.Background()
	// Parse the prefix once with netip instead of going through net.ParseCIDR (whose
	// error was discarded) and re-parsing its string form.
	c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
		Action: CheckActionAllow,
		Ranges: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
	}})

	movesOutOfRange := diffFrom(
		nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
		nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
		nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
	)
	assert.True(t, AffectsPosture(ctx, movesOutOfRange, c))

	staysInRange := diffFrom(
		nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
		nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
		nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
	)
	assert.False(t, AffectsPosture(ctx, staysInRange, c))
}
// TestAffectsPosture_IrrelevantFieldChange: only the hostname changes, and neither
// configured check (client version, geo location) is given a differing input, so
// posture is not affected even though checks are present.
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
	configured := checks(ChecksDefinition{
		NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
		GeoLocationCheck: &GeoLocationCheck{
			Action:    CheckActionAllow,
			Locations: []Location{{CountryCode: "DE"}},
		},
	})
	germany := nbpeer.Location{CountryCode: "DE"}
	before := nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"}
	after := nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"}

	diff := diffFrom(before, after, germany, germany)

	assert.False(t, AffectsPosture(context.Background(), diff, configured))
}
func TestAffectsPosture_NoChecks(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, nil))
}

View File

@@ -7,6 +7,7 @@ import (
"regexp"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
@@ -52,34 +53,46 @@ type Checks struct {
Checks ChecksDefinition `gorm:"serializer:json"`
}
// AffectsPosture reports whether the peer metadata changes described by diff can
// alter the outcome of any of the given posture checks. It maps each check kind to
// the metadata fields it inspects, so an unrelated change (e.g. a hostname update)
// does not force a posture re-evaluation.
func AffectsPosture(diff *nbpeer.MetaDiff, checks []*Checks) bool {
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
// replays each check against the peer's old and new state and compares verdicts, so a
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
// evaluation error counts.
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
if diff == nil {
return false
}
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
for _, c := range checks {
if c.Checks.ProcessCheck != nil && diff.Files {
return true
}
if c.Checks.OSVersionCheck != nil && (diff.OSVersion || diff.OS || diff.KernelVersion) {
return true
}
if c.Checks.NBVersionCheck != nil && diff.WtVersion {
return true
}
if c.Checks.GeoLocationCheck != nil && diff.LocationChanged {
return true
}
if c.Checks.PeerNetworkRangeCheck != nil && diff.NetworkAddresses {
return true
for _, check := range c.GetChecks() {
if verdictChanged(ctx, check, oldPeer, newPeer) {
return true
}
}
}
return false
}
// verdictChanged evaluates check once against oldPeer and once against newPeer and reports
// whether the two verdicts differ. A verdict is "pass" only when the check returned true
// with a nil error, so an evaluation error counts as deny: errors on both sides are the
// same verdict (no change), an error on one side only is a flip. Both verdicts and both
// errors are written to the trace log.
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
	passedBefore, errBefore := check.Check(ctx, oldPeer)
	passedAfter, errAfter := check.Check(ctx, newPeer)

	// An evaluation error overrides whatever bool the check returned.
	verdictBefore := errBefore == nil && passedBefore
	verdictAfter := errAfter == nil && passedAfter

	flipped := verdictBefore != verdictAfter
	log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
		check.Name(), verdictBefore, verdictAfter, flipped, errBefore, errAfter)
	return flipped
}
// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`

View File

@@ -489,6 +489,7 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
policy := &types.Policy{
AccountID: account.Id,
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,

View File

@@ -1059,8 +1059,8 @@ func (am *DefaultAccountManager) BuildUserInfosForAccount(ctx context.Context, a
if err != nil {
return nil, err
}
log.WithContext(ctx).Debugf("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
log.WithContext(ctx).Debugf("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
log.WithContext(ctx).Tracef("Got %d users from ExternalCache for account %s", len(usersFromIntegration), accountID)
log.WithContext(ctx).Tracef("Got %d users from InternalCache for account %s", len(queriedUsers), accountID)
queriedUsers = append(queriedUsers, usersFromIntegration...)
}

View File

@@ -48,6 +48,10 @@ type Type int32
var (
ErrExtraSettingsNotFound = errors.New("extra settings not found")
ErrPeerAlreadyLoggedIn = errors.New("peer with the same public key is already logged in")
// ErrNoAuthMethodProvided is returned when a peer login attempt carries neither a
// setup key nor an SSO token. Match it with errors.Is.
ErrNoAuthMethodProvided = Errorf(Unauthenticated, "no peer auth method provided, please use a setup key or interactive SSO login")
)
// Error is an internal error
@@ -66,6 +70,16 @@ func (e *Error) Error() string {
return e.Message
}
// Is reports whether target is itself an *Error with the same type and message as e,
// which lets errors.Is match an error chain against a sentinel *Error.
//
// The comparison is deliberately shallow: errors.Is already walks the chain of the error
// being tested and calls Is on each link, and the errors package asks Is methods not to
// unwrap either side. A target that merely wraps an *Error therefore does not match.
// A nil receiver or a nil *Error target never matches and never panics.
func (e *Error) Is(target error) bool {
	t, ok := target.(*Error)
	if !ok || t == nil || e == nil {
		return false
	}
	return e.ErrorType == t.ErrorType && e.Message == t.Message
}
// Errorf returns Error(ErrorType, fmt.Sprintf(format, a...)).
func Errorf(errorType Type, format string, a ...interface{}) error {
return &Error{