From d488f583115402c62d9974e787a5b23f1b73fc32 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:44:46 +0100 Subject: [PATCH] [management] fix set disconnected status for connected peer (#5247) --- management/internals/shared/grpc/server.go | 32 ++++++----- management/server/account.go | 16 +++++- management/server/account/manager.go | 2 +- management/server/account_test.go | 55 +++++++++++++++++++ management/server/mock_server/account_mock.go | 5 +- 5 files changed, 89 insertions(+), 21 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 3704b3188..befcd2adf 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -307,11 +307,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return mapError(ctx, err) } + streamStartTime := time.Now().UTC() + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) return err } @@ -319,7 +321,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) return err } @@ -336,7 +338,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, streamStartTime) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { @@ -404,7 +406,7 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { @@ -416,11 +418,11 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg if !open { log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return nil } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -429,7 +431,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return srv.Context().Err() } } @@ -437,16 +439,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { key, err := s.secretsManager.GetWGKey() if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.Send(&proto.EncryptedMessage{ @@ -454,7 +456,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed sending update message") } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) @@ -486,15 +488,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even return nil } -func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) } -func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) { - err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) +func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { + err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime) if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } diff --git a/management/server/account.go b/management/server/account.go index 8f9dad031..4f53415f5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1684,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID return peer, netMap, postureChecks, dnsfwdPort, nil } -func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { - err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) +func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) + if err != nil { + log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) + return nil + } + + if peer.Status.LastSeen.After(streamStartTime) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", + peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) + return nil + } + + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 5e9bb42a2..eed7739da 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -115,7 +115,7 @@ type Manager interface { GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) - OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error + OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 86cc69e8b..f3d98916c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1961,6 +1961,61 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } } +func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peerPubKey := key.PublicKey().String() + + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: peerPubKey, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, false) + require.NoError(t, err, "unable to add peer") + + t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err, "unable to get peer") + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := time.Now().UTC() + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.False(t, peer.Status.Connected, "peer should be disconnected") + }) + + t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour) + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, + "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") + }) +} + func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 026989898..a4754d180 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -221,9 +221,8 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { - // TODO implement me - panic("implement me") +func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + return nil } func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) {