[management] Fence peer status updates with a session token (#6193)

* [management] Fence peer status updates with a session token

The connect/disconnect path used a best-effort LastSeen-after-streamStart
comparison to decide whether a status update should land. Under contention
— a re-sync arriving while the previous stream's disconnect was still in
flight, or two management replicas seeing the same peer at once — the
check was a read-then-decide-then-write window: any UPDATE in between
caused the wrong row to be written. The Go-side time.Now() that fed the
comparison also drifted under lock contention, since it was captured
seconds before the write actually committed.

Replace it with an integer-nanosecond fencing token stored alongside the
status. Every gRPC sync stream uses its open time (UnixNano) as its token.
Connects only land when the incoming token is strictly greater than the
stored one; disconnects only land when the incoming token equals the
stored one (i.e. we're the stream that owns the current session). Both
are single optimistic-locked UPDATEs — no read-then-write, no transaction
wrapper.

LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The
caller never supplies it, so the value always reflects the real moment
of the UPDATE rather than the moment the caller queued the work — which
was already off by minutes under heavy lock contention.

Side effects (geo lookup, peer-login-expiration scheduling, network-map
fan-out) are explicitly documented as running after the fence UPDATE
commits, never inside it. Geo also skips the update when realIP equals
the stored ConnectionIP, dropping a redundant SavePeerLocation call on
same-IP reconnects.

Tests cover the three semantic cases (matched disconnect lands, stale
disconnect dropped, stale connect dropped) plus a 16-goroutine race test
that asserts the highest token always wins.

* [management] Add SessionStartedAt to peer status updates

Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety.

* Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields
This commit is contained in:
Maycon Santos
2026-05-18 20:25:12 +02:00
committed by GitHub
parent 705f87fc20
commit 13d32d274f
10 changed files with 354 additions and 118 deletions

View File

@@ -1868,35 +1868,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu
return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain
}
// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's
// network map and then marks the peer connected with a session token
// derived from syncTime (the moment the gRPC stream opened). Any
// concurrent stream that started earlier loses the optimistic-lock race
// in MarkPeerConnected and bails without writing.
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err)
}
err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime)
if err != nil {
if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err)
}
return peer, netMap, postureChecks, dnsfwdPort, nil
}
// OnPeerDisconnected is invoked when a sync stream ends. It marks the
// peer disconnected only when the stored SessionStartedAt matches the
// nanosecond token derived from streamStartTime — i.e. only when this
// is the stream that currently owns the peer's session. A mismatch
// means a newer stream has already replaced us, so the disconnect is
// dropped.
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err)
return nil
}
if peer.Status.LastSeen.After(streamStartTime) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect",
peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339))
return nil
}
err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC())
if err != nil {
if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil {
log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err)
}
return nil

View File

@@ -61,7 +61,8 @@ type Manager interface {
GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsers(ctx context.Context, accountID string) ([]*types.User, error)
GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error
MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
DeletePeer(ctx context.Context, accountID, peerID, userID string) error
UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error)
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error

View File

@@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal
}
// MarkPeerConnected mocks base method.
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime)
ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerConnected indicates an expected call of MarkPeerConnected.
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt)
}
// MarkPeerDisconnected mocks base method.
func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt)
ret0, _ := ret[0].(error)
return ret0
}
// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected.
func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt)
}
// OnPeerDisconnected mocks base method.

View File

