diff --git a/go.mod b/go.mod index 1da44da3b..af6aa327f 100644 --- a/go.mod +++ b/go.mod @@ -66,6 +66,7 @@ require ( github.com/pion/transport/v3 v3.0.1 github.com/pion/turn/v3 v3.0.1 github.com/prometheus/client_golang v1.19.1 + github.com/r3labs/diff v1.1.0 github.com/rs/xid v1.3.0 github.com/shirou/gopsutil/v3 v3.24.4 github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 diff --git a/go.sum b/go.sum index 842311344..f22e26be6 100644 --- a/go.sum +++ b/go.sum @@ -413,6 +413,8 @@ github.com/prometheus/common v0.53.0 h1:U2pL9w9nmJwJDa4qqLQ3ZaePJ6ZTwt7cMD3AG3+a github.com/prometheus/common v0.53.0/go.mod h1:BrxBKv3FWBIGXw89Mg1AeBq7FSyRzXWI3l3e7W3RN5U= github.com/prometheus/procfs v0.15.0 h1:A82kmvXJq2jTu5YUhSGNlYoxh85zLnKgPz4bMZgI5Ek= github.com/prometheus/procfs v0.15.0/go.mod h1:Y0RJ/Y5g5wJpkTisOtqwDSo4HwhGmLB4VQSw2sQJLHk= +github.com/r3labs/diff v1.1.0 h1:V53xhrbTHrWFWq3gI4b94AjgEJOerO1+1l0xyHOBi8M= +github.com/r3labs/diff v1.1.0/go.mod h1:7WjXasNzi0vJetRcB/RqNl5dlIsmXcTTLmF5IoH6Xig= github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= github.com/rs/cors v1.8.0 h1:P2KMzcFwrPoSjkF1WLRPsp3UMLyql8L4v9hQpVeK5so= diff --git a/management/server/account_test.go b/management/server/account_test.go index 71b43bd65..e6c9b60da 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1108,61 +1108,132 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) { assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"])) } -func TestAccountManager_NetworkUpdates(t *testing.T) { - manager, err := createManager(t) - if err != nil { - t.Fatal(err) - return +func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + group := group.Group{ + ID: "group-id", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, } - userID := "account_creator" + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() - account, err := createAccount(manager, "test_account", userID, "") - if err != nil { - t.Fatal(err) - } - - setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) - if err != nil { - t.Fatal("error creating setup key") - return - } - - if account.Network.Serial != 0 { - t.Errorf("expecting account network to have an initial Serial=0") - return - } - - getPeer := func() *nbpeer.Peer { - key, err := wgtypes.GeneratePrivateKey() - if err != nil { - t.Fatal(err) - return nil + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) } - expectedPeerKey := key.PublicKey().String() + }() - peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ - Key: expectedPeerKey, - Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, - }) - if err != nil { - t.Fatalf("expecting peer1 to be added, got failure %v", err) - return nil - } - - return peer - } - - peer1 := getPeer() - peer2 := getPeer() - peer3 := getPeer() - - account, err = manager.Store.GetAccount(context.Background(), account.Id) - if err != nil { - t.Fatal(err) + if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { + t.Errorf("save group: %v", err) return } + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) { + manager, account, peer1, _, _ := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) { + manager, account, peer1, _, _ := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + policy := Policy{ + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"group-id"}, + Destinations: []string{"group-id"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + } + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 2 { + t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { + t.Errorf("save policy: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) { + manager, account, peer1, _, peer3 := setupNetworkMapTest(t) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 1 { + t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) + } + }() + + if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { + t.Errorf("delete peer: %v", err) + return + } + + wg.Wait() +} + +func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) @@ -1185,108 +1256,40 @@ func TestAccountManager_NetworkUpdates(t *testing.T) { }, } + if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } + + if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { + t.Errorf("save policy: %v", err) + return + } + wg := sync.WaitGroup{} - t.Run("save group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + wg.Add(1) + go func() { + defer wg.Done() - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil { - t.Errorf("save group: %v", err) - return + message := <-updMsg + networkMap := message.Update.GetNetworkMap() + if len(networkMap.RemotePeers) != 0 { + t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) } + }() - wg.Wait() - }) + // clean policy is pre requirement for delete group + if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil { + t.Errorf("delete default rule: %v", err) + return + } - t.Run("delete policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() + if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { + t.Errorf("delete group: %v", err) + return + } - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - wg.Wait() - }) - - t.Run("save policy update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 2 { - t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil { - t.Errorf("delete default rule: %v", err) - return - } - - wg.Wait() - }) - t.Run("delete peer update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 1 { - t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil { - t.Errorf("delete peer: %v", err) - return - } - - wg.Wait() - }) - - t.Run("delete group update", func(t *testing.T) { - wg.Add(1) - go func() { - defer wg.Done() - - message := <-updMsg - networkMap := message.Update.GetNetworkMap() - if len(networkMap.RemotePeers) != 0 { - t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers)) - } - }() - - // clean policy is pre requirement for delete group - _ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID) - - if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil { - t.Errorf("delete group: %v", err) - return - } - - wg.Wait() - }) + wg.Wait() } func TestAccountManager_DeletePeer(t *testing.T) { @@ -2328,3 +2331,46 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool { return true } } + +func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) { + t.Helper() + + manager, err := createManager(t) + if err != nil { + t.Fatal(err) + } + + account, err := createAccount(manager, "test_account", userID, "") + if err != nil { + t.Fatal(err) + } + + setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false) + if err != nil { + t.Fatal("error creating setup key") + } + + getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer { + key, err := wgtypes.GeneratePrivateKey() + if err != nil { + t.Fatal(err) + } + expectedPeerKey := key.PublicKey().String() + + peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + if err != nil { + t.Fatalf("expecting peer to be added, got failure %v", err) + } + + return peer + } + + peer1 := getPeer(manager, setupKey) + peer2 := getPeer(manager, setupKey) + peer3 := getPeer(manager, setupKey) + + return manager, account, peer1, peer2, peer3 +} diff --git a/management/server/network.go b/management/server/network.go index 0e7d753a7..91d844c3e 100644 --- a/management/server/network.go +++ b/management/server/network.go @@ -40,9 +40,9 @@ type Network struct { Dns string // Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added). // Used to synchronize state to the client apps. - Serial uint64 + Serial uint64 `diff:"-"` - mu sync.Mutex `json:"-" gorm:"-"` + mu sync.Mutex `json:"-" gorm:"-" diff:"-"` } // NewNetwork creates a new Network initializing it with a Serial=0 diff --git a/management/server/peer.go b/management/server/peer.go index b8605fbb7..ff30fb1ff 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -261,6 +261,8 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou FirewallRulesIsEmpty: true, }, }, + NetworkMap: &NetworkMap{}, + Checks: []*posture.Checks{}, }) am.peersUpdateManager.CloseChannel(ctx, peer.ID) am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain())) @@ -932,6 +934,6 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account postureChecks := am.getPeerPostureChecks(account, peer) remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap) update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks) - am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update}) + go am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks}) } } diff --git a/management/server/peer/peer.go b/management/server/peer/peer.go index 4f808a79e..a193ac6df 100644 --- a/management/server/peer/peer.go +++ b/management/server/peer/peer.go @@ -18,35 +18,35 @@ type Peer struct { // WireGuard public key Key string `gorm:"index"` // A setup key this peer was registered with - SetupKey string + SetupKey string `diff:"-"` // IP address of the Peer IP net.IP `gorm:"serializer:json"` // Meta is a Peer system meta data - Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"` + Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"` // Name is peer's name (machine name) Name string // DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's // domain to the peer label. e.g. peer-dns-label.netbird.cloud DNSLabel string // Status peer's management connection status - Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"` + Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"` // The user ID that registered the peer - UserID string + UserID string `diff:"-"` // SSHKey is a public SSH key of the peer SSHKey string // SSHEnabled indicates whether SSH server is enabled on the peer SSHEnabled bool // LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login. // Works with LastLogin - LoginExpirationEnabled bool + LoginExpirationEnabled bool `diff:"-"` // LastLogin the time when peer performed last login operation - LastLogin time.Time + LastLogin time.Time `diff:"-"` // CreatedAt records the time the peer was created - CreatedAt time.Time + CreatedAt time.Time `diff:"-"` // Indicate ephemeral peer attribute - Ephemeral bool + Ephemeral bool `diff:"-"` // Geo location based on connection IP - Location Location `gorm:"embedded;embeddedPrefix:location_"` + Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"` } type PeerStatus struct { //nolint:revive diff --git a/management/server/testdata/store.json b/management/server/testdata/store.json index 1fa4e3a9a..6a8fc0712 100644 --- a/management/server/testdata/store.json +++ b/management/server/testdata/store.json @@ -19,7 +19,7 @@ "Revoked": false, "UsedTimes": 0, "LastUsed": "0001-01-01T00:00:00Z", - "AutoGroups": null, + "AutoGroups": ["cq9bbkjjuspi5gd38epg"], "UsageLimit": 0, "Ephemeral": false } @@ -69,9 +69,41 @@ "LastLogin": "0001-01-01T00:00:00Z" } }, - "Groups": null, + "Groups": { + "cq9bbkjjuspi5gd38epg": { + "ID": "cq9bbkjjuspi5gd38epg", + "Name": "All", + "Peers": [] + } + }, "Rules": null, - "Policies": [], + "Policies": [ + { + "ID": "cq9bbkjjuspi5gd38eq0", + "Name": "Default", + "Description": "This is a default rule that allows connections between all the resources", + "Enabled": true, + "Rules": [ + { + "ID": "cq9bbkjjuspi5gd38eq0", + "Name": "Default", + "Description": "This is a default rule that allows connections between all the resources", + "Enabled": true, + "Action": "accept", + "Destinations": [ + "cq9bbkjjuspi5gd38epg" + ], + "Sources": [ + "cq9bbkjjuspi5gd38epg" + ], + "Bidirectional": true, + "Protocol": "all", + "Ports": null + } + ], + "SourcePostureChecks": null + } + ], "Routes": null, "NameServerGroups": null, "DNSSettings": null, diff --git a/management/server/updatechannel.go b/management/server/updatechannel.go index c11225dbc..0db5b323b 100644 --- a/management/server/updatechannel.go +++ b/management/server/updatechannel.go @@ -2,9 +2,12 @@ package server import ( "context" + "fmt" "sync" "time" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/r3labs/diff" log "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/proto" @@ -14,14 +17,18 @@ import ( const channelBufferSize = 100 type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + NetworkMap *NetworkMap + Checks []*posture.Checks } type PeersUpdateManager struct { // peerChannels is an update channel indexed by Peer.ID peerChannels map[string]chan *UpdateMessage + // peerNetworkMaps is the UpdateMessage indexed by Peer.ID. + peerUpdateMessage map[string]*UpdateMessage // channelsMux keeps the mutex to access peerChannels - channelsMux *sync.Mutex + channelsMux *sync.RWMutex // metrics provides method to collect application metrics metrics telemetry.AppMetrics } @@ -29,9 +36,10 @@ type PeersUpdateManager struct { // NewPeersUpdateManager returns a new instance of PeersUpdateManager func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager { return &PeersUpdateManager{ - peerChannels: make(map[string]chan *UpdateMessage), - channelsMux: &sync.Mutex{}, - metrics: metrics, + peerChannels: make(map[string]chan *UpdateMessage), + peerUpdateMessage: make(map[string]*UpdateMessage), + channelsMux: &sync.RWMutex{}, + metrics: metrics, } } @@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda start := time.Now() var found, dropped bool + // skip sending sync update to the peer if there is no change in update message, + // it will not check on turn credential refresh as we do not send network map or client posture checks + if update.NetworkMap != nil { + updated := p.handlePeerMessageUpdate(ctx, peerID, update) + if !updated { + return + } + } + p.channelsMux.Lock() + defer func() { p.channelsMux.Unlock() if p.metrics != nil { @@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda } }() + if update.NetworkMap != nil { + lastSentUpdate := p.peerUpdateMessage[peerID] + if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() >= update.Update.NetworkMap.GetSerial() { + log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update", + peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial()) + return + } + p.peerUpdateMessage[peerID] = update + } + if channel, ok := p.peerChannels[peerID]; ok { found = true select { @@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c closed = true delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } // mbragin: todo shouldn't it be more? or configurable? channel := make(chan *UpdateMessage, channelBufferSize) @@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) { if channel, ok := p.peerChannels[peerID]; ok { delete(p.peerChannels, peerID) close(channel) + delete(p.peerUpdateMessage, peerID) } log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID) @@ -170,3 +200,49 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool { return ok } + +// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent. +func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool { + p.channelsMux.RLock() + lastSentUpdate := p.peerUpdateMessage[peerID] + p.channelsMux.RUnlock() + + if lastSentUpdate != nil { + updated, err := isNewPeerUpdateMessage(lastSentUpdate, update) + if err != nil { + log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err) + return false + } + if !updated { + log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID) + return false + } + } + + return true +} + +// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent. +func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bool, error) { + if lastSentUpdate.Update.NetworkMap.GetSerial() >= currUpdateToSend.Update.NetworkMap.GetSerial() { + return false, nil + } + + changelog, err := diff.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks) + if err != nil { + return false, fmt.Errorf("failed to diff checks: %v", err) + } + if len(changelog) > 0 { + return true, nil + } + + changelog, err = diff.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap) + if err != nil { + return false, fmt.Errorf("failed to diff network map: %v", err) + } + if len(changelog) > 0 { + return true, nil + } + + return false, nil +} diff --git a/management/server/updatechannel_test.go b/management/server/updatechannel_test.go index 69f5b895c..6d8caab26 100644 --- a/management/server/updatechannel_test.go +++ b/management/server/updatechannel_test.go @@ -6,6 +6,8 @@ import ( "time" "github.com/netbirdio/netbird/management/proto" + "github.com/netbirdio/netbird/management/server/posture" + "github.com/stretchr/testify/assert" ) // var peersUpdater *PeersUpdateManager @@ -77,3 +79,104 @@ func TestCloseChannel(t *testing.T) { t.Error("Error closing the channel") } } + +func TestHandlePeerMessageUpdate(t *testing.T) { + tests := []struct { + name string + peerID string + existingUpdate *UpdateMessage + newUpdate *UpdateMessage + expectedResult bool + }{ + { + name: "update message with turn credentials update", + peerID: "peer", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + WiretrusteeConfig: &proto.WiretrusteeConfig{}, + }, + }, + expectedResult: true, + }, + { + name: "update message for peer without existing update", + peerID: "peer1", + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + expectedResult: true, + }, + { + name: "update message with no changes in update", + peerID: "peer2", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + Checks: []*posture.Checks{}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + Checks: []*posture.Checks{}, + }, + expectedResult: false, + }, + { + name: "update message with changes in checks", + peerID: "peer3", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + Checks: []*posture.Checks{}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + Checks: []*posture.Checks{{ID: "check1"}}, + }, + expectedResult: true, + }, + { + name: "update message with lower serial number", + peerID: "peer4", + existingUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 2}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 2}}, + }, + newUpdate: &UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{Serial: 1}, + }, + NetworkMap: &NetworkMap{Network: &Network{Serial: 1}}, + }, + expectedResult: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p := NewPeersUpdateManager(nil) + ctx := context.Background() + + if tt.existingUpdate != nil { + p.peerUpdateMessage[tt.peerID] = tt.existingUpdate + } + + result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate) + assert.Equal(t, tt.expectedResult, result) + }) + } +}