Compare commits

..

2 Commits

Author SHA1 Message Date
Viktor Liu
486cd4c0e3 Add test for bidirectional SSH rule authorized users on source peers 2026-05-16 15:59:02 +02:00
Viktor Liu
b59382d4f2 Collect SSH authorized users for bidirectional rules on source peers 2026-05-16 15:49:07 +02:00
25 changed files with 571 additions and 1636 deletions

View File

@@ -260,23 +260,15 @@ WriteRegStr ${REG_ROOT} "${UNINSTALL_PATH}" "Publisher" "${COMP_NAME}"
WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Drop Run, App Paths and Uninstall entries left in the 32-bit registry view
; or HKCU by legacy installers.
DetailPrint "Cleaning legacy 32-bit / HKCU entries..."
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -307,16 +299,11 @@ ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
DetailPrint "Terminating Netbird UI process..."
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart entries from every view a previous installer may have used.
; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
SetRegView 32
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DeleteRegKey HKLM "${REG_APP_PATH}"
DeleteRegKey HKLM "${UI_REG_APP_PATH}"
DeleteRegKey HKLM "${UNINSTALL_PATH}"
SetRegView 64
; Handle data deletion based on checkbox
DetailPrint "Checking if user requested data deletion..."

View File

@@ -64,13 +64,6 @@
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
<!-- Drop the HKCU Run\Netbird value written by legacy NSIS installers. -->
<Component Id="NetbirdLegacyHKCUCleanup" Guid="*">
<RegistryValue Root="HKCU" Key="Software\NetBird GmbH\Installer"
Name="LegacyHKCUCleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKCU"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
@@ -83,28 +76,10 @@
</Directory>
</StandardDirectory>
<!-- Drop Run, App Paths and Uninstall entries written by legacy NSIS
installers into the 32-bit registry view (HKLM\Software\Wow6432Node). -->
<Component Id="NetbirdLegacyWow6432Cleanup" Directory="NetbirdInstallDir"
Guid="bda5d628-16bd-4086-b2c1-5099d8d51763" Bitness="always32">
<RegistryValue Root="HKLM" Key="Software\NetBird GmbH\Installer"
Name="LegacyWow6432Cleanup" Type="integer" Value="1" KeyPath="yes" />
<RemoveRegistryValue Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Run" Name="Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\App Paths\Netbird-ui" />
<RemoveRegistryKey Action="removeOnInstall" Root="HKLM"
Key="Software\Microsoft\Windows\CurrentVersion\Uninstall\Netbird" />
</Component>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
<ComponentRef Id="NetbirdLegacyHKCUCleanup" />
<ComponentRef Id="NetbirdLegacyWow6432Cleanup" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

View File

