From 2ffa6a23319dc0cc2509ec8453fddc57661395bf Mon Sep 17 00:00:00 2001 From: mlsmaycon Date: Mon, 16 Mar 2026 18:20:32 +0100 Subject: [PATCH] [management] Refactor peer activity checks to prevent stale status updates --- management/server/account.go | 16 +++------------- management/server/account_test.go | 23 +++++++++++++++++++++++ management/server/peer.go | 6 +++--- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 75db36a5f..696b3879f 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1760,20 +1760,10 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID return peer, netMap, postureChecks, dnsfwdPort, nil } +// OnPeerDisconnected marks a peer as disconnected using streamStartTime for stale detection. +// The actual staleness check happens inside MarkPeerConnected's transaction to avoid TOCTOU races. func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { - peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) - if err != nil { - log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) - return nil - } - - if peer.Status.LastSeen.After(streamStartTime) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", - peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) - return nil - } - - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) + err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, streamStartTime) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account_test.go b/management/server/account_test.go index fdec43617..4863d732d 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2030,6 +2030,29 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") }) + t.Run("skip stale disconnect when peer reconnected to another server", func(t *testing.T) { + // Simulate: peer connects to Server A at T1, then Server B sends stale disconnect from T0 + serverAConnectTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, serverAConnectTime) + require.NoError(t, err, "server A should connect peer") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected via server A") + + // Server B's stream started before the peer reconnected to Server A + serverBStreamStart := serverAConnectTime.Add(-5 * time.Second) + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, serverBStreamStart) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, + "peer should remain connected: server B's stale disconnect must not override server A's newer connect") + require.Equal(t, serverAConnectTime.Unix(), peer.Status.LastSeen.Unix(), + "LastSeen should remain as server A's connect time") + }) + t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { node2SyncTime := time.Now().UTC() err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) diff --git a/management/server/peer.go b/management/server/peer.go index 78ecbfcae..3eea14bfa 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -117,9 +117,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } - if connected && !syncTime.After(peer.Status.LastSeen) { - log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", - peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) + if !syncTime.After(peer.Status.LastSeen) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping status update to connected=%t", + peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339), connected) skipped = true return nil }