Compare commits

...

7 Commits

Author SHA1 Message Date
pascal
9eadf50f4c fix typo 2026-06-23 23:02:48 +02:00
pascal
be06016ad2 validate posture checks on meta change before account update 2026-06-23 22:58:32 +02:00
Viktor Liu
17b2044596 [client] Skip re-resolving cached management cache domains (#6518) 2026-06-23 17:55:57 +02:00
Bethuel Mmbaga
07101c59ac [management] Reschedule inactivity expiration when a peer disconnects (#6523) 2026-06-23 17:44:32 +03:00
Riccardo Manfrin
51b6f6291b Fixup debug config (#6514) 2026-06-22 22:01:49 +02:00
Pascal Fischer
2ebf26006a [management] empty file check in nmap on other posturechecks (#6511) 2026-06-22 19:54:38 +02:00
Pascal Fischer
211a26019a [management] validate meta change against posture checks (#6510) 2026-06-22 19:42:04 +02:00
19 changed files with 804 additions and 183 deletions

View File

@@ -130,7 +130,7 @@ func debugConfigDump(cmd *cobra.Command, _ []string) error {
client := proto.NewDaemonServiceClient(conn)
resp, err := client.GetConfig(cmd.Context(), &proto.GetConfigRequest{
ProfileName: activeProf.Name,
ProfileName: string(activeProf.ID),
Username: currUser.Username,
})
if err != nil {

View File

@@ -51,13 +51,20 @@ type cachedRecord struct {
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// records, refreshing, failedResolves, mgmtDomain and serverDomains are all
// guarded by mutex.
type Resolver struct {
records map[dns.Question]*cachedRecord
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
// failedResolves records the last failed initial resolve per domain so a
// domain that never resolves isn't retried on every server-domains update
// until refreshBackoff elapses. Entries are cleared on success and pruned
// to the current server-domains set.
failedResolves map[domain.Domain]time.Time
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
@@ -76,9 +83,10 @@ type Resolver struct {
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
failedResolves: make(map[domain.Domain]time.Time),
cacheTTL: resolveCacheTTL(),
}
}
@@ -173,7 +181,9 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// entry for that qtype. When one family hard-errors while the other succeeds,
// the resolved family is still cached but AddDomain returns an error so the
// caller retries the incomplete resolve rather than treating it as complete.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
@@ -203,6 +213,10 @@ func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
if errA != nil || errAAAA != nil {
return fmt.Errorf("resolve %s: incomplete, a family failed: %w", d.SafeString(), errors.Join(errA, errAAAA))
}
return nil
}
@@ -462,6 +476,7 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
delete(m.failedResolves, d)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -505,6 +520,7 @@ func (m *Resolver) UpdateFromServerDomains(ctx context.Context, serverDomains dn
allDomains := m.extractDomainsFromServerDomains(updatedServerDomains)
currentDomains := m.GetCachedDomains()
removedDomains = m.removeStaleDomains(currentDomains, allDomains)
m.pruneFailedResolves(allDomains)
}
m.addNewDomains(ctx, newDomains)
@@ -577,13 +593,85 @@ func (m *Resolver) isManagementDomain(domain domain.Domain) bool {
return m.mgmtDomain != nil && domain == *m.mgmtDomain
}
// addNewDomains resolves and caches all domains from the update
// addNewDomains resolves and caches domains that are not yet in the cache,
// running the lookups concurrently. Domains already cached are skipped and left
// to the stale-while-revalidate refresh path, so a sync never re-resolves them
// synchronously: once NetBird owns the OS resolver the resolve runs through the
// handler chain and would otherwise dial the managed upstreams under the engine
// sync lock on every update.
func (m *Resolver) addNewDomains(ctx context.Context, newDomains domain.List) {
var wg sync.WaitGroup
seen := make(map[domain.Domain]struct{}, len(newDomains))
for _, newDomain := range newDomains {
if err := m.AddDomain(ctx, newDomain); err != nil {
log.Warnf("failed to add/update domain=%s: %v", newDomain.SafeString(), err)
} else {
log.Debugf("added/updated management cache domain=%s", newDomain.SafeString())
if _, dup := seen[newDomain]; dup {
continue
}
seen[newDomain] = struct{}{}
if !m.needsResolve(newDomain) {
continue
}
wg.Add(1)
go func(d domain.Domain) {
defer wg.Done()
if err := m.AddDomain(ctx, d); err != nil {
m.markResolveFailed(d)
log.Warnf("failed to add/update domain=%s: %v", d.SafeString(), err)
return
}
m.clearResolveFailed(d)
log.Debugf("added/updated management cache domain=%s", d.SafeString())
}(newDomain)
}
wg.Wait()
}
// needsResolve reports whether d should be resolved now. A recent failed or
// incomplete resolve gates retries on the backoff even when one family is
// already cached, so a transiently-failed family is retried instead of being
// treated as fully resolved. Otherwise a domain with any cached record is left
// to the stale-while-revalidate refresh path.
func (m *Resolver) needsResolve(d domain.Domain) bool {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
m.mutex.RLock()
defer m.mutex.RUnlock()
if failedAt, ok := m.failedResolves[d]; ok {
return time.Since(failedAt) >= refreshBackoff
}
for _, qtype := range []uint16{dns.TypeA, dns.TypeAAAA} {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
if _, ok := m.records[q]; ok {
return false
}
}
return true
}
func (m *Resolver) markResolveFailed(d domain.Domain) {
m.mutex.Lock()
m.failedResolves[d] = time.Now()
m.mutex.Unlock()
}
func (m *Resolver) clearResolveFailed(d domain.Domain) {
m.mutex.Lock()
delete(m.failedResolves, d)
m.mutex.Unlock()
}
// pruneFailedResolves drops failure markers for domains no longer present in
// the server-domains set, keeping the map bounded to the current set (a
// failed-only domain has no cached record, so RemoveDomain never sees it).
func (m *Resolver) pruneFailedResolves(domains domain.List) {
m.mutex.Lock()
defer m.mutex.Unlock()
for d := range m.failedResolves {
if !slices.Contains(domains, d) {
delete(m.failedResolves, d)
}
}
}

View File

@@ -21,6 +21,7 @@ type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
qErr map[string]error
err error
hasRoot bool
onLookup func()
@@ -30,6 +31,7 @@ func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
qErr: map[string]error{},
hasRoot: true,
}
}
@@ -47,6 +49,9 @@ func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriori
f.calls[key]++
answers := f.answers[key]
err := f.err
if err == nil {
err = f.qErr[key]
}
onLookup := f.onLookup
f.mu.Unlock()
@@ -75,6 +80,12 @@ func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
}
}
func (f *fakeChain) setErr(name string, qtype uint16, err error) {
f.mu.Lock()
defer f.mu.Unlock()
f.qErr[name+"|"+dns.TypeToString[qtype]] = err
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()

View File

@@ -0,0 +1,183 @@
package mgmt
import (
"context"
"errors"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/shared/management/domain"
)
// A domain already in the cache must not be re-resolved on a subsequent server
// domains update; it is left to the stale-while-revalidate refresh path.
func TestResolver_UpdateFromServerDomains_SkipsCached(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("signal.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must resolve the domain")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"cached domain must not be re-resolved on a subsequent update")
}
// New domains in a single update must resolve concurrently rather than serially.
func TestResolver_AddNewDomains_ResolvesConcurrently(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
var inflight, maxInflight atomic.Int32
chain.onLookup = func() {
n := inflight.Add(1)
for {
old := maxInflight.Load()
if n <= old || maxInflight.CompareAndSwap(old, n) {
break
}
}
time.Sleep(50 * time.Millisecond)
inflight.Add(-1)
}
relays := []domain.Domain{"a.example.com", "b.example.com", "c.example.com", "d.example.com"}
for _, d := range relays {
chain.setAnswer(dns.Fqdn(string(d)), dns.TypeA, "10.0.0.2")
}
r.SetChainResolver(chain, 50)
start := time.Now()
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: relays})
require.NoError(t, err)
elapsed := time.Since(start)
assert.GreaterOrEqual(t, int(maxInflight.Load()), 2, "domains must resolve concurrently")
// Serial resolution of 4 domains would take at least 4*50ms; concurrent is far less.
assert.Less(t, elapsed, 300*time.Millisecond, "resolution should not be serial")
}
// A domain that fails to resolve must not be retried on every update; the
// failure backoff suppresses re-resolution until it expires.
func TestResolver_UpdateFromServerDomains_BacksOffFailures(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{Signal: domain.Domain("signal.example.com")}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
require.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"first update must attempt the resolve")
_, err = r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("signal.example.com.", dns.TypeA),
"failed resolve must back off and not retry on the next update")
}
// A domain listed under more than one server-domain type (e.g. STUN and TURN on
// the same host) must be resolved once per update, not once per occurrence.
func TestResolver_AddNewDomains_DedupesDuplicateDomains(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("dup.example.com.", dns.TypeA, "10.0.0.9")
r.SetChainResolver(chain, 50)
sd := dnsconfig.ServerDomains{
Stuns: []domain.Domain{"dup.example.com"},
Turns: []domain.Domain{"dup.example.com"},
}
_, err := r.UpdateFromServerDomains(context.Background(), sd)
require.NoError(t, err)
assert.Equal(t, 1, chain.callCount("dup.example.com.", dns.TypeA),
"a domain appearing under multiple server-domain types must resolve once")
}
// A failure marker must be dropped once its domain leaves the server-domains set
// so the map stays bounded to the current set.
func TestResolver_UpdateFromServerDomains_PrunesFailedResolves(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("resolve boom")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("gone.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
require.True(t, marked, "failed resolve must be recorded")
_, err = r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Signal: domain.Domain("other.example.com")})
require.NoError(t, err)
r.mutex.RLock()
_, stillMarked := r.failedResolves[domain.Domain("gone.example.com")]
r.mutex.RUnlock()
assert.False(t, stillMarked, "failure marker for a domain no longer in the set must be pruned")
}
// When one family hard-errors while the other resolves, the domain is cached
// for the working family but recorded as incomplete so the failed family is
// retried under backoff instead of being treated as fully resolved forever.
func TestResolver_AddNewDomains_RetriesPartialFamilyFailure(t *testing.T) {
d := domain.Domain("relay.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("relay.example.com.", dns.TypeA, "10.0.0.2")
chain.setErr("relay.example.com.", dns.TypeAAAA, errors.New("servfail"))
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, aCached := r.records[dns.Question{Name: "relay.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}]
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
require.True(t, aCached, "the working family must still be cached")
require.True(t, marked, "a partial failure must be recorded so the failed family is retried")
assert.False(t, r.needsResolve(d), "within the backoff window the domain is not retried")
r.mutex.Lock()
r.failedResolves[d] = time.Now().Add(-2 * refreshBackoff)
r.mutex.Unlock()
assert.True(t, r.needsResolve(d), "after the backoff elapses the domain is retried to pick up the missing family")
}
// A family that returns NODATA (legitimately absent, e.g. an IPv4-only host) is
// not a failure: the domain must not be marked for retry, otherwise it would be
// re-resolved on every sync.
func TestResolver_AddNewDomains_NodataIsNotFailure(t *testing.T) {
d := domain.Domain("v4only.example.com")
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("v4only.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
_, err := r.UpdateFromServerDomains(context.Background(), dnsconfig.ServerDomains{Relay: []domain.Domain{d}})
require.NoError(t, err)
r.mutex.RLock()
_, marked := r.failedResolves[d]
r.mutex.RUnlock()
assert.False(t, marked, "a NODATA family must not be recorded as a failure")
assert.False(t, r.needsResolve(d), "an IPv4-only host must not be re-resolved on later syncs")
}

View File

@@ -1205,7 +1205,7 @@ func (s *Server) SyncMeta(ctx context.Context, req *proto.EncryptedMessage) (*pr
return nil, msg
}
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()))
err = s.accountManager.SyncPeerMeta(ctx, peerKey.String(), extractPeerMeta(ctx, syncMetaReq.GetMeta()), realIP)
if err != nil {
return nil, mapError(ctx, err)
}
@@ -1254,7 +1254,10 @@ func (s *Server) Logout(ctx context.Context, req *proto.EncryptedMessage) (*prot
func toProtocolChecks(ctx context.Context, postureChecks []*posture.Checks) []*proto.Checks {
protoChecks := make([]*proto.Checks, 0, len(postureChecks))
for _, postureCheck := range postureChecks {
protoChecks = append(protoChecks, toProtocolCheck(postureCheck))
check := toProtocolCheck(postureCheck)
if check != nil {
protoChecks = append(protoChecks, check)
}
}
return protoChecks
@@ -1278,5 +1281,9 @@ func toProtocolCheck(postureCheck *posture.Checks) *proto.Checks {
}
}
if len(protoCheck.Files) == 0 {
return nil
}
return protoCheck
}

View File

@@ -1889,12 +1889,12 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
// concurrent stream that started earlier loses the optimistic-lock race
// in MarkPeerConnected and bails without writing.
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano(), netMap); err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, accountID, syncTime.UnixNano(), netMap); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
@@ -1914,13 +1914,13 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account
return nil
}
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
func (am *DefaultAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
accountID, err := am.Store.GetAccountIDByPeerPubKey(ctx, peerPubKey)
if err != nil {
return err
}
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, UpdateAccountPeers: true}, accountID)
_, _, _, _, err = am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta, RealIP: realIP, UpdateAccountPeers: true}, accountID)
if err != nil {
return err
}