@@ -2,7 +2,6 @@ package manager
import (
"context"
"math/rand"
"sync"
"time"
@@ -12,344 +11,240 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/telemetry"
)
const (
// cleanupWindow is the small grace period added on top of the
// staleness horizon before a sweep fires. It absorbs minor clock
// skew between the management server and the database and avoids
// firing a sweep right at the boundary where last_seen could still
// be one tick under the threshold.
// cleanupWindow is the time to wait after the nearest peer deadline before starting the cleanup procedure.
cleanupWindow = 1 * time.Minute
// initialLoadMinDelay and initialLoadMaxDelay bracket the random
// delay applied before the post-restart catch-up query runs. Spread
// across replicas this prevents a thundering herd of catch-up
// queries hitting the database simultaneously after a deploy.
initialLoadMinDelay = 8 * time.Minute
initialLoadMaxDelay = 10 * time.Minute
)
// timeNow is the clock used by this package. It is a variable rather than a
// direct time.Now call so that it can be replaced with a controllable clock.
var timeNow = time.Now
// accountEntry is the per-account state held by the cleanup tracker.
// We don't track which peers are pending — the sweep query gets the
// authoritative list straight from the database every time. We only
// need to know the latest disconnect we've observed for this account
// (so we can decide when it's safe to drop the entry) and the timer
// that will fire the next sweep.
type accountEntry struct {
lastDisconnectedAt time.Time
timer *time.Timer
// ephemeralPeer is a single node of the singly linked list of ephemeral peers
// that are waiting to be cleaned up.
type ephemeralPeer struct {
	// id is the peer ID; accountID is the ID of the account the peer belongs to.
	id, accountID string
	// deadline is the instant from which cleanup treats the peer as expired.
	deadline time.Time
	// next points to the following node in the list; nil marks the tail.
	next *ephemeralPeer
}
// EphemeralManager tracks accounts that may have ephemeral peers in
// need of cleanup and runs a per-account sweep at the appropriate
// time. State is in-memory and account-scoped: a sweep deletes any
// ephemeral peer in the account that has been disconnected for at
// least lifeTime, then either drops the account from the tracker
// (when no recent disconnects have arrived) or re-arms the timer.
// todo: consider removing the peer from the ephemeral list when the peer has been deleted via the API. If we do not,
// in the worst case this manager will produce an invalid error message.
// EphemeralManager keeps a list of ephemeral peers. After EphemeralLifeTime of inactivity a peer is deleted
// automatically. Inactivity means the peer is disconnected from the Management server.
type EphemeralManager struct {
store store.Store
peersManager peers.Manager
accountsLock sync.Mutex
accounts map[string]*accountEntry
// initialLoadTimer is the one-shot timer used to defer the
// post-restart catch-up query; held so Stop() can cancel it.
initialLoadTimer *time.Timer
// stopped is flipped by Stop() so any timer that fires after
// teardown becomes a no-op instead of touching a half-dismantled
// store.
stopped bool
headPeer *ephemeralPeer
tailPeer *ephemeralPeer
peersLock sync.Mutex
timer *time.Timer
lifeTime time.Duration
cleanupWindow time.Duration
// initialLoadDelay returns the wall-clock delay to wait before
// running the post-restart catch-up query. Pluggable so tests can
// fire the load immediately.
initialLoadDelay func() time.Duration
// bgCtx is the long-lived context captured at LoadInitialPeers
// time. Timer-driven sweeps use it because they fire long after
// the original gRPC handler ctx that produced an OnPeerDisconnected
// call has been cancelled.
bgCtx context.Context
// metrics is nil-safe; methods on telemetry.EphemeralPeersMetrics
// no-op when the receiver is nil so deployments without an app
// metrics provider work unchanged.
metrics *telemetry.EphemeralPeersMetrics
}
// NewEphemeralManager instantiates a new EphemeralManager.
func NewEphemeralManager(store store.Store, peersManager peers.Manager) *EphemeralManager {
return &EphemeralManager{
store: store,
peersManager: peersManager,
accounts: make(map[string]*accountEntry),
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
initialLoadDelay: defaultInitialLoadDelay,
store: store,
peersManager: peersManager,
lifeTime: ephemeral.EphemeralLifeTime,
cleanupWindow: cleanupWindow,
}
}
// SetMetrics attaches a metrics collector. Pass nil to detach.
func (e *EphemeralManager) SetMetrics(m *telemetry.EphemeralPeersMetrics) {
e.accountsLock.Lock()
e.metrics = m
e.accountsLock.Unlock()
// LoadInitialPeers loads the ephemeral peers from the store into the linked
// list and, when at least one peer was loaded, arms a timer that runs cleanup
// once lifeTime has elapsed. cleanup itself schedules the following run for
// the new head of the list.
//
// The supplied ctx is captured by the timer callback and passed to cleanup.
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
	e.peersLock.Lock()
	defer e.peersLock.Unlock()

	e.loadEphemeralPeers(ctx)

	// Nothing was loaded, so there is nothing to schedule.
	if e.headPeer == nil {
		return
	}

	runCleanup := func() {
		e.cleanup(ctx)
	}
	e.timer = time.AfterFunc(e.lifeTime, runCleanup)
}
// LoadInitialPeers schedules the post-restart catch-up query for a
// random moment 8-10 minutes from now. Returns immediately. The
// catch-up populates the per-account tracker from the database so any
// peers that disconnected before the restart still get cleaned up.
//
// The random delay is critical: without it, every management replica
// hitting the same Postgres instance after a deploy would issue the
// catch-up query simultaneously.
func (e *EphemeralManager) LoadInitialPeers(ctx context.Context) {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
// Stop stops the pending cleanup timer, if one has been armed. The timer
// field itself is left in place.
func (e *EphemeralManager) Stop() {
	e.peersLock.Lock()
	defer e.peersLock.Unlock()

	if pending := e.timer; pending != nil {
		pending.Stop()
	}
}
// OnPeerConnected removes the peer from the linked list of ephemeral peers. Because it is called when the peer
// becomes active, the manager will not delete the peer while it is active.
func (e *EphemeralManager) OnPeerConnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
e.bgCtx = ctx
log.WithContext(ctx).Tracef("remove peer from ephemeral list: %s", peer.ID)
delay := e.initialLoadDelay()
log.WithContext(ctx).Infof("ephemeral peer initial load scheduled in %s", delay)
e.initialLoadTimer = time.AfterFunc(delay, func() {
e.loadInitialAccounts(e.bgCtx)
})
}
e.peersLock.Lock()
defer e.peersLock.Unlock()
// Stop cancels the deferred initial load and any per-account timers.
func (e *EphemeralManager) Stop() {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
e.removePeer(peer.ID)
e.stopped = true
if e.initialLoadTimer != nil {
e.initialLoadTimer.Stop()
e.initialLoadTimer = nil
// stop the unnecessary timer
if e.headPeer == nil && e.timer != nil {
e.timer.Stop()
e.timer = nil
}
for _, entry := range e.accounts {
if entry.timer != nil {
entry.timer.Stop()
}
}
e.accounts = make(map[string]*accountEntry)
}
// OnPeerConnected is a no-op in the account-scoped design. The sweep
// query filters out connected peers at the database level, so we don't
// need an explicit "remove from list" signal when a peer reconnects.
// Kept on the interface to preserve the existing call sites.
func (e *EphemeralManager) OnPeerConnected(_ context.Context, _ *nbpeer.Peer) {
}
// OnPeerDisconnected registers a disconnect for the peer's account and
// arms a sweep if one isn't already scheduled. Non-ephemeral peers are
// ignored.
// OnPeerDisconnected adds the peer to the linked list of ephemeral peers. Because the peer
// is inactive, it will be deleted after the EphemeralLifeTime period.
func (e *EphemeralManager) OnPeerDisconnected(ctx context.Context, peer *nbpeer.Peer) {
if !peer.Ephemeral {
return
}
now := timeNow()
log.WithContext(ctx).Tracef("add peer to ephemeral list: %s", peer.ID)
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
e.peersLock.Lock()
defer e.peersLock.Unlock()
if e.isPeerOnList(peer.ID) {
return
}
entry, existed := e.accounts[peer.AccountID]
if !existed {
entry = &accountEntry{}
e.accounts[peer.AccountID] = entry
e.metrics.IncPending()
}
entry.lastDisconnectedAt = now
if entry.timer == nil {
delay := e.lifeTime + e.cleanupWindow
log.WithContext(ctx).Tracef("ephemeral: scheduling sweep for account %s in %s", peer.AccountID, delay)
accountID := peer.AccountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(e.bgCtxOrFallback(ctx), accountID)
e.addPeer(peer.AccountID, peer.ID, e.newDeadLine())
if e.timer == nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
})
}
}
// bgCtxOrFallback returns the long-lived background context captured at
// LoadInitialPeers time, falling back to the supplied ctx when the
// manager hasn't been started through LoadInitialPeers (e.g. in tests
// that drive the manager directly). Must be called with the lock held
// or before the timer is armed.
func (e *EphemeralManager) bgCtxOrFallback(ctx context.Context) context.Context {
if e.bgCtx != nil {
return e.bgCtx
func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) {
peers, err := e.store.GetAllEphemeralPeers(ctx, store.LockingStrengthNone)
if err != nil {
log.WithContext(ctx).Debugf("failed to load ephemeral peers: %s", err)
return
}
return ctx
t := e.newDeadLine()
for _, p := range peers {
e.addPeer(p.AccountID, p.ID, t)
}
log.WithContext(ctx).Debugf("loaded ephemeral peer(s): %d", len(peers))
}
// loadInitialAccounts runs the post-restart catch-up query and seeds
// the tracker with one entry per account that has at least one
// disconnected ephemeral peer.
func (e *EphemeralManager) loadInitialAccounts(ctx context.Context) {
accounts, err := e.store.GetEphemeralAccountsLastDisconnect(ctx)
if err != nil {
log.WithContext(ctx).Errorf("failed to load ephemeral accounts on startup: %v", err)
return
}
func (e *EphemeralManager) cleanup(ctx context.Context) {
log.Tracef("on ephemeral cleanup")
deletePeers := make(map[string]*ephemeralPeer)
e.peersLock.Lock()
now := timeNow()
added := 0
for p := e.headPeer; p != nil; p = p.next {
if now.Before(p.deadline) {
break
}
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
deletePeers[p.id] = p
e.headPeer = p.next
if p.next == nil {
e.tailPeer = nil
}
}
for accountID, lastDisc := range accounts {
// If we already learned about this account via an
// OnPeerDisconnected that arrived during the random delay
// window, prefer the live timestamp.
if _, alreadyTracked := e.accounts[accountID]; alreadyTracked {
continue
if e.headPeer != nil {
delay := e.headPeer.deadline.Sub(timeNow()) + e.cleanupWindow
if delay < 0 {
delay = 0
}
entry := &accountEntry{lastDisconnectedAt: lastDisc}
horizon := lastDisc.Add(e.lifeTime)
var delay time.Duration
if horizon.After(now) {
delay = horizon.Sub(now) + e.cleanupWindow
} else {
// Already past the staleness window — sweep right away
// (one cleanupWindow later, to keep startup load smooth
// when many accounts qualify at once).
delay = e.cleanupWindow
}
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
e.timer = time.AfterFunc(delay, func() {
e.cleanup(ctx)
})
e.accounts[accountID] = entry
added++
} else {
e.timer = nil
}
e.metrics.AddPending(int64(added))
log.WithContext(ctx).Debugf("ephemeral: loaded %d account(s) for cleanup tracking", added)
e.peersLock.Unlock()
peerIDsPerAccount := make(map[string][]string)
for id, p := range deletePeers {
peerIDsPerAccount[p.accountID] = append(peerIDsPerAccount[p.accountID], id)
}
for accountID, peerIDs := range peerIDsPerAccount {
log.WithContext(ctx).Tracef("cleanup: deleting %d ephemeral peers for account %s", len(peerIDs), accountID)
err := e.peersManager.DeletePeers(ctx, accountID, peerIDs, activity.SystemInitiator, true)
if err != nil {
log.WithContext(ctx).Errorf("failed to delete ephemeral peers: %s", err)
}
}
}
// sweep runs the cleanup pass for a single account. It queries the
// database for disconnected ephemeral peers that have crossed the
// staleness window, deletes them via peers.Manager, and then decides
// whether to drop the account from the tracker or re-arm the timer.
func (e *EphemeralManager) sweep(ctx context.Context, accountID string) {
now := timeNow()
e.accountsLock.Lock()
entry, ok := e.accounts[accountID]
if !ok || e.stopped {
e.accountsLock.Unlock()
return
func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) {
ep := &ephemeralPeer{
id: peerID,
accountID: accountID,
deadline: deadline,
}
lastDisc := entry.lastDisconnectedAt
entry.timer = nil
e.accountsLock.Unlock()
threshold := now.Add(-e.lifeTime)
stalePeerIDs, err := e.store.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, threshold)
if err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to query stale peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
if e.headPeer == nil {
e.headPeer = ep
}
if e.tailPeer != nil {
e.tailPeer.next = ep
}
e.tailPeer = ep
}
func (e *EphemeralManager) removePeer(id string) {
if e.headPeer == nil {
return
}
if len(stalePeerIDs) > 0 {
log.WithContext(ctx).Tracef("ephemeral: deleting %d peer(s) for account %s", len(stalePeerIDs), accountID)
if err := e.peersManager.DeletePeers(ctx, accountID, stalePeerIDs, activity.SystemInitiator, true); err != nil {
log.WithContext(ctx).Errorf("ephemeral: failed to delete peers for account %s: %v", accountID, err)
e.metrics.CountCleanupError()
e.rearm(ctx, accountID, e.cleanupWindow)
if e.headPeer.id == id {
e.headPeer = e.headPeer.next
if e.tailPeer.id == id {
e.tailPeer = nil
}
return
}
for p := e.headPeer; p.next != nil; p = p.next {
if p.next.id == id {
// if we remove the last element from the chain then set the last-1 as tail
if e.tailPeer.id == id {
e.tailPeer = p
}
p.next = p.next.next
return
}
e.metrics.CountCleanupRun()
e.metrics.CountPeersCleaned(int64(len(stalePeerIDs)))
}
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
}
entry, ok = e.accounts[accountID]
if !ok {
return
}
// Drop rule: if every disconnect we've observed has now crossed
// the staleness window, the sweep we just ran saw everything that
// could possibly need cleaning. Dropping is safe — a future
// disconnect will recreate the entry. The check uses the latest
// lastDisc, which may have advanced (concurrently with the sweep
// itself) due to a new OnPeerDisconnected, in which case we
// correctly re-arm.
horizon := entry.lastDisconnectedAt.Add(e.lifeTime)
if !horizon.After(now) {
delete(e.accounts, accountID)
e.metrics.DecPending(1)
log.WithContext(ctx).Tracef("ephemeral: dropping account %s (lastDisc=%s, horizon=%s, now=%s)",
accountID, lastDisc, horizon, now)
return
}
delay := horizon.Sub(now) + e.cleanupWindow
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
}
// rearm reschedules a sweep `delay` from now. Used after a recoverable
// error in the sweep path so the account doesn't get stuck.
func (e *EphemeralManager) rearm(ctx context.Context, accountID string, delay time.Duration) {
e.accountsLock.Lock()
defer e.accountsLock.Unlock()
if e.stopped {
return
func (e *EphemeralManager) isPeerOnList(id string) bool {
for p := e.headPeer; p != nil; p = p.next {
if p.id == id {
return true
}
}
entry, ok := e.accounts[accountID]
if !ok {
return
}
idForClosure := accountID
entry.timer = time.AfterFunc(delay, func() {
e.sweep(ctx, idForClosure)
})
return false
}
// defaultInitialLoadDelay returns a random duration in
// [initialLoadMinDelay, initialLoadMaxDelay). Process-wide
// math/rand is acceptable here — the delay is purely a smoothing
// jitter, not a security primitive.
func defaultInitialLoadDelay() time.Duration {
span := int64(initialLoadMaxDelay - initialLoadMinDelay)
if span <= 0 {
return initialLoadMinDelay
}
return initialLoadMinDelay + time.Duration(rand.Int63n(span))
// newDeadLine returns the current time, as reported by timeNow, shifted
// forward by the manager's lifeTime.
func (e *EphemeralManager) newDeadLine() time.Time {
	now := timeNow()
	return now.Add(e.lifeTime)
}

View File

@@ -2,544 +2,299 @@ package manager
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/golang/mock/gomock"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/peers"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
nbAccount "github.com/netbirdio/netbird/management/server/account"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// MockStore is a thin in-memory stand-in that implements only the two
// methods the EphemeralManager uses. It honors the account / ephemeral
// / connected / lastSeen attributes of each peer so the cleanup logic
// can be exercised end-to-end without bringing up sqlite or Postgres.
type MockStore struct {
store.Store
mu sync.Mutex
account *types.Account
}
func (s *MockStore) GetStaleEphemeralPeerIDsForAccount(_ context.Context, accountID string, olderThan time.Time) ([]string, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.account == nil || s.account.Id != accountID {
return nil, nil
}
var ids []string
for _, p := range s.account.Peers {
if !p.Ephemeral {
continue
}
if p.Status == nil || p.Status.Connected {
continue
}
if p.Status.LastSeen.Before(olderThan) {
ids = append(ids, p.ID)
func (s *MockStore) GetAllEphemeralPeers(_ context.Context, _ store.LockingStrength) ([]*nbpeer.Peer, error) {
var peers []*nbpeer.Peer
for _, v := range s.account.Peers {
if v.Ephemeral {
peers = append(peers, v)
}
}
return ids, nil
return peers, nil
}
func (s *MockStore) GetEphemeralAccountsLastDisconnect(_ context.Context) (map[string]time.Time, error) {
s.mu.Lock()
defer s.mu.Unlock()
out := map[string]time.Time{}
if s.account == nil {
return out, nil
}
var latest time.Time
hasAny := false
for _, p := range s.account.Peers {
if !p.Ephemeral || p.Status == nil || p.Status.Connected {
continue
}
if !hasAny || p.Status.LastSeen.After(latest) {
latest = p.Status.LastSeen
hasAny = true
}
}
if hasAny {
out[s.account.Id] = latest
}
return out, nil
// MockAccountManager is a test double that embeds the account manager
// interface and records how often its overridden methods are called.
type MockAccountManager struct {
	nbAccount.Manager

	// store backs DeletePeer and is what GetStore returns.
	store *MockStore
	// wg, when non-nil, has Done called once per DeletePeer call.
	wg *sync.WaitGroup

	// mu is held while the call counters below are read or updated.
	mu sync.Mutex
	// deletePeerCalls is the number of DeletePeer calls made so far.
	deletePeerCalls int
	// bufferUpdateCalls counts BufferUpdateAccountPeers calls per account ID.
	bufferUpdateCalls map[string]int
}
// withFakeClock pins timeNow to a settable value for the duration of t.
// Returns a getter and a setter so subtests can advance virtual time.
func withFakeClock(t *testing.T, start time.Time) (get func() time.Time, set func(time.Time)) {
t.Helper()
var mu sync.Mutex
now := start
// DeletePeer records the call, removes peerID from the mock store's account
// and, when a WaitGroup is configured, marks one unit of work as done.
// It always returns nil.
func (a *MockAccountManager) DeletePeer(_ context.Context, accountID, peerID, userID string) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	a.deletePeerCalls++
	delete(a.store.account.Peers, peerID)

	if done := a.wg; done != nil {
		done.Done()
	}
	return nil
}
// GetDeletePeerCalls returns how many times DeletePeer has been called.
func (a *MockAccountManager) GetDeletePeerCalls() int {
	a.mu.Lock()
	calls := a.deletePeerCalls
	a.mu.Unlock()

	return calls
}
// BufferUpdateAccountPeers increments the call counter kept for accountID,
// creating the counter map on first use. ctx and reason are not used.
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
	a.mu.Lock()
	defer a.mu.Unlock()

	counters := a.bufferUpdateCalls
	if counters == nil {
		counters = make(map[string]int)
		a.bufferUpdateCalls = counters
	}
	counters[accountID]++
}
// GetBufferUpdateCalls returns how many times BufferUpdateAccountPeers has
// been called for accountID, or 0 when no call has been recorded.
func (a *MockAccountManager) GetBufferUpdateCalls(accountID string) int {
	a.mu.Lock()
	defer a.mu.Unlock()

	// Indexing a nil map yields the zero value, so an unset counter map
	// reports 0 without a separate nil check.
	return a.bufferUpdateCalls[accountID]
}
// GetStore returns the mock store held by the manager as a store.Store.
func (a *MockAccountManager) GetStore() store.Store {
	var backing store.Store = a.store
	return backing
}
func TestNewManager(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
mu.Lock()
defer mu.Unlock()
return now
return startTime
}
t.Cleanup(func() { timeNow = time.Now })
return func() time.Time {
mu.Lock()
defer mu.Unlock()
return now
}, func(v time.Time) {
mu.Lock()
defer mu.Unlock()
now = v
}
}
// newManagerForTest builds a manager with short timers and no random
// initial-load delay so tests run instantly.
func newManagerForTest(t *testing.T, st store.Store, peersMgr peers.Manager) *EphemeralManager {
t.Helper()
mgr := NewEphemeralManager(st, peersMgr)
mgr.lifeTime = 100 * time.Millisecond
mgr.cleanupWindow = 10 * time.Millisecond
mgr.initialLoadDelay = func() time.Duration { return 0 }
t.Cleanup(mgr.Stop)
return mgr
}
// TestOnPeerDisconnected_RegistersAndSweeps drives the OnPeerDisconnected
// path with a fake clock: a single ephemeral peer disconnects, we
// advance past the staleness window, and the sweep deletes it.
func TestOnPeerDisconnected_RegistersAndSweeps(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
store := &MockStore{}
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersManager := peers.NewMockManager(ctrl)
var deletedMu sync.Mutex
var deleted []string
var deleteCalls atomic.Int32
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, accountID string, peerIDs []string, _ string, _ bool) error {
deleteCalls.Add(1)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
mockStore.mu.Unlock()
deletedMu.Lock()
deleted = append(deleted, peerIDs...)
deletedMu.Unlock()
return nil
}).AnyTimes()
}).
AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
// One ephemeral peer that disconnected "now".
now := getNow()
p := &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
if len(store.account.Peers) != numberOfPeers {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", numberOfPeers, len(store.account.Peers))
}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
// Advance past lifeTime + cleanupWindow so the timer-driven sweep fires.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool { return deleteCalls.Load() >= 1 }, 2*time.Second, 5*time.Millisecond,
"sweep should fire and delete the stale peer")
deletedMu.Lock()
deletedCopy := append([]string(nil), deleted...)
deletedMu.Unlock()
require.Equal(t, []string{"p1"}, deletedCopy, "only the one ephemeral peer should be deleted")
}
// TestOnPeerDisconnected_NonEphemeralIgnored: a non-ephemeral disconnect
// must not register the account or arm any timer.
func TestOnPeerDisconnected_NonEphemeralIgnored(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
// No DeletePeers expectation — must not be called.
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
ID: "p1",
AccountID: "acc-1",
Ephemeral: false,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
func TestNewManagerPeerConnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
mgr.accountsLock.Lock()
require.Empty(t, mgr.accounts, "non-ephemeral disconnect must not register an account")
mgr.accountsLock.Unlock()
store := &MockStore{}
ctrl := gomock.NewController(t)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for ephemeral peers (except the connected one)
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
return nil
}).
AnyTimes()
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
mgr.OnPeerConnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
expected := numberOfPeers + 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
}
// TestSweep_DropsAccountWhenIdle: after a sweep cleans the stale peers,
// if no more disconnects have arrived the account must be dropped from
// the in-memory tracker.
func TestSweep_DropsAccountWhenIdle(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
func TestNewManagerPeerDisconnected(t *testing.T) {
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
store := &MockStore{}
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
peersManager := peers.NewMockManager(ctrl)
numberOfPeers := 5
numberOfEphemeralPeers := 3
seedPeers(store, numberOfPeers, numberOfEphemeralPeers)
// Expect DeletePeers to be called for the one disconnected peer
peersManager.EXPECT().
DeletePeers(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
for _, peerID := range peerIDs {
delete(store.account.Peers, peerID)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
}).
AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr := NewEphemeralManager(store, peersManager)
mgr.loadEphemeralPeers(context.Background())
for _, v := range store.account.Peers {
mgr.OnPeerConnected(context.Background(), v)
now := getNow()
p := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
}
mgr.OnPeerDisconnected(context.Background(), store.account.Peers["ephemeral_peer_0"])
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
startTime = startTime.Add(ephemeral.EphemeralLifeTime + 1)
mgr.cleanup(context.Background())
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should be dropped after sweep with no new disconnects")
expected := numberOfPeers + numberOfEphemeralPeers - 1
if len(store.account.Peers) != expected {
t.Errorf("failed to cleanup ephemeral peers, expected: %d, result: %d", expected, len(store.account.Peers))
}
}
// TestSweep_ReArmsWhenNewDisconnectArrived: simulate the race where a
// fresh disconnect arrives just before the sweep fires. The sweep must
// observe the updated lastDisc and re-arm rather than drop.
func TestSweep_ReArmsWhenNewDisconnectArrived(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
const (
ephemeralPeers = 10
testLifeTime = 1 * time.Second
testCleanupWindow = 100 * time.Millisecond
)
t.Cleanup(func() {
timeNow = time.Now
})
startTime := time.Now()
timeNow = func() time.Time {
return startTime
}
mockStore := &MockStore{}
account := newAccountWithId(context.Background(), "account", "", "", false)
mockStore.account = account
wg := &sync.WaitGroup{}
wg.Add(ephemeralPeers)
mockAM := &MockAccountManager{
store: mockStore,
wg: wg,
}
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
peersManager := peers.NewMockManager(ctrl)
// Set up expectation that DeletePeers will be called once with all peer IDs
peersManager.EXPECT().
DeletePeers(gomock.Any(), account.Id, gomock.Any(), gomock.Any(), true).
DoAndReturn(func(ctx context.Context, accountID string, peerIDs []string, userID string, checkConnected bool) error {
// Simulate the actual deletion behavior
for _, peerID := range peerIDs {
err := mockAM.DeletePeer(ctx, accountID, peerID, userID)
if err != nil {
return err
}
}
mockStore.mu.Unlock()
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
return nil
}).AnyTimes()
}).
Times(1)
mgr := newManagerForTest(t, mockStore, peersMgr)
mgr := NewEphemeralManager(mockStore, peersManager)
mgr.lifeTime = testLifeTime
mgr.cleanupWindow = testCleanupWindow
now := getNow()
p1 := &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now}}
mockStore.account.Peers[p1.ID] = p1
mgr.OnPeerDisconnected(context.Background(), p1)
// Advance most of the way toward the first sweep, then introduce
// a fresh disconnect that resets lastDisc.
setNow(now.Add(mgr.lifeTime - 10*time.Millisecond))
p2 := &nbpeer.Peer{ID: "p2", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()}}
mockStore.account.Peers[p2.ID] = p2
mgr.OnPeerDisconnected(context.Background(), p2)
// Push past p1's staleness so the first sweep runs and cleans p1
// but observes p2 already on the account entry. It must re-arm.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p1"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p1 should be cleaned at the first sweep")
// The account should still be tracked because p2 is younger than lifeTime
// from the sweep's vantage point at this moment.
mgr.accountsLock.Lock()
_, stillTracked := mgr.accounts["acc-1"]
mgr.accountsLock.Unlock()
require.True(t, stillTracked, "account should remain tracked because p2's disconnect kept it active")
// Push past p2's staleness; second sweep cleans p2 and drops the account.
setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mgr.accountsLock.Lock()
defer mgr.accountsLock.Unlock()
return len(mgr.accounts) == 0
}, 2*time.Second, 5*time.Millisecond, "account should drop after the final sweep")
}
// TestSweep_BatchesPeersPerAccount: many ephemeral peers disconnect on
// the same account; a single sweep must delete them all in one
// DeletePeers call.
func TestSweep_BatchesPeersPerAccount(t *testing.T) {
const ephemeralPeers = 8
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
deleteBatches := make(chan []string, 4)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
cp := append([]string(nil), peerIDs...)
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
deleteBatches <- cp
return nil
}).Times(1)
mgr := newManagerForTest(t, mockStore, peersMgr)
now := getNow()
for i := 0; i < ephemeralPeers; i++ {
id := fmt.Sprintf("p-%d", i)
// Stagger by a fraction of cleanupWindow so they all fall on
// the same sweep tick.
when := now.Add(time.Duration(i) * time.Millisecond)
p := &nbpeer.Peer{ID: id, AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: when}}
mockStore.account.Peers[id] = p
// Add peers and disconnect them at slightly different times (within cleanup window)
for i := range ephemeralPeers {
p := &nbpeer.Peer{ID: fmt.Sprintf("peer-%d", i), AccountID: account.Id, Ephemeral: true}
mockStore.account.Peers[p.ID] = p
mgr.OnPeerDisconnected(context.Background(), p)
startTime = startTime.Add(testCleanupWindow / (ephemeralPeers * 2))
}
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
// Advance time past the lifetime to trigger cleanup
startTime = startTime.Add(testLifeTime + testCleanupWindow)
select {
case batch := <-deleteBatches:
require.Len(t, batch, ephemeralPeers, "all peers should be deleted in a single batch")
case <-time.After(2 * time.Second):
t.Fatal("expected one batched DeletePeers call")
}
// Wait for all deletions to complete
wg.Wait()
assert.Len(t, mockStore.account.Peers, 0, "all ephemeral peers should be cleaned up after the lifetime")
assert.Equal(t, 1, mockAM.GetBufferUpdateCalls(account.Id), "buffer update should be called once")
assert.Equal(t, ephemeralPeers, mockAM.GetDeletePeerCalls(), "should have deleted all peers")
}
// TestLoadInitialAccounts_SeedsFromStore exercises the post-restart
// catch-up path: pre-populate the store, point the manager at it, and
// confirm both already-stale and not-yet-stale peers get cleaned at
// their proper times.
func TestLoadInitialAccounts_SeedsFromStore(t *testing.T) {
mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
getNow, setNow := withFakeClock(t, time.Now())
func seedPeers(store *MockStore, numberOfPeers int, numberOfEphemeralPeers int) {
store.account = newAccountWithId(context.Background(), "my account", "", "", false)
now := getNow()
// p-stale: already past the staleness window when load runs.
mockStore.account.Peers["p-stale"] = &nbpeer.Peer{
ID: "p-stale", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now.Add(-time.Hour)},
}
// p-fresh: disconnected but not yet stale.
mockStore.account.Peers["p-fresh"] = &nbpeer.Peer{
ID: "p-fresh", AccountID: "acc-1", Ephemeral: true,
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: now},
for i := 0; i < numberOfPeers; i++ {
peerId := fmt.Sprintf("peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: false,
}
store.account.Peers[p.ID] = p
}
ctrl := gomock.NewController(t)
peersMgr := peers.NewMockManager(ctrl)
peersMgr.EXPECT().
DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
DoAndReturn(func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
mockStore.mu.Lock()
for _, id := range peerIDs {
delete(mockStore.account.Peers, id)
}
mockStore.mu.Unlock()
return nil
}).AnyTimes()
mgr := newManagerForTest(t, mockStore, peersMgr)
// Drive loadInitialAccounts directly with the fake-clock-aware now.
mgr.loadInitialAccounts(context.Background())
// First sweep should fire shortly (cleanupWindow) for the stale peer.
setNow(now.Add(5 * mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p-stale"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p-stale should be deleted on the first sweep")
// p-fresh is not yet stale; advance past its window.
setNow(now.Add(mgr.lifeTime + 5*mgr.cleanupWindow))
require.Eventually(t, func() bool {
mockStore.mu.Lock()
defer mockStore.mu.Unlock()
_, gone := mockStore.account.Peers["p-fresh"]
return !gone
}, 2*time.Second, 5*time.Millisecond, "p-fresh should be deleted once it crosses the staleness window")
}
// TestStop_CancelsPendingWork verifies that Stop() cancels both the
// deferred initial load and the per-account sweep timers, and that
// OnPeerDisconnected calls arriving after Stop are ignored.
//
// Manager state is copied out while accountsLock is held and asserted only
// after the lock is released: require.* aborts the test goroutine on
// failure, and aborting with the lock held would leave it locked for any
// timer callback or cleanup that needs it.
func TestStop_CancelsPendingWork(t *testing.T) {
	mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
	withFakeClock(t, time.Now())

	ctrl := gomock.NewController(t)
	peersMgr := peers.NewMockManager(ctrl)
	// No DeletePeers expectation is registered on purpose: DeletePeers must
	// NOT be called after Stop, and gomock fails the test on any
	// unexpected call.

	mgr := NewEphemeralManager(mockStore, peersMgr)
	mgr.lifeTime = 100 * time.Millisecond
	mgr.cleanupWindow = 10 * time.Millisecond
	// Use a long delay so the initial-load timer is still pending.
	mgr.initialLoadDelay = func() time.Duration { return time.Hour }
	mgr.LoadInitialPeers(context.Background())

	mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
		ID: "p1", AccountID: "acc-1", Ephemeral: true,
		Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
	})

	mgr.accountsLock.Lock()
	loadTimer := mgr.initialLoadTimer
	trackedBeforeStop := len(mgr.accounts)
	mgr.accountsLock.Unlock()
	require.NotNil(t, loadTimer, "initial-load timer should be armed")
	require.Equal(t, 1, trackedBeforeStop, "account should be tracked after disconnect")

	mgr.Stop()

	mgr.accountsLock.Lock()
	trackedAfterStop := len(mgr.accounts)
	stopped := mgr.stopped
	mgr.accountsLock.Unlock()
	require.Zero(t, trackedAfterStop, "Stop should clear tracked accounts")
	require.True(t, stopped, "stopped flag must be set")

	// Post-stop disconnect must be ignored.
	mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
		ID: "p2", AccountID: "acc-1", Ephemeral: true,
		Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
	})

	mgr.accountsLock.Lock()
	trackedAfterLateDisconnect := len(mgr.accounts)
	mgr.accountsLock.Unlock()
	require.Zero(t, trackedAfterLateDisconnect, "disconnects after Stop must be ignored")
}
// TestOnPeerConnected_IsNoop checks that OnPeerConnected leaves the
// per-account tracking untouched: an account that became tracked through a
// disconnect must still be tracked, exactly once, after the same peer
// reports a connect.
//
// The tracked-account count is read under accountsLock but asserted after
// the lock is released, so a failing require.* cannot abort the test
// goroutine while the lock is still held.
func TestOnPeerConnected_IsNoop(t *testing.T) {
	mockStore := &MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)}
	withFakeClock(t, time.Now())

	ctrl := gomock.NewController(t)
	peersMgr := peers.NewMockManager(ctrl)
	mgr := newManagerForTest(t, mockStore, peersMgr)

	// trackedAccounts returns how many accounts the manager currently tracks.
	trackedAccounts := func() int {
		mgr.accountsLock.Lock()
		defer mgr.accountsLock.Unlock()
		return len(mgr.accounts)
	}

	mgr.OnPeerDisconnected(context.Background(), &nbpeer.Peer{
		ID: "p1", AccountID: "acc-1", Ephemeral: true,
		Status: &nbpeer.PeerStatus{Connected: false, LastSeen: timeNow()},
	})
	require.Equal(t, 1, trackedAccounts(), "disconnect should track the account")

	mgr.OnPeerConnected(context.Background(), &nbpeer.Peer{ID: "p1", AccountID: "acc-1", Ephemeral: true})
	require.Equal(t, 1, trackedAccounts(), "OnPeerConnected must be a no-op")
}
// TestSweep_StoreErrorReArms covers the sweep's error path. While the
// stale-peer query fails, the account has to stay tracked; once the store
// recovers, a follow-up sweep is expected to delete the stale peer.
func TestSweep_StoreErrorReArms(t *testing.T) {
	failingStore := &erroringStore{
		MockStore: MockStore{account: newAccountWithId(context.Background(), "acc-1", "", "", false)},
	}
	getNow, setNow := withFakeClock(t, time.Now())

	peersMgr := peers.NewMockManager(gomock.NewController(t))
	mgr := newManagerForTest(t, failingStore, peersMgr)

	stalePeer := &nbpeer.Peer{
		ID: "p1", AccountID: "acc-1", Ephemeral: true,
		Status: &nbpeer.PeerStatus{Connected: false, LastSeen: getNow()},
	}
	failingStore.account.Peers[stalePeer.ID] = stalePeer
	mgr.OnPeerDisconnected(context.Background(), stalePeer)

	// Break the store first, then move the fake clock past the staleness
	// window so the sweep that fires hits the failing query.
	failingStore.fail.Store(true)
	setNow(getNow().Add(mgr.lifeTime + 5*mgr.cleanupWindow))

	// Wait until the failing sweep has run at least once.
	sweepFailed := func() bool { return failingStore.failedCalls.Load() >= 1 }
	require.Eventually(t, sweepFailed, 2*time.Second, 5*time.Millisecond, "expected at least one failing sweep")

	mgr.accountsLock.Lock()
	_, stillTracked := mgr.accounts["acc-1"]
	mgr.accountsLock.Unlock()
	require.True(t, stillTracked, "account must remain tracked after a sweep error")

	// Recover: register the deletion expectation before clearing the
	// failure flag, so a sweep that succeeds finds it in place.
	removeFromStore := func(_ context.Context, _ string, peerIDs []string, _ string, _ bool) error {
		failingStore.mu.Lock()
		defer failingStore.mu.Unlock()
		for _, peerID := range peerIDs {
			delete(failingStore.account.Peers, peerID)
		}
		return nil
	}
	peersMgr.EXPECT().
		DeletePeers(gomock.Any(), "acc-1", gomock.Any(), gomock.Any(), true).
		DoAndReturn(removeFromStore).
		AnyTimes()
	failingStore.fail.Store(false)

	peerRemoved := func() bool {
		failingStore.mu.Lock()
		defer failingStore.mu.Unlock()
		_, present := failingStore.account.Peers["p1"]
		return !present
	}
	require.Eventually(t, peerRemoved, 2*time.Second, 5*time.Millisecond, "rearmed sweep should clean up after the store recovers")
}
// erroringStore embeds MockStore and adds a switchable failure mode, used
// by tests to drive the sweep's error-rearm path.
type erroringStore struct {
	MockStore

	// fail, while true, makes GetStaleEphemeralPeerIDsForAccount return a
	// synthetic error instead of delegating to MockStore.
	fail atomic.Bool
	// failedCalls counts the calls that were answered with that error.
	failedCalls atomic.Int32
}
// GetStaleEphemeralPeerIDsForAccount delegates to the embedded MockStore
// unless the failure switch is on. In failure mode it increments
// failedCalls and returns a nil slice together with a synthetic error.
func (s *erroringStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
	if !s.fail.Load() {
		return s.MockStore.GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan)
	}

	s.failedCalls.Add(1)
	return nil, errors.New("synthetic store error")
}
// TestDefaultInitialLoadDelay confirms the jitter falls inside the
// documented [8m, 10m) range — sanity check for the production timer.
func TestDefaultInitialLoadDelay(t *testing.T) {
for i := 0; i < 1000; i++ {
d := defaultInitialLoadDelay()
assert.GreaterOrEqual(t, d, initialLoadMinDelay)
assert.Less(t, d, initialLoadMaxDelay)
for i := 0; i < numberOfEphemeralPeers; i++ {
peerId := fmt.Sprintf("ephemeral_peer_%d", i)
p := &nbpeer.Peer{
ID: peerId,
Ephemeral: true,
}
store.account.Peers[p.ID] = p
}
}
@@ -596,7 +351,3 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis
}
return acc
}
// silence the import "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
// (still needed indirectly for ephemeral.EphemeralLifeTime in production paths).
var _ = ephemeral.EphemeralLifeTime

