Merge branch 'refs/heads/feature/optimize-network-map-updates' into feature/validate-group-association

This commit is contained in:
bcmmbaga
2024-07-22 15:24:30 +03:00
19 changed files with 523 additions and 183 deletions

View File

@@ -1108,61 +1108,132 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
}
func TestAccountManager_NetworkUpdates(t *testing.T) {
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
return
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
group := group.Group{
ID: "group-id",
Name: "GroupA",
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
}
userID := "account_creator"
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
account, err := createAccount(manager, "test_account", userID, "")
if err != nil {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
return
}
if account.Network.Serial != 0 {
t.Errorf("expecting account network to have an initial Serial=0")
return
}
getPeer := func() *nbpeer.Peer {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return nil
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 2 {
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
}
expectedPeerKey := key.PublicKey().String()
}()
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
if err != nil {
t.Fatalf("expecting peer1 to be added, got failure %v", err)
return nil
}
return peer
}
peer1 := getPeer()
peer2 := getPeer()
peer3 := getPeer()
account, err = manager.Store.GetAccount(context.Background(), account.Id)
if err != nil {
t.Fatal(err)
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
}
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 0 {
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
}
}()
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
wg.Wait()
}
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
manager, account, peer1, _, _ := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
policy := Policy{
Enabled: true,
Rules: []*PolicyRule{
{
Enabled: true,
Sources: []string{"group-id"},
Destinations: []string{"group-id"},
Bidirectional: true,
Action: PolicyTrafficActionAccept,
},
},
}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 2 {
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
}
}()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
t.Errorf("save policy: %v", err)
return
}
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 1 {
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
}
}()
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
t.Errorf("delete peer: %v", err)
return
}
wg.Wait()
}
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
@@ -1185,108 +1256,40 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
},
}
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
t.Errorf("save policy: %v", err)
return
}
wg := sync.WaitGroup{}
t.Run("save group update", func(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 2 {
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
}
}()
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
t.Errorf("save group: %v", err)
return
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 0 {
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
}
}()
wg.Wait()
})
// clean policy is pre requirement for delete group
if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
t.Run("delete policy update", func(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err)
return
}
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 0 {
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
}
}()
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
wg.Wait()
})
t.Run("save policy update", func(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 2 {
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
}
}()
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
t.Errorf("delete default rule: %v", err)
return
}
wg.Wait()
})
t.Run("delete peer update", func(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 1 {
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
}
}()
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
t.Errorf("delete peer: %v", err)
return
}
wg.Wait()
})
t.Run("delete group update", func(t *testing.T) {
wg.Add(1)
go func() {
defer wg.Done()
message := <-updMsg
networkMap := message.Update.GetNetworkMap()
if len(networkMap.RemotePeers) != 0 {
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
}
}()
// clean policy is pre requirement for delete group
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
t.Errorf("delete group: %v", err)
return
}
wg.Wait()
})
wg.Wait()
}
func TestAccountManager_DeletePeer(t *testing.T) {
@@ -2328,3 +2331,46 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
return true
}
}
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
t.Helper()
manager, err := createManager(t)
if err != nil {
t.Fatal(err)
}
account, err := createAccount(manager, "test_account", userID, "")
if err != nil {
t.Fatal(err)
}
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
if err != nil {
t.Fatal("error creating setup key")
}
getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer {
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
}
expectedPeerKey := key.PublicKey().String()
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
Key: expectedPeerKey,
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
})
if err != nil {
t.Fatalf("expecting peer to be added, got failure %v", err)
}
return peer
}
peer1 := getPeer(manager, setupKey)
peer2 := getPeer(manager, setupKey)
peer3 := getPeer(manager, setupKey)
return manager, account, peer1, peer2, peer3
}

View File

@@ -40,9 +40,9 @@ type Network struct {
Dns string
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
// Used to synchronize state to the client apps.
Serial uint64
Serial uint64 `diff:"-"`
mu sync.Mutex `json:"-" gorm:"-"`
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
}
// NewNetwork creates a new Network initializing it with a Serial=0

View File

@@ -266,6 +266,8 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
FirewallRulesIsEmpty: true,
},
},
NetworkMap: &NetworkMap{},
Checks: []*posture.Checks{},
})
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
@@ -888,7 +890,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
postureChecks := am.getPeerPostureChecks(account, peer)
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
go am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks})
}
}