View File

@@ -62,7 +62,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
@@ -123,7 +123,7 @@ type Manager interface {
GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error)
SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error)
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)

View File

@@ -1323,17 +1323,17 @@ func (mr *MockManagerMockRecorder) ExtendPeerSession(ctx, peerPubKey, userID int
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, accountID, sessionStartedAt, nmap)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, accountID, sessionStartedAt, nmap interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, accountID, sessionStartedAt, nmap)
}
// MarkPeerDisconnected mocks base method.
@@ -1586,17 +1586,17 @@ func (mr *MockManagerMockRecorder) SyncPeer(ctx, sync, accountID interface{}) *g
}
// SyncPeerMeta mocks base method.
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta) error {
func (m *MockManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta peer.PeerSystemMeta, realIP net.IP) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta)
ret := m.ctrl.Call(m, "SyncPeerMeta", ctx, peerPubKey, meta, realIP)
ret0, _ := ret[0].(error)
return ret0
}
// SyncPeerMeta indicates an expected call of SyncPeerMeta.
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) SyncPeerMeta(ctx, peerPubKey, meta, realIP interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SyncPeerMeta", reflect.TypeOf((*MockManager)(nil).SyncPeerMeta), ctx, peerPubKey, meta, realIP)
}
// SyncUserJWTGroups mocks base method.