View File

@@ -112,11 +112,7 @@ func (s *BaseServer) AuthManager() auth.Manager {
func (s *BaseServer) EphemeralManager() ephemeral.Manager {
return Create(s, func() ephemeral.Manager {
em := manager.NewEphemeralManager(s.Store(), s.PeersManager())
if metrics := s.Metrics(); metrics != nil {
em.SetMetrics(metrics.EphemeralPeersMetrics())
}
return em
return manager.NewEphemeralManager(s.Store(), s.PeersManager())
})
}

View File

@@ -522,11 +522,10 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even
}
func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {
uncanceledCTX := context.WithoutCancel(ctx)
unlock := s.acquirePeerLockByUID(uncanceledCTX, peer.Key)
unlock := s.acquirePeerLockByUID(ctx, peer.Key)
defer unlock()
s.cancelPeerRoutinesWithoutLock(uncanceledCTX, accountID, peer, streamStartTime)
s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime)
}
func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) {

View File

@@ -291,15 +291,10 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.NewPermissionDeniedError()
}
// Canonicalize the incoming range so a caller-supplied prefix with host bits
// (e.g. 100.64.1.1/16) compares equal to the masked form stored on network.Net.
newSettings.NetworkRange = newSettings.NetworkRange.Masked()
var oldSettings *types.Settings
var updateAccountPeers bool
var groupChangesAffectPeers bool
var reloadReverseProxy bool
var effectiveOldNetworkRange netip.Prefix
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool
@@ -313,16 +308,6 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return err
}
// No lock: the transaction already holds Settings(Update), and network.Net is
// only mutated by reallocateAccountPeerIPs, which is reachable only through
// this same code path. A Share lock here would extend an unnecessary row lock
// and complicate ordering against updatePeerIPv6InTransaction.
network, err := transaction.GetAccountNetwork(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return fmt.Errorf("get account network: %w", err)
}
effectiveOldNetworkRange = prefixFromIPNet(network.Net)
if oldSettings.Extra != nil && newSettings.Extra != nil &&
oldSettings.Extra.PeerApprovalEnabled && !newSettings.Extra.PeerApprovalEnabled {
approvedCount, err := transaction.ApproveAccountPeers(ctx, accountID)
@@ -336,7 +321,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
}
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
if oldSettings.NetworkRange != newSettings.NetworkRange {
if err = am.reallocateAccountPeerIPs(ctx, transaction, accountID, newSettings.NetworkRange); err != nil {
return err
}
@@ -411,9 +396,9 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, eventMeta)
}
if newSettings.NetworkRange.IsValid() && newSettings.NetworkRange != effectiveOldNetworkRange {
if oldSettings.NetworkRange != newSettings.NetworkRange {
eventMeta := map[string]any{
"old_network_range": effectiveOldNetworkRange.String(),
"old_network_range": oldSettings.NetworkRange.String(),
"new_network_range": newSettings.NetworkRange.String(),
}
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountNetworkRangeUpdated, eventMeta)
@@ -458,22 +443,6 @@ func ipv6SettingsChanged(old, updated *types.Settings) bool {
return !slices.Equal(oldGroups, newGroups)
}
// prefixFromIPNet returns the overlay prefix actually allocated on the account
// network, or an invalid prefix if none is set. Settings.NetworkRange is a
// user-facing override that is empty on legacy accounts, so the effective
// range must be read from network.Net to compare against an incoming update.
//
// The zero netip.Prefix (IsValid() == false) is returned when:
//   - ipNet.IP is nil or is not a 4- or 16-byte address;
//   - ipNet.Mask is nil or not a canonical run of ones followed by zeros;
//   - the mask width does not fit the address family.
//
// IPv4 addresses held in 16-byte form are unmapped to plain IPv4. A 16-byte
// mask on such an address is accepted when it covers the 96-bit mapping
// prefix, and is reduced to the equivalent 32-bit prefix length.
func prefixFromIPNet(ipNet net.IPNet) netip.Prefix {
	if ipNet.IP == nil {
		return netip.Prefix{}
	}
	addr, ok := netip.AddrFromSlice(ipNet.IP)
	if !ok {
		return netip.Prefix{}
	}
	addr = addr.Unmap()

	// Size reports (0, 0) for a nil or non-canonical mask; the width check
	// below rejects that instead of silently producing a /0 prefix.
	ones, bits := ipNet.Mask.Size()
	if addr.Is4() && bits == 128 && ones >= 96 {
		// IPv4 network described with an IPv6-width mask: drop the 96 bits
		// that only cover the IPv4-mapped prefix.
		ones, bits = ones-96, 32
	}
	if bits != addr.BitLen() {
		return netip.Prefix{}
	}
	return netip.PrefixFrom(addr, ones)
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
@@ -1868,32 +1837,35 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's
// network map and then marks the peer connected with a session token
// derived from syncTime (the moment the gRPC stream opened). Any
// concurrent stream that started earlier loses the optimistic-lock race
// in MarkPeerConnected and bails without writing.
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return peer, netMap, postureChecks, dnsfwdPort, nil
}
// OnPeerDisconnected is invoked when a sync stream ends. It marks the
// peer disconnected only when the stored SessionStartedAt matches the
// nanosecond token derived from streamStartTime — i.e. only when this
// is the stream that currently owns the peer's session. A mismatch
// means a newer stream has already replaced us, so the disconnect is
// dropped.
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}
if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
if err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
return nil