View File

@@ -18,35 +18,35 @@ type Peer struct {
// WireGuard public key
Key string `gorm:"index"`
// A setup key this peer was registered with
SetupKey string
SetupKey string `diff:"-"`
// IP address of the Peer
IP net.IP `gorm:"serializer:json"`
// Meta is a Peer system meta data
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
// Name is peer's name (machine name)
Name string
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
DNSLabel string
// Status peer's management connection status
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
// The user ID that registered the peer
UserID string
UserID string `diff:"-"`
// SSHKey is a public SSH key of the peer
SSHKey string
// SSHEnabled indicates whether SSH server is enabled on the peer
SSHEnabled bool
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
// Works with LastLogin
LoginExpirationEnabled bool
LoginExpirationEnabled bool `diff:"-"`
// LastLogin the time when peer performed last login operation
LastLogin time.Time
LastLogin time.Time `diff:"-"`
// CreatedAt records the time the peer was created
CreatedAt time.Time
CreatedAt time.Time `diff:"-"`
// Indicate ephemeral peer attribute
Ephemeral bool
Ephemeral bool `diff:"-"`
// Geo location based on connection IP
Location Location `gorm:"embedded;embeddedPrefix:location_"`
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
}
type PeerStatus struct { //nolint:revive

View File

@@ -274,10 +274,15 @@ func (s *SqlStore) GetInstallationID() string {
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
var peerCopy nbpeer.Peer
peerCopy.Status = &peerStatus
result := s.db.Model(&nbpeer.Peer{}).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(peerCopy)
fieldsToUpdate := []string{
"peer_status_last_seen", "peer_status_connected",
"peer_status_login_expired", "peer_status_required_approval",
}
result := s.db.Model(&nbpeer.Peer{}).
Select(fieldsToUpdate).
Where("account_id = ? AND id = ?", accountID, peerID).
Updates(&peerCopy)
if result.Error != nil {
return result.Error
}

View File

@@ -373,7 +373,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
require.NoError(t, err)
// save status of non-existing peer
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
assert.Error(t, err)
parsedErr, ok := status.FromError(err)
@@ -388,7 +388,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
IP: net.IP{127, 0, 0, 1},
Meta: nbpeer.PeerSystemMeta{},
Name: "peer name",
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
}
err = store.SaveAccount(context.Background(), account)

View File

@@ -19,7 +19,7 @@
"Revoked": false,
"UsedTimes": 0,
"LastUsed": "0001-01-01T00:00:00Z",
"AutoGroups": null,
"AutoGroups": ["cq9bbkjjuspi5gd38epg"],
"UsageLimit": 0,
"Ephemeral": false
}
@@ -69,9 +69,41 @@
"LastLogin": "0001-01-01T00:00:00Z"
}
},
"Groups": null,
"Groups": {
"cq9bbkjjuspi5gd38epg": {
"ID": "cq9bbkjjuspi5gd38epg",
"Name": "All",
"Peers": []
}
},
"Rules": null,
"Policies": [],
"Policies": [
{
"ID": "cq9bbkjjuspi5gd38eq0",
"Name": "Default",
"Description": "This is a default rule that allows connections between all the resources",
"Enabled": true,
"Rules": [
{
"ID": "cq9bbkjjuspi5gd38eq0",
"Name": "Default",
"Description": "This is a default rule that allows connections between all the resources",
"Enabled": true,
"Action": "accept",
"Destinations": [
"cq9bbkjjuspi5gd38epg"
],
"Sources": [
"cq9bbkjjuspi5gd38epg"
],
"Bidirectional": true,
"Protocol": "all",
"Ports": null
}
],
"SourcePostureChecks": null
}
],
"Routes": null,
"NameServerGroups": null,
"DNSSettings": null,

View File