View File

@@ -1836,7 +1836,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1907,7 +1907,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1916,6 +1916,117 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SchedulesInactivityExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: true,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
// Establish a session so the matching-token disconnect is actually applied.
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
// Install the mock only now, so the assertion observes the disconnect, not
// the earlier connect.
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
// expected: disconnect re-armed the inactivity expiry timer
case <-time.After(time.Second):
t.Fatal("expected inactivity expiration to be rescheduled when an eligible peer disconnects")
}
}
func TestDefaultAccountManager_MarkPeerDisconnected_SkipsInactivityExpirationWhenDisabled(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to create an account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"},
InactivityExpirationEnabled: true,
}, false)
require.NoError(t, err, "unable to add peer")
// Peer is eligible (SSO + inactivity enabled) but the account-level setting
// stays disabled, so disconnect must not schedule anything.
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true,
PeerInactivityExpiration: time.Hour,
PeerInactivityExpirationEnabled: false,
Extra: &types.ExtraSettings{},
})
require.NoError(t, err, "expecting to update account settings successfully but got error")
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
scheduled := make(chan struct{}, 1)
manager.peerInactivityExpiry = &MockScheduler{
CancelFunc: func(ctx context.Context, IDs []string) {},
ScheduleFunc: func(ctx context.Context, in time.Duration, ID string, job func() (nextRunIn time.Duration, reschedule bool)) {
select {
case scheduled <- struct{}{}:
default:
}
},
}
err = manager.MarkPeerDisconnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer disconnected")
select {
case <-scheduled:
t.Fatal("inactivity expiration must not be scheduled while the account-level setting is disabled")
case <-time.After(200 * time.Millisecond):
// expected: nothing scheduled
}
}
func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -1935,7 +2046,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1956,7 +2067,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, streamStartTime.UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1980,7 +2091,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node2SyncTime.UnixNano(), nil)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -1990,7 +2101,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, node1StaleSyncTime.UnixNano(), nil)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
@@ -2052,7 +2163,7 @@ func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token, nil)
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, accountID, token, nil)
}()
}
@@ -2093,7 +2204,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano(), nil)
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), accountID, time.Now().UTC().UnixNano(), nil)
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}