View File

@@ -61,8 +61,7 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error

View File

@@ -1305,31 +1305,17 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
}
// MarkPeerDisconnected mocks base method.
func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected.
func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime)
}
// OnPeerDisconnected mocks base method.

View File

@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1910,16 +1910,15 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
}, false)
require.NoError(t, err, "unable to add peer")
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should equal the token we passed in")
streamStartTime := time.Now().UTC()
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
@@ -1927,127 +1926,49 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0")
})
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
// Older stream tries to mark disconnect with its own (older) session token —
// fencing kicks in and the write is dropped.
staleStreamStartTime := streamStartTime.Add(-1 * time.Hour)
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because the stored session is newer than the disconnect token")
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should still hold the winning stream's token")
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
})
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should equal node2SyncTime token")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should still be connected")
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should NOT be overwritten by stale token from blocked goroutine")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
})
}
// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace stresses the
// session-token fencing under contention. A set of goroutines is released
// at the same moment, each marking the same peer connected with its own
// token. Whatever order they run in, no call may fail and the store must
// end up holding the largest token.
func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
	manager, _, err := createManager(t)
	require.NoError(t, err, "unable to create account manager")

	accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
	require.NoError(t, err, "unable to get account")

	key, err := wgtypes.GenerateKey()
	require.NoError(t, err, "unable to generate WireGuard key")
	peerPubKey := key.PublicKey().String()

	_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
		Key:  peerPubKey,
		Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
	}, false)
	require.NoError(t, err, "unable to add peer")

	const workers = 16

	// Tokens are 1ms apart so the comparison is unambiguous; the last one
	// generated is the largest and is the one expected to win.
	base := time.Now().UTC().UnixNano()
	tokens := make([]int64, 0, workers)
	for i := 0; i < workers; i++ {
		tokens = append(tokens, base+int64(i)*int64(time.Millisecond))
	}
	expected := tokens[len(tokens)-1]

	var ready, done sync.WaitGroup
	ready.Add(workers)
	done.Add(workers)
	// gate stays open-blocked until every worker has checked in, then is
	// closed to release them all together.
	gate := make(chan struct{})

	// require.* calls t.FailNow, which must only run on the test goroutine.
	// Workers therefore report through this buffered channel and the
	// results are asserted after done.Wait().
	errs := make(chan error, workers)

	for _, token := range tokens {
		go func(token int64) {
			defer done.Done()
			ready.Done()
			<-gate
			errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
		}(token)
	}

	ready.Wait()
	close(gate)
	done.Wait()
	close(errs)

	for workerErr := range errs {
		require.NoError(t, workerErr, "MarkPeerConnected must not error under contention")
	}

	peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
	require.NoError(t, err)
	require.True(t, peer.Status.Connected, "peer should be connected after the race")
	require.Equal(t, expected, peer.Status.SessionStartedAt,
		"the largest token must win regardless of execution order")
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -2070,7 +1991,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}
@@ -4049,96 +3970,6 @@ func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangeChange(t *testi
}
}
// TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved checks
// that a settings update carrying the range already in use does not renumber
// peers. Legacy accounts have Settings.NetworkRange unset in the DB while
// network.Net holds the allocated overlay; the dashboard backfills the GET
// response from network.Net and sends it back on PUT, so the comparison has to
// be made against the effective range. A genuinely different range must still
// reallocate.
func TestDefaultAccountManager_UpdateAccountSettings_NetworkRangePreserved(t *testing.T) {
	manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
	ctx := context.Background()

	settings, err := manager.Store.GetAccountSettings(ctx, store.LockingStrengthNone, account.Id)
	require.NoError(t, err)
	require.False(t, settings.NetworkRange.IsValid(), "precondition: new accounts leave Settings.NetworkRange unset")

	network, err := manager.Store.GetAccountNetwork(ctx, store.LockingStrengthNone, account.Id)
	require.NoError(t, err)
	require.NotNil(t, network.Net.IP, "precondition: network.Net should be allocated")

	// Derive the effective prefix from the allocated network.
	addr, ok := netip.AddrFromSlice(network.Net.IP)
	require.True(t, ok)
	ones, _ := network.Net.Mask.Size()
	effective := netip.PrefixFrom(addr.Unmap(), ones)
	require.True(t, effective.IsValid())

	before := map[string]netip.Addr{
		peer1.ID: peer1.IP,
		peer2.ID: peer2.IP,
		peer3.ID: peer3.IP,
	}

	// applyRange submits a settings update that differs between calls only
	// in the NetworkRange it carries, and requires it to succeed.
	applyRange := func(networkRange netip.Prefix) {
		_, updateErr := manager.UpdateAccountSettings(ctx, account.Id, userID, &types.Settings{
			PeerLoginExpirationEnabled: true,
			PeerLoginExpiration:        types.DefaultPeerLoginExpiration,
			NetworkRange:               networkRange,
			Extra:                      &types.ExtraSettings{},
		})
		require.NoError(t, updateErr)
	}

	// Case 1: the dashboard echoes back the GET-backfilled effective range.
	applyRange(effective)
	peers, err := manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
	require.NoError(t, err)
	require.Len(t, peers, len(before))
	for _, p := range peers {
		assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when range matches effective", p.ID)
	}

	// Case 2: same range written with host bits set; it must be treated as
	// equal once canonicalized.
	hostBitsForm := netip.PrefixFrom(peer1.IP, ones)
	require.NotEqual(t, effective, hostBitsForm, "precondition: host-bit form should differ before masking")
	applyRange(hostBitsForm)
	peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
	require.NoError(t, err)
	for _, p := range peers {
		assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change for host-bit-set equivalent range", p.ID)
	}

	// Case 3: NetworkRange left at its zero value (an invalid prefix), i.e.
	// omitted from the update.
	applyRange(netip.Prefix{})
	peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
	require.NoError(t, err)
	for _, p := range peers {
		assert.Equal(t, before[p.ID], p.IP, "peer %s IP should not change when NetworkRange omitted", p.ID)
	}

	// Case 4 (sanity): a really different range still moves every peer.
	newRange := netip.MustParsePrefix("100.99.0.0/16")
	applyRange(newRange)
	peers, err = manager.Store.GetAccountPeers(ctx, store.LockingStrengthNone, account.Id, "", "")
	require.NoError(t, err)
	for _, p := range peers {
		assert.True(t, newRange.Contains(p.IP), "peer %s should be in new range %s, got %s", p.ID, newRange, p.IP)
		assert.NotEqual(t, before[p.ID], p.IP, "peer %s IP should change on real range update", p.ID)
	}
}
func TestDefaultAccountManager_UpdateAccountSettings_IPv6EnabledGroups(t *testing.T) {
manager, _, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
ctx := context.Background()

View File

@@ -6,9 +6,7 @@ import (
"errors"
"fmt"
"net/http"
"net/url"
"os"
"path"
"strings"
"github.com/dexidp/dex/storage"
@@ -140,13 +138,10 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
return nil, fmt.Errorf("invalid IdP storage config: %w", err)
}
// Build CLI redirect URIs including the device callback. Dex uses the issuer-relative
// path (for example, /oauth2/device/callback) when completing the device flow, so
// include it explicitly in addition to the legacy bare path and absolute URL.
// Build CLI redirect URIs including the device callback (both relative and absolute)
cliRedirectURIs := c.CLIRedirectURIs
cliRedirectURIs = append(cliRedirectURIs, "/device/callback")
cliRedirectURIs = append(cliRedirectURIs, issuerRelativeDeviceCallback(c.Issuer))
cliRedirectURIs = append(cliRedirectURIs, strings.TrimSuffix(c.Issuer, "/")+"/device/callback")
cliRedirectURIs = append(cliRedirectURIs, c.Issuer+"/device/callback")
// Build dashboard redirect URIs including the OAuth callback for proxy authentication
dashboardRedirectURIs := c.DashboardRedirectURIs
@@ -159,10 +154,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
// MGMT api and the dashboard, adding baseURL means less configuration for the instance admin
dashboardPostLogoutRedirectURIs = append(dashboardPostLogoutRedirectURIs, baseURL)
redirectURIs := make([]string, 0)
redirectURIs = append(redirectURIs, cliRedirectURIs...)
redirectURIs = append(redirectURIs, dashboardRedirectURIs...)
cfg := &dex.YAMLConfig{
Issuer: c.Issuer,
Storage: dex.Storage{
@@ -188,14 +179,14 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
ID: staticClientDashboard,
Name: "NetBird Dashboard",
Public: true,
RedirectURIs: redirectURIs,
RedirectURIs: dashboardRedirectURIs,
PostLogoutRedirectURIs: sanitizePostLogoutRedirectURIs(dashboardPostLogoutRedirectURIs),
},
{
ID: staticClientCLI,
Name: "NetBird CLI",
Public: true,
RedirectURIs: redirectURIs,
RedirectURIs: cliRedirectURIs,
},
},
StaticConnectors: c.StaticConnectors,
@@ -226,14 +217,6 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) {
return cfg, nil
}
// issuerRelativeDeviceCallback returns the device-flow callback path prefixed
// with the path component of the issuer URL, e.g. "/oauth2/device/callback"
// for an issuer of "https://example.com/oauth2". When the issuer cannot be
// parsed or carries no path, the bare "/device/callback" is returned.
func issuerRelativeDeviceCallback(issuer string) string {
	const deviceCallback = "/device/callback"

	parsed, err := url.Parse(issuer)
	if err != nil {
		return deviceCallback
	}
	if parsed.Path == "" {
		return deviceCallback
	}

	// path.Join cleans the result, so a trailing slash on the issuer path
	// does not produce a double slash.
	return path.Join(parsed.Path, deviceCallback)
}
// Due to how the frontend generates the logout, sometimes it appends a trailing slash
// and because Dex only allows exact matches, we need to make sure we always have both
// versions of each provided uri
@@ -316,7 +299,7 @@ func resolveSessionCookieEncryptionKey(configuredKey string) (string, error) {
}
}
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key:%s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
return "", fmt.Errorf("invalid embedded IdP session cookie encryption key: %s (or sessionCookieEncryptionKey) must be 16, 24, or 32 bytes as a raw string or base64-encoded to one of those lengths; got %d raw bytes", sessionCookieEncryptionKeyEnv, len([]byte(key)))
}
func validSessionCookieEncryptionKeyLength(length int) bool {

View File

@@ -314,34 +314,6 @@ func TestEmbeddedIdPManager_UpdateUserPassword(t *testing.T) {
})
}
func TestEmbeddedIdPConfig_ToYAMLConfig_IncludesDeviceCallbackRedirectURI(t *testing.T) {
config := &EmbeddedIdPConfig{
Enabled: true,
Issuer: "https://example.com/oauth2",
Storage: EmbeddedStorageConfig{
Type: "sqlite3",
Config: EmbeddedStorageTypeConfig{
File: filepath.Join(t.TempDir(), "dex.db"),
},
},
}
yamlConfig, err := config.ToYAMLConfig()
require.NoError(t, err)
var cliRedirectURIs []string
for _, client := range yamlConfig.StaticClients {
if client.ID == staticClientCLI {
cliRedirectURIs = client.RedirectURIs
break
}
}
require.NotEmpty(t, cliRedirectURIs)
assert.Contains(t, cliRedirectURIs, "/device/callback")
assert.Contains(t, cliRedirectURIs, "/oauth2/device/callback")
assert.Contains(t, cliRedirectURIs, "https://example.com/oauth2/device/callback")
}
func TestEmbeddedIdPConfig_ToYAMLConfig_SessionCookieEncryptionKey(t *testing.T) {
t.Setenv(sessionCookieEncryptionKeyEnv, "")

View File

@@ -38,8 +38,7 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
@@ -228,14 +227,7 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
// Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing
// hook so tests that inject MarkPeerDisconnectedFunc actually observe
// disconnect events. Falls through to nil when no hook is set, which
// is the original behaviour.
if am.MarkPeerDisconnectedFunc != nil {
return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano())
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
return nil
}
@@ -336,21 +328,13 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// MarkPeerDisconnected is the mock counterpart of MarkPeerDisconnected from
// the server.AccountManager interface. It forwards to MarkPeerDisconnectedFunc
// when one is set and otherwise reports the method as unimplemented.
func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
	hook := am.MarkPeerDisconnectedFunc
	if hook == nil {
		return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented")
	}
	return hook(ctx, peerKey, accountID, sessionStartedAt)
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
if am.DeleteAccountFunc != nil {

View File

@@ -16,6 +16,7 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -28,7 +29,6 @@ import (
"github.com/netbirdio/netbird/management/server/activity"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -63,64 +63,56 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
// MarkPeerConnected marks a peer as connected with optimistic-locked
// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument
// is the start time of the gRPC sync stream that owns this update,
// expressed as Unix nanoseconds — only the call whose token is greater
// than what's stored wins. LastSeen is written by the database itself;
// we never pass it down.
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusConnect, time.Since(start))
}()
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
// syncTime is used as the LastSeen timestamp and for stale request detection
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
var err error
var skipped bool
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
outcome := telemetry.PeerStatusError
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
outcome = telemetry.PeerStatusPeerNotFound
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, outcome)
return err
}
updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt)
if err != nil {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusError)
return err
}
if !updated {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusStale)
log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID)
return nil
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusConnect, telemetry.PeerStatusApplied)
if am.geo != nil && realIP != nil {
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
expired := peer.Status != nil && peer.Status.LoginExpired
if peer.AddedWithSSOLogin() {
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
if err != nil {
return err
}
if connected && !syncTime.After(peer.Status.LastSeen) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
skipped = true
return nil
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
return err
})
if skipped {
return nil
}
if err != nil {
return err
}
if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.schedulePeerLoginExpiration(ctx, accountID)
}
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
if expired {
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
@@ -128,60 +120,41 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
// MarkPeerDisconnected marks a peer as disconnected, but only when the
// stored session token matches the one passed in. A mismatch means a
// newer stream has already taken ownership of the peer — disconnects from
// the older stream are ignored. LastSeen is written by the database.
func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error {
start := time.Now()
defer func() {
am.metrics.AccountManagerMetrics().RecordPeerStatusUpdateDuration(telemetry.PeerStatusDisconnect, time.Since(start))
}()
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = syncTime
newStatus.Connected = connected
// whenever peer got connected that means that it logged in successfully
if newStatus.Connected {
newStatus.LoginExpired = false
}
peer.Status = newStatus
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
outcome := telemetry.PeerStatusError
if s, ok := status.FromError(err); ok && s.Type() == status.NotFound {
outcome = telemetry.PeerStatusPeerNotFound
if geo != nil && realIP != nil {
location, err := geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
} else {
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
err = transaction.SavePeerLocation(ctx, accountID, peer)
if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, outcome)
return err
}
updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt)
if err != nil {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusError)
return err
}
if !updated {
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusStale)
log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping",
peer.ID, sessionStartedAt)
return nil
}
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
return nil
}
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
// updatePeerLocationIfChanged refreshes the geolocation on a separate
// row update, only when the connection IP actually changed. Geo lookups
// are expensive so we skip same-IP reconnects.
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return
}
location, err := am.geo.Lookup(realIP)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return
}
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
return false, err
}
return oldStatus.LoginExpired, nil
}
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.