@@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
require.NoError(t, err, "unable to mark peer connected")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
@@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
require.NoError(t, err, "unable to get the account")
// when we mark peer as connected, the peer login expiration routine should trigger
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
require.NoError(t, err, "unable to mark peer connected")
failed := waitTimeout(wg, time.Second)
@@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
}, false)
require.NoError(t, err, "unable to add peer")
t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
t.Run("disconnect peer when session token matches", func(t *testing.T) {
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err, "unable to get peer")
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := time.Now().UTC()
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should equal the token we passed in")
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
require.NoError(t, err)
@@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) {
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.False(t, peer.Status.Connected, "peer should be disconnected")
require.Equal(t, int64(0), peer.Status.SessionStartedAt, "SessionStartedAt should be reset to 0")
})
t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC())
t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) {
// Newer stream wins on connect (sets SessionStartedAt = now ns).
streamStartTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano())
require.NoError(t, err, "unable to mark peer connected")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour)
// Older stream tries to mark disconnect with its own (older) session token —
// fencing kicks in and the write is dropped.
staleStreamStartTime := streamStartTime.Add(-1 * time.Hour)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime)
err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime)
require.NoError(t, err)
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected,
"peer should remain connected because LastSeen > streamStartTime (zombie stream protection)")
"peer should remain connected because the stored session is newer than the disconnect token")
require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should still hold the winning stream's token")
})
t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) {
t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) {
node2SyncTime := time.Now().UTC()
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano())
require.NoError(t, err, "node 2 should connect peer")
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime")
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should equal node2SyncTime token")
node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime)
err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node1StaleSyncTime.UnixNano())
require.NoError(t, err, "stale connect should not return error")
peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should still be connected")
require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(),
"LastSeen should NOT be overwritten by stale syncTime from blocked goroutine")
require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt,
"SessionStartedAt should NOT be overwritten by stale token from blocked goroutine")
})
}
// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the
// fencing protocol under contention: many goroutines race to mark the
// same peer connected with distinct session tokens at the same time.
// The contract is that the highest token always wins and is what remains
// in the store, regardless of execution order.
func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID})
require.NoError(t, err, "unable to get account")
key, err := wgtypes.GenerateKey()
require.NoError(t, err, "unable to generate WireGuard key")
peerPubKey := key.PublicKey().String()
_, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{
Key: peerPubKey,
Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"},
}, false)
require.NoError(t, err, "unable to add peer")
const workers = 16
base := time.Now().UTC().UnixNano()
tokens := make([]int64, workers)
for i := range tokens {
// Spread tokens by 1ms so the comparison is unambiguous; the
// largest is index workers-1.
tokens[i] = base + int64(i)*int64(time.Millisecond)
}
expected := tokens[workers-1]
var ready sync.WaitGroup
ready.Add(workers)
var start sync.WaitGroup
start.Add(1)
var done sync.WaitGroup
done.Add(workers)
// require.* calls t.FailNow which is documented as unsafe from
// non-test goroutines (it calls runtime.Goexit on the wrong stack and
// races with the WaitGroup). Collect errors here and assert from the
// main goroutine after done.Wait().
errs := make(chan error, workers)
for i := 0; i < workers; i++ {
token := tokens[i]
go func() {
defer done.Done()
ready.Done()
start.Wait()
errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token)
}()
}
ready.Wait()
start.Done()
done.Wait()
close(errs)
for err := range errs {
require.NoError(t, err, "MarkPeerConnected must not error under contention")
}
peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey)
require.NoError(t, err)
require.True(t, peer.Status.Connected, "peer should be connected after the race")
require.Equal(t, expected, peer.Status.SessionStartedAt,
"the largest token must win regardless of execution order")
}
func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) {
manager, _, err := createManager(t)
require.NoError(t, err, "unable to create account manager")
@@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
account, err := manager.Store.GetAccount(context.Background(), accountID)
require.NoError(t, err, "unable to get the account")
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC())
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano())
require.NoError(t, err, "unable to mark peer connected")
wg := &sync.WaitGroup{}

View File