View File

@@ -39,7 +39,7 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
@@ -114,7 +114,7 @@ type MockAccountManager struct {
GetIdpManagerFunc func() idp.Manager
UpdateIntegratedValidatorFunc func(ctx context.Context, accountID, userID, validator string, groups []string) error
GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error)
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error
SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error
FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)
GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error)
GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error)
@@ -345,9 +345,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt, nmap)
return am.MarkPeerConnectedFunc(ctx, peerKey, accountID, sessionStartedAt, nmap)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
@@ -975,9 +975,9 @@ func (am *MockAccountManager) GroupValidation(ctx context.Context, accountId str
}
// SyncPeerMeta mocks SyncPeerMeta of the AccountManager interface
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error {
func (am *MockAccountManager) SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) error {
if am.SyncPeerMetaFunc != nil {
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta)
return am.SyncPeerMetaFunc(ctx, peerPubKey, meta, realIP)
}
return status.Errorf(codes.Unimplemented, "method SyncPeerMeta is not implemented")
}

View File

@@ -74,7 +74,7 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64, nmap *types.NetworkMap) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
@@ -102,10 +102,6 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
if am.geo != nil && realIP != nil {
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
if err = am.schedulePeerExpirations(ctx, accountID, peer); err != nil {
return err
}
@@ -192,27 +188,40 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
}
}
if peer.AddedWithSSOLogin() && peer.InactivityExpirationEnabled {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed getting account settings to schedule inactivity expiration for peer %s: %v", peer.ID, err)
} else if settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
return nil
}
// updatePeerLocationIfChanged refreshes the geolocation on a separate
// row update, only when the connection IP actually changed. Geo lookups
// are expensive so we skip same-IP reconnects.
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
// resolvePeerLocation looks up the geo location for realIP, returning nil when
// there is nothing to apply: geo disabled, no real IP, the IP is unchanged from
// what the peer already has, or the lookup failed. Geo lookups are skipped on
// same-IP reconnects since they are comparatively expensive. The returned value
// is applied by Peer.UpdateMetaIfNew so the change is persisted by its peer save.
func (am *DefaultAccountManager) resolvePeerLocation(ctx context.Context, peer *nbpeer.Peer, realIP net.IP) *nbpeer.Location {
if am.geo == nil || realIP == nil {
return nil
}
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return
return nil
}
location, err := am.geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return
return nil
}
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
return &nbpeer.Location{
ConnectionIP: realIP,
CountryCode: location.Country.ISOCode,
CityName: location.City.Names.En,
GeoNameID: location.City.GeonameID,
}
}
@@ -980,7 +989,8 @@ func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
// SyncPeer checks whether peer is eligible for receiving NetworkMap (authenticated) and returns its NetworkMap if eligible
func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
var peer *nbpeer.Peer
var updated, versionChanged, ipv6CapabilityChanged bool
var ipv6CapabilityChanged bool
var metaDiff nbpeer.MetaDiff
var err error
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1010,9 +1020,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
newLocation := am.resolvePeerLocation(ctx, peer, sync.RealIP)
metaDiff = peer.UpdateMetaIfNew(ctx, sync.Meta, newLocation)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if updated {
if metaDiff.Updated() {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
log.WithContext(ctx).Tracef("peer %s metadata updated", peer.ID)
if err = transaction.SavePeer(ctx, accountID, peer); err != nil {
@@ -1040,9 +1051,10 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return nil, nil, nil, 0, err
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(resPostureChecks) > 0 || versionChanged)) {
metaDiffAffectsPosture := posture.AffectsPosture(ctx, &metaDiff, resPostureChecks)
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || metaDiffAffectsPosture || metaDiff.VersionChanged() || metaDiff.HostnameChanged() {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, updated, len(resPostureChecks) > 0)
affectedPeerIDs := am.syncPeerAffectedPeers(ctx, accountID, peer.ID, nmap, peerNotValid, metaDiffAffectsPosture)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1059,8 +1071,8 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
// metadata change that flips a posture result removes this peer from others'
// maps asymmetrically; that case (and an invalid peer, whose map is empty) falls
// back to the resolver.
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaUpdated, hasPostureChecks bool) []string {
if peerNotValid || (metaUpdated && hasPostureChecks) {
func (am *DefaultAccountManager) syncPeerAffectedPeers(ctx context.Context, accountID, peerID string, nmap *types.NetworkMap, peerNotValid, metaChangeAffectedPosture bool) []string {
if peerNotValid || metaChangeAffectedPosture {
return am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, []string{peerID})
}
return affectedPeerIDsFromNetworkMap(nmap, peerID)

View File

@@ -107,6 +107,15 @@ type Location struct {
GeoNameID uint // city level geoname id
}
// equal reports whether two locations match. ConnectionIP is a net.IP slice, so it uses
// IP.Equal, not ==.
func (l Location) equal(other Location) bool {
return l.CountryCode == other.CountryCode &&
l.CityName == other.CityName &&
l.GeoNameID == other.GeoNameID &&
l.ConnectionIP.Equal(other.ConnectionIP)
}
// NetworkAddress is the IP address with network and MAC address of a network interface
type NetworkAddress struct {
NetIP netip.Prefix `gorm:"serializer:json"`
@@ -256,50 +265,88 @@ func (p *Peer) Copy() *Peer {
}
}
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
// UpdateMetaIfNew updates peer's system metadata and connection geo location if
// new information is provided. newLocation is the geo location resolved from the
// peer's current connection IP, or nil when there is nothing to apply (geo
// disabled, no real IP, or the IP is unchanged); the caller owns the expensive
// lookup and the same-IP guard. It returns a MetaDiff describing what changed;
// diff.Updated() reports whether the peer needs to be persisted.
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta, newLocation *Location) MetaDiff {
if meta.isEmpty() {
return updated, versionChanged
return MetaDiff{}
}
versionChanged = p.Meta.WtVersion != meta.WtVersion
// Avoid overwriting UIVersion if the update was triggered sole by the CLI client
if meta.UIVersion == "" {
meta.UIVersion = p.Meta.UIVersion
}
oldVersion := p.Meta.WtVersion
effectiveLocation := p.Location
if newLocation != nil {
effectiveLocation = *newLocation
}
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
diff := diffMeta(p.Meta, meta, p.Location, effectiveLocation)
if diff.Updated() {
p.Meta = meta
updated = true
}
p.Location = effectiveLocation
if diff.Updated() {
log.WithContext(ctx).Debug(diff.LogSummary())
}
versionInfo := ""
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return updated, versionChanged
return diff
}
// MetaDiff holds a peer's full before/after state across a sync: both metas and both
// connection locations (the location lives on Peer, not PeerSystemMeta, but posture
// checks read it). Changed lists what moved, for logging and the persistence decision;
// the snapshots let a posture check be replayed against old and new. Everything is derived
// from these fields, so there are no parallel per-field flags to keep in sync.
type MetaDiff struct {
OldMeta PeerSystemMeta
NewMeta PeerSystemMeta
OldLocation Location
NewLocation Location
Changed []string
}
// Updated reports whether anything changed and the peer must be persisted. diffMeta fills
// Changed in the pass that builds the diff, so this is a length check, not a re-comparison.
// Pointer receiver: MetaDiff embeds two metas, so copying it per call is wasteful.
func (d *MetaDiff) Updated() bool {
return len(d.Changed) != 0
}
// VersionChanged reports whether the WireGuard client version changed (a client upgrade).
func (d *MetaDiff) VersionChanged() bool {
return d.OldMeta.WtVersion != d.NewMeta.WtVersion
}
// HostnameChanged reports whether the peer's hostname changed.
func (d *MetaDiff) HostnameChanged() bool {
return d.OldMeta.Hostname != d.NewMeta.Hostname
}
// LogSummary renders the changed fields as a single human-readable line.
func (d *MetaDiff) LogSummary() string {
return fmt.Sprintf("peer meta updated, %d field(s) changed: %s",
len(d.Changed), strings.Join(d.Changed, ", "))
}
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. It is the single
// source of truth for meta comparison: isEqual reports equality as an empty
// diff, so the log line can never disagree with the change decision. Slices are
// cloned before sorting, so callers' meta is not mutated.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
var diff []string
return diffMeta(oldMeta, newMeta, Location{}, Location{}).Changed
}
// diffMeta snapshots a peer's old and new state and records a Changed entry per field that
// moved. It is the single source of truth for the comparison: isEqual is an empty Changed
// list, so the log line and the persistence decision can never disagree.
func diffMeta(oldMeta, newMeta PeerSystemMeta, oldLocation, newLocation Location) MetaDiff {
d := MetaDiff{OldMeta: oldMeta, NewMeta: newMeta, OldLocation: oldLocation, NewLocation: newLocation}
add := func(field string, oldVal, newVal any) {
diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
d.Changed = append(d.Changed, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
}
if oldMeta.Hostname != newMeta.Hostname {
@@ -353,16 +400,18 @@ func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
}
if !sameMultiset(oldMeta.NetworkAddresses, newMeta.NetworkAddresses) {
add("network_addresses", fmt.Sprintf("%v", oldMeta.NetworkAddresses), fmt.Sprintf("%v", newMeta.NetworkAddresses))
}
if !sameMultiset(oldMeta.Files, newMeta.Files) {
add("files", fmt.Sprintf("%v", oldMeta.Files), fmt.Sprintf("%v", newMeta.Files))
}
return diff
if !oldLocation.equal(newLocation) {
add("connection_ip", oldLocation.ConnectionIP, newLocation.ConnectionIP)
}
return d
}
// sameMultiset reports whether two slices contain the same elements with the

View File

@@ -0,0 +1,202 @@
package posture
import (
"context"
"net"
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
)
// diffFrom builds a MetaDiff from the old/new snapshots AffectsPosture replays against.
func diffFrom(oldMeta, newMeta nbpeer.PeerSystemMeta, oldLoc, newLoc nbpeer.Location) *nbpeer.MetaDiff {
return &nbpeer.MetaDiff{
OldMeta: oldMeta,
NewMeta: newMeta,
OldLocation: oldLoc,
NewLocation: newLoc,
}
}
func checks(def ChecksDefinition) []*Checks {
return []*Checks{{Checks: def}}
}
func TestAffectsPosture_NilDiff(t *testing.T) {
assert.False(t, AffectsPosture(context.Background(), nil, checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
})))
}
func TestAffectsPosture_NBVersion(t *testing.T) {
c := checks(ChecksDefinition{NBVersionCheck: &NBVersionCheck{MinVersion: "1.2.0"}})
tests := []struct {
name string
oldVer, newVer string
want bool
}{
{"both above min, no flip", "1.3.0", "1.4.0", false},
{"both below min, no flip", "1.0.0", "1.1.0", false},
{"crosses up below->above", "1.1.0", "1.3.0", true},
{"crosses down above->below", "1.3.0", "1.1.0", true},
{"unparsable old only -> flip", "garbage", "1.3.0", true},
{"unparsable both -> no flip", "garbage", "junk", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: tt.oldVer},
nbpeer.PeerSystemMeta{WtVersion: tt.newVer},
nbpeer.Location{}, nbpeer.Location{},
)
assert.Equal(t, tt.want, AffectsPosture(context.Background(), diff, c))
})
}
}
func TestAffectsPosture_OSVersion_KernelBumpWithinMin(t *testing.T) {
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "5.0.0"},
}})
// Kernel moves but stays above the minimum: verdict stays pass -> not affected.
withinMin := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.15.0-arch2"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), withinMin, c))
// Kernel drops below the minimum: verdict flips pass -> fail -> affected.
crossesDown := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "5.10.0-arch1"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0-arch1"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), crossesDown, c))
}
func TestAffectsPosture_OSVersion_GoOSSwitchFlipsVerdict(t *testing.T) {
// Only Linux is constrained. An OS outside the switch (freebsd) passes; switching to a
// failing linux kernel flips the verdict pass -> fail.
c := checks(ChecksDefinition{OSVersionCheck: &OSVersionCheck{
Linux: &MinKernelVersionCheck{MinKernelVersion: "6.0.0"},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "freebsd"},
nbpeer.PeerSystemMeta{GoOS: "linux", KernelVersion: "4.19.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_GoOSSwitchFlipsVerdict(t *testing.T) {
// Process runs at a linux path. Switching GoOS to windows (no WindowsPath configured)
// flips the verdict.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
files := []nbpeer.File{{Path: "/usr/bin/foo", ProcessIsRunning: true}}
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: files},
nbpeer.PeerSystemMeta{GoOS: "windows", Files: files},
nbpeer.Location{}, nbpeer.Location{},
)
assert.True(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_Process_UnrelatedFileChange(t *testing.T) {
// A tracked process stays running while an unrelated file is added: the verdict does
// not move, so posture is not affected.
c := checks(ChecksDefinition{ProcessCheck: &ProcessCheck{
Processes: []Process{{LinuxPath: "/usr/bin/foo"}},
}})
diff := diffFrom(
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
}},
nbpeer.PeerSystemMeta{GoOS: "linux", Files: []nbpeer.File{
{Path: "/usr/bin/foo", ProcessIsRunning: true},
{Path: "/usr/bin/bar", ProcessIsRunning: true},
}},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_GeoLocation(t *testing.T) {
c := checks(ChecksDefinition{GeoLocationCheck: &GeoLocationCheck{
Action: CheckActionAllow,
Locations: []Location{{CountryCode: "DE"}},
}})
// Moving within allowed countries keeps the verdict; moving out flips it.
stayAllowed := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE", CityName: "Berlin"},
nbpeer.Location{CountryCode: "DE", CityName: "Munich"},
)
assert.False(t, AffectsPosture(context.Background(), stayAllowed, c))
moveOut := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{CountryCode: "DE"},
nbpeer.Location{CountryCode: "FR"},
)
assert.True(t, AffectsPosture(context.Background(), moveOut, c))
}
func TestAffectsPosture_PeerNetworkRange_ConnectionIP(t *testing.T) {
// The check reads the connection IP. Moving out of the allowed range flips the verdict;
// moving within it does not.
_, allowed, _ := net.ParseCIDR("10.0.0.0/8")
c := checks(ChecksDefinition{PeerNetworkRangeCheck: &PeerNetworkRangeCheck{
Action: CheckActionAllow,
Ranges: []netip.Prefix{netip.MustParsePrefix(allowed.String())},
}})
movesOutOfRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("8.8.8.8")},
)
assert.True(t, AffectsPosture(context.Background(), movesOutOfRange, c))
staysInRange := diffFrom(
nbpeer.PeerSystemMeta{}, nbpeer.PeerSystemMeta{},
nbpeer.Location{ConnectionIP: net.ParseIP("10.1.2.3")},
nbpeer.Location{ConnectionIP: net.ParseIP("10.9.9.9")},
)
assert.False(t, AffectsPosture(context.Background(), staysInRange, c))
}
func TestAffectsPosture_IrrelevantFieldChange(t *testing.T) {
// Hostname changes but no check reads it: not affected even with checks present.
c := checks(ChecksDefinition{
NBVersionCheck: &NBVersionCheck{MinVersion: "1.0.0"},
GeoLocationCheck: &GeoLocationCheck{Action: CheckActionAllow, Locations: []Location{{CountryCode: "DE"}}},
})
diff := diffFrom(
nbpeer.PeerSystemMeta{Hostname: "old", WtVersion: "1.5.0"},
nbpeer.PeerSystemMeta{Hostname: "new", WtVersion: "1.5.0"},
nbpeer.Location{CountryCode: "DE"}, nbpeer.Location{CountryCode: "DE"},
)
assert.False(t, AffectsPosture(context.Background(), diff, c))
}
func TestAffectsPosture_NoChecks(t *testing.T) {
diff := diffFrom(
nbpeer.PeerSystemMeta{WtVersion: "1.0.0"},
nbpeer.PeerSystemMeta{WtVersion: "2.0.0"},
nbpeer.Location{}, nbpeer.Location{},
)
assert.False(t, AffectsPosture(context.Background(), diff, nil))
}

