mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-30 06:06:38 +00:00
Merge branch 'refs/heads/feature/optimize-network-map-updates' into feature/validate-group-association
This commit is contained in:
@@ -1108,61 +1108,132 @@ func TestAccountManager_AddPeerWithUserID(t *testing.T) {
|
||||
assert.Equal(t, peer.IP.String(), fmt.Sprint(ev.Meta["ip"]))
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
func TestAccountManager_NetworkUpdates_SaveGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
group := group.Group{
|
||||
ID: "group-id",
|
||||
Name: "GroupA",
|
||||
Peers: []string{peer1.ID, peer2.ID, peer3.ID},
|
||||
}
|
||||
|
||||
userID := "account_creator"
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
account, err := createAccount(manager, "test_account", userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
return
|
||||
}
|
||||
|
||||
if account.Network.Serial != 0 {
|
||||
t.Errorf("expecting account network to have an initial Serial=0")
|
||||
return
|
||||
}
|
||||
|
||||
getPeer := func() *nbpeer.Peer {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return nil
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
}()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expecting peer1 to be added, got failure %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
peer1 := getPeer()
|
||||
peer2 := getPeer()
|
||||
peer3 := getPeer()
|
||||
|
||||
account, err = manager.Store.GetAccount(context.Background(), account.Id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_SavePolicy(t *testing.T) {
|
||||
manager, account, peer1, _, _ := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
policy := Policy{
|
||||
Enabled: true,
|
||||
Rules: []*PolicyRule{
|
||||
{
|
||||
Enabled: true,
|
||||
Sources: []string{"group-id"},
|
||||
Destinations: []string{"group-id"},
|
||||
Bidirectional: true,
|
||||
Action: PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeletePeer(t *testing.T) {
|
||||
manager, account, peer1, _, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
||||
t.Errorf("delete peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_NetworkUpdates_DeleteGroup(t *testing.T) {
|
||||
manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t)
|
||||
|
||||
updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID)
|
||||
defer manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID)
|
||||
|
||||
@@ -1185,108 +1256,40 @@ func TestAccountManager_NetworkUpdates(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
||||
t.Errorf("save policy: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
t.Run("save group update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SaveGroup(context.Background(), account.Id, userID, &group); err != nil {
|
||||
t.Errorf("save group: %v", err)
|
||||
return
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
// clean policy is pre requirement for delete group
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
t.Run("delete policy update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("save policy update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 2 {
|
||||
t.Errorf("mismatch peers count: 2 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.SavePolicy(context.Background(), account.Id, userID, &policy); err != nil {
|
||||
t.Errorf("delete default rule: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
t.Run("delete peer update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 1 {
|
||||
t.Errorf("mismatch peers count: 1 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
if err := manager.DeletePeer(context.Background(), account.Id, peer3.ID, userID); err != nil {
|
||||
t.Errorf("delete peer: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
t.Run("delete group update", func(t *testing.T) {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
message := <-updMsg
|
||||
networkMap := message.Update.GetNetworkMap()
|
||||
if len(networkMap.RemotePeers) != 0 {
|
||||
t.Errorf("mismatch peers count: 0 expected, got %v", len(networkMap.RemotePeers))
|
||||
}
|
||||
}()
|
||||
|
||||
// clean policy is pre requirement for delete group
|
||||
_ = manager.DeletePolicy(context.Background(), account.Id, policy.ID, userID)
|
||||
|
||||
if err := manager.DeleteGroup(context.Background(), account.Id, "", group.ID); err != nil {
|
||||
t.Errorf("delete group: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
})
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestAccountManager_DeletePeer(t *testing.T) {
|
||||
@@ -2328,3 +2331,46 @@ func waitTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *Account, *nbpeer.Peer, *nbpeer.Peer, *nbpeer.Peer) {
|
||||
t.Helper()
|
||||
|
||||
manager, err := createManager(t)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
account, err := createAccount(manager, "test_account", userID, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
setupKey, err := manager.CreateSetupKey(context.Background(), account.Id, "test-key", SetupKeyReusable, time.Hour, nil, 999, userID, false)
|
||||
if err != nil {
|
||||
t.Fatal("error creating setup key")
|
||||
}
|
||||
|
||||
getPeer := func(manager *DefaultAccountManager, setupKey *SetupKey) *nbpeer.Peer {
|
||||
key, err := wgtypes.GeneratePrivateKey()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expectedPeerKey := key.PublicKey().String()
|
||||
|
||||
peer, _, _, err := manager.AddPeer(context.Background(), setupKey.Key, "", &nbpeer.Peer{
|
||||
Key: expectedPeerKey,
|
||||
Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("expecting peer to be added, got failure %v", err)
|
||||
}
|
||||
|
||||
return peer
|
||||
}
|
||||
|
||||
peer1 := getPeer(manager, setupKey)
|
||||
peer2 := getPeer(manager, setupKey)
|
||||
peer3 := getPeer(manager, setupKey)
|
||||
|
||||
return manager, account, peer1, peer2, peer3
|
||||
}
|
||||
|
||||
@@ -40,9 +40,9 @@ type Network struct {
|
||||
Dns string
|
||||
// Serial is an ID that increments by 1 when any change to the network happened (e.g. new peer has been added).
|
||||
// Used to synchronize state to the client apps.
|
||||
Serial uint64
|
||||
Serial uint64 `diff:"-"`
|
||||
|
||||
mu sync.Mutex `json:"-" gorm:"-"`
|
||||
mu sync.Mutex `json:"-" gorm:"-" diff:"-"`
|
||||
}
|
||||
|
||||
// NewNetwork creates a new Network initializing it with a Serial=0
|
||||
|
||||
@@ -266,6 +266,8 @@ func (am *DefaultAccountManager) deletePeers(ctx context.Context, account *Accou
|
||||
FirewallRulesIsEmpty: true,
|
||||
},
|
||||
},
|
||||
NetworkMap: &NetworkMap{},
|
||||
Checks: []*posture.Checks{},
|
||||
})
|
||||
am.peersUpdateManager.CloseChannel(ctx, peer.ID)
|
||||
am.StoreEvent(ctx, userID, peer.ID, account.Id, activity.PeerRemovedByUser, peer.EventMeta(am.GetDNSDomain()))
|
||||
@@ -888,7 +890,7 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
|
||||
postureChecks := am.getPeerPostureChecks(account, peer)
|
||||
remotePeerNetworkMap := account.GetPeerNetworkMap(ctx, peer.ID, am.dnsDomain, approvedPeersMap)
|
||||
update := toSyncResponse(ctx, nil, peer, nil, remotePeerNetworkMap, am.GetDNSDomain(), postureChecks)
|
||||
am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update})
|
||||
go am.peersUpdateManager.SendUpdate(ctx, peer.ID, &UpdateMessage{Update: update, NetworkMap: remotePeerNetworkMap, Checks: postureChecks})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -18,35 +18,35 @@ type Peer struct {
|
||||
// WireGuard public key
|
||||
Key string `gorm:"index"`
|
||||
// A setup key this peer was registered with
|
||||
SetupKey string
|
||||
SetupKey string `diff:"-"`
|
||||
// IP address of the Peer
|
||||
IP net.IP `gorm:"serializer:json"`
|
||||
// Meta is a Peer system meta data
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_"`
|
||||
Meta PeerSystemMeta `gorm:"embedded;embeddedPrefix:meta_" diff:"-"`
|
||||
// Name is peer's name (machine name)
|
||||
Name string
|
||||
// DNSLabel is the parsed peer name for domain resolution. It is used to form an FQDN by appending the account's
|
||||
// domain to the peer label. e.g. peer-dns-label.netbird.cloud
|
||||
DNSLabel string
|
||||
// Status peer's management connection status
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_"`
|
||||
Status *PeerStatus `gorm:"embedded;embeddedPrefix:peer_status_" diff:"-"`
|
||||
// The user ID that registered the peer
|
||||
UserID string
|
||||
UserID string `diff:"-"`
|
||||
// SSHKey is a public SSH key of the peer
|
||||
SSHKey string
|
||||
// SSHEnabled indicates whether SSH server is enabled on the peer
|
||||
SSHEnabled bool
|
||||
// LoginExpirationEnabled indicates whether peer's login expiration is enabled and once expired the peer has to re-login.
|
||||
// Works with LastLogin
|
||||
LoginExpirationEnabled bool
|
||||
LoginExpirationEnabled bool `diff:"-"`
|
||||
// LastLogin the time when peer performed last login operation
|
||||
LastLogin time.Time
|
||||
LastLogin time.Time `diff:"-"`
|
||||
// CreatedAt records the time the peer was created
|
||||
CreatedAt time.Time
|
||||
CreatedAt time.Time `diff:"-"`
|
||||
// Indicate ephemeral peer attribute
|
||||
Ephemeral bool
|
||||
Ephemeral bool `diff:"-"`
|
||||
// Geo location based on connection IP
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_"`
|
||||
Location Location `gorm:"embedded;embeddedPrefix:location_" diff:"-"`
|
||||
}
|
||||
|
||||
type PeerStatus struct { //nolint:revive
|
||||
|
||||
@@ -274,10 +274,15 @@ func (s *SqlStore) GetInstallationID() string {
|
||||
func (s *SqlStore) SavePeerStatus(accountID, peerID string, peerStatus nbpeer.PeerStatus) error {
|
||||
var peerCopy nbpeer.Peer
|
||||
peerCopy.Status = &peerStatus
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Where("account_id = ? AND id = ?", accountID, peerID).
|
||||
Updates(peerCopy)
|
||||
|
||||
fieldsToUpdate := []string{
|
||||
"peer_status_last_seen", "peer_status_connected",
|
||||
"peer_status_login_expired", "peer_status_required_approval",
|
||||
}
|
||||
result := s.db.Model(&nbpeer.Peer{}).
|
||||
Select(fieldsToUpdate).
|
||||
Where("account_id = ? AND id = ?", accountID, peerID).
|
||||
Updates(&peerCopy)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
@@ -373,7 +373,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// save status of non-existing peer
|
||||
newStatus := nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()}
|
||||
newStatus := nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()}
|
||||
err = store.SavePeerStatus(account.Id, "non-existing-peer", newStatus)
|
||||
assert.Error(t, err)
|
||||
parsedErr, ok := status.FromError(err)
|
||||
@@ -388,7 +388,7 @@ func TestSqlite_SavePeerStatus(t *testing.T) {
|
||||
IP: net.IP{127, 0, 0, 1},
|
||||
Meta: nbpeer.PeerSystemMeta{},
|
||||
Name: "peer name",
|
||||
Status: &nbpeer.PeerStatus{Connected: false, LastSeen: time.Now().UTC()},
|
||||
Status: &nbpeer.PeerStatus{Connected: true, LastSeen: time.Now().UTC()},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
|
||||
38
management/server/testdata/store.json
vendored
38
management/server/testdata/store.json
vendored
@@ -19,7 +19,7 @@
|
||||
"Revoked": false,
|
||||
"UsedTimes": 0,
|
||||
"LastUsed": "0001-01-01T00:00:00Z",
|
||||
"AutoGroups": null,
|
||||
"AutoGroups": ["cq9bbkjjuspi5gd38epg"],
|
||||
"UsageLimit": 0,
|
||||
"Ephemeral": false
|
||||
}
|
||||
@@ -69,9 +69,41 @@
|
||||
"LastLogin": "0001-01-01T00:00:00Z"
|
||||
}
|
||||
},
|
||||
"Groups": null,
|
||||
"Groups": {
|
||||
"cq9bbkjjuspi5gd38epg": {
|
||||
"ID": "cq9bbkjjuspi5gd38epg",
|
||||
"Name": "All",
|
||||
"Peers": []
|
||||
}
|
||||
},
|
||||
"Rules": null,
|
||||
"Policies": [],
|
||||
"Policies": [
|
||||
{
|
||||
"ID": "cq9bbkjjuspi5gd38eq0",
|
||||
"Name": "Default",
|
||||
"Description": "This is a default rule that allows connections between all the resources",
|
||||
"Enabled": true,
|
||||
"Rules": [
|
||||
{
|
||||
"ID": "cq9bbkjjuspi5gd38eq0",
|
||||
"Name": "Default",
|
||||
"Description": "This is a default rule that allows connections between all the resources",
|
||||
"Enabled": true,
|
||||
"Action": "accept",
|
||||
"Destinations": [
|
||||
"cq9bbkjjuspi5gd38epg"
|
||||
],
|
||||
"Sources": [
|
||||
"cq9bbkjjuspi5gd38epg"
|
||||
],
|
||||
"Bidirectional": true,
|
||||
"Protocol": "all",
|
||||
"Ports": null
|
||||
}
|
||||
],
|
||||
"SourcePostureChecks": null
|
||||
}
|
||||
],
|
||||
"Routes": null,
|
||||
"NameServerGroups": null,
|
||||
"DNSSettings": null,
|
||||
|
||||
@@ -2,9 +2,12 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/r3labs/diff"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
@@ -14,14 +17,18 @@ import (
|
||||
const channelBufferSize = 100
|
||||
|
||||
type UpdateMessage struct {
|
||||
Update *proto.SyncResponse
|
||||
Update *proto.SyncResponse
|
||||
NetworkMap *NetworkMap
|
||||
Checks []*posture.Checks
|
||||
}
|
||||
|
||||
type PeersUpdateManager struct {
|
||||
// peerChannels is an update channel indexed by Peer.ID
|
||||
peerChannels map[string]chan *UpdateMessage
|
||||
// peerNetworkMaps is the UpdateMessage indexed by Peer.ID.
|
||||
peerUpdateMessage map[string]*UpdateMessage
|
||||
// channelsMux keeps the mutex to access peerChannels
|
||||
channelsMux *sync.Mutex
|
||||
channelsMux *sync.RWMutex
|
||||
// metrics provides method to collect application metrics
|
||||
metrics telemetry.AppMetrics
|
||||
}
|
||||
@@ -29,9 +36,10 @@ type PeersUpdateManager struct {
|
||||
// NewPeersUpdateManager returns a new instance of PeersUpdateManager
|
||||
func NewPeersUpdateManager(metrics telemetry.AppMetrics) *PeersUpdateManager {
|
||||
return &PeersUpdateManager{
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
channelsMux: &sync.Mutex{},
|
||||
metrics: metrics,
|
||||
peerChannels: make(map[string]chan *UpdateMessage),
|
||||
peerUpdateMessage: make(map[string]*UpdateMessage),
|
||||
channelsMux: &sync.RWMutex{},
|
||||
metrics: metrics,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -40,7 +48,17 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
start := time.Now()
|
||||
var found, dropped bool
|
||||
|
||||
// skip sending sync update to the peer if there is no change in update message,
|
||||
// it will not check on turn credential refresh as we do not send network map or client posture checks
|
||||
if update.NetworkMap != nil {
|
||||
updated := p.handlePeerMessageUpdate(ctx, peerID, update)
|
||||
if !updated {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
p.channelsMux.Lock()
|
||||
|
||||
defer func() {
|
||||
p.channelsMux.Unlock()
|
||||
if p.metrics != nil {
|
||||
@@ -48,6 +66,16 @@ func (p *PeersUpdateManager) SendUpdate(ctx context.Context, peerID string, upda
|
||||
}
|
||||
}()
|
||||
|
||||
if update.NetworkMap != nil {
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
if lastSentUpdate != nil && lastSentUpdate.Update.NetworkMap.GetSerial() >= update.Update.NetworkMap.GetSerial() {
|
||||
log.WithContext(ctx).Debugf("peer %s new network map serial: %d not greater than last sent: %d, skip sending update",
|
||||
peerID, update.Update.NetworkMap.GetSerial(), lastSentUpdate.Update.NetworkMap.GetSerial())
|
||||
return
|
||||
}
|
||||
p.peerUpdateMessage[peerID] = update
|
||||
}
|
||||
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
found = true
|
||||
select {
|
||||
@@ -80,6 +108,7 @@ func (p *PeersUpdateManager) CreateChannel(ctx context.Context, peerID string) c
|
||||
closed = true
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
// mbragin: todo shouldn't it be more? or configurable?
|
||||
channel := make(chan *UpdateMessage, channelBufferSize)
|
||||
@@ -94,6 +123,7 @@ func (p *PeersUpdateManager) closeChannel(ctx context.Context, peerID string) {
|
||||
if channel, ok := p.peerChannels[peerID]; ok {
|
||||
delete(p.peerChannels, peerID)
|
||||
close(channel)
|
||||
delete(p.peerUpdateMessage, peerID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Debugf("closed updates channel of a peer %s", peerID)
|
||||
@@ -170,3 +200,49 @@ func (p *PeersUpdateManager) HasChannel(peerID string) bool {
|
||||
|
||||
return ok
|
||||
}
|
||||
|
||||
// handlePeerMessageUpdate checks if the update message for a peer is new and should be sent.
|
||||
func (p *PeersUpdateManager) handlePeerMessageUpdate(ctx context.Context, peerID string, update *UpdateMessage) bool {
|
||||
p.channelsMux.RLock()
|
||||
lastSentUpdate := p.peerUpdateMessage[peerID]
|
||||
p.channelsMux.RUnlock()
|
||||
|
||||
if lastSentUpdate != nil {
|
||||
updated, err := isNewPeerUpdateMessage(lastSentUpdate, update)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("error checking for SyncResponse updates: %v", err)
|
||||
return false
|
||||
}
|
||||
if !updated {
|
||||
log.WithContext(ctx).Debugf("peer %s network map is not updated, skip sending update", peerID)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// isNewPeerUpdateMessage checks if the given current update message is a new update that should be sent.
|
||||
func isNewPeerUpdateMessage(lastSentUpdate, currUpdateToSend *UpdateMessage) (bool, error) {
|
||||
if lastSentUpdate.Update.NetworkMap.GetSerial() >= currUpdateToSend.Update.NetworkMap.GetSerial() {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
changelog, err := diff.Diff(lastSentUpdate.Checks, currUpdateToSend.Checks)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff checks: %v", err)
|
||||
}
|
||||
if len(changelog) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
changelog, err = diff.Diff(lastSentUpdate.NetworkMap, currUpdateToSend.NetworkMap)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to diff network map: %v", err)
|
||||
}
|
||||
if len(changelog) > 0 {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/proto"
|
||||
"github.com/netbirdio/netbird/management/server/posture"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// var peersUpdater *PeersUpdateManager
|
||||
@@ -77,3 +79,104 @@ func TestCloseChannel(t *testing.T) {
|
||||
t.Error("Error closing the channel")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandlePeerMessageUpdate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
peerID string
|
||||
existingUpdate *UpdateMessage
|
||||
newUpdate *UpdateMessage
|
||||
expectedResult bool
|
||||
}{
|
||||
{
|
||||
name: "update message with turn credentials update",
|
||||
peerID: "peer",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
WiretrusteeConfig: &proto.WiretrusteeConfig{},
|
||||
},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message for peer without existing update",
|
||||
peerID: "peer1",
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with no changes in update",
|
||||
peerID: "peer2",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
Checks: []*posture.Checks{},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
Checks: []*posture.Checks{},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
{
|
||||
name: "update message with changes in checks",
|
||||
peerID: "peer3",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
Checks: []*posture.Checks{},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
Checks: []*posture.Checks{{ID: "check1"}},
|
||||
},
|
||||
expectedResult: true,
|
||||
},
|
||||
{
|
||||
name: "update message with lower serial number",
|
||||
peerID: "peer4",
|
||||
existingUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 2},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 2}},
|
||||
},
|
||||
newUpdate: &UpdateMessage{
|
||||
Update: &proto.SyncResponse{
|
||||
NetworkMap: &proto.NetworkMap{Serial: 1},
|
||||
},
|
||||
NetworkMap: &NetworkMap{Network: &Network{Serial: 1}},
|
||||
},
|
||||
expectedResult: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := NewPeersUpdateManager(nil)
|
||||
ctx := context.Background()
|
||||
|
||||
if tt.existingUpdate != nil {
|
||||
p.peerUpdateMessage[tt.peerID] = tt.existingUpdate
|
||||
}
|
||||
|
||||
result := p.handlePeerMessageUpdate(ctx, tt.peerID, tt.newUpdate)
|
||||
assert.Equal(t, tt.expectedResult, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user