@@ -38,7 +38,8 @@ type MockAccountManager struct {
GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error)
ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error)
GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error)
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error
MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error
MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error
SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error
GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error)
@@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str
return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error {
// Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing
// hook so tests that inject MarkPeerDisconnectedFunc actually observe
// disconnect events. Falls through to nil when no hook is set, which
// is the original behaviour.
if am.MarkPeerDisconnectedFunc != nil {
return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano())
}
return nil
}
@@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth
}
// MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
if am.MarkPeerConnectedFunc != nil {
return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime)
return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented")
}
// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface
func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error {
if am.MarkPeerDisconnectedFunc != nil {
return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt)
}
return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented")
}
// DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface
func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error {
if am.DeleteAccountFunc != nil {

View File

@@ -16,7 +16,6 @@ import (
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/geolocation"
"github.com/netbirdio/netbird/management/server/idp"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
@@ -63,56 +62,51 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID
return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID)
}
// MarkPeerConnected marks peer as connected (true) or disconnected (false)
// syncTime is used as the LastSeen timestamp and for stale request detection
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error {
var peer *nbpeer.Peer
var settings *types.Settings
var expired bool
var err error
var skipped bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey)
if err != nil {
return err
}
if connected && !syncTime.After(peer.Status.LastSeen) {
log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect",
peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339))
skipped = true
return nil
}
expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime)
return err
})
if skipped {
return nil
}
// MarkPeerConnected marks a peer as connected with optimistic-locked
// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument
// is the start time of the gRPC sync stream that owns this update,
// expressed as Unix nanoseconds — only the call whose token is greater
// than what's stored wins. LastSeen is written by the database itself;
// we never pass it down.
//
// Disconnects use MarkPeerDisconnected and require the session to match
// exactly; see PeerStatus.SessionStartedAt for the protocol.
func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
return err
}
updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt)
if err != nil {
return err
}
if !updated {
log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID)
return nil
}
if am.geo != nil && realIP != nil {
am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP)
}
expired := peer.Status != nil && peer.Status.LoginExpired
if peer.AddedWithSSOLogin() {
settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled {
am.schedulePeerLoginExpiration(ctx, accountID)
}
if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled {
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
}
}
if expired {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
if err != nil {
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
@@ -120,41 +114,46 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
return nil
}
func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) {
oldStatus := peer.Status.Copy()
newStatus := oldStatus
newStatus.LastSeen = syncTime
newStatus.Connected = connected
// whenever peer got connected that means that it logged in successfully
if newStatus.Connected {
newStatus.LoginExpired = false
}
peer.Status = newStatus
if geo != nil && realIP != nil {
location, err := geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
} else {
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
err = transaction.SavePeerLocation(ctx, accountID, peer)
if err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
}
log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected)
err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus)
// MarkPeerDisconnected marks a peer as disconnected, but only when the
// stored session token matches the one passed in. A mismatch means a
// newer stream has already taken ownership of the peer — disconnects from
// the older stream are ignored. LastSeen is written by the database.
func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error {
peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey)
if err != nil {
return false, err
return err
}
return oldStatus.LoginExpired, nil
updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt)
if err != nil {
return err
}
if !updated {
log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping",
peer.ID, sessionStartedAt)
}
return nil
}
// updatePeerLocationIfChanged refreshes the geolocation on a separate
// row update, only when the connection IP actually changed. Geo lookups
// are expensive so we skip same-IP reconnects.
func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) {
if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) {
return
}
location, err := am.geo.Lookup(realIP)
if err != nil {
log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err)
return
}
peer.Location.ConnectionIP = realIP
peer.Location.CountryCode = location.Country.ISOCode
peer.Location.CityName = location.City.Names.En
peer.Location.GeoNameID = location.City.GeonameID
if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil {
log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err)
}
}
// UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated.

View File