View File

@@ -7,6 +7,8 @@ import (
"regexp"
"github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
@@ -51,6 +53,46 @@ type Checks struct {
Checks ChecksDefinition `gorm:"serializer:json"`
}
// AffectsPosture reports whether the change in diff flips the verdict of any check. It
// replays each check against the peer's old and new state and compares verdicts, so a
// change that moves a field but stays the right side of a threshold (e.g. a kernel bump
// still above the minimum) does not force a re-evaluation. See verdictChanged for how an
// evaluation error counts.
func AffectsPosture(ctx context.Context, diff *nbpeer.MetaDiff, checks []*Checks) bool {
if diff == nil {
return false
}
oldPeer := nbpeer.Peer{Meta: diff.OldMeta, Location: diff.OldLocation}
newPeer := nbpeer.Peer{Meta: diff.NewMeta, Location: diff.NewLocation}
for _, c := range checks {
for _, check := range c.GetChecks() {
if verdictChanged(ctx, check, oldPeer, newPeer) {
return true
}
}
}
return false
}
// verdictChanged replays check against old and new state and reports whether the verdict
// differs. Like callers, it treats an evaluation error as deny: two errors are the same
// verdict (no change), an error on one side only is a flip.
func verdictChanged(ctx context.Context, check Check, oldPeer, newPeer nbpeer.Peer) bool {
oldPass, oldErr := check.Check(ctx, oldPeer)
newPass, newErr := check.Check(ctx, newPeer)
oldVerdict := oldPass && oldErr == nil
newVerdict := newPass && newErr == nil
changed := oldVerdict != newVerdict
log.WithContext(ctx).Tracef("posture check %s replay: verdict %t -> %t (changed=%t), errs: %v -> %v",
check.Name(), oldVerdict, newVerdict, changed, oldErr, newErr)
return changed
}
// ChecksDefinition contains definition of actual check
type ChecksDefinition struct {
NBVersionCheck *NBVersionCheck `json:",omitempty"`

View File

@@ -581,28 +581,6 @@ func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accoun
return result.RowsAffected > 0, nil
}
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
// Since the location field has been migrated to JSON serialization,
// updating the struct ensures the correct data format is inserted into the database.
peerCopy.Location = peerWithLocation.Location
result := s.db.Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerWithLocation.ID).
Updates(peerCopy)
if result.Error != nil {
return status.Errorf(status.Internal, "failed to save peer locations to store: %v", result.Error)
}
if result.RowsAffected == 0 {
return status.Errorf(status.NotFound, peerNotFoundFMT, peerWithLocation.ID)
}
return nil
}
// ApproveAccountPeers marks all peers that currently require approval in the given account as approved.
func (s *SqlStore) ApproveAccountPeers(ctx context.Context, accountID string) (int, error) {
result := s.db.Model(&nbpeer.Peer{}).

View File

@@ -618,56 +618,6 @@ func TestSqlStore_SavePeerStatus(t *testing.T) {
assert.WithinDurationf(t, newStatus.LastSeen, actual.LastSeen.UTC(), time.Millisecond, "LastSeen should be equal")
}
func TestSqlStore_SavePeerLocation(t *testing.T) {
store, cleanUp, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
t.Cleanup(cleanUp)
assert.NoError(t, err)
account, err := store.GetAccount(context.Background(), "bf1c8084-ba50-4ce7-9439-34653001fc3b")
require.NoError(t, err)
peer := &nbpeer.Peer{
AccountID: account.Id,
ID: "testpeer",
Location: nbpeer.Location{
ConnectionIP: net.ParseIP("0.0.0.0"),
CountryCode: "YY",
CityName: "City",
GeoNameID: 1,
},
CreatedAt: time.Now().UTC(),
Meta: nbpeer.PeerSystemMeta{},
}
// error is expected as peer is not in store yet
err = store.SavePeerLocation(context.Background(), account.Id, peer)
assert.Error(t, err)
account.Peers[peer.ID] = peer
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
peer.Location.ConnectionIP = net.ParseIP("35.1.1.1")
peer.Location.CountryCode = "DE"
peer.Location.CityName = "Berlin"
peer.Location.GeoNameID = 2950159
err = store.SavePeerLocation(context.Background(), account.Id, account.Peers[peer.ID])
assert.NoError(t, err)
account, err = store.GetAccount(context.Background(), account.Id)
require.NoError(t, err)
actual := account.Peers[peer.ID].Location
assert.Equal(t, peer.Location, actual)
peer.ID = "non-existing-peer"
err = store.SavePeerLocation(context.Background(), account.Id, peer)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
require.True(t, ok)
require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error")
}
func Test_TestGetAccountByPrivateDomain(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("The SQLite store is not properly supported by Windows yet")

View File

@@ -185,7 +185,6 @@ type Store interface {
// recorded by the database. Returns true when the update happened,
// false when a newer session has taken over.
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error

View File

@@ -2968,20 +2968,6 @@ func (mr *MockStoreMockRecorder) SavePeer(ctx, accountID, peer interface{}) *gom
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeer", reflect.TypeOf((*MockStore)(nil).SavePeer), ctx, accountID, peer)
}
// SavePeerLocation mocks base method.
func (m *MockStore) SavePeerLocation(ctx context.Context, accountID string, peer *peer.Peer) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SavePeerLocation", ctx, accountID, peer)
ret0, _ := ret[0].(error)
return ret0
}
// SavePeerLocation indicates an expected call of SavePeerLocation.
func (mr *MockStoreMockRecorder) SavePeerLocation(ctx, accountID, peer interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerLocation", reflect.TypeOf((*MockStore)(nil).SavePeerLocation), ctx, accountID, peer)
}
// SavePeerStatus mocks base method.
func (m *MockStore) SavePeerStatus(ctx context.Context, accountID, peerID string, status peer.PeerStatus) error {
m.ctrl.T.Helper()

View File

@@ -12,6 +12,9 @@ type PeerSync struct {
WireGuardPubKey string
// Meta is the system information passed by peer, must be always present
Meta nbpeer.PeerSystemMeta
// RealIP is the peer's connection IP, used to refresh its geo location.
// May be nil when the request has no associated connection IP.
RealIP net.IP
// UpdateAccountPeers indicate updating account peers,
// which occurs when the peer's metadata is updated
UpdateAccountPeers bool