From 13d32d274f74b700557f8f6a615f56be2ab9c6a5 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 18 May 2026 20:25:12 +0200 Subject: [PATCH] [management] Fence peer status updates with a session token (#6193) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [management] Fence peer status updates with a session token The connect/disconnect path used a best-effort LastSeen-after-streamStart comparison to decide whether a status update should land. Under contention — a re-sync arriving while the previous stream's disconnect was still in flight, or two management replicas seeing the same peer at once — the check was a read-then-decide-then-write window: any UPDATE in between caused the wrong row to be written. The Go-side time.Now() that fed the comparison also drifted under lock contention, since it was captured seconds before the write actually committed. Replace it with an integer-nanosecond fencing token stored alongside the status. Every gRPC sync stream uses its open time (UnixNano) as its token. Connects only land when the incoming token is strictly greater than the stored one; disconnects only land when the incoming token equals the stored one (i.e. we're the stream that owns the current session). Both are single optimistic-locked UPDATEs — no read-then-write, no transaction wrapper. LastSeen is now written by the database itself (CURRENT_TIMESTAMP). The caller never supplies it, so the value always reflects the real moment of the UPDATE rather than the moment the caller queued the work — which was already off by minutes under heavy lock contention. Side effects (geo lookup, peer-login-expiration scheduling, network-map fan-out) are explicitly documented as running after the fence UPDATE commits, never inside it. Geo also skips the update when realIP equals the stored ConnectionIP, dropping a redundant SavePeerLocation call on same-IP reconnects. 
Tests cover the three semantic cases (matched disconnect lands, stale disconnect dropped, stale connect dropped) plus a 16-goroutine race test that asserts the highest token always wins. * [management] Add SessionStartedAt to peer status updates Stored `SessionStartedAt` for fencing token propagation across goroutines and updated database queries/functions to handle the new field. Removed outdated geolocation handling logic and adjusted tests for concurrency safety. * Rename `peer_status_required_approval` to `peer_status_requires_approval` in SQL store fields --- management/server/account.go | 29 ++-- management/server/account/manager.go | 3 +- management/server/account/manager_mock.go | 22 ++- management/server/account_test.go | 115 ++++++++++++--- management/server/mock_server/account_mock.go | 24 +++- management/server/peer.go | 131 +++++++++--------- management/server/peer/peer.go | 19 ++- management/server/store/sql_store.go | 84 ++++++++++- management/server/store/store.go | 15 ++ management/server/store/store_mock.go | 30 ++++ 10 files changed, 354 insertions(+), 118 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index e7b4acaac..8e4e595f0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1868,35 +1868,32 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } +// SyncAndMarkPeer is the per-Sync entry point: it refreshes the peer's +// network map and then marks the peer connected with a session token +// derived from syncTime (the moment the gRPC stream opened). Any +// concurrent stream that started earlier loses the optimistic-lock race +// in MarkPeerConnected and bails without writing. 
func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime) - if err != nil { + if err := am.MarkPeerConnected(ctx, peerPubKey, realIP, accountID, syncTime.UnixNano()); err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } return peer, netMap, postureChecks, dnsfwdPort, nil } +// OnPeerDisconnected is invoked when a sync stream ends. It marks the +// peer disconnected only when the stored SessionStartedAt matches the +// nanosecond token derived from streamStartTime — i.e. only when this +// is the stream that currently owns the peer's session. A mismatch +// means a newer stream has already replaced us, so the disconnect is +// dropped. 
func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) - if err != nil { - log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) - return nil - } - - if peer.Status.LastSeen.After(streamStartTime) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", - peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) - return nil - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) - if err != nil { + if err := am.MarkPeerDisconnected(ctx, peerPubKey, accountID, streamStartTime.UnixNano()); err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } return nil diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 71af0645c..ae3de8d79 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -61,7 +61,8 @@ type Manager interface { GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error + MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error + MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, p *nbpeer.Peer) (*nbpeer.Peer, error) 
UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error diff --git a/management/server/account/manager_mock.go b/management/server/account/manager_mock.go index 7ffc41d73..0486e63ec 100644 --- a/management/server/account/manager_mock.go +++ b/management/server/account/manager_mock.go @@ -1305,17 +1305,31 @@ func (mr *MockManagerMockRecorder) LoginPeer(ctx, login interface{}) *gomock.Cal } // MarkPeerConnected mocks base method. -func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { +func (m *MockManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, connected, realIP, accountID, syncTime) + ret := m.ctrl.Call(m, "MarkPeerConnected", ctx, peerKey, realIP, accountID, sessionStartedAt) ret0, _ := ret[0].(error) return ret0 } // MarkPeerConnected indicates an expected call of MarkPeerConnected. -func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, connected, realIP, accountID, syncTime interface{}) *gomock.Call { +func (mr *MockManagerMockRecorder) MarkPeerConnected(ctx, peerKey, realIP, accountID, sessionStartedAt interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, connected, realIP, accountID, syncTime) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnected", reflect.TypeOf((*MockManager)(nil).MarkPeerConnected), ctx, peerKey, realIP, accountID, sessionStartedAt) +} + +// MarkPeerDisconnected mocks base method. 
+func (m *MockManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerDisconnected", ctx, peerKey, accountID, sessionStartedAt) + ret0, _ := ret[0].(error) + return ret0 +} + +// MarkPeerDisconnected indicates an expected call of MarkPeerDisconnected. +func (mr *MockManagerMockRecorder) MarkPeerDisconnected(ctx, peerKey, accountID, sessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnected", reflect.TypeOf((*MockManager)(nil).MarkPeerDisconnected), ctx, peerKey, accountID, sessionStartedAt) } // OnPeerDisconnected mocks base method. diff --git a/management/server/account_test.go b/management/server/account_test.go index 60720faa6..ba621030c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1813,7 +1813,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -1884,7 +1884,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. 
require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1910,15 +1910,16 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { }, false) require.NoError(t, err, "unable to add peer") - t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + t.Run("disconnect peer when session token matches", func(t *testing.T) { + streamStartTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err, "unable to get peer") require.True(t, peer.Status.Connected, "peer should be connected") - - streamStartTime := time.Now().UTC() + require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should equal the token we passed in") err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) require.NoError(t, err) @@ -1926,49 +1927,127 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.False(t, peer.Status.Connected, "peer should be disconnected") + require.Equal(t, int64(0), peer.Status.SessionStartedAt, 
"SessionStartedAt should be reset to 0") }) - t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) + t.Run("skip disconnect when stored session is newer (zombie stream protection)", func(t *testing.T) { + // Newer stream wins on connect (sets SessionStartedAt = now ns). + streamStartTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, streamStartTime.UnixNano()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should be connected") - streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour) + // Older stream tries to mark disconnect with its own (older) session token — + // fencing kicks in and the write is dropped. 
+ staleStreamStartTime := streamStartTime.Add(-1 * time.Hour) - err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, staleStreamStartTime) require.NoError(t, err) peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, - "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") + "peer should remain connected because the stored session is newer than the disconnect token") + require.Equal(t, streamStartTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should still hold the winning stream's token") }) - t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { + t.Run("skip stale connect when stored session is newer (blocked goroutine protection)", func(t *testing.T) { node2SyncTime := time.Now().UTC() - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, node2SyncTime.UnixNano()) require.NoError(t, err, "node 2 should connect peer") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should be connected") - require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime") + require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should equal node2SyncTime token") node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute) - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, 
nil, accountID, node1StaleSyncTime.UnixNano()) require.NoError(t, err, "stale connect should not return error") peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) require.NoError(t, err) require.True(t, peer.Status.Connected, "peer should still be connected") - require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), - "LastSeen should NOT be overwritten by stale syncTime from blocked goroutine") + require.Equal(t, node2SyncTime.UnixNano(), peer.Status.SessionStartedAt, + "SessionStartedAt should NOT be overwritten by stale token from blocked goroutine") }) } +// TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace exercises the +// fencing protocol under contention: many goroutines race to mark the +// same peer connected with distinct session tokens at the same time. +// The contract is that the highest token always wins and is what remains +// in the store, regardless of execution order. +func TestDefaultAccountManager_MarkPeerConnected_ConcurrentRace(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to get account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peerPubKey := key.PublicKey().String() + + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: peerPubKey, + Meta: nbpeer.PeerSystemMeta{Hostname: "race-peer"}, + }, false) + require.NoError(t, err, "unable to add peer") + + const workers = 16 + base := time.Now().UTC().UnixNano() + tokens := make([]int64, workers) + for i := range tokens { + // Spread tokens by 1ms so the comparison is unambiguous; the + // largest is index workers-1. 
+ tokens[i] = base + int64(i)*int64(time.Millisecond) + } + expected := tokens[workers-1] + + var ready sync.WaitGroup + ready.Add(workers) + var start sync.WaitGroup + start.Add(1) + var done sync.WaitGroup + done.Add(workers) + + // require.* calls t.FailNow, which must only be called from the + // goroutine running the test (it calls runtime.Goexit, which would end + // just the worker goroutine, not the test). Collect errors here and + // assert from the main goroutine after done.Wait(). + errs := make(chan error, workers) + + for i := 0; i < workers; i++ { + token := tokens[i] + go func() { + defer done.Done() + ready.Done() + start.Wait() + errs <- manager.MarkPeerConnected(context.Background(), peerPubKey, nil, accountID, token) + }() + } + + ready.Wait() + start.Done() + done.Wait() + close(errs) + for err := range errs { + require.NoError(t, err, "MarkPeerConnected must not error under contention") + } + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected after the race") + require.Equal(t, expected, peer.Status.SessionStartedAt, + "the largest token must win regardless of execution order") +} + func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") @@ -1991,7 +2070,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), nil, accountID, time.Now().UTC().UnixNano()) require.NoError(t, err, "unable to mark peer connected") wg := 
&sync.WaitGroup{} diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 08091d4b7..aba408184 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -38,7 +38,8 @@ type MockAccountManager struct { GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error + MarkPeerDisconnectedFunc func(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) @@ -227,7 +228,14 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { +func (am *MockAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + // Mirror DefaultAccountManager.OnPeerDisconnected: drive the fencing + // hook so tests that inject MarkPeerDisconnectedFunc actually observe + // 
disconnect events. Falls through to nil when no hook is set, which + // is the original behaviour. + if am.MarkPeerDisconnectedFunc != nil { + return am.MarkPeerDisconnectedFunc(ctx, peerPubKey, accountID, streamStartTime.UnixNano()) + } return nil } @@ -328,13 +336,21 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime) + return am.MarkPeerConnectedFunc(ctx, peerKey, realIP, accountID, sessionStartedAt) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } +// MarkPeerDisconnected mock implementation of MarkPeerDisconnected from server.AccountManager interface +func (am *MockAccountManager) MarkPeerDisconnected(ctx context.Context, peerKey string, accountID string, sessionStartedAt int64) error { + if am.MarkPeerDisconnectedFunc != nil { + return am.MarkPeerDisconnectedFunc(ctx, peerKey, accountID, sessionStartedAt) + } + return status.Errorf(codes.Unimplemented, "method MarkPeerDisconnected is not implemented") +} + // DeleteAccount mock implementation of DeleteAccount from server.AccountManager interface func (am *MockAccountManager) DeleteAccount(ctx context.Context, accountID, userID string) error { if am.DeleteAccountFunc != nil { diff --git a/management/server/peer.go b/management/server/peer.go index c3b130ba2..4790a5aab 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -16,7 +16,6 @@ import ( "golang.org/x/exp/maps" nbdns 
"github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/idp" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" @@ -63,56 +62,51 @@ func (am *DefaultAccountManager) GetPeers(ctx context.Context, accountID, userID return am.Store.GetUserPeers(ctx, store.LockingStrengthNone, accountID, userID) } -// MarkPeerConnected marks peer as connected (true) or disconnected (false) -// syncTime is used as the LastSeen timestamp and for stale request detection -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { - var peer *nbpeer.Peer - var settings *types.Settings - var expired bool - var err error - var skipped bool - - err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) - if err != nil { - return err - } - - if connected && !syncTime.After(peer.Status.LastSeen) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", - peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) - skipped = true - return nil - } - - expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime) - return err - }) - if skipped { - return nil - } +// MarkPeerConnected marks a peer as connected with optimistic-locked +// fencing on PeerStatus.SessionStartedAt. The sessionStartedAt argument +// is the start time of the gRPC sync stream that owns this update, +// expressed as Unix nanoseconds — only the call whose token is greater +// than what's stored wins. LastSeen is written by the database itself; +// we never pass it down. 
+// +// Disconnects use MarkPeerDisconnected and require the session to match +// exactly; see PeerStatus.SessionStartedAt for the protocol. +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, realIP net.IP, accountID string, sessionStartedAt int64) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) if err != nil { return err } + updated, err := am.Store.MarkPeerConnectedIfNewerSession(ctx, accountID, peer.ID, sessionStartedAt) + if err != nil { + return err + } + if !updated { + log.WithContext(ctx).Tracef("peer %s already has a newer session in store, skipping connect", peer.ID) + return nil + } + + if am.geo != nil && realIP != nil { + am.updatePeerLocationIfChanged(ctx, accountID, peer, realIP) + } + + expired := peer.Status != nil && peer.Status.LoginExpired + if peer.AddedWithSSOLogin() { - settings, err = am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) + settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { return err } - if peer.LoginExpirationEnabled && settings.PeerLoginExpirationEnabled { am.schedulePeerLoginExpiration(ctx, accountID) } - if peer.InactivityExpirationEnabled && settings.PeerInactivityExpirationEnabled { am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } } if expired { - err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) - if err != nil { + if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil { return fmt.Errorf("notify network map controller of peer update: %w", err) } } @@ -120,41 +114,46 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) { - 
oldStatus := peer.Status.Copy() - newStatus := oldStatus - newStatus.LastSeen = syncTime - newStatus.Connected = connected - // whenever peer got connected that means that it logged in successfully - if newStatus.Connected { - newStatus.LoginExpired = false - } - peer.Status = newStatus - - if geo != nil && realIP != nil { - location, err := geo.Lookup(realIP) - if err != nil { - log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) - } else { - peer.Location.ConnectionIP = realIP - peer.Location.CountryCode = location.Country.ISOCode - peer.Location.CityName = location.City.Names.En - peer.Location.GeoNameID = location.City.GeonameID - err = transaction.SavePeerLocation(ctx, accountID, peer) - if err != nil { - log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) - } - } - } - - log.WithContext(ctx).Debugf("saving peer status for peer %s is connected: %t", peer.ID, connected) - - err := transaction.SavePeerStatus(ctx, accountID, peer.ID, *newStatus) +// MarkPeerDisconnected marks a peer as disconnected, but only when the +// stored session token matches the one passed in. A mismatch means a +// newer stream has already taken ownership of the peer — disconnects from +// the older stream are ignored. LastSeen is written by the database. 
+func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerPubKey string, accountID string, sessionStartedAt int64) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) if err != nil { - return false, err + return err } - return oldStatus.LoginExpired, nil + updated, err := am.Store.MarkPeerDisconnectedIfSameSession(ctx, accountID, peer.ID, sessionStartedAt) + if err != nil { + return err + } + if !updated { + log.WithContext(ctx).Tracef("peer %s session token mismatch on disconnect (token=%d), skipping", + peer.ID, sessionStartedAt) + } + return nil +} + +// updatePeerLocationIfChanged refreshes the geolocation on a separate +// row update, only when the connection IP actually changed. Geo lookups +// are expensive so we skip same-IP reconnects. +func (am *DefaultAccountManager) updatePeerLocationIfChanged(ctx context.Context, accountID string, peer *nbpeer.Peer, realIP net.IP) { + if peer.Location.ConnectionIP != nil && peer.Location.ConnectionIP.Equal(realIP) { + return + } + location, err := am.geo.Lookup(realIP) + if err != nil { + log.WithContext(ctx).Warnf("failed to get location for peer %s realip: [%s]: %v", peer.ID, realIP.String(), err) + return + } + peer.Location.ConnectionIP = realIP + peer.Location.CountryCode = location.Country.ISOCode + peer.Location.CityName = location.City.Names.En + peer.Location.GeoNameID = location.City.GeonameID + if err := am.Store.SavePeerLocation(ctx, accountID, peer); err != nil { + log.WithContext(ctx).Warnf("could not store location for peer %s: %s", peer.ID, err) + } } // UpdatePeer updates peer. Only Peer.Name, Peer.SSHEnabled, Peer.LoginExpirationEnabled and Peer.InactivityExpirationEnabled can be updated. 
diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 17df761a1..2963dfcbd 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -74,8 +74,19 @@ type ProxyMeta struct { } type PeerStatus struct { //nolint:revive - // LastSeen is the last time peer was connected to the management service + // LastSeen is the last time the peer status was updated (i.e. the last + // time we observed the peer being alive on a sync stream). Written by + // the database (CURRENT_TIMESTAMP) — callers do not supply it. LastSeen time.Time + // SessionStartedAt records when the currently-active sync stream began, + // stored as Unix nanoseconds. It acts as the optimistic-locking token + // for status updates: a stream is only allowed to mutate the peer's + // status when its own token strictly exceeds the stored token (when connecting) + // or matches it exactly (for disconnects). Zero means "no + // active session". Integer nanoseconds are used so equality is + // precision-safe across drivers, and so the predicates compose to a + // single bigint comparison. + SessionStartedAt int64 // Connected indicates whether peer is connected to the management service or not Connected bool // LoginExpired @@ -375,10 +386,14 @@ func (p *Peer) EventMeta(dnsDomain string) map[string]any { return meta } -// Copy PeerStatus +// Copy PeerStatus. SessionStartedAt must be propagated so clone-based +// callers (Peer.Copy, MarkLoginExpired, UpdateLastLogin) don't silently +// reset the fencing token to zero — that would let any subsequent +// SavePeerStatus write reopen the optimistic-lock window. 
func (p *PeerStatus) Copy() *PeerStatus { return &PeerStatus{ LastSeen: p.LastSeen, + SessionStartedAt: p.SessionStartedAt, Connected: p.Connected, LoginExpired: p.LoginExpired, RequiresApproval: p.RequiresApproval, diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 893ee2168..8cf37de56 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -498,8 +498,9 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, peerCopy.Status = &peerStatus fieldsToUpdate := []string{ - "peer_status_last_seen", "peer_status_connected", - "peer_status_login_expired", "peer_status_required_approval", + "peer_status_last_seen", "peer_status_session_started_at", + "peer_status_connected", "peer_status_login_expired", + "peer_status_requires_approval", } result := s.db.Model(&nbpeer.Peer{}). Select(fieldsToUpdate). @@ -516,6 +517,69 @@ func (s *SqlStore) SavePeerStatus(ctx context.Context, accountID, peerID string, return nil } +// MarkPeerConnectedIfNewerSession is an atomic optimistic-locked update. +// The peer is marked connected with the given session token only when +// the stored SessionStartedAt is strictly smaller than the incoming +// one — equivalently, when no newer stream has already taken ownership. +// The sentinel zero (set on peer creation or after a disconnect) counts +// as the smallest possible token. This is the write half of the +// fencing protocol described on PeerStatus.SessionStartedAt. +// +// The post-write side effects in the caller — geo lookup, +// schedulePeerLoginExpiration, checkAndSchedulePeerInactivityExpiration, +// OnPeersUpdated — all run AFTER this method returns and are deliberately +// outside the database write so they cannot extend the row-lock window. +// +// LastSeen is set to the database's clock (CURRENT_TIMESTAMP) at the +// moment the row is written. 
The caller never supplies LastSeen because +// the value would otherwise drift under lock contention — a Go-side +// time.Now() captured before the write can be minutes older than the +// moment the UPDATE actually commits under load, which previously caused real ordering bugs. +func (s *SqlStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { + result := s.db.WithContext(ctx). + Model(&nbpeer.Peer{}). + Where(accountAndIDQueryCondition, accountID, peerID). + Where("peer_status_session_started_at < ?", newSessionStartedAt). + Updates(map[string]any{ + "peer_status_connected": true, + "peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"), + "peer_status_session_started_at": newSessionStartedAt, + "peer_status_login_expired": false, + }) + if result.Error != nil { + return false, status.Errorf(status.Internal, "mark peer connected: %v", result.Error) + } + return result.RowsAffected > 0, nil +} + +// MarkPeerDisconnectedIfSameSession is an atomic optimistic-locked update. +// The peer is marked disconnected only when the stored SessionStartedAt +// matches the incoming token — meaning the stream that owns the current +// session is the one ending. If a newer stream has already replaced the +// session, the update is skipped. LastSeen is set to CURRENT_TIMESTAMP at +// write time; see MarkPeerConnectedIfNewerSession for the rationale. +// +// A zero sessionStartedAt is rejected by the guard below; the underlying +// WHERE on equality would otherwise match a peer that has no active session. +func (s *SqlStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { + if sessionStartedAt == 0 { + return false, nil + } + result := s.db.WithContext(ctx). + Model(&nbpeer.Peer{}). + Where(accountAndIDQueryCondition, accountID, peerID). + Where("peer_status_session_started_at = ?", sessionStartedAt).
+ Updates(map[string]any{ + "peer_status_connected": false, + "peer_status_last_seen": gorm.Expr("CURRENT_TIMESTAMP"), + "peer_status_session_started_at": int64(0), + }) + if result.Error != nil { + return false, status.Errorf(status.Internal, "mark peer disconnected: %v", result.Error) + } + return result.RowsAffected > 0, nil +} + func (s *SqlStore) SavePeerLocation(ctx context.Context, accountID string, peerWithLocation *nbpeer.Peer) error { // To maintain data integrity, we create a copy of the peer's location to prevent unintended updates to other fields. var peerCopy nbpeer.Peer @@ -1723,9 +1787,10 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee inactivity_expiration_enabled, last_login, created_at, ephemeral, extra_dns_labels, allow_extra_dns_labels, meta_hostname, meta_go_os, meta_kernel, meta_core, meta_platform, meta_os, meta_os_version, meta_wt_version, meta_ui_version, meta_kernel_version, meta_network_addresses, meta_system_serial_number, meta_system_product_name, meta_system_manufacturer, - meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_connected, peer_status_login_expired, - peer_status_requires_approval, location_connection_ip, location_country_code, location_city_name, - location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 FROM peers WHERE account_id = $1` + meta_environment, meta_flags, meta_files, meta_capabilities, peer_status_last_seen, peer_status_session_started_at, + peer_status_connected, peer_status_login_expired, peer_status_requires_approval, location_connection_ip, + location_country_code, location_city_name, location_geo_name_id, proxy_meta_embedded, proxy_meta_cluster, ipv6 + FROM peers WHERE account_id = $1` rows, err := s.pool.Query(ctx, query, accountID) if err != nil { return nil, err @@ -1738,6 +1803,7 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee lastLogin, createdAt sql.NullTime sshEnabled, 
loginExpirationEnabled, inactivityExpirationEnabled, ephemeral, allowExtraDNSLabels sql.NullBool peerStatusLastSeen sql.NullTime + peerStatusSessionStartedAt sql.NullInt64 peerStatusConnected, peerStatusLoginExpired, peerStatusRequiresApproval, proxyEmbedded sql.NullBool ip, extraDNS, netAddr, env, flags, files, capabilities, connIP, ipv6 []byte metaHostname, metaGoOS, metaKernel, metaCore, metaPlatform sql.NullString @@ -1752,8 +1818,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee &allowExtraDNSLabels, &metaHostname, &metaGoOS, &metaKernel, &metaCore, &metaPlatform, &metaOS, &metaOSVersion, &metaWtVersion, &metaUIVersion, &metaKernelVersion, &netAddr, &metaSystemSerialNumber, &metaSystemProductName, &metaSystemManufacturer, &env, &flags, &files, &capabilities, - &peerStatusLastSeen, &peerStatusConnected, &peerStatusLoginExpired, &peerStatusRequiresApproval, &connIP, - &locationCountryCode, &locationCityName, &locationGeoNameID, &proxyEmbedded, &proxyCluster, &ipv6) + &peerStatusLastSeen, &peerStatusSessionStartedAt, &peerStatusConnected, &peerStatusLoginExpired, + &peerStatusRequiresApproval, &connIP, &locationCountryCode, &locationCityName, &locationGeoNameID, + &proxyEmbedded, &proxyCluster, &ipv6) if err == nil { if lastLogin.Valid { @@ -1780,6 +1847,9 @@ func (s *SqlStore) getPeers(ctx context.Context, accountID string) ([]nbpeer.Pee if peerStatusLastSeen.Valid { p.Status.LastSeen = peerStatusLastSeen.Time } + if peerStatusSessionStartedAt.Valid { + p.Status.SessionStartedAt = peerStatusSessionStartedAt.Int64 + } if peerStatusConnected.Valid { p.Status.Connected = peerStatusConnected.Bool } diff --git a/management/server/store/store.go b/management/server/store/store.go index aa601c33f..a723c1fc3 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -167,6 +167,21 @@ type Store interface { GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) 
SavePeer(ctx context.Context, accountID string, peer *nbpeer.Peer) error SavePeerStatus(ctx context.Context, accountID, peerID string, status nbpeer.PeerStatus) error + // MarkPeerConnectedIfNewerSession sets the peer to connected with the + // given session token, but only when the stored SessionStartedAt is + // strictly less than newSessionStartedAt (the sentinel zero counts as + // "older"). LastSeen is recorded by the database at the moment the + // row is updated — never by the caller — so it always reflects the + // real write time even under lock contention. + // Returns true when the update happened, false when this stream lost + // the race against a newer session. + MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) + // MarkPeerDisconnectedIfSameSession sets the peer to disconnected and + // resets SessionStartedAt to zero, but only when the stored + // SessionStartedAt equals the given sessionStartedAt. LastSeen is + // recorded by the database. Returns true when the update happened, + // false when a newer session has taken over. 
+ MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) SavePeerLocation(ctx context.Context, accountID string, peer *nbpeer.Peer) error ApproveAccountPeers(ctx context.Context, accountID string) (int, error) DeletePeer(ctx context.Context, accountID string, peerID string) error diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 9780c521e..d51629606 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -2878,6 +2878,36 @@ func (mr *MockStoreMockRecorder) SavePeerStatus(ctx, accountID, peerID, status i return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SavePeerStatus", reflect.TypeOf((*MockStore)(nil).SavePeerStatus), ctx, accountID, peerID, status) } +// MarkPeerConnectedIfNewerSession mocks base method. +func (m *MockStore) MarkPeerConnectedIfNewerSession(ctx context.Context, accountID, peerID string, newSessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerConnectedIfNewerSession", ctx, accountID, peerID, newSessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerConnectedIfNewerSession indicates an expected call of MarkPeerConnectedIfNewerSession. +func (mr *MockStoreMockRecorder) MarkPeerConnectedIfNewerSession(ctx, accountID, peerID, newSessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerConnectedIfNewerSession", reflect.TypeOf((*MockStore)(nil).MarkPeerConnectedIfNewerSession), ctx, accountID, peerID, newSessionStartedAt) +} + +// MarkPeerDisconnectedIfSameSession mocks base method. 
+func (m *MockStore) MarkPeerDisconnectedIfSameSession(ctx context.Context, accountID, peerID string, sessionStartedAt int64) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "MarkPeerDisconnectedIfSameSession", ctx, accountID, peerID, sessionStartedAt) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// MarkPeerDisconnectedIfSameSession indicates an expected call of MarkPeerDisconnectedIfSameSession. +func (mr *MockStoreMockRecorder) MarkPeerDisconnectedIfSameSession(ctx, accountID, peerID, sessionStartedAt interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MarkPeerDisconnectedIfSameSession", reflect.TypeOf((*MockStore)(nil).MarkPeerDisconnectedIfSameSession), ctx, accountID, peerID, sessionStartedAt) +} + // SavePolicy mocks base method. func (m *MockStore) SavePolicy(ctx context.Context, policy *types2.Policy) error { m.ctrl.T.Helper()