@@ -2,9 +2,12 @@ package server
import (
"context"
"fmt"
"sync"
"time"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/r3labs/diff"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/proto"
@@ -14,14 +17,18 @@ import (
const channelBufferSize = 100
type UpdateMessage struct {
Update *proto.SyncResponse
Update *proto.SyncResponse
NetworkMap *NetworkMap
Checks []*posture.Checks
}
type PeersUpdateManager struct {
// peerChannels is an update channel indexed by Peer.ID
peerChannels map[string]chan *UpdateMessage
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
peerUpdateMessage map[string]*UpdateMessage
// channelsMux keeps the mutex to access peerChannels
channelsMux *sync.Mutex
channelsMux *sync.RWMutex
// metrics provides method to collect application metrics
metrics telemetry.AppMetrics
}
@@ -29,9 +36,10 @@ type PeersUpdateManager struct {
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
return &PeersUpdateManager{
peerChannels: make(map[string]chan *UpdateMessage),
channelsMux: &sync.Mutex{},
metrics: metrics,
peerChannels: make(map[string]chan *UpdateMessage),
peerUpdateMessage: make(map[string]*UpdateMessage),
channelsMux: &sync.RWMutex{},
metrics: metrics,
}
}
@@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
start := time.Now()
var found, dropped bool
// skip sending sync update to the peer if there is no change in update message,
// it will not check on turn credential refresh as we do not send network map or client posture checks
if update.NetworkMap != nil {
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
if !updated {
return
}
}
p.channelsMux.Lock()
defer func() {
p.channelsMux.Unlock()
if p.metrics != nil {
@@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
}
}()
if update.NetworkMap != nil {
lastSentUpdate := p.peerUpdateMessage[peerID]
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() >= update.Update.NetworkMap.GetSerial() {
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
return
}
p.peerUpdateMessage[peerID] = update
}
if channel, ok := p.peerChannels[peerID]; ok {
found = true
select {
@@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
closed = true
delete(p.peerChannels, peerID)
close(channel)
delete(p.peerUpdateMessage, peerID)
}
// mbragin: todo shouldn't it be more? or configurable?
channel := make(chan *UpdateMessage, channelBufferSize)
@@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
if channel, ok := p.peerChannels[peerID]; ok {
delete(p.peerChannels, peerID)
close(channel)
delete(p.peerUpdateMessage, peerID)
}
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
@@ -170,3 +200,49 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
return ok
}
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
p.channelsMux.RLock()
lastSentUpdate := p.peerUpdateMessage[peerID]
p.channelsMux.RUnlock()
if lastSentUpdate != nil {
updated, err := isNewPeerUpdateMessage(lastSentUpdate, update)
if err != nil {
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
return false
}
if !updated {
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
return false
}
}
return true
}
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bool, error) {
if lastSentUpdate.Update.NetworkMap.GetSerial() >= currUpdateToSend.Update.NetworkMap.GetSerial() {
return false, nil
}
changelog, err := diff.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks)
if err != nil {
return false, fmt.Errorf("failed to diff checks: %v", err)
}
if len(changelog) > 0 {
return true, nil
}
changelog, err = diff.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
if err != nil {
return false, fmt.Errorf("failed to diff network map: %v", err)
}
if len(changelog) > 0 {
return true, nil
}
return false, nil
}

View File

@@ -6,6 +6,8 @@ import (
"time"
"github.com/netbirdio/netbird/management/proto"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/stretchr/testify/assert"
)
// var peersUpdater *PeersUpdateManager
@@ -77,3 +79,104 @@ func TestCloseChannel(t *testing.T) {
t.Error("Error closing the channel")
}
}
func TestHandlePeerMessageUpdate(t *testing.T) {
tests := []struct {
name string
peerID string
existingUpdate *UpdateMessage
newUpdate *UpdateMessage
expectedResult bool
}{
{
name: "update message with turn credentials update",
peerID: "peer",
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
WiretrusteeConfig: &proto.WiretrusteeConfig{},
},
},
expectedResult: true,
},
{
name: "update message for peer without existing update",
peerID: "peer1",
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
expectedResult: true,
},
{
name: "update message with no changes in update",
peerID: "peer2",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
},
expectedResult: false,
},
{
name: "update message with changes in checks",
peerID: "peer3",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
Checks: []*posture.Checks{},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
Checks: []*posture.Checks{{ID: "check1"}},
},
expectedResult: true,
},
{
name: "update message with lower serial number",
peerID: "peer4",
existingUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 2},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
},
newUpdate: &UpdateMessage{
Update: &proto.SyncResponse{
NetworkMap: &proto.NetworkMap{Serial: 1},
},
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
},
expectedResult: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := NewPeersUpdateManager(nil)
ctx := context.Background()
if tt.existingUpdate != nil {
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
}
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
assert.Equal(t, tt.expectedResult, result)
})
}
}