@@ -74,8 +74,19 @@ type ProxyMeta struct {
}
type PeerStatus struct { //nolint:revive
// LastSeen is the last time peer was connected to the management service
// LastSeen is the last time the peer status was updated (i.e. the last
// time we observed the peer being alive on a sync stream). Written by
// the database (CURRENT_TIMESTAMP) — callers do not supply it.
LastSeen time.Time
// SessionStartedAt records when the currently-active sync stream began,
// stored as Unix nanoseconds. It acts as the optimistic-locking token
// for status updates: a stream is only allowed to mutate the peer's
// status when its own token strictly exceeds the stored token (when connecting)
// or matches it exactly (for disconnects). Zero means "no
// active session". Integer nanoseconds are used so equality is
// precision-safe across drivers, and so the predicates compose to a
// single bigint comparison.
SessionStartedAt int64
// Connected indicates whether peer is connected to the management service or not
Connected bool
// LoginExpired
@@ -375,10 +386,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any {
return meta
}
// Copy PeerStatus
// Copy PeerStatus. SessionStartedAt must be propagated so clone-based
// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently
// reset the fencing token to zero — that would let any subsequent
// SavePeerStatus write reopen the optimistic-lock window.
func (p *PeerStatus) Copy() *PeerStatus {
return &PeerStatus{
LastSeen: p.LastSeen,
SessionStartedAt: p.SessionStartedAt,
Connected: p.Connected,
LoginExpired: p.LoginExpired,
RequiresApproval: p.RequiresApproval,

View File

@@ -498,8 +498,9 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
peerCopy.Status = &peerStatus
fieldsToUpdate := []string{
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
"peer_status_last_seen", "peer_status_session_started_at",
"peer_status_connected", "peer_status_login_expired",
"peer_status_requires_approval",
}
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
@@ -516,6 +517,69 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string,
return nil
}
// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update.
// The peer is marked connected with the given session token only when
// the stored SessionStartedAt is strictly smaller than the incoming
// one — equivalently, when no newer stream has already taken ownership.
// The sentinel zero (set on peer creation or after a disconnect) counts
// as the smallest possible token. This is the write half of the
// fencing protocol described on PeerStatus.SessionStartedAt.
//
// The post-write side effects in the caller — geo lookup,
// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration,
// OnPeersUpdated — all run AFTER this method returns and are deliberately
// outside the database write so they cannot extend the row-lock window.
//
// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the
// moment the row is written. The caller never supplies LastSeen because
// the value would otherwise drift under lock contention — a Go-side
// time.Now() taken before the write can land minutes later than the
// actual UPDATE under load, which previously caused real ordering bugs.
func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
result := s.db.WithContext(ctx).
Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerID).
Where("peer_status_session_started_at < ?", newSessionStartedAt).
Updates(map[string]any{
"peer_status_connected": true,
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
"peer_status_session_started_at": newSessionStartedAt,
"peer_status_login_expired": false,
})
if result.Error != nil {
return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error)
}
return result.RowsAffected > 0, nil
}
// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update.
// The peer is marked disconnected only when the stored SessionStartedAt
// matches the incoming token — meaning the stream that owns the current
// session is the one ending. If a newer stream has already replaced the
// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at
// write time; see MarkPeerConnectedIfNewerSession for the rationale.
//
// A zero sessionStartedAt is rejected at the call site; the underlying
// WHERE on equality would otherwise match every never-connected peer.
func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
if sessionStartedAt == 0 {
return false, nil
}
result := s.db.WithContext(ctx).
Model(&nbpeer.Peer{}).
Where(accountAndIDQueryCondition, accountID, peerID).
Where("peer_status_session_started_at = ?", sessionStartedAt).
Updates(map[string]any{
"peer_status_connected": false,
"peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"),
"peer_status_session_started_at": int64(0),
})
if result.Error != nil {
return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error)
}
return result.RowsAffected > 0, nil
}
func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error {
// To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields.
var peerCopy nbpeer.Peer
@@ -1723,9 +1787,10 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname,
meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version,
meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer,
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired,
peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name,
location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1`
meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at,
peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip,
location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6
FROM peers WHERE account_id = $1`
rows, err := s.pool.Query(ctx, query, accountID)
if err != nil {
return nil, err
@@ -1738,6 +1803,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
lastLogin, createdAt sql.NullTime
sshEnabled, loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool
peerStatusLastSeen sql.NullTime
peerStatusSessionStartedAt sql.NullInt64
peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool
ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte
metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString
@@ -1752,8 +1818,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
&allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform,
&metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr,
&metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities,
&peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP,
&locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6)
&peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired,
&peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID,
&proxyEmbedded, &proxyCluster, &ipv6)
if err == nil {
if lastLogin.Valid {
@@ -1780,6 +1847,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee
if peerStatusLastSeen.Valid {
p.Status.LastSeen = peerStatusLastSeen.Time
}
if peerStatusSessionStartedAt.Valid {
p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64
}
if peerStatusConnected.Valid {
p.Status.Connected = peerStatusConnected.Bool
}

View File

@@ -167,6 +167,21 @@ type Store interface {
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error
SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error
// MarkPeerConnectedIfNewerSession sets the peer to connected with the
// given session token, but only when the stored SessionStartedAt is
// strictly less than newSessionStartedAt (the sentinel zero counts as
// "older"). LastSeen is recorded by the database at the moment the
// row is updated — never by the caller — so it always reflects the
// real write time even under lock contention.
// Returns true when the update happened, false when this stream lost
// the race against a newer session.
MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error)
// MarkPeerDisconnectedIfSameSession sets the peer to disconnected and
// resets SessionStartedAt to zero, but only when the stored
// SessionStartedAt equals the given sessionStartedAt. LastSeen is
// recorded by the database. Returns true when the update happened,
// false when a newer session has taken over.
MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error)
SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error
ApproveAccountPeers(ctx context.Context, accountID string) (int, error)
DeletePeer(ctx context.Context, accountID string, peerID string) error

View File

@@ -2878,6 +2878,36 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status)
}
// MarkPeerConnectedIfNewerSession mocks base method.
func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession.
func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt)
}
// MarkPeerDisconnectedIfSameSession mocks base method.
func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt)
ret0, _ := ret[0].(bool)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession.
func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt)
}
// SavePolicy mocks base method.
func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error {
m.ctrl.T.Helper()