From af8f730bdac9109b02f3432efd6b6d24a251cea9 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:00:43 +0100 Subject: [PATCH] [management] check stream start time for connecting peer (#5267) --- management/internals/shared/grpc/server.go | 10 +++--- management/server/account.go | 6 ++-- management/server/account/manager.go | 4 +-- management/server/account_test.go | 33 +++++++++++++++---- management/server/mock_server/account_mock.go | 12 +++---- management/server/peer.go | 20 ++++++++--- 6 files changed, 58 insertions(+), 27 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index befcd2adf..98c68ebda 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -300,20 +300,18 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S metahash := metaHash(peerMeta, realIP.String()) s.loginFilter.addLogin(peerKey.String(), metahash) - peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return mapError(ctx, err) } - streamStartTime := time.Now().UTC() - err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -321,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -338,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, streamStartTime) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { diff --git a/management/server/account.go b/management/server/account.go index 4f53415f5..a9f59773a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1670,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -1697,7 +1697,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account return nil } - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index eed7739da..1d25b0af7 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -58,7 +58,7 @@ type Manager interface { GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error @@ -114,7 +114,7 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index f3d98916c..443e6344e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1979,7 +1979,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { require.NoError(t, err, "unable to add peer") t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) @@ -1997,7 +1997,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { }) t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) @@ -2014,6 +2014,27 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { require.True(t, peer.Status.Connected, "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") }) + + t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { + node2SyncTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) + require.NoError(t, err, "node 2 should connect peer") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime") + + node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime) + require.NoError(t, err, "stale connect should not return error") + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should still be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), + "LastSeen should NOT be overwritten by stale syncTime from blocked goroutine") + }) } func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { @@ -2038,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} @@ -3231,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC()) assert.NoError(b, err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a4754d180..8471d0a94 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -37,8 +37,8 @@ type MockAccountManager struct { GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -214,9 +214,9 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime) } return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } @@ -322,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) + return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index ab72d3051..a4bdc784d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { +// syncTime is used as the LastSeen timestamp and for stale request detection +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { var peer *nbpeer.Peer var settings *types.Settings var expired bool var err error + var skipped bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) @@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } - expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID) + if connected && !syncTime.After(peer.Status.LastSeen) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", + peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) + skipped = true + return nil + } + + expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime) return err }) + if skipped { + return nil + } if err != nil { return err } @@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus - newStatus.LastSeen = time.Now().UTC() + newStatus.LastSeen = syncTime newStatus.Connected = connected // whenever peer got connected that means that it logged in successfully if newStatus.Connected {