From 22126d04846b20e17d9c617da7e3cee5eec678d4 Mon Sep 17 00:00:00 2001 From: Pascal Fischer Date: Wed, 6 Nov 2024 20:29:59 +0100 Subject: [PATCH] add session id to update channel --- management/server/account_test.go | 20 ++++---- management/server/dns_test.go | 14 ++--- management/server/group_test.go | 22 ++++---- management/server/grpcserver.go | 34 +++++++------ management/server/nameserver_test.go | 12 ++--- management/server/peer.go | 2 +- management/server/peer_test.go | 30 ++++++----- management/server/policy_test.go | 22 ++++---- management/server/posture_checks_test.go | 22 ++++---- management/server/route_test.go | 16 +++--- management/server/setupkey_test.go | 6 +-- management/server/token_mgr_test.go | 2 +- management/server/updatechannel.go | 65 ++++++++++++++++-------- management/server/updatechannel_test.go | 49 +++++++++++++++--- management/server/user_test.go | 19 +++---- 15 files changed, 202 insertions(+), 133 deletions(-) diff --git a/management/server/account_test.go b/management/server/account_test.go index 1cd4ae449..0a83ebb90 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1147,14 +1147,14 @@ func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { require.NoError(t, err) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() - message := <-updMsg + message := <-updMsg.channel networkMap := message.Update.GetNetworkMap() if len(networkMap.RemotePeers) != 2 { t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) @@ -1174,14 +1174,14 @@ func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { manager, account, peer1, _, _ := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() - message := <-updMsg + message := <-updMsg.channel networkMap := message.Update.GetNetworkMap() if len(networkMap.RemotePeers) != 0 { t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) @@ -1210,7 +1210,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) policy := Policy{ Enabled: true, @@ -1230,7 +1230,7 @@ func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { go func() { defer wg.Done() - message := <-updMsg + message := <-updMsg.channel networkMap := message.Update.GetNetworkMap() if len(networkMap.RemotePeers) != 2 { t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) @@ -1277,14 +1277,14 @@ func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { } updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) wg := sync.WaitGroup{} wg.Add(1) go func() { defer wg.Done() - message := <-updMsg + message := <-updMsg.channel networkMap := message.Update.GetNetworkMap() if len(networkMap.RemotePeers) != 1 { t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) @@ -1303,7 +1303,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) - defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) group := group.Group{ ID: "groupA", @@ -1339,7 +1339,7 @@ func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { go func() { defer wg.Done() - message := <-updMsg + message := <-updMsg.channel networkMap := message.Update.GetNetworkMap() if len(networkMap.RemotePeers) != 0 { t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 8a66da96c..c7cb0bce9 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -499,14 +499,14 @@ func TestDNSAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) // Saving DNS settings with groups that have no peers should not trigger updates to account peers or send peer updates t.Run("saving dns setting with unused groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -526,7 +526,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { t.Run("creating dns setting with unused groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -559,7 +559,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -585,7 +585,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { t.Run("saving dns setting with used groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -605,7 +605,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { t.Run("removing group with no peers from dns settings", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -625,7 +625,7 @@ func TestDNSAccountPeersUpdate(t *testing.T) { t.Run("removing group with peers from dns settings", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/group_test.go b/management/server/group_test.go index 89184e819..2c36e5ddf 100644 --- a/management/server/group_test.go +++ b/management/server/group_test.go @@ -418,14 +418,14 @@ func TestGroupAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) // Saving a group that is not linked to any resource should not update account peers t.Run("saving unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -448,7 +448,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -467,7 +467,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("removing peer from unliked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -485,7 +485,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("deleting group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -519,7 +519,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("saving linked group to policy", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -541,7 +541,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("adding peer to linked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -559,7 +559,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { t.Run("removing peer from linked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -588,7 +588,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -629,7 +629,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -656,7 +656,7 @@ func TestGroupAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 4c4ef6c3c..f58847818 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -194,31 +194,31 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, peerUpdates *PeerUpdateChannel, srv proto.ManagementService_SyncServer) error { for { select { // condition when there are some updates - case update, open := <-updates: + case update, open := <-peerUpdates.channel: if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) + s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(peerUpdates.channel) + 1) } if !open { log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID) return nil } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, peerUpdates.sessionID, update, srv); err != nil { return err } // condition when client <-> server connection has been terminated case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects - log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + log.WithContext(ctx).Debugf("stream of peer %s with session %s has been closed", peerKey.String(), peerUpdates.sessionID) + s.cancelPeerRoutines(ctx, accountID, peer, peerUpdates.sessionID) return srv.Context().Err() } } @@ -226,10 +226,10 @@ func (s *GRPCServer) handleUpdates(ctx context.Context, accountID string, peerKe // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, sessionID string, update *UpdateMessage, srv proto.ManagementService_SyncServer) error { encryptedResp, err := encryption.EncryptMessage(peerKey, s.wgKey, update.Update) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, sessionID) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.SendMsg(&proto.EncryptedMessage{ @@ -237,18 +237,22 @@ func (s *GRPCServer) sendUpdate(ctx context.Context, accountID string, peerKey w Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, sessionID) return status.Errorf(codes.Internal, "failed sending update message") } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) return nil } -func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { - s.peersUpdateManager.CloseChannel(ctx, peer.ID) - s.secretsManager.CancelRefresh(peer.ID) - _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) - s.ephemeralManager.OnPeerDisconnected(ctx, peer) +func (s *GRPCServer) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, sessionID string) { + + bool1 := s.peersUpdateManager.CloseChannel(ctx, peer.ID, sessionID) + if bool1 { + _ = s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) + + s.secretsManager.CancelRefresh(sessionID) + s.ephemeralManager.OnPeerDisconnected(ctx, peer) + } } func (s *GRPCServer) validateToken(ctx context.Context, jwtToken string) (string, error) { diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 846dbf023..3d26ace1b 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -960,7 +960,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) // Creating a nameserver group with a distribution group no peers should not update account peers @@ -968,7 +968,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { t.Run("creating nameserver group with distribution group no peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -995,7 +995,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { t.Run("saving nameserver group with distribution group no peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1013,7 +1013,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { t.Run("creating nameserver group with distribution group has peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1039,7 +1039,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { t.Run("saving nameserver group with distribution group has peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1069,7 +1069,7 @@ func TestNameServerAccountPeersUpdate(t *testing.T) { t.Run("deleting nameserver group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/peer.go b/management/server/peer.go index 7cc2209c5..7a8657c94 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -313,7 +313,7 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou }, NetworkMap: &NetworkMap{}, }) - am.peersUpdateManager.CloseChannel(ctx, peer.ID) + am.peersUpdateManager.CloseChannel(ctx, peer.ID, SessionIdForceOverwrite) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) } diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 5127f77fb..7eaf7025c 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -864,10 +864,14 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { b.Fatalf("Failed to get account: %v", err) } - peerChannels := make(map[string]chan *UpdateMessage) + peerChannels := make(map[string]*PeerUpdateChannel) for peerID := range account.Peers { - peerChannels[peerID] = make(chan *UpdateMessage, channelBufferSize) + peerChannels[peerID] = &PeerUpdateChannel{ + peerID: peerID, + channel: make(chan *UpdateMessage, channelBufferSize), + sessionID: xid.New().String(), + } } manager.peersUpdateManager.peerChannels = peerChannels @@ -1315,14 +1319,14 @@ func TestPeerAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1340,7 +1344,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("adding peer to unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1365,7 +1369,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with unlinked group", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1383,7 +1387,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("updating peer label", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1417,7 +1421,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1443,7 +1447,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with linked group to policy", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1481,7 +1485,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1507,7 +1511,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with linked group to route", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1536,7 +1540,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1562,7 +1566,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) { t.Run("deleting peer with linked group to route", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/policy_test.go b/management/server/policy_test.go index e7f0f9cd2..7661b0acc 100644 --- a/management/server/policy_test.go +++ b/management/server/policy_test.go @@ -856,7 +856,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) // Saving policy with rule groups with no peers should not update account's peers and not send peer update @@ -878,7 +878,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -913,7 +913,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -948,7 +948,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -982,7 +982,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1016,7 +1016,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1051,7 +1051,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1085,7 +1085,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1105,7 +1105,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1126,7 +1126,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { policyID := "policy-destination-has-peers-source-none" done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1145,7 +1145,7 @@ func TestPolicyAccountPeersUpdate(t *testing.T) { policyID := "policy-rule-groups-no-peers" done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/posture_checks_test.go b/management/server/posture_checks_test.go index c63538b9d..9cb5b669d 100644 --- a/management/server/posture_checks_test.go +++ b/management/server/posture_checks_test.go @@ -147,7 +147,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) postureCheck := posture.Checks{ @@ -165,7 +165,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Run("saving unused posture check", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -183,7 +183,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Run("updating unused posture check", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -222,7 +222,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Run("linking posture check to policy with peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -251,7 +251,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -269,7 +269,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Run("removing posture check from policy", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -289,7 +289,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Run("deleting unused posture check", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -328,7 +328,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -352,7 +352,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { t.Run("updating linked posture check to policy where destination has peers but source does not", func(t *testing.T) { updMsg1 := manager.peersUpdateManager.CreateChannel(context.Background(), peer2.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer2.ID, updMsg1.sessionID) }) policy = Policy{ ID: "policyB", @@ -375,7 +375,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg1) + peerShouldReceiveUpdate(t, updMsg1.channel) close(done) }() @@ -416,7 +416,7 @@ func TestPostureCheckAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/route_test.go b/management/server/route_test.go index 4893e19b9..c71bdbe3d 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -1807,7 +1807,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1ID, updMsg.sessionID) }) // Creating a route with no routing peer and no peers in PeerGroups or Groups should not update account peers and not send peer update @@ -1827,7 +1827,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1863,7 +1863,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1899,7 +1899,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { t.Run("creating route with a routing peer", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1924,7 +1924,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1942,7 +1942,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { t.Run("deleting route", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1978,7 +1978,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -2018,7 +2018,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/setupkey_test.go b/management/server/setupkey_test.go index 2ed8aef95..4b654290a 100644 --- a/management/server/setupkey_test.go +++ b/management/server/setupkey_test.go @@ -408,7 +408,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) var setupKey *SetupKey @@ -417,7 +417,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { t.Run("creating setup key", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -435,7 +435,7 @@ func TestSetupKeyAccountPeersUpdate(t *testing.T) { t.Run("saving setup key", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() diff --git a/management/server/token_mgr_test.go b/management/server/token_mgr_test.go index 3e63346c2..8f104c496 100644 --- a/management/server/token_mgr_test.go +++ b/management/server/token_mgr_test.go @@ -104,7 +104,7 @@ func TestTimeBasedAuthSecretsManager_SetupRefresh(t *testing.T) { loop: for timeout := time.After(5 * time.Second); ; { select { - case update := <-updateChannel: + case update := <-updateChannel.channel: updates = append(updates, update) case <-timeout: break loop diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index 59b6fd094..7b8055571 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -5,6 +5,7 @@ import ( "sync" "time" + "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" @@ -12,15 +13,22 @@ import ( ) const channelBufferSize = 100 +const SessionIdForceOverwrite = "FORCE" type UpdateMessage struct { Update *proto.SyncResponse NetworkMap *NetworkMap } +type PeerUpdateChannel struct { + peerID string + sessionID string + channel chan *UpdateMessage +} + type PeersUpdateManager struct { - // peerChannels is an update channel indexed by Peer.ID - peerChannels map[string]chan *UpdateMessage + // peerChannels is a map of peerID to the channel used to deliver updates relevant to the peer + peerChannels map[string]*PeerUpdateChannel // channelsMux keeps the mutex to access peerChannels channelsMux *sync.RWMutex // metrics provides method to collect application metrics @@ -30,7 +38,7 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), + peerChannels: make(map[string]*PeerUpdateChannel), channelsMux: &sync.RWMutex{}, metrics: metrics, } @@ -50,14 +58,14 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() - if channel, ok := p.peerChannels[peerID]; ok { + if peerUpdates, ok := p.peerChannels[peerID]; ok { found = true select { - case channel <- update: + case peerUpdates.channel <- update: log.WithContext(ctx).Debugf("update was sent to channel for peer %s", peerID) default: dropped = true - log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(channel)) + log.WithContext(ctx).Warnf("channel for peer %s is %d full or closed", peerID, len(peerUpdates.channel)) } } else { log.WithContext(ctx).Debugf("peer %s has no channel", peerID) @@ -65,7 +73,7 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } // CreateChannel creates a go channel for a given peer used to deliver updates relevant to the peer. -func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) chan *UpdateMessage { +func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) *PeerUpdateChannel { start := time.Now() closed := false @@ -81,24 +89,39 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c if channel, ok := p.peerChannels[peerID]; ok { closed = true delete(p.peerChannels, peerID) - close(channel) + close(channel.channel) + log.WithContext(ctx).Debugf("overwriting existing channel for peer %s", peerID) } - // mbragin: todo shouldn't it be more? or configurable? - channel := make(chan *UpdateMessage, channelBufferSize) - p.peerChannels[peerID] = channel - log.WithContext(ctx).Debugf("opened updates channel for a peer %s", peerID) + peerUpdateChannel := &PeerUpdateChannel{ + peerID: peerID, + sessionID: uuid.New().String(), + // mbragin: todo shouldn't it be more? or configurable? + channel: make(chan *UpdateMessage, channelBufferSize), + } - return channel + p.peerChannels[peerID] = peerUpdateChannel + + log.WithContext(ctx).Debugf("opened updates channel for a peer %s and session %s", peerID, peerUpdateChannel.sessionID) + + return peerUpdateChannel } -func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { - if channel, ok := p.peerChannels[peerID]; ok { - delete(p.peerChannels, peerID) - close(channel) +func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string, sessionID string) bool { + if peerUpdates, ok := p.peerChannels[peerID]; ok { + if peerUpdates.sessionID == sessionID || sessionID == SessionIdForceOverwrite { + delete(p.peerChannels, peerID) + close(peerUpdates.channel) + log.WithContext(ctx).Debugf("closed updates channel of a peer %s and session %s", peerID, sessionID) + return true + } + log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but current session is %s", peerID, sessionID, peerUpdates.sessionID) + return false } - log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) + log.WithContext(ctx).Warnf("tried to close updates channel of a peer %s with session %s, but no channel found", peerID, sessionID) + + return true } // CloseChannels closes updates channel for each given peer @@ -114,12 +137,12 @@ func (p *PeersUpdateManager) CloseChannels(ctx context.Context, peerIDs []string }() for _, id := range peerIDs { - p.closeChannel(ctx, id) + p.closeChannel(ctx, id, SessionIdForceOverwrite) } } // CloseChannel closes updates channel of a given peer -func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) { +func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string, sessionID string) bool { start := time.Now() p.channelsMux.Lock() @@ -130,7 +153,7 @@ func (p *PeersUpdateManager) CloseChannel(ctx context.Context, peerID string) { } }() - p.closeChannel(ctx, peerID) + return p.closeChannel(ctx, peerID, sessionID) } // GetAllConnectedPeers returns a copy of the connected peers map diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 69f5b895c..ac83dc7a5 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -5,6 +5,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/netbirdio/netbird/management/proto" ) @@ -13,7 +15,7 @@ import ( func TestCreateChannel(t *testing.T) { peer := "test-create" peersUpdater := NewPeersUpdateManager(nil) - defer peersUpdater.CloseChannel(context.Background(), peer) + defer peersUpdater.CloseChannel(context.Background(), peer, "sessionID") _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { @@ -35,7 +37,7 @@ func TestSendUpdate(t *testing.T) { } peersUpdater.SendUpdate(context.Background(), peer, update1) select { - case <-peersUpdater.peerChannels[peer]: + case <-peersUpdater.peerChannels[peer].channel: default: t.Error("Update wasn't send") } @@ -56,7 +58,7 @@ func TestSendUpdate(t *testing.T) { select { case <-timeout: t.Error("timed out reading previously sent updates") - case updateReader := <-peersUpdater.peerChannels[peer]: + case updateReader := <-peersUpdater.peerChannels[peer].channel: if updateReader.Update.NetworkMap.Serial == update2.Update.NetworkMap.Serial { t.Error("got the update that shouldn't have been sent") } @@ -65,15 +67,50 @@ func TestSendUpdate(t *testing.T) { } -func TestCloseChannel(t *testing.T) { +func TestCloseChannel_WithCorrectSessionID(t *testing.T) { peer := "test-close" peersUpdater := NewPeersUpdateManager(nil) - _ = peersUpdater.CreateChannel(context.Background(), peer) + peerUpdates := peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") } - peersUpdater.CloseChannel(context.Background(), peer) + + updateDB := peersUpdater.CloseChannel(context.Background(), peer, peerUpdates.sessionID) if _, ok := peersUpdater.peerChannels[peer]; ok { t.Error("Error closing the channel") } + + assert.Equal(t, true, updateDB) +} + +func TestCloseChannel_WithWrongSessionID(t *testing.T) { + peer := "test-close" + peersUpdater := NewPeersUpdateManager(nil) + peersUpdater.CreateChannel(context.Background(), peer) + if _, ok := peersUpdater.peerChannels[peer]; !ok { + t.Error("Error creating the channel") + } + + updateDB := peersUpdater.CloseChannel(context.Background(), peer, "wrongSessionID") + if _, ok := peersUpdater.peerChannels[peer]; !ok { + t.Error("Should not close channel with wrong session id") + } + + assert.Equal(t, false, updateDB) +} + +func TestCloseChannel_WithForceOverwrite(t *testing.T) { + peer := "test-close" + peersUpdater := NewPeersUpdateManager(nil) + peersUpdater.CreateChannel(context.Background(), peer) + if _, ok := peersUpdater.peerChannels[peer]; !ok { + t.Error("Error creating the channel") + } + + updateDB := peersUpdater.CloseChannel(context.Background(), peer, SessionIdForceOverwrite) + if _, ok := peersUpdater.peerChannels[peer]; ok { + t.Error("Should close channel if forced") + } + + assert.Equal(t, true, updateDB) } diff --git a/management/server/user_test.go b/management/server/user_test.go index d4f560a54..ebc5f9b9d 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -10,13 +10,14 @@ import ( "github.com/eko/gocache/v3/cache" cacheStore "github.com/eko/gocache/v3/store" "github.com/google/go-cmp/cmp" - nbgroup "github.com/netbirdio/netbird/management/server/group" - nbpeer "github.com/netbirdio/netbird/management/server/peer" gocache "github.com/patrickmn/go-cache" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + nbgroup "github.com/netbirdio/netbird/management/server/group" + nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/activity" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integration_reference" @@ -1297,14 +1298,14 @@ func TestUserAccountPeersUpdate(t *testing.T) { updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID, updMsg.sessionID) }) // Creating a new regular user should not update account peers and not send peer update t.Run("creating new regular user with no groups", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1327,7 +1328,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { t.Run("updating user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1350,7 +1351,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { t.Run("deleting user with no linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldNotReceiveUpdate(t, updMsg) + peerShouldNotReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1387,7 +1388,7 @@ func TestUserAccountPeersUpdate(t *testing.T) { t.Run("updating user with linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, updMsg) + peerShouldReceiveUpdate(t, updMsg.channel) close(done) }() @@ -1408,14 +1409,14 @@ func TestUserAccountPeersUpdate(t *testing.T) { peer4UpdMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer4.ID) t.Cleanup(func() { - manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID) + manager.peersUpdateManager.CloseChannel(context.Background(), peer4.ID, peer4UpdMsg.sessionID) }) // deleting user with linked peers should update account peers and send peer update t.Run("deleting user with linked peers", func(t *testing.T) { done := make(chan struct{}) go func() { - peerShouldReceiveUpdate(t, peer4UpdMsg) + peerShouldReceiveUpdate(t, peer4UpdMsg.channel) close(done) }()