View File

@@ -74,19 +74,8 @@ type ProxyMeta struct {
}
type PeerStatus struct { //nolint:revive
// LastSeen is the last time the peer status was updated (i.e. the last
// time we observed the peer being alive on a sync stream). Written by
// the database (CURRENT_TIMESTAMP) — callers do not supply it.
// LastSeen is the last time peer was connected to the management service
LastSeen time.Time
// SessionStartedAt records when the currently-active sync stream began,
// stored as Unix nanoseconds. It acts as the optimistic-locking token
// for status updates: a stream is only allowed to mutate the peer's
// status when its own token strictly exceeds the stored token (when connecting)
// or matches it exactly (for disconnects). Zero means "no
// active session". Integer nanoseconds are used so equality is
// precision-safe across drivers, and so the predicates compose to a
// single bigint comparison.
SessionStartedAt int64
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired
@@ -386,14 +375,10 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any {
return meta
}
// Copy PeerStatus. SessionStartedAt must be propagated so clone-based
// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently
// reset the fencing token to zero — that would let any subsequent
// SavePeerStatus write reopen the optimistic-lock window.
// Copy PeerStatus
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
SessionStartedAt: p.SessionStartedAt,
Connected: p.Connected,
LoginExpired: p.LoginExpired,
RequiresApproval: p.RequiresApproval,

View File

@@ -498,9 +498,8 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
peerCopy.Status = &peerStatus
fieldsToUpdate := []string{
"peer_status_last_seen", "peer_status_session_started_at",
"peer_status_connected", "peer_status_login_expired",
"peer_status_requires_approval",
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
}
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
@@ -517,69 +516,6 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
return nil
}
// MarkPeerConnectedIfNewerSession marks the peer connected in a single
// conditional UPDATE, which is the "connect" half of the session-token
// fencing on the peer status. The row is only touched when its stored
// peer_status_session_started_at is strictly less than newSessionStartedAt,
// i.e. when no stream with an equal or newer token already owns the peer; a
// stored token of zero therefore loses to any positive token.
//
// On a match the update sets connected=true, clears login_expired, stores the
// new token, and stamps last_seen with the database's CURRENT_TIMESTAMP. The
// caller never supplies last_seen: taking it from the database clock at write
// time avoids recording a Go-side timestamp captured well before the UPDATE
// actually ran.
//
// Only the status write happens here; any follow-up work is left to the
// caller so it does not lengthen the time the row is held.
//
// It returns true when a row was updated, false when the stored token was not
// older (or no such peer row exists), and an Internal status error when the
// statement itself fails.
func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
	connectedFields := map[string]any{
		"peer_status_connected":          true,
		"peer_status_last_seen":          gorm.Expr("CURRENT_TIMESTAMP"),
		"peer_status_session_started_at": newSessionStartedAt,
		"peer_status_login_expired":      false,
	}

	tx := s.db.WithContext(ctx).
		Model(&nbpeer.Peer{}).
		Where(accountAndIDQueryCondition, accountID, peerID).
		Where("peer_status_session_started_at < ?", newSessionStartedAt).
		Updates(connectedFields)
	if err := tx.Error; err != nil {
		return false, status.Errorf(status.Internal, "mark peer connected: %v", err)
	}

	return tx.RowsAffected > 0, nil
}
// MarkPeerDisconnectedIfSameSession marks the peer disconnected in a single
// conditional UPDATE, which is the "disconnect" half of the session-token
// fencing on the peer status. The row is only touched when its stored
// peer_status_session_started_at equals sessionStartedAt, meaning the stream
// that is ending is the one that currently owns the session. If a different
// token is stored, nothing is written.
//
// On a match the update sets connected=false, resets the stored token to
// zero, and stamps last_seen with the database's CURRENT_TIMESTAMP.
//
// A zero sessionStartedAt is rejected up front and reported as "not updated":
// an equality match on zero would otherwise hit every row whose token is
// still zero.
//
// It returns true when a row was updated, false when nothing matched, and an
// Internal status error when the statement itself fails.
func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
	if sessionStartedAt == 0 {
		return false, nil
	}

	disconnectedFields := map[string]any{
		"peer_status_connected":          false,
		"peer_status_last_seen":          gorm.Expr("CURRENT_TIMESTAMP"),
		"peer_status_session_started_at": int64(0),
	}

	tx := s.db.WithContext(ctx).
		Model(&nbpeer.Peer{}).
		Where(accountAndIDQueryCondition, accountID, peerID).
		Where("peer_status_session_started_at = ?", sessionStartedAt).
		Updates(disconnectedFields)
	if err := tx.Error; err != nil {
		return false, status.Errorf(status.Internal, "mark peer disconnected: %v", err)
	}

	return tx.RowsAffected > 0, nil
}
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
@@ -1787,10 +1723,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at,
peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip,
location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6
FROM peers WHERE account_id = $1`
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -1803,7 +1738,6 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
lastLogin, createdAt sql.NullTime
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
peerStatusLastSeen sql.NullTime
peerStatusSessionStartedAt sql.NullInt64
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
@@ -1818,9 +1752,8 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
&allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities,
&peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired,
&peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID,
&proxyEmbedded, &proxyCluster, &ipv6)
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6)
if err == nil {
if lastLogin.Valid {
@@ -1847,9 +1780,6 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
if peerStatusLastSeen.Valid {
p.Status.LastSeen = peerStatusLastSeen.Time
}
if peerStatusSessionStartedAt.Valid {
p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64
}
if peerStatusConnected.Valid {
p.Status.Connected = peerStatusConnected.Bool
}
@@ -3463,49 +3393,6 @@ func (s *SqlStore) GetAllEphemeralPeers(ctx context.Context, lockStrength Lockin
return allEphemeralPeers, nil
}
// GetStaleEphemeralPeerIDsForAccount lists the IDs of peers in accountID that
// are ephemeral, currently not connected, and whose last_seen is strictly
// before olderThan. A query failure is logged with its cause and surfaced to
// the caller as a generic Internal status error.
func (s *SqlStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
	staleQuery := s.db.WithContext(ctx).
		Model(&nbpeer.Peer{}).
		Where("account_id = ? AND ephemeral = ? AND peer_status_connected = ? AND peer_status_last_seen < ?",
			accountID, true, false, olderThan)

	var ids []string
	if err := staleQuery.Pluck("id", &ids).Error; err != nil {
		log.WithContext(ctx).Errorf("failed to query stale ephemeral peers for account %s: %v", accountID, err)
		return nil, status.Errorf(status.Internal, "query stale ephemeral peers")
	}

	return ids, nil
}
// GetEphemeralAccountsLastDisconnect builds a map from account ID to the
// greatest peer_status_last_seen found among that account's ephemeral
// peers that are marked as not connected. Accounts without such a peer
// do not appear in the result.
//
// On a query failure the underlying error is logged and a generic
// status.Internal error is returned in its place.
func (s *SqlStore) GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error) {
	// accountRow receives one aggregated result row per account.
	type accountRow struct {
		AccountID string
		LastSeen  time.Time
	}

	var aggregated []accountRow
	result := s.db.WithContext(ctx).
		Model(&nbpeer.Peer{}).
		Select("account_id, MAX(peer_status_last_seen) AS last_seen").
		Where("ephemeral = ? AND peer_status_connected = ?", true, false).
		Group("account_id").
		Scan(&aggregated)
	if result.Error != nil {
		log.WithContext(ctx).Errorf("failed to load ephemeral-account last disconnect map: %v", result.Error)
		return nil, status.Errorf(status.Internal, "load ephemeral accounts")
	}

	lastDisconnect := make(map[string]time.Time, len(aggregated))
	for i := range aggregated {
		lastDisconnect[aggregated[i].AccountID] = aggregated[i].LastSeen
	}
	return lastDisconnect, nil
}
// DeletePeer removes a peer from the store.
func (s *SqlStore) DeletePeer(ctx context.Context, accountID string, peerID string) error {
result := s.db.Delete(&nbpeer.Peer{}, accountAndIDQueryCondition, accountID, peerID)

View File

@@ -165,32 +165,8 @@ type Store interface {
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
// GetStaleEphemeralPeerIDsForAccount returns the IDs of disconnected
// ephemeral peers whose last_seen is strictly older than olderThan,
// scoped to a single account. Used by the per-account cleanup sweep.
GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error)
// GetEphemeralAccountsLastDisconnect returns, for every account that
// has at least one disconnected ephemeral peer, the most recent
// last_seen across that account's disconnected ephemeral peers. Used
// to reconstruct the per-account cleanup tracker after a restart.
GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
// given session token, but only when the stored SessionStartedAt is
// strictly less than newSessionStartedAt (the sentinel zero counts as
// "older"). LastSeen is recorded by the database at the moment the
// row is updated — never by the caller — so it always reflects the
// real write time even under lock contention.
// Returns true when the update happened, false when this stream lost
// the race against a newer session.
MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error)
// MarkPeerDisconnectedIfSameSession sets the peer to disconnected and
// resets SessionStartedAt to zero, but only when the stored
// SessionStartedAt equals the given sessionStartedAt. LastSeen is
// recorded by the database. Returns true when the update happened,
// false when a newer session has taken over.
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error

View File

@@ -1376,36 +1376,6 @@ func (mr *MockStoreMockRecorder) GetAllEphemeralPeers(ctx, lockStrength interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAllEphemeralPeers", reflect.TypeOf((*MockStore)(nil).GetAllEphemeralPeers), ctx, lockStrength)
}
// GetStaleEphemeralPeerIDsForAccount mocks base method.
//
// The call and its arguments are handed to the gomock controller; the
// values it yields are returned to the caller.
func (m *MockStore) GetStaleEphemeralPeerIDsForAccount(ctx context.Context, accountID string, olderThan time.Time) ([]string, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetStaleEphemeralPeerIDsForAccount", ctx, accountID, olderThan)
	// Comma-ok assertions: a nil or mistyped value degrades to the zero
	// value instead of panicking.
	peerIDs, _ := results[0].([]string)
	callErr, _ := results[1].(error)
	return peerIDs, callErr
}
// GetStaleEphemeralPeerIDsForAccount indicates an expected call of GetStaleEphemeralPeerIDsForAccount.
func (mr *MockStoreMockRecorder) GetStaleEphemeralPeerIDsForAccount(ctx, accountID, olderThan interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	// The method type is taken from a nil *MockStore and passed to the
	// controller alongside the method name.
	methodType := reflect.TypeOf((*MockStore)(nil).GetStaleEphemeralPeerIDsForAccount)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetStaleEphemeralPeerIDsForAccount", methodType, ctx, accountID, olderThan)
}
// GetEphemeralAccountsLastDisconnect mocks base method.
//
// The call is handed to the gomock controller; the values it yields are
// returned to the caller.
func (m *MockStore) GetEphemeralAccountsLastDisconnect(ctx context.Context) (map[string]time.Time, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "GetEphemeralAccountsLastDisconnect", ctx)
	// Comma-ok assertions: a nil or mistyped value degrades to the zero
	// value instead of panicking.
	lastDisconnect, _ := results[0].(map[string]time.Time)
	callErr, _ := results[1].(error)
	return lastDisconnect, callErr
}
// GetEphemeralAccountsLastDisconnect indicates an expected call of GetEphemeralAccountsLastDisconnect.
func (mr *MockStoreMockRecorder) GetEphemeralAccountsLastDisconnect(ctx interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	// The method type is taken from a nil *MockStore and passed to the
	// controller alongside the method name.
	methodType := reflect.TypeOf((*MockStore)(nil).GetEphemeralAccountsLastDisconnect)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEphemeralAccountsLastDisconnect", methodType, ctx)
}
// GetAllProxyAccessTokens mocks base method.
func (m *MockStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types2.ProxyAccessToken, error) {
m.ctrl.T.Helper()
@@ -2908,36 +2878,6 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status)
}
// MarkPeerConnectedIfNewerSession mocks base method.
//
// The call and its arguments are handed to the gomock controller; the
// values it yields are returned to the caller.
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
	// Comma-ok assertions: a nil or mistyped value degrades to the zero
	// value instead of panicking.
	updated, _ := results[0].(bool)
	callErr, _ := results[1].(error)
	return updated, callErr
}
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	// The method type is taken from a nil *MockStore and passed to the
	// controller alongside the method name.
	methodType := reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", methodType, ctx, accountID, peerID, newSessionStartedAt)
}
// MarkPeerDisconnectedIfSameSession mocks base method.
//
// The call and its arguments are handed to the gomock controller; the
// values it yields are returned to the caller.
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
	m.ctrl.T.Helper()
	results := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
	// Comma-ok assertions: a nil or mistyped value degrades to the zero
	// value instead of panicking.
	updated, _ := results[0].(bool)
	callErr, _ := results[1].(error)
	return updated, callErr
}
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	// The method type is taken from a nil *MockStore and passed to the
	// controller alongside the method name.
	methodType := reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", methodType, ctx, accountID, peerID, sessionStartedAt)
}
// SavePolicy mocks base method.
func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error {
m.ctrl.T.Helper()

View File

@@ -16,8 +16,6 @@ type AccountManagerMetrics struct {
getPeerNetworkMapDurationMs metric.Float64Histogram
networkMapObjectCount metric.Int64Histogram
peerMetaUpdateCount metric.Int64Counter
peerStatusUpdateCounter metric.Int64Counter
peerStatusUpdateDurationMs metric.Float64Histogram
}
// NewAccountManagerMetrics creates an instance of AccountManagerMetrics
@@ -66,24 +64,6 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
return nil, err
}
// peerStatusUpdateCounter records every attempt to mark a peer as connected or disconnected
peerStatusUpdateCounter, err := meter.Int64Counter("management.account.peer.status.update.counter",
metric.WithUnit("1"),
metric.WithDescription("Number of peer status update attempts, labeled by operation (connect|disconnect) and outcome (applied|stale|error|peer_not_found)"))
if err != nil {
return nil, err
}
peerStatusUpdateDurationMs, err := meter.Float64Histogram("management.account.peer.status.update.duration.ms",
metric.WithUnit("milliseconds"),
metric.WithExplicitBucketBoundaries(
1, 5, 15, 25, 50, 100, 250, 500, 1000, 2000, 5000,
),
metric.WithDescription("Duration of a peer status update (fence UPDATE + post-write side effects), labeled by operation"))
if err != nil {
return nil, err
}
return &AccountManagerMetrics{
ctx: ctx,
getPeerNetworkMapDurationMs: getPeerNetworkMapDurationMs,
@@ -91,35 +71,10 @@ func NewAccountManagerMetrics(ctx context.Context, meter metric.Meter) (*Account
updateAccountPeersCounter: updateAccountPeersCounter,
networkMapObjectCount: networkMapObjectCount,
peerMetaUpdateCount: peerMetaUpdateCount,
peerStatusUpdateCounter: peerStatusUpdateCounter,
peerStatusUpdateDurationMs: peerStatusUpdateDurationMs,
}, nil
}
// PeerStatusOperation is the "operation" label value attached to
// fence-locked peer status writes.
type PeerStatusOperation string

// Operation label values.
const (
	// PeerStatusConnect labels a connect write.
	PeerStatusConnect PeerStatusOperation = "connect"
	// PeerStatusDisconnect labels a disconnect write.
	PeerStatusDisconnect PeerStatusOperation = "disconnect"
)

// PeerStatusOutcome is the "outcome" label value describing how a
// fence-locked peer status write resolved.
type PeerStatusOutcome string

// Outcome label values.
const (
	// PeerStatusApplied means the fence WHERE matched and the UPDATE landed.
	PeerStatusApplied PeerStatusOutcome = "applied"
	// PeerStatusStale means the fence WHERE rejected the write because a
	// newer session has already taken ownership (connect: stored token
	// >= incoming; disconnect: stored token != incoming).
	PeerStatusStale PeerStatusOutcome = "stale"
	// PeerStatusError means the store returned a non-NotFound error.
	PeerStatusError PeerStatusOutcome = "error"
	// PeerStatusPeerNotFound means the peer lookup failed (the peer was
	// deleted between the gRPC sync handshake and the status write).
	PeerStatusPeerNotFound PeerStatusOutcome = "peer_not_found"
)
// CountUpdateAccountPeersDuration counts the duration of updating account peers
func (metrics *AccountManagerMetrics) CountUpdateAccountPeersDuration(duration time.Duration) {
metrics.updateAccountPeersDurationMs.Record(metrics.ctx, float64(duration.Nanoseconds())/1e6)
@@ -149,23 +104,3 @@ func (metrics *AccountManagerMetrics) CountUpdateAccountPeersTriggered(resource,
func (metrics *AccountManagerMetrics) CountPeerMetUpdate() {
// Adds one to the peer meta update counter; no attributes are attached.
metrics.peerMetaUpdateCount.Add(metrics.ctx, 1)
}
// CountPeerStatusUpdate adds one to the peer status update counter. The
// measurement carries two attributes, "operation" and "outcome", whose
// values are the string forms of op and outcome.
func (metrics *AccountManagerMetrics) CountPeerStatusUpdate(op PeerStatusOperation, outcome PeerStatusOutcome) {
	labels := metric.WithAttributes(
		attribute.String("operation", string(op)),
		attribute.String("outcome", string(outcome)),
	)
	metrics.peerStatusUpdateCounter.Add(metrics.ctx, 1, labels)
}
// RecordPeerStatusUpdateDuration records d, converted to fractional
// milliseconds, on the peer status update duration histogram. The
// measurement carries an "operation" attribute set to the string form
// of op.
func (metrics *AccountManagerMetrics) RecordPeerStatusUpdateDuration(op PeerStatusOperation, d time.Duration) {
	elapsedMs := float64(d.Nanoseconds()) / 1e6
	opLabel := metric.WithAttributes(attribute.String("operation", string(op)))
	metrics.peerStatusUpdateDurationMs.Record(metrics.ctx, elapsedMs, opLabel)
}

View File

@@ -29,7 +29,6 @@ type MockAppMetrics struct {
StoreMetricsFunc func() *StoreMetrics
UpdateChannelMetricsFunc func() *UpdateChannelMetrics
AddAccountManagerMetricsFunc func() *AccountManagerMetrics
EphemeralPeersMetricsFunc func() *EphemeralPeersMetrics
}
// GetMeter mocks the GetMeter function of the AppMetrics interface
@@ -104,14 +103,6 @@ func (mock *MockAppMetrics) AccountManagerMetrics() *AccountManagerMetrics {
return nil
}
// EphemeralPeersMetrics mocks the EphemeralPeersMetrics function of the AppMetrics interface.
// It delegates to EphemeralPeersMetricsFunc when that hook is set and
// returns nil otherwise.
func (mock *MockAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
	stub := mock.EphemeralPeersMetricsFunc
	if stub == nil {
		return nil
	}
	return stub()
}
// AppMetrics is metrics interface
type AppMetrics interface {
GetMeter() metric2.Meter
@@ -123,7 +114,6 @@ type AppMetrics interface {
StoreMetrics() *StoreMetrics
UpdateChannelMetrics() *UpdateChannelMetrics
AccountManagerMetrics() *AccountManagerMetrics
EphemeralPeersMetrics() *EphemeralPeersMetrics
}
// defaultAppMetrics are core application metrics based on OpenTelemetry https://opentelemetry.io/
@@ -139,7 +129,6 @@ type defaultAppMetrics struct {
storeMetrics *StoreMetrics
updateChannelMetrics *UpdateChannelMetrics
accountManagerMetrics *AccountManagerMetrics
ephemeralMetrics *EphemeralPeersMetrics
}
// IDPMetrics returns metrics for the idp package
@@ -172,11 +161,6 @@ func (appMetrics *defaultAppMetrics) AccountManagerMetrics() *AccountManagerMetr
return appMetrics.accountManagerMetrics
}
// EphemeralPeersMetrics returns metrics for the ephemeral peer cleanup loop.
// It is a plain accessor for the ephemeralMetrics field and performs no
// nil check of its own.
func (appMetrics *defaultAppMetrics) EphemeralPeersMetrics() *EphemeralPeersMetrics {
return appMetrics.ephemeralMetrics
}
// Close stop application metrics HTTP handler and closes listener.
func (appMetrics *defaultAppMetrics) Close() error {
if appMetrics.listener == nil {
@@ -261,11 +245,6 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
}
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
@@ -275,7 +254,6 @@ func NewDefaultAppMetrics(ctx context.Context) (AppMetrics, error) {
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
ephemeralMetrics: ephemeralMetrics,
}, nil
}
@@ -312,11 +290,6 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
return nil, fmt.Errorf("failed to initialize account manager metrics: %w", err)
}
ephemeralMetrics, err := NewEphemeralPeersMetrics(ctx, meter)
if err != nil {
return nil, fmt.Errorf("failed to initialize ephemeral peers metrics: %w", err)
}
return &defaultAppMetrics{
Meter: meter,
ctx: ctx,
@@ -327,6 +300,5 @@ func NewAppMetricsWithMeter(ctx context.Context, meter metric2.Meter) (AppMetric
storeMetrics: storeMetrics,
updateChannelMetrics: updateChannelMetrics,
accountManagerMetrics: accountManagerMetrics,
ephemeralMetrics: ephemeralMetrics,
}, nil
}

View File

@@ -1,115 +0,0 @@
package telemetry
import (
"context"
"go.opentelemetry.io/otel/metric"
)
// EphemeralPeersMetrics tracks the ephemeral peer cleanup pipeline: how
// many accounts are currently being tracked for cleanup, how many sweep
// runs deleted at least one peer, how many peers have been removed, and
// how many delete batches failed.
type EphemeralPeersMetrics struct {
// ctx is the context passed to every instrument call made by the methods.
ctx context.Context
// pending is the up/down counter of accounts currently tracked for cleanup.
pending metric.Int64UpDownCounter
// cleanupRuns counts cleanup sweeps that deleted at least one peer.
cleanupRuns metric.Int64Counter
// peersCleaned counts the ephemeral peers deleted by the cleanup loop.
peersCleaned metric.Int64Counter
// errors counts per-account cleanup batches that failed to delete.
errors metric.Int64Counter
}
// NewEphemeralPeersMetrics registers the ephemeral cleanup instruments on
// meter and returns them bundled with ctx. Instruments are created in a
// fixed order (tracked accounts, cleanup runs, peers cleaned, errors); the
// first registration error aborts construction and is returned as is.
func NewEphemeralPeersMetrics(ctx context.Context, meter metric.Meter) (*EphemeralPeersMetrics, error) {
	trackedAccounts, err := meter.Int64UpDownCounter("management.ephemeral.accounts.tracked",
		metric.WithUnit("1"),
		metric.WithDescription("Number of accounts currently tracked for ephemeral peer cleanup"))
	if err != nil {
		return nil, err
	}

	// The three monotonic counters differ only in name and description.
	newCounter := func(name, description string) (metric.Int64Counter, error) {
		return meter.Int64Counter(name, metric.WithUnit("1"), metric.WithDescription(description))
	}

	sweepRuns, err := newCounter("management.ephemeral.cleanup.runs.counter",
		"Number of ephemeral cleanup sweeps that deleted at least one peer")
	if err != nil {
		return nil, err
	}

	deletedPeers, err := newCounter("management.ephemeral.peers.cleaned.counter",
		"Total number of ephemeral peers deleted by the cleanup loop")
	if err != nil {
		return nil, err
	}

	failedBatches, err := newCounter("management.ephemeral.cleanup.errors.counter",
		"Number of ephemeral cleanup batches (per account) that failed to delete")
	if err != nil {
		return nil, err
	}

	return &EphemeralPeersMetrics{
		ctx:          ctx,
		pending:      trackedAccounts,
		cleanupRuns:  sweepRuns,
		peersCleaned: deletedPeers,
		errors:       failedBatches,
	}, nil
}
// All methods are nil-receiver safe so callers that haven't wired metrics
// (tests, self-hosted with metrics off) can invoke them unconditionally.
// IncPending adds one to the tracked-accounts counter, for use when a new
// account becomes eligible for ephemeral cleanup tracking. A nil receiver
// makes the call a no-op.
func (m *EphemeralPeersMetrics) IncPending() {
	if m != nil {
		m.pending.Add(m.ctx, 1)
	}
}
// AddPending adds n to the tracked-accounts counter; used at startup when
// the catch-up query seeds the tracker. The call is a no-op on a nil
// receiver or when n is not positive.
func (m *EphemeralPeersMetrics) AddPending(n int64) {
	if m != nil && n > 0 {
		m.pending.Add(m.ctx, n)
	}
}
// DecPending subtracts n from the tracked-accounts counter, for accounts
// dropped from the tracker (no more disconnects to chase). The call is a
// no-op on a nil receiver or when n is not positive.
func (m *EphemeralPeersMetrics) DecPending(n int64) {
	if m != nil && n > 0 {
		m.pending.Add(m.ctx, -n)
	}
}
// CountCleanupRun adds one to the cleanup-runs counter. It is meant to be
// called for passes that processed >0 peers; idle ticks are deliberately
// left uncounted by callers so the rate reflects useful work. A nil
// receiver makes the call a no-op.
func (m *EphemeralPeersMetrics) CountCleanupRun() {
	if m != nil {
		m.cleanupRuns.Add(m.ctx, 1)
	}
}
// CountPeersCleaned adds n, the number of peers a single tick deleted, to
// the peers-cleaned counter. The call is a no-op on a nil receiver or when
// n is not positive.
func (m *EphemeralPeersMetrics) CountPeersCleaned(n int64) {
	if m != nil && n > 0 {
		m.peersCleaned.Add(m.ctx, n)
	}
}
// CountCleanupError adds one to the cleanup-errors counter for a failed
// delete batch. A nil receiver makes the call a no-op.
func (m *EphemeralPeersMetrics) CountCleanupError() {
	if m != nil {
		m.errors.Add(m.ctx, 1)
	}
}

View File

@@ -892,7 +892,12 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
// Auth is collected when this peer serves the rule. For bidirectional
// rules the peer-in-sources side also serves inbound traffic, so it
// must be treated as a destination too.
peerServesAuth := peerInDestinations || (rule.Bidirectional && peerInSources)
if peerServesAuth && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
@@ -924,7 +929,7 @@ func (a *Account) GetPeerConnectionResources(ctx context.Context, peer *nbpeer.P
default:
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
} else if peerServesAuth && policyRuleImpliesLegacySSH(rule) && peer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = a.getAllowedUserIDs()
}

View File

@@ -341,7 +341,12 @@ func (a *Account) getPeersGroupsPoliciesRoutes(
for _, srcGroupID := range rule.Sources {
relevantGroupIDs[srcGroupID] = a.GetGroup(srcGroupID)
}
}
// SSH auth requirements are gathered whenever this peer serves
// the rule. For bidirectional rules the peer-in-sources side
// also serves inbound traffic and must be treated as a destination.
if peerInDestinations || (rule.Bidirectional && peerInSources) {
if rule.Protocol == PolicyRuleProtocolNetbirdSSH {
switch {
case len(rule.AuthorizedGroups) > 0:

View File

@@ -221,7 +221,12 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) (
generateResources(rule, sourcePeers, FirewallRuleDirectionIN)
}
if peerInDestinations && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
// Auth is collected when this peer serves the rule. For bidirectional
// rules the peer-in-sources side also serves inbound traffic, so it
// must be treated as a destination too.
peerServesAuth := peerInDestinations || (rule.Bidirectional && peerInSources)
if peerServesAuth && rule.Protocol == PolicyRuleProtocolNetbirdSSH {
sshEnabled = true
switch {
case len(rule.AuthorizedGroups) > 0:
@@ -252,7 +257,7 @@ func (c *NetworkMapComponents) getPeerConnectionResources(targetPeerID string) (
default:
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
}
} else if peerInDestinations && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
} else if peerServesAuth && policyRuleImpliesLegacySSH(rule) && targetPeer.SSHEnabled {
sshEnabled = true
authorizedUsers[auth.Wildcard] = c.getAllowedUserIDs()
}
@@ -557,7 +562,6 @@ func (c *NetworkMapComponents) getRoutingPeerRoutes(peerID string) (enabledRoute
return enabledRoutes, disabledRoutes
}
func (c *NetworkMapComponents) filterRoutesByGroups(routes []*route.Route, groupListMap LookupMap) []*route.Route {
var filteredRoutes []*route.Route
for _, r := range routes {

View File

@@ -980,6 +980,44 @@ func TestComponents_SSHAuthorizedUsersContent(t *testing.T) {
assert.True(t, hasRoot || hasAdmin, "AuthorizedUsers should contain 'root' or 'admin' machine user mapping")
}
// TestComponents_SSHAuthorizedUsersBidirectionalSource checks that a peer on
// the sources side of a bidirectional NetbirdSSH rule is handed the rule's
// authorized users. Because the rule also applies destinations -> sources,
// the source-side peer acts as a destination as well and has to be able to
// authorize inbound SSH coming from the rule's destinations.
func TestComponents_SSHAuthorizedUsersBidirectionalSource(t *testing.T) {
	account, validatedPeers := scalableTestAccountWithoutDefaultPolicy(20, 2)

	// A user that auto-joins the group the rule authorizes.
	account.Users["user-dev"] = &types.User{Id: "user-dev", Role: types.UserRoleUser, AccountID: "test-account", AutoGroups: []string{"ssh-users"}}
	account.Groups["ssh-users"] = &types.Group{ID: "ssh-users", Name: "SSH Users", Peers: []string{}}

	bidirectionalRule := &types.PolicyRule{
		ID:               "rule-ssh-bidir",
		Name:             "SSH both ways",
		Enabled:          true,
		Action:           types.PolicyTrafficActionAccept,
		Protocol:         types.PolicyRuleProtocolNetbirdSSH,
		Bidirectional:    true,
		Sources:          []string{"group-0"},
		Destinations:     []string{"group-1"},
		AuthorizedGroups: map[string][]string{"ssh-users": {"root"}},
	}
	account.Policies = append(account.Policies, &types.Policy{
		ID:        "policy-ssh-bidir",
		Name:      "Bidirectional SSH",
		Enabled:   true,
		AccountID: "test-account",
		Rules:     []*types.PolicyRule{bidirectionalRule},
	})

	// Source side: "peer-0" is expected to belong to group-0.
	sourceMap := componentsNetworkMap(account, "peer-0", validatedPeers)
	require.NotNil(t, sourceMap)
	assert.True(t, sourceMap.EnableSSH, "source-side peer of bidirectional SSH rule should have SSH enabled")
	require.NotEmpty(t, sourceMap.AuthorizedUsers, "source-side peer should receive authorized users from bidirectional rule")

	sourceRootUsers, sourceHasRoot := sourceMap.AuthorizedUsers["root"]
	require.True(t, sourceHasRoot, "source-side peer should map the 'root' local user")
	_, sourceHasDev := sourceRootUsers["user-dev"]
	assert.True(t, sourceHasDev, "source-side peer should include 'user-dev' under 'root'")

	// Destination side: "peer-10" is expected to belong to group-1.
	destinationMap := componentsNetworkMap(account, "peer-10", validatedPeers)
	require.NotNil(t, destinationMap)
	assert.True(t, destinationMap.EnableSSH, "destination-side peer should also have SSH enabled")
	_, destinationHasRoot := destinationMap.AuthorizedUsers["root"]
	assert.True(t, destinationHasRoot, "destination-side peer should also map the 'root' local user")
}
// TestComponents_SSHLegacyImpliedSSH verifies that a non-SSH ALL protocol policy with
// SSHEnabled peer implies legacy SSH access.
func TestComponents_SSHLegacyImpliedSSH(t *testing.T) {