Compare commits

...

45 Commits

Author SHA1 Message Date
pascal
ad3989b34e reduce complexity 2026-06-04 12:21:19 +02:00
pascal
2102349574 cleanup 2026-06-03 19:04:02 +02:00
pascal
08b25e1d19 use nmap peers for affected on AddPeer 2026-06-03 18:44:35 +02:00
pascal
c2471510b6 use unified resolver 2026-06-03 18:31:38 +02:00
pascal
42e25ad602 account for routing peers 2026-06-02 18:56:04 +02:00
pascal
54fc223ba6 Merge branch 'main' into feature/affected-peers 2026-06-01 18:42:31 +02:00
pascal
31aa39572d merge main 2026-05-28 18:51:48 +02:00
pascal
77604bb467 remove nil return 2026-05-26 18:32:32 +02:00
pascal
1072f73a73 fix merge issues 2026-05-26 18:08:03 +02:00
pascal
4b6721c878 reduce cognitive complexity 2026-05-26 18:01:43 +02:00
pascal
0b78e310b7 Merge branch 'main' into feature/affected-peers 2026-05-26 17:57:59 +02:00
pascal
68b942722c account for proxy peers and services 2026-05-26 17:55:18 +02:00
pascal
7af7630e5b consider removed peers from group 2026-05-26 11:43:29 +02:00
pascal
12361e5479 hardening channel drain 2026-05-21 17:40:45 +02:00
pascal
d3ae81e601 use uncanceled context 2026-05-21 17:37:17 +02:00
pascal
5c3f2ab0df update test 2026-05-21 17:10:51 +02:00
pascal
ba554a73d0 Merge remote-tracking branch 'origin/main' into feature/affected-peers 2026-05-21 16:40:56 +02:00
pascal
9e236ac20e Merge branch 'main' into feature/affected-peers
# Conflicts:
#	management/server/peer.go
2026-05-20 18:11:01 +02:00
pascal
c948d7398f further improve db calls 2026-05-08 20:51:46 +02:00
pascal
13d26106f8 improve db calls 2026-05-08 20:44:17 +02:00
pascal
3e83164bcd fix affected group handling 2026-05-08 20:27:47 +02:00
pascal
6568c905c6 fix test 2026-05-08 19:54:14 +02:00
pascal
aa9a1a42f5 remove complexity 2026-05-08 19:36:21 +02:00
pascal
5ae6c25ac0 fix test 2026-05-08 19:31:59 +02:00
pascal
1d906e411d fix test 2026-05-08 19:31:46 +02:00
pascal
3012228b91 missing files 2026-05-08 16:48:09 +02:00
pascal
85851bc477 extract submethods 2026-05-08 16:43:27 +02:00
pascal
fed4f1b024 drain channel between tests 2026-05-08 14:33:31 +02:00
pascal
70e84d5228 add own peer on peer update 2026-05-07 18:07:47 +02:00
pascal
57529c7f18 linter 2026-05-07 17:50:02 +02:00
pascal
fd99bc072d Merge branch 'main' into feature/affected-peers 2026-05-07 17:39:38 +02:00
pascal
40e6ec16c6 log 2026-05-07 17:36:09 +02:00
pascal
ec476d5072 extend logging 2026-05-07 16:55:45 +02:00
pascal
550ae5558e update after merge 2026-05-07 16:24:54 +02:00
pascal
46494bd860 bugfixes 2026-05-07 16:08:45 +02:00
pascal
c7bff8f074 Merge branch 'main' into feature/affected-peers
# Conflicts:
#	management/internals/controllers/network_map/controller/controller.go
2026-05-07 15:59:28 +02:00
pascal
3a95f39f2c Merge branch 'main' into feature/affected-peers
# Conflicts:
#	management/server/group.go
#	management/server/peer.go
2026-05-07 12:28:51 +02:00
pascal
6b4d4076f4 extend tests 2026-05-04 15:16:59 +02:00
pascal
63d2217d8a Merge main into feature/affected-peers
Resolve conflicts keeping affected-peers logic while adopting
UpdateReason parameter from main for UpdateAccountPeers and
BufferUpdateAccountPeers signatures.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-05-04 12:22:13 +02:00
pascal
0bfccd65d2 add to networks modules 2026-04-30 16:20:41 +02:00
pascal
26d778374b clean comments 2026-04-28 16:56:34 +02:00
pascal
5ec8bebfa5 add tests 2026-04-28 16:27:44 +02:00
pascal
cefb37e920 affected filtering on peers update 2026-04-28 13:48:01 +02:00
pascal
5a16c812fd use buffering affected peers 2026-04-27 18:18:29 +02:00
pascal
285bbc5ffb calculate affected peers 2026-04-27 17:49:12 +02:00
38 changed files with 5707 additions and 558 deletions

View File

@@ -22,14 +22,14 @@ type removePeerCall struct {
}
type mockServer struct {
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
mu sync.Mutex
addCalls []addPeerCall
removed []removePeerCall
nextID rp.PeerID
addErr error
removeErr error
closed bool
ran bool
}
func (m *mockServer) AddPeer(cfg rp.PeerConfig) (rp.PeerID, error) {
@@ -51,7 +51,7 @@ func (m *mockServer) RemovePeer(id rp.PeerID) error {
return m.removeErr
}
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Run() error { m.ran = true; return nil }
func (m *mockServer) Close() error { m.closed = true; return nil }
type setPSKCall struct {

View File

@@ -41,4 +41,3 @@ func TestDeterministicSeedKey_TooShortKey_ReturnsError(t *testing.T) {
_, err = DeterministicSeedKey(long, short)
require.Error(t, err)
}

View File

@@ -44,7 +44,7 @@ type Controller struct {
EphemeralPeersManager ephemeral.Manager
accountUpdateLocks sync.Map
sendAccountUpdateLocks sync.Map
affectedPeerUpdateLocks sync.Map
updateAccountPeersBufferInterval atomic.Int64
// dnsDomain is used for peer resolution. This is appended to the peer's name
dnsDomain string
@@ -63,6 +63,13 @@ type bufferUpdate struct {
update atomic.Bool
}
type bufferAffectedUpdate struct {
sendMu sync.Mutex
dataMu sync.Mutex
next *time.Timer
peerIDs map[string]struct{}
}
var _ network_map.Controller = (*Controller)(nil)
func NewController(ctx context.Context, store store.Store, metrics telemetry.AppMetrics, peersUpdateManager network_map.PeersUpdateManager, requestBuffer account.RequestBuffer, integratedPeerValidator integrated_validator.IntegratedValidator, settingsManager settings.Manager, dnsDomain string, proxyController port_forwarding.Controller, ephemeralPeersManager ephemeral.Manager, config *config.Config) *Controller {
@@ -200,7 +207,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
@@ -225,44 +232,6 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
return nil
}
func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
log.WithContext(ctx).Tracef("buffer sending update peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.sendAccountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate)
if !b.mu.TryLock() {
b.update.Store(true)
return nil
}
if b.next != nil {
b.next.Stop()
}
go func() {
defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
if !b.update.Load() {
return
}
b.update.Store(false)
if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID, reason)
})
return
}
b.next.Reset(time.Duration(c.updateAccountPeersBufferInterval.Load()))
}()
return nil
}
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
@@ -272,6 +241,143 @@ func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, r
return c.sendUpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
if len(peerIDs) == 0 {
return nil
}
return c.sendUpdateForAffectedPeers(ctx, accountID, peerIDs)
}
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
if !c.hasConnectedPeers(peerIDs) {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
return nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
globalStart := time.Now()
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
if len(peersToUpdate) == 0 {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
return nil
}
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: sending network map to %d connected peers", len(peersToUpdate))
approvedPeersMap, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
return fmt.Errorf("failed to get validate peers: %v", err)
}
var wg sync.WaitGroup
semaphore := make(chan struct{}, 10)
account.InjectProxyPolicies(ctx)
dnsCache := &cache.DNSConfigCache{}
dnsDomain := c.GetDNSDomain(account.Settings)
peersCustomZone := account.GetPeersCustomZone(ctx, dnsDomain)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
return fmt.Errorf("failed to get proxy network maps: %v", err)
}
extraSetting, err := c.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get flow enabled status: %v", err)
}
dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion)
accountZones, err := c.repo.GetAccountZones(ctx, account.Id)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account zones: %v", err)
return fmt.Errorf("failed to get account zones: %v", err)
}
for _, peer := range peersToUpdate {
wg.Add(1)
semaphore <- struct{}{}
go func(p *nbpeer.Peer) {
defer wg.Done()
defer func() { <-semaphore }()
start := time.Now()
postureChecks, err := c.getPeerPostureChecks(account, p.ID)
if err != nil {
log.WithContext(ctx).Debugf("failed to get posture checks for peer %s: %v", p.ID, err)
return
}
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
proxyNetworkMap, ok := proxyNetworkMaps[p.ID]
if ok {
remotePeerNetworkMap.Merge(proxyNetworkMap)
}
peerGroups := account.GetPeerGroups(p.ID)
start = time.Now()
update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort)
c.metrics.CountToSyncResponseDuration(time.Since(start))
c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{
Update: update,
MessageType: network_map.MessageTypeNetworkMap,
})
}(peer)
}
wg.Wait()
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersDuration(time.Since(globalStart))
}
return nil
}
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
return true
}
}
return false
}
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
var result []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
result = append(result, peer)
}
}
return result
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)
@@ -380,6 +486,100 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
return nil
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (c *Controller) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
if len(peerIDs) == 0 {
return nil
}
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
log.WithContext(ctx).Tracef("buffer updating %d affected peers for account %s from %s", len(peerIDs), accountID, util.GetCallerName())
bufUpd, _ := c.affectedPeerUpdateLocks.LoadOrStore(accountID, &bufferAffectedUpdate{
peerIDs: make(map[string]struct{}),
})
b := bufUpd.(*bufferAffectedUpdate)
b.addPeerIDs(peerIDs)
if !b.sendMu.TryLock() {
// Another goroutine is already sending; it will pick up our IDs on its next drain.
return nil
}
b.stopTimer()
collected := b.drainPeerIDs()
go func() {
defer b.sendMu.Unlock()
_ = c.sendUpdateForAffectedPeers(ctx, accountID, collected)
// Check if more peer IDs accumulated while we were sending.
if !b.hasPending() {
return
}
// Schedule a debounced flush for the newly accumulated IDs.
b.setTimer(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
ids := b.drainPeerIDs()
if len(ids) > 0 {
_ = c.sendUpdateForAffectedPeers(ctx, accountID, ids)
}
})
}()
return nil
}
func (b *bufferAffectedUpdate) addPeerIDs(ids []string) {
b.dataMu.Lock()
for _, id := range ids {
b.peerIDs[id] = struct{}{}
}
b.dataMu.Unlock()
}
func (b *bufferAffectedUpdate) drainPeerIDs() []string {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if len(b.peerIDs) == 0 {
return nil
}
ids := make([]string, 0, len(b.peerIDs))
for id := range b.peerIDs {
ids = append(ids, id)
}
b.peerIDs = make(map[string]struct{})
return ids
}
func (b *bufferAffectedUpdate) hasPending() bool {
b.dataMu.Lock()
defer b.dataMu.Unlock()
return len(b.peerIDs) > 0
}
func (b *bufferAffectedUpdate) stopTimer() {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next != nil {
b.next.Stop()
}
}
func (b *bufferAffectedUpdate) setTimer(d time.Duration, f func()) {
b.dataMu.Lock()
defer b.dataMu.Unlock()
if b.next == nil {
b.next = time.AfterFunc(d, f)
return
}
b.next.Reset(d)
}
func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) {
if isRequiresApproval {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
@@ -577,21 +777,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
return false, nil
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID)
return nil
}
return nil
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationUpdate})
}
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationCreate})
}
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
network, err := c.repo.GetAccountNetwork(ctx, accountID)
if err != nil {
return err
@@ -624,7 +827,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
c.peersUpdateManager.CloseChannel(ctx, peerID)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) == 0 {
log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping network map update", accountID)
return nil
}
return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
}
// GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result)

View File

@@ -19,6 +19,8 @@ const (
type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
@@ -27,9 +29,9 @@ type Controller interface {
GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error)
CountStreams() int
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error
OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error
OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error
DisconnectPeers(ctx context.Context, accountId string, peerIDs []string)
OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error)
OnPeerDisconnected(ctx context.Context, accountID string, peerID string)

View File

@@ -57,6 +57,20 @@ func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, r
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockController) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
ret0, _ := ret[0].(error)
return ret0
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
}
// CountStreams mocks base method.
func (m *MockController) CountStreams() int {
m.ctrl.T.Helper()
@@ -158,45 +172,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID
}
// OnPeersAdded mocks base method.
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersAdded indicates an expected call of OnPeersAdded.
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersDeleted mocks base method.
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error {
func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs)
ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersDeleted indicates an expected call of OnPeersDeleted.
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs)
}
// OnPeersUpdated mocks base method.
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error {
func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs)
ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// OnPeersUpdated indicates an expected call of OnPeersUpdated.
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call {
func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs)
}
// StartWarmup mocks base method.
@@ -250,3 +264,17 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockController) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockControllerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockController)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}

View File

@@ -2573,7 +2573,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -2664,7 +2666,9 @@ func (am *DefaultAccountManager) UpdatePeerIPv6(ctx context.Context, accountID,
}
if updateNetworkMap {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peerID}); err != nil {
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
return fmt.Errorf("notify network map controller: %w", err)
}
}

View File

@@ -13,6 +13,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
@@ -109,7 +110,7 @@ type Manager interface {
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
ExtendPeerSession(ctx context.Context, peerPubKey, userID string) (time.Time, error) // used by peer gRPC API for ExtendAuthSession
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) // used by peer gRPC API
GetExternalCacheManager() ExternalCacheManager
GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error)
@@ -128,6 +129,9 @@ type Manager interface {
GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error)
DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason)
ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason)
BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error)
SyncUserJWTGroups(ctx context.Context, userAuth auth.UserAuth) error

View File

@@ -15,6 +15,7 @@ import (
dns "github.com/netbirdio/netbird/dns"
service "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
activity "github.com/netbirdio/netbird/management/server/activity"
affectedpeers "github.com/netbirdio/netbird/management/server/affectedpeers"
idp "github.com/netbirdio/netbird/management/server/idp"
peer "github.com/netbirdio/netbird/management/server/peer"
posture "github.com/netbirdio/netbird/management/server/posture"
@@ -122,6 +123,18 @@ func (mr *MockManagerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reas
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
}
// BufferUpdateAffectedPeers mocks base method.
func (m *MockManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "BufferUpdateAffectedPeers", ctx, accountID, peerIDs, reason)
}
// BufferUpdateAffectedPeers indicates an expected call of BufferUpdateAffectedPeers.
func (mr *MockManagerMockRecorder) BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).BufferUpdateAffectedPeers), ctx, accountID, peerIDs, reason)
}
// BuildUserInfosForAccount mocks base method.
func (m *MockManager) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) {
m.ctrl.T.Helper()
@@ -1637,6 +1650,32 @@ func (mr *MockManagerMockRecorder) UpdateAccountPeers(ctx, accountID, reason int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockManager)(nil).UpdateAccountPeers), ctx, accountID, reason)
}
// UpdateAffectedPeers mocks base method.
func (m *MockManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "UpdateAffectedPeers", ctx, accountID, peerIDs)
}
// UpdateAffectedPeers indicates an expected call of UpdateAffectedPeers.
func (mr *MockManagerMockRecorder) UpdateAffectedPeers(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAffectedPeers", reflect.TypeOf((*MockManager)(nil).UpdateAffectedPeers), ctx, accountID, peerIDs)
}
// ResolveAffectedPeers mocks base method.
func (m *MockManager) ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ResolveAffectedPeers", ctx, s, accountID, change)
ret0, _ := ret[0].([]string)
return ret0
}
// ResolveAffectedPeers indicates an expected call of ResolveAffectedPeers.
func (mr *MockManagerMockRecorder) ResolveAffectedPeers(ctx, s, accountID, change interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ResolveAffectedPeers", reflect.TypeOf((*MockManager)(nil).ResolveAffectedPeers), ctx, s, accountID, change)
}
// UpdateAccountSettings mocks base method.
func (m *MockManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
m.ctrl.T.Helper()

View File

@@ -3282,6 +3282,19 @@ func setupNetworkMapTest(t *testing.T) (*DefaultAccountManager, *update_channel.
// when the channel delivers.
const peerUpdateTimeout = 5 * time.Second
func drainPeerUpdates(ch <-chan *network_map.UpdateMessage) {
for {
select {
case _, ok := <-ch:
if !ok {
return
}
case <-time.After(200 * time.Millisecond):
return
}
}
}
func peerShouldNotReceiveUpdate(t *testing.T, updateMessage <-chan *network_map.UpdateMessage) {
t.Helper()
select {

View File

@@ -0,0 +1,119 @@
package server
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
// TestAffectedPeers_DependencyCoverageMatrix enumerates each network-map
// dependency crossed with the change-type that can alter it, asserting the
// resolver folds in exactly the peers whose map changes. A new dependency that
// the resolver fails to walk should fail one of these rows; a new change-type
// without a row is a coverage gap to add here.
func TestAffectedPeers_DependencyCoverageMatrix(t *testing.T) {
type row struct {
name string
build func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string)
}
rows := []row{
{
name: "policy-groups/source-group-change refreshes source+routing, excludes unrelated",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "resource-routing-bridge/router-peer-change refreshes policy sources",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ChangedPeerIDs: []string{s.routerPeerID}},
[]string{s.sourcePeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "policy-change/explicit-policy refreshes source+routing",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
return affectedpeers.Change{Policies: []*types.Policy{policy}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "policy-destinationresource/explicit-policy bridges to routing peer",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
return affectedpeers.Change{Policies: []*types.Policy{policy}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "resource-change refreshes source+routing on its network",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{ResourceIDs: []string{s.resourceID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "network-change refreshes source+routing on that network",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
return affectedpeers.Change{NetworkIDs: []string{s.networkID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "posture-check-change refreshes source+routing of gated policy",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
Name: "cov-min-version",
Checks: posture.ChecksDefinition{NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"}},
}, true)
require.NoError(t, err)
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.SourcePostureChecks = []string{check.ID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
return affectedpeers.Change{PostureCheckIDs: []string{check.ID}},
[]string{s.sourcePeerID, s.routerPeerID}, []string{s.unrelatedPeerID}
},
},
{
name: "empty-change yields nothing",
build: func(t *testing.T, s *routerScenario, ctx context.Context) (affectedpeers.Change, []string, []string) {
return affectedpeers.Change{}, nil, []string{s.sourcePeerID, s.routerPeerID, s.unrelatedPeerID}
},
},
}
for _, r := range rows {
t.Run(r.name, func(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
change, mustContain, mustExclude := r.build(t, s, ctx)
affected := s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, change)
for _, id := range mustContain {
assert.Contains(t, affected, id, "expected peer to be affected")
}
for _, id := range mustExclude {
assert.NotContains(t, affected, id, "peer must not be affected")
}
})
}
}

View File

@@ -0,0 +1,143 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/require"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// An update spans an old and a new state. The affected set must be the UNION of
// peers reachable before and after the change; resolving only against the final
// state drops peers that were reachable but no longer are. These tests pin the
// two paths where the old state is reachable only by the changed object's
// previous references: detaching a resource group, and re-pointing a router peer.
// TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources:
// a resource is reachable by a source group via two destination resource groups;
// detaching one of them must still refresh that group's policy source peers, even
// though the post-update resource no longer maps to it.
func TestAffectedPeers_E2E_UpdateResource_DetachGroup_RefreshesOldGroupSources(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A second resource group + a second source group/peer that reaches the
// resource only through that second group.
const detachGroupID = "rs-detach-grp"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: detachGroupID, Name: "rs-detach"}))
const secondSourceGroupID = "rs-source-grp-2"
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-detach-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
}))
resourcesManager, _, _ := s.managers()
// Attach the resource to the detach group as well: now in [resourceGroup, detachGroup].
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID, detachGroupID},
Enabled: true,
})
require.NoError(t, err)
// Policy granting the second source group access via the detach group.
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(secondSourceGroupID, detachGroupID), true)
require.NoError(t, err)
secondSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, secondSourcePeer.ID) })
settleAffectedUpdates(secondSrcCh)
done := make(chan struct{})
go func() {
// Detaching the resource from detachGroup removes the second source's
// access; that source peer must be refreshed even though the post-update
// resource no longer maps to detachGroup.
peerShouldReceiveUpdate(t, secondSrcCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID}, // detached detachGroup
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: detaching a resource group did not refresh the old group's policy source peer")
}
}
// TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer:
// changing router.Peer within the same network must still refresh the OLD routing
// peer, which loses its routing role.
func TestAffectedPeers_E2E_UpdateRouter_RepointPeer_RefreshesOldRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
_, routersManager, _ := s.managers()
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
require.NoError(t, err)
require.Len(t, routers, 1)
router := routers[0]
oldRoutingPeer := router.Peer
require.NotEmpty(t, oldRoutingPeer)
// A new peer to become the routing peer in place of the old one.
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-newrouter-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
newRoutingPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
oldCh := s.updateManager.CreateChannel(ctx, oldRoutingPeer)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, oldRoutingPeer) })
settleAffectedUpdates(oldCh)
done := make(chan struct{})
go func() {
// The old routing peer stops serving the resource and must be refreshed.
peerShouldReceiveUpdate(t, oldCh)
close(done)
}()
_, err = routersManager.UpdateRouter(ctx, userID, &routerTypes.NetworkRouter{
ID: router.ID,
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: newRoutingPeer.ID, // repoint within the same network
Masquerade: true,
Metric: 9999,
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: re-pointing the router peer did not refresh the old routing peer")
}
}

View File

@@ -0,0 +1,251 @@
package server
import (
"context"
"encoding/json"
"fmt"
"math/rand"
"sort"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/exp/maps"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// allPeerMaps computes the serialized per-peer network map for every peer in the
// account, mirroring the controller's compute path so the property test compares
// against real output.
func allPeerMaps(t *testing.T, manager *DefaultAccountManager, accountID string) map[string]string {
t.Helper()
ctx := context.Background()
account, err := manager.Store.GetAccount(ctx, accountID)
require.NoError(t, err)
account.InjectProxyPolicies(ctx)
validated := make(map[string]struct{}, len(account.Peers))
for id := range account.Peers {
validated[id] = struct{}{}
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
out := make(map[string]string, len(account.Peers))
for peerID := range account.Peers {
nm := account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validated, resourcePolicies, routers, nil, groupIDToUserIDs)
// Network.Serial is an account-global counter bumped on every change; it
// is not a per-peer dependency, so normalize it out of the comparison.
if nm.Network != nil {
nm.Network.Serial = 0
}
out[peerID] = canonicalJSON(t, nm)
}
return out
}
// canonicalJSON marshals v and returns an order-insensitive string form: every
// JSON array is sorted by the canonical form of its elements. The network map's
// Peers/Routes/FirewallRules/SourceRanges slices have nondeterministic order, so
// a raw JSON compare would report spurious changes.
func canonicalJSON(t *testing.T, v interface{}) string {
t.Helper()
b, err := json.Marshal(v)
require.NoError(t, err)
var parsed interface{}
require.NoError(t, json.Unmarshal(b, &parsed))
canonicalized, err := json.Marshal(sortAny(parsed))
require.NoError(t, err)
return string(canonicalized)
}
func sortAny(v interface{}) interface{} {
switch val := v.(type) {
case []interface{}:
for i := range val {
val[i] = sortAny(val[i])
}
sort.Slice(val, func(i, j int) bool {
bi, _ := json.Marshal(val[i])
bj, _ := json.Marshal(val[j])
return string(bi) < string(bj)
})
return val
case map[string]interface{}:
for k := range val {
val[k] = sortAny(val[k])
}
return val
default:
return v
}
}
// changedPeers returns the peer IDs whose serialized map differs between before
// and after.
func changedPeers(before, after map[string]string) []string {
var changed []string
for id, b := range before {
a, ok := after[id]
if !ok || a != b {
changed = append(changed, id)
}
}
for id := range after {
if _, ok := before[id]; !ok {
changed = append(changed, id)
}
}
return changed
}
// TestAffectedPeers_Property_ResolverSupersetsRealChanges builds a topology,
// applies random changes, and asserts that the resolver's affected set is a
// superset of the peers whose real network map actually changed. If the resolver
// ever misses a dependency, a change will alter a peer's map without that peer
// appearing in the affected set, failing here.
func TestAffectedPeers_Property_ResolverSupersetsRealChanges(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A pre-existing peer->resource policy so the resource/router bridge is live.
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
// Extra peers and groups to give mutations room to move membership around.
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "prop-key", types.SetupKeyReusable, 0, nil, 999, userID, false, false)
require.NoError(t, err)
extraPeers := make([]string, 0, 4)
for i := 0; i < 4; i++ {
p := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
extraPeers = append(extraPeers, p.ID)
}
extraGroups := []string{"prop-grp-0", "prop-grp-1"}
for _, g := range extraGroups {
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{ID: g, Name: g}))
}
rng := rand.New(rand.NewSource(1))
allGroups := append([]string{s.sourceGroupID, s.resourceGroupID, s.routerPeerGroupID}, extraGroups...)
allPeers := append([]string{s.sourcePeerID, s.routerPeerID, s.routerGroupPeerID, s.unrelatedPeerID}, extraPeers...)
for iter := 0; iter < 60; iter++ {
change, apply := s.randomMutation(t, rng, allGroups, allPeers)
if apply == nil {
continue
}
before := allPeerMaps(t, s.manager, s.accountID)
resolvedSet := make(map[string]struct{})
resolve := func() {
require.NoError(t, s.manager.Store.ExecuteInTransaction(ctx, func(tx store.Store) error {
for _, id := range s.manager.ResolveAffectedPeers(ctx, tx, s.accountID, change) {
resolvedSet[id] = struct{}{}
}
return nil
}))
}
// Resolve on both sides of the mutation and union: removals are visible
// only pre-apply (the leaving peer is still a member), additions only
// post-apply (the joining peer is now a member). Production captures both
// via per-path handling (e.g. UpdateGroup passes peersToRemove); the union
// models that without coupling the test to each path's ordering.
resolve()
changedIDs := change.ChangedPeerIDs
apply()
resolve()
after := allPeerMaps(t, s.manager, s.accountID)
// The explicitly-changed peer's own map refresh is the caller's
// responsibility (the resolver returns the peers to propagate to), so it
// is allowed to be absent from the resolved set.
changedExplicitly := make(map[string]struct{}, len(changedIDs))
for _, id := range changedIDs {
changedExplicitly[id] = struct{}{}
}
for _, id := range changedPeers(before, after) {
if _, stillExists := after[id]; !stillExists {
continue
}
if _, isExplicit := changedExplicitly[id]; isExplicit {
continue
}
_, ok := resolvedSet[id]
require.Truef(t, ok,
"iter %d: peer %s network map changed but was not in the resolver's affected set %v (change=%+v)",
iter, id, maps.Keys(resolvedSet), change)
}
}
}
// randomMutation picks a random change, returns the Change to resolve and a
// function that applies the underlying store mutation. apply is nil when the
// drawn mutation is a no-op for the current state.
func (s *routerScenario) randomMutation(t *testing.T, rng *rand.Rand, allGroups, allPeers []string) (affectedpeers.Change, func()) {
t.Helper()
ctx := context.Background()
switch rng.Intn(3) {
case 0:
groupID := allGroups[rng.Intn(len(allGroups))]
peerID := allPeers[rng.Intn(len(allPeers))]
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
require.NoError(t, err)
if slicesContains(grp.Peers, peerID) {
return affectedpeers.Change{}, nil
}
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
func() {
require.NoError(t, s.manager.GroupAddPeer(ctx, s.accountID, groupID, peerID))
}
case 1:
groupID := allGroups[rng.Intn(len(allGroups))]
grp, err := s.manager.Store.GetGroupByID(ctx, store.LockingStrengthNone, s.accountID, groupID)
require.NoError(t, err)
if len(grp.Peers) == 0 {
return affectedpeers.Change{}, nil
}
peerID := grp.Peers[rng.Intn(len(grp.Peers))]
return affectedpeers.Change{ChangedGroupIDs: []string{groupID}, ChangedPeerIDs: []string{peerID}},
func() {
require.NoError(t, s.manager.GroupDeletePeer(ctx, s.accountID, groupID, peerID))
}
default:
src := allGroups[rng.Intn(len(allGroups))]
dst := allGroups[rng.Intn(len(allGroups))]
policy := &types.Policy{
Enabled: true,
Name: fmt.Sprintf("prop-policy-%d", rng.Int()),
Rules: []*types.PolicyRule{{
Enabled: true,
Sources: []string{src},
Destinations: []string{dst},
Action: types.PolicyTrafficActionAccept,
}},
}
return affectedpeers.Change{Policies: []*types.Policy{policy}},
func() {
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
}
}
}
func slicesContains(s []string, v string) bool {
for _, x := range s {
if x == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,125 @@
package server
import (
"context"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// countingStore wraps a real store and counts the per-account collection loads
// the resolver performs, so a test can assert each is read at most once and that
// irrelevant collections are skipped entirely.
type countingStore struct {
store.Store
mu sync.Mutex
counts map[string]int
}
func newCountingStore(s store.Store) *countingStore {
return &countingStore{Store: s, counts: map[string]int{}}
}
func (c *countingStore) bump(name string) {
c.mu.Lock()
c.counts[name]++
c.mu.Unlock()
}
func (c *countingStore) count(name string) int {
c.mu.Lock()
defer c.mu.Unlock()
return c.counts[name]
}
func (c *countingStore) GetAccountPolicies(ctx context.Context, ls store.LockingStrength, accountID string) ([]*types.Policy, error) {
c.bump("policies")
return c.Store.GetAccountPolicies(ctx, ls, accountID)
}
func (c *countingStore) GetAccountRoutes(ctx context.Context, ls store.LockingStrength, accountID string) ([]*route.Route, error) {
c.bump("routes")
return c.Store.GetAccountRoutes(ctx, ls, accountID)
}
func (c *countingStore) GetAccountNameServerGroups(ctx context.Context, ls store.LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
c.bump("nameservers")
return c.Store.GetAccountNameServerGroups(ctx, ls, accountID)
}
func (c *countingStore) GetAccountDNSSettings(ctx context.Context, ls store.LockingStrength, accountID string) (*types.DNSSettings, error) {
c.bump("dnssettings")
return c.Store.GetAccountDNSSettings(ctx, ls, accountID)
}
func (c *countingStore) GetNetworkRoutersByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*routerTypes.NetworkRouter, error) {
c.bump("routers")
return c.Store.GetNetworkRoutersByAccountID(ctx, ls, accountID)
}
func (c *countingStore) GetNetworkResourcesByAccountID(ctx context.Context, ls store.LockingStrength, accountID string) ([]*resourceTypes.NetworkResource, error) {
c.bump("resources")
return c.Store.GetNetworkResourcesByAccountID(ctx, ls, accountID)
}
func (c *countingStore) GetAccountServices(ctx context.Context, ls store.LockingStrength, accountID string) ([]*rpservice.Service, error) {
c.bump("services")
return c.Store.GetAccountServices(ctx, ls, accountID)
}
// TestAffectedPeers_QueryCount_NoRedundantFullTableLoads asserts the resolver
// loads each per-account collection at most once per Resolve (memoization) even
// on a change that drives every bridge, and skips the services table when the
// account has no embedded proxy peers.
func TestAffectedPeers_QueryCount_NoRedundantFullTableLoads(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
cs := newCountingStore(s.manager.Store)
// A group change that exercises policies, routers, resources and the bridge.
affected, err := affectedpeers.Resolve(ctx, cs, s.accountID, affectedpeers.Change{ChangedGroupIDs: []string{s.sourceGroupID}})
require.NoError(t, err)
assert.Contains(t, affected, s.routerPeerID, "bridge must still resolve the routing peer")
for _, name := range []string{"policies", "routes", "nameservers", "dnssettings", "routers", "resources"} {
assert.LessOrEqualf(t, cs.count(name), 1,
"%s must be loaded at most once per Resolve, got %d", name, cs.count(name))
}
assert.Equal(t, 0, cs.count("services"),
"services must not be loaded when the account has no embedded proxy peers")
}
// TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads asserts that a change with
// no group/peer signal touches no per-account collections beyond what its inputs
// require.
func TestAffectedPeers_QueryCount_NarrowChangeSkipsLoads(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
cs := newCountingStore(s.manager.Store)
// A bare network-id change drives only the router->source bridge: routers and
// resources are needed, but routes/nameservers/dnssettings/services are not.
_, err := affectedpeers.Resolve(ctx, cs, s.accountID, affectedpeers.Change{NetworkIDs: []string{s.networkID}})
require.NoError(t, err)
assert.Equal(t, 0, cs.count("routes"), "routes must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("nameservers"), "nameservers must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("dnssettings"), "dnssettings must not be loaded for a network-only change")
assert.Equal(t, 0, cs.count("services"), "services must not be loaded for a network-only change")
}

View File

@@ -0,0 +1,369 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/affectedpeers"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/types"
)
func (s *routerScenario) resolveGroupChangeAffected(ctx context.Context, changedGroupIDs []string) []string {
return s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, affectedpeers.Change{ChangedGroupIDs: changedGroupIDs})
}
func (s *routerScenario) resolvePeerChangeAffected(ctx context.Context, changedPeerIDs []string) []string {
return s.manager.resolveAffectedPeersForPeerChanges(ctx, s.manager.Store, s.accountID, changedPeerIDs)
}
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.sourcePeerID, "source group member must be affected")
assert.Contains(t, affected, s.routerPeerID,
"changing the source group of a peer->resource policy must refresh the resource's routing peer")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_GroupChange_SourceGroupMembership_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.routerGroupPeerID,
"changing the source group must refresh the routing peer defined via router.PeerGroups")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_GroupChange_RouterPeerGroupMembership_RefreshesPolicySources(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.routerPeerGroupID})
assert.Contains(t, affected, s.routerGroupPeerID, "the routing peer itself must be affected")
assert.Contains(t, affected, s.sourcePeerID,
"changing the router's PeerGroups must refresh the source peers of policies serving the resource")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_PeerChange_SourcePeer_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID,
"a status change on a source peer must refresh the resource's routing peer that serves it")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_PeerChange_RoutingPeer_RefreshesPolicySources(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.routerPeerID})
assert.Contains(t, affected, s.sourcePeerID,
"a status change on the routing peer must refresh the source peers that route through it")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_PeerChange_SourcePeer_ByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID,
"DestinationResource-targeted policy must still bridge a source-peer change to the routing peer")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_E2E_DeleteGroup_ResolvesAffectedPeers(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const memberOnlyGroupID = "rs-memberonly-grp"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: memberOnlyGroupID, Name: "rs-memberonly", Peers: []string{s.sourcePeerID},
}))
affected := s.resolveGroupChangeAffected(ctx, []string{memberOnlyGroupID})
assert.Empty(t, affected, "an unlinked group has no network-map impact, so no peer is affected")
require.NoError(t, s.manager.DeleteGroup(ctx, s.accountID, userID, memberOnlyGroupID))
}
func TestAffectedPeers_DeleteGroup_LinkedGroupIsBlocked(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
err = s.manager.DeleteGroup(ctx, s.accountID, userID, s.sourceGroupID)
require.Error(t, err, "deleting a policy-linked group must be blocked by validateDeleteGroup")
var linkErr *GroupLinkError
require.ErrorAs(t, err, &linkErr, "expected a GroupLinkError")
}
func TestAffectedPeers_GroupAddResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const extraResourceGroupID = "rs-resource-grp-extra"
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: extraResourceGroupID, Name: "rs-resource-extra",
}))
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, extraResourceGroupID), true)
require.NoError(t, err)
require.NoError(t, s.manager.GroupAddResource(ctx, s.accountID, extraResourceGroupID, types.Resource{
ID: s.resourceID,
Type: types.ResourceTypeHost,
}))
affected := s.resolveGroupChangeAffected(ctx, []string{extraResourceGroupID})
assert.Contains(t, affected, s.routerPeerID,
"attaching a resource to a policy destination group must refresh the resource's routing peer")
assert.Contains(t, affected, s.sourcePeerID, "policy source peers must refresh")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func (s *routerScenario) resolvePostureCheckAffected(ctx context.Context, postureCheckID string) []string {
return s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, affectedpeers.Change{PostureCheckIDs: []string{postureCheckID}})
}
func (s *routerScenario) createPostureCheckGatedPolicy(t *testing.T, ctx context.Context) string {
t.Helper()
check, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
Name: "rs-min-version",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.30.0"},
},
}, true)
require.NoError(t, err)
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.SourcePostureChecks = []string{check.ID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, true)
require.NoError(t, err)
return check.ID
}
func TestAffectedPeers_PostureCheckChange_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
checkID := s.createPostureCheckGatedPolicy(t, ctx)
affected := s.resolvePostureCheckAffected(ctx, checkID)
assert.Contains(t, affected, s.sourcePeerID, "policy source peer must be affected by a posture-check change")
assert.Contains(t, affected, s.routerPeerID,
"a posture check gating a peer->resource policy must refresh the resource's routing peer")
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
func TestAffectedPeers_E2E_SavePostureCheck_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
checkID := s.createPostureCheckGatedPolicy(t, ctx)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err := s.manager.SavePostureChecks(ctx, s.accountID, userID, &posture.Checks{
ID: checkID,
Name: "rs-min-version",
Checks: posture.ChecksDefinition{
NBVersionCheck: &posture.NBVersionCheck{MinVersion: "0.31.0"},
},
}, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: editing a posture check did not refresh source + routing peers")
}
}
func TestAffectedPeers_E2E_UpdateResource_DestinationResourcePolicy_RefreshesSourcePeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
resourcesManager, _, _ := s.managers()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/25",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: updating a DestinationResource-targeted resource did not refresh its policy source peer")
}
}
func TestAffectedPeers_E2E_UpdateResource_DisabledSiblingRouter_StillBridged(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
resourcesManager, routersManager, _ := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-disabled", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
disabledRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: disabledRouterPeer.ID,
Masquerade: true,
Metric: 9000,
Enabled: false,
})
require.NoError(t, err)
disabledCh := s.updateManager.CreateChannel(ctx, disabledRouterPeer.ID)
t.Cleanup(func() { s.updateManager.CloseChannel(ctx, disabledRouterPeer.ID) })
settleAffectedUpdates(disabledCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, disabledCh)
close(done)
}()
_, err = resourcesManager.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/25",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh the disabled sibling router's peer")
}
}
func TestAffectedPeers_GroupChange_RouterInOtherNetworkNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "groupiso")
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolveGroupChangeAffected(ctx, []string{s.sourceGroupID})
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a source-group change for another resource")
}
func TestAffectedPeers_PeerChange_RouterInOtherNetworkNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "peeriso")
ctx := context.Background()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
affected := s.resolvePeerChangeAffected(ctx, []string{s.sourcePeerID})
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a source-peer change for another resource")
}

View File

@@ -0,0 +1,791 @@
package server
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks"
"github.com/netbirdio/netbird/management/server/networks/resources"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/networks/routers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
)
// routerScenario captures the topology from the bug report:
//
// network ── router (routing peer) ── resource (in resourceGroup)
// independent peer ──(policy: source -> resource)──> resource
//
// The routing peer must be refreshed when a policy grants a source peer access
// to the resource, because the network map connects the source peer to the
// routing peer at compute time (Account.GetPoliciesForNetworkResource +
// addNetworksRoutingPeers). The routing peer is NOT a member of the resource
// group, so static group/peer resolution alone cannot find it.
type routerScenario struct {
manager *DefaultAccountManager
updateManager *update_channel.PeersUpdateManager
accountID string
networkID string
sourcePeerID string // independent peer that the policy grants access from
sourceGroupID string // group containing the source peer
routerPeerID string // peer acting as the routing peer (direct router.Peer)
routerGroupPeerID string // peer that is a member of routerPeerGroup
routerPeerGroupID string // group used for router.PeerGroups
resourceID string // network resource
resourceGroupID string // group whose member is the resource (no peers)
unrelatedPeerID string // peer in no relevant entity
}
// setupRouterScenario builds the topology above with the default policy removed
// and channels NOT yet created, so callers control exactly when updates can flow.
func setupRouterScenario(t *testing.T, directRouterPeer bool) *routerScenario {
t.Helper()
manager, updateManager, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
account, err := createAccount(manager, "router_scenario", userID, "")
require.NoError(t, err)
accountID := account.Id
// Remove the default policy so AddPeer/CreateGroup don't schedule unrelated updates.
policies, err := manager.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
require.NoError(t, err)
for _, p := range policies {
require.NoError(t, manager.Store.DeletePolicy(ctx, accountID, p.ID))
}
setupKey, err := manager.CreateSetupKey(ctx, accountID, "rs-key", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
sourcePeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
routerPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
routerGroupPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
unrelatedPeer := addPeerToAccount(t, manager, accountID, setupKey.Key)
const (
sourceGroupID = "rs-source-grp"
routerPeerGroupID = "rs-router-grp"
resourceGroupID = "rs-resource-grp"
)
for _, g := range []*types.Group{
{ID: sourceGroupID, Name: "rs-source", Peers: []string{sourcePeer.ID}},
{ID: routerPeerGroupID, Name: "rs-router", Peers: []string{routerGroupPeer.ID}},
{ID: resourceGroupID, Name: "rs-resource"}, // intentionally peerless; the resource is its only member
} {
require.NoError(t, manager.CreateGroup(ctx, accountID, userID, g))
}
permissionsManager := permissions.NewManager(manager.Store)
groupsManager := groups.NewManager(manager.Store, permissionsManager, manager)
resourcesManager := resources.NewManager(manager.Store, permissionsManager, groupsManager, manager, manager.serviceManager)
routersManager := routers.NewManager(manager.Store, permissionsManager, manager)
networksManager := networks.NewManager(manager.Store, permissionsManager, resourcesManager, routersManager, manager)
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
ID: "rs-network",
AccountID: accountID,
Name: "rs-network",
})
require.NoError(t, err)
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
AccountID: accountID,
NetworkID: network.ID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
router := &routerTypes.NetworkRouter{
ID: "rs-router",
NetworkID: network.ID,
AccountID: accountID,
Masquerade: true,
Metric: 9999,
Enabled: true,
}
if directRouterPeer {
router.Peer = routerPeer.ID
} else {
router.PeerGroups = []string{routerPeerGroupID}
}
_, err = routersManager.CreateRouter(ctx, userID, router)
require.NoError(t, err)
return &routerScenario{
manager: manager,
updateManager: updateManager,
accountID: accountID,
networkID: network.ID,
sourcePeerID: sourcePeer.ID,
sourceGroupID: sourceGroupID,
routerPeerID: routerPeer.ID,
routerGroupPeerID: routerGroupPeer.ID,
routerPeerGroupID: routerPeerGroupID,
resourceID: resource.ID,
resourceGroupID: resourceGroupID,
unrelatedPeerID: unrelatedPeer.ID,
}
}
// peerToResourcePolicy builds a policy granting the source group access to the
// resource, referencing the resource by its group in the rule destination.
func peerToResourcePolicyByGroup(sourceGroupID, resourceGroupID string) *types.Policy {
return &types.Policy{
Enabled: true,
Name: "peer-to-resource-by-group",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{sourceGroupID},
Destinations: []string{resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
}
// peerToResourcePolicyByResource builds a policy referencing the resource
// directly via DestinationResource rather than its group.
func peerToResourcePolicyByResource(sourceGroupID, resourceID string) *types.Policy {
return &types.Policy{
Enabled: true,
Name: "peer-to-resource-by-resource",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{sourceGroupID},
DestinationResource: types.Resource{ID: resourceID, Type: types.ResourceTypeHost},
Action: types.PolicyTrafficActionAccept,
},
},
}
}
// ---------------------------------------------------------------------------
// Resolution-level tests: collectPolicyAffectedGroupsAndPeers + resolvePeerIDs.
//
// These isolate the resolver from the controller and assert directly on the set
// of peer IDs the policy path would refresh. They make the gap explicit: the
// routing peer is expected to be in the affected set but is not.
// ---------------------------------------------------------------------------
// resolvePolicyAffected mirrors SavePolicy's resolution: resolve the affected
// peers for the given policy.
func (s *routerScenario) resolvePolicyAffected(ctx context.Context, policy *types.Policy) []string {
return s.manager.ResolveAffectedPeers(ctx, s.manager.Store, s.accountID, affectedpeers.Change{Policies: []*types.Policy{policy}})
}
func TestAffectedPeers_PolicyToResourceByGroup_IncludesSourcePeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// The source peer is in the source group and must always be present.
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
}
func TestAffectedPeers_PolicyToResourceByGroup_IncludesRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// BUG: the direct routing peer serves the resource's subnet to the source
// peer, so it must be refreshed when the policy is created. The policy path
// only resolves the literal rule groups (source group + resource group);
// the resource group has no peer members and the router peer is reachable
// only through the network, so it is dropped.
assert.Contains(t, affected, s.routerPeerID,
"routing peer (router.Peer) serving the resource must be affected by a policy granting access to it")
}
func TestAffectedPeers_PolicyToResourceByGroup_IncludesRoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// Same gap when the router is defined via PeerGroups instead of a direct peer.
assert.Contains(t, affected, s.routerGroupPeerID,
"routing peer (router.PeerGroups member) serving the resource must be affected")
}
func TestAffectedPeers_PolicyToResourceByDestinationResource_IncludesRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
affected := s.resolvePolicyAffected(ctx, policy)
// When the resource is referenced via DestinationResource, RuleGroups()
// returns only the source group and the resource ID is not a peer, so
// collectPolicyAffectedGroupsAndPeers yields nothing for the destination at
// all. The routing peer is dropped here too.
assert.Contains(t, affected, s.routerPeerID,
"routing peer must be affected when the resource is referenced via DestinationResource")
}
func TestAffectedPeers_PolicyToResourceByDestinationResource_IncludesRoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
policy := peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID)
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerGroupPeerID,
"routing peer (PeerGroups) must be affected when the resource is referenced via DestinationResource")
}
func TestAffectedPeers_PolicyToResourceWithSourceResourcePeer_IncludesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// Source expressed as a direct peer (SourceResource), destination as resource group.
policy := &types.Policy{
Enabled: true,
Name: "sourceResource-peer-to-resource",
Rules: []*types.PolicyRule{
{
Enabled: true,
SourceResource: types.Resource{ID: s.sourcePeerID, Type: types.ResourceTypePeer},
Destinations: []string{s.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
affected := s.resolvePolicyAffected(ctx, policy)
// The direct source peer IS picked up (collectPolicyAffectedGroupsAndPeers
// handles SourceResource peers), but the routing peer is still missing.
assert.Contains(t, affected, s.sourcePeerID, "direct source peer must be affected")
assert.Contains(t, affected, s.routerPeerID, "routing peer must be affected")
}
func TestAffectedPeers_PolicyToResource_UnrelatedPeerNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
affected := s.resolvePolicyAffected(ctx, policy)
// Guard against an over-broad fix: a peer in no relevant entity must never
// be pulled in.
assert.NotContains(t, affected, s.unrelatedPeerID, "unrelated peer must not be affected")
}
// ---------------------------------------------------------------------------
// Control: the resource/router managers DO bridge resource-group -> router.
// These document the existing (correct) behaviour on the resource side and
// highlight the asymmetry with the policy side above.
// ---------------------------------------------------------------------------
func TestAffectedPeers_ResourceSideBridgesToRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
// A pre-existing policy grants the source group access to the resource.
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
// Drive an update through the resource manager and assert the routing peer
// is among the affected set by observing the channel. This path walks
// policies whose destinations reference the resource's groups, folds in the
// source groups, and loads the network's routers, so it reaches both the
// source peer and the routing peer.
permissionsManager := permissions.NewManager(s.manager.Store)
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
rm := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err = rm.UpdateResource(ctx, userID, &resourceTypes.NetworkResource{
ID: s.resourceID,
AccountID: s.accountID,
NetworkID: s.networkID,
Name: "rs-resource-host",
Address: "10.20.30.0/24",
GroupIDs: []string{s.resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: resource update did not refresh source peer + routing peer")
}
}
// ---------------------------------------------------------------------------
// End-to-end: reproduce the reported symptom through SavePolicy with channels.
//
// Creating the peer->resource policy must wake the routing peer. These fail on
// the current code because the policy path never resolves the routing peer.
//
// IMPORTANT: setup (CreateNetwork/CreateResource/CreateRouter) fires async
// `go UpdateAffectedPeers` goroutines. Channels are opened and then drained via
// settleAffectedUpdates before the action under test, so the assertion only
// observes updates caused by that action and not stragglers from setup. The
// resolution-level tests above are the timing-free, authoritative proof; these
// reproduce the operator-visible symptom.
// ---------------------------------------------------------------------------
// settleAffectedUpdates waits for in-flight async updates to arrive, then drains
// every given channel so subsequent assertions start from a clean slate.
func settleAffectedUpdates(chans ...<-chan *network_map.UpdateMessage) {
time.Sleep(300 * time.Millisecond)
for _, ch := range chans {
drainPeerUpdates(ch)
}
}
func TestAffectedPeers_E2E_CreatePolicyToResource_RefreshesRoutingPeer_DirectRouter(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
unrelatedCh := s.updateManager.CreateChannel(ctx, s.unrelatedPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, s.unrelatedPeerID)
})
settleAffectedUpdates(srcCh, routerCh, unrelatedCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh) // FAILS today: routing peer not resolved
peerShouldNotReceiveUpdate(t, unrelatedCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: creating peer->resource policy did not refresh the routing peer")
}
}
func TestAffectedPeers_E2E_CreatePolicyToResource_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh) // FAILS today
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: routing peer (PeerGroups) not refreshed on policy create")
}
}
func TestAffectedPeers_E2E_CreatePolicyByDestinationResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh) // FAILS today
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: routing peer not refreshed when policy targets DestinationResource")
}
}
func TestAffectedPeers_E2E_DeletePolicyToResource_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh) // FAILS today: deleting the policy must also refresh the router
close(done)
}()
require.NoError(t, s.manager.DeletePolicy(ctx, s.accountID, policy.ID, userID))
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: deleting peer->resource policy did not refresh the routing peer")
}
}
func (s *routerScenario) managers() (resources.Manager, routers.Manager, networks.Manager) {
permissionsManager := permissions.NewManager(s.manager.Store)
groupsManager := groups.NewManager(s.manager.Store, permissionsManager, s.manager)
resourcesManager := resources.NewManager(s.manager.Store, permissionsManager, groupsManager, s.manager, s.manager.serviceManager)
routersManager := routers.NewManager(s.manager.Store, permissionsManager, s.manager)
networksManager := networks.NewManager(s.manager.Store, permissionsManager, resourcesManager, routersManager, s.manager)
return resourcesManager, routersManager, networksManager
}
type secondTopology struct {
networkID string
resourceID string
resourceGroupID string
routerPeerID string
}
func (s *routerScenario) addSecondTopology(t *testing.T, suffix string) secondTopology {
t.Helper()
ctx := context.Background()
resourcesManager, routersManager, networksManager := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-"+suffix, types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
routerPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
resourceGroupID := "rs-resource-grp-" + suffix
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: resourceGroupID, Name: "rs-resource-" + suffix,
}))
network, err := networksManager.CreateNetwork(ctx, userID, &networkTypes.Network{
ID: "rs-network-" + suffix,
AccountID: s.accountID,
Name: "rs-network-" + suffix,
})
require.NoError(t, err)
resource, err := resourcesManager.CreateResource(ctx, userID, &resourceTypes.NetworkResource{
AccountID: s.accountID,
NetworkID: network.ID,
Name: "rs-resource-host-" + suffix,
Address: "10.40.50.0/24",
GroupIDs: []string{resourceGroupID},
Enabled: true,
})
require.NoError(t, err)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: network.ID,
AccountID: s.accountID,
Peer: routerPeer.ID,
Masquerade: true,
Metric: 9999,
Enabled: true,
})
require.NoError(t, err)
return secondTopology{
networkID: network.ID,
resourceID: resource.ID,
resourceGroupID: resourceGroupID,
routerPeerID: routerPeer.ID,
}
}
func TestAffectedPeers_E2E_UpdatePolicyRepointResource_RefreshesBothRoutingPeers(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "b")
ctx := context.Background()
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerACh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
routerBCh := s.updateManager.CreateChannel(ctx, second.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
s.updateManager.CloseChannel(ctx, second.routerPeerID)
})
settleAffectedUpdates(srcCh, routerACh, routerBCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerACh)
peerShouldReceiveUpdate(t, routerBCh)
close(done)
}()
policy.Rules[0].Destinations = []string{second.resourceGroupID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: re-pointing the policy destination did not refresh both routing peers")
}
}
func TestAffectedPeers_E2E_UpdatePolicyAddSourceGroup_RefreshesRoutingPeer(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
const secondSourceGroupID = "rs-source-grp-2"
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondSourcePeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
require.NoError(t, s.manager.CreateGroup(ctx, s.accountID, userID, &types.Group{
ID: secondSourceGroupID, Name: "rs-source-2", Peers: []string{secondSourcePeer.ID},
}))
policy, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID), true)
require.NoError(t, err)
newSrcCh := s.updateManager.CreateChannel(ctx, secondSourcePeer.ID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, secondSourcePeer.ID)
s.updateManager.CloseChannel(ctx, s.routerPeerID)
})
settleAffectedUpdates(newSrcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, newSrcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
policy.Rules[0].Sources = []string{s.sourceGroupID, secondSourceGroupID}
_, err = s.manager.SavePolicy(ctx, s.accountID, userID, policy, false)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: adding a source group did not refresh the new source peer + routing peer")
}
}
func TestAffectedPeers_E2E_CreatePolicyByDestinationResource_RefreshesRoutingPeer_RouterPeerGroups(t *testing.T) {
s := setupRouterScenario(t, false)
ctx := context.Background()
srcCh := s.updateManager.CreateChannel(ctx, s.sourcePeerID)
routerCh := s.updateManager.CreateChannel(ctx, s.routerGroupPeerID)
t.Cleanup(func() {
s.updateManager.CloseChannel(ctx, s.sourcePeerID)
s.updateManager.CloseChannel(ctx, s.routerGroupPeerID)
})
settleAffectedUpdates(srcCh, routerCh)
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, srcCh)
peerShouldReceiveUpdate(t, routerCh)
close(done)
}()
_, err := s.manager.SavePolicy(ctx, s.accountID, userID, peerToResourcePolicyByResource(s.sourceGroupID, s.resourceID), true)
require.NoError(t, err)
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout: DestinationResource policy with PeerGroups router did not refresh the routing peer")
}
}
func TestAffectedPeers_PolicyToResource_IncludesAllRoutingPeersOnNetwork(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
_, routersManager, _ := s.managers()
setupKey, err := s.manager.CreateSetupKey(ctx, s.accountID, "rs-key-r2", types.SetupKeyReusable, time.Hour, nil, 999, userID, false, false)
require.NoError(t, err)
secondRouterPeer := addPeerToAccount(t, s.manager, s.accountID, setupKey.Key)
_, err = routersManager.CreateRouter(ctx, userID, &routerTypes.NetworkRouter{
NetworkID: s.networkID,
AccountID: s.accountID,
Peer: secondRouterPeer.ID,
Masquerade: true,
Metric: 9998,
Enabled: true,
})
require.NoError(t, err)
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.routerPeerID, "first routing peer must be affected")
assert.Contains(t, affected, secondRouterPeer.ID, "second routing peer on the same network must also be affected")
}
func TestAffectedPeers_PolicyToResource_DisabledRouterStillAffected(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
routers, err := s.manager.Store.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, s.accountID, s.networkID)
require.NoError(t, err)
require.Len(t, routers, 1)
routers[0].Enabled = false
require.NoError(t, s.manager.Store.UpdateNetworkRouter(ctx, routers[0]))
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled router's peer must still be affected: Enabled must not gate affected-peers")
}
func TestAffectedPeers_PolicyToResource_DisabledResourceStillAffected(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
res, err := s.manager.Store.GetNetworkResourceByID(ctx, store.LockingStrengthNone, s.accountID, s.resourceID)
require.NoError(t, err)
res.Enabled = false
require.NoError(t, s.manager.Store.SaveNetworkResource(ctx, res))
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.sourcePeerID, "source peer must be affected")
assert.Contains(t, affected, s.routerPeerID,
"disabled resource must still resolve the routing peer: Enabled must not gate affected-peers")
}
func TestAffectedPeers_PolicyToResource_DisabledRuleStillAffected(t *testing.T) {
s := setupRouterScenario(t, true)
ctx := context.Background()
policy := peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID)
policy.Rules[0].Enabled = false
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerPeerID,
"disabled rule must still resolve the routing peer: Enabled must not gate affected-peers")
}
func TestAffectedPeers_MultiRulePolicy_IncludesAllRoutingPeers(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "c")
ctx := context.Background()
policy := &types.Policy{
Enabled: true,
Name: "multi-rule-two-resources",
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{s.sourceGroupID},
Destinations: []string{s.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
{
Enabled: true,
Sources: []string{s.sourceGroupID},
Destinations: []string{second.resourceGroupID},
Action: types.PolicyTrafficActionAccept,
},
},
}
affected := s.resolvePolicyAffected(ctx, policy)
assert.Contains(t, affected, s.routerPeerID, "routing peer for resource A must be affected")
assert.Contains(t, affected, second.routerPeerID, "routing peer for resource B must be affected")
}
func TestAffectedPeers_PolicyToResource_RouterInOtherNetworkNotAffected(t *testing.T) {
s := setupRouterScenario(t, true)
second := s.addSecondTopology(t, "d")
ctx := context.Background()
affected := s.resolvePolicyAffected(ctx, peerToResourcePolicyByGroup(s.sourceGroupID, s.resourceGroupID))
assert.Contains(t, affected, s.routerPeerID, "network A's routing peer must be affected")
assert.NotContains(t, affected, second.routerPeerID,
"a router in an unrelated network must not be affected by a policy that does not target its resource")
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,669 @@
package affectedpeers
import (
"context"
log "github.com/sirupsen/logrus"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// Change describes what changed in an account. The resolver never consults the
// Enabled flag of any object: toggling Enabled is itself an observable change.
type Change struct {
ChangedGroupIDs []string
ChangedPeerIDs []string
Policies []*types.Policy
Routes []*route.Route
PostureCheckIDs []string
ResourceIDs []string
NetworkIDs []string
}
func (c Change) isEmpty() bool {
return len(c.ChangedGroupIDs) == 0 &&
len(c.ChangedPeerIDs) == 0 &&
len(c.Policies) == 0 &&
len(c.Routes) == 0 &&
len(c.PostureCheckIDs) == 0 &&
len(c.ResourceIDs) == 0 &&
len(c.NetworkIDs) == 0
}
// Resolve returns the deduplicated peer IDs whose network map may have changed by
// the given Change. Safe to call inside or after a transaction.
//
// At trace level it logs the full reasoning — which inputs drove which graph
// walks to which groups/peers, including the resource<->router bridge hops — so a
// miscalculation can be diagnosed from the logs alone.
func Resolve(ctx context.Context, s store.Store, accountID string, c Change) ([]string, error) {
if c.isEmpty() {
return nil, nil
}
r := newResolver(ctx, s, accountID, c)
log.WithContext(ctx).Tracef("affectedpeers resolve start: account=%s changedGroups=%v changedPeers=%v policies=%d routes=%d postureChecks=%v resources=%v networks=%v",
accountID, c.ChangedGroupIDs, c.ChangedPeerIDs, len(c.Policies), len(c.Routes), c.PostureCheckIDs, c.ResourceIDs, c.NetworkIDs)
r.walk()
return r.expand()
}
// Collect returns the affected group IDs and direct peer IDs without expanding
// groups to members. For tests asserting on the intermediate sets; use Resolve otherwise.
func Collect(ctx context.Context, s store.Store, accountID string, c Change) (groupIDs []string, directPeerIDs []string) {
if c.isEmpty() {
return nil, nil
}
r := newResolver(ctx, s, accountID, c)
r.walk()
return setToSlice(r.groupSet), setToSlice(r.peerSet)
}
func newResolver(ctx context.Context, s store.Store, accountID string, c Change) *resolver {
r := &resolver{
ctx: ctx,
store: s,
accountID: accountID,
change: c,
changedGroupSet: toSet(c.ChangedGroupIDs),
changedPeerSet: toSet(c.ChangedPeerIDs),
groupSet: make(map[string]struct{}),
peerSet: make(map[string]struct{}),
resourceIDs: toSet(c.ResourceIDs),
networkIDs: toSet(c.NetworkIDs),
}
r.matchedPolicies = append(r.matchedPolicies, c.Policies...)
return r
}
func (r *resolver) walk() {
r.collectFromExplicitPolicies()
r.collectFromExplicitRoutes(r.change.Routes)
r.collectFromPostureChecks(r.change.PostureCheckIDs)
if len(r.changedGroupSet) > 0 || len(r.changedPeerSet) > 0 {
r.collectFromPolicies()
r.collectFromRoutes()
r.collectFromNameServers()
r.collectFromDNSSettings()
r.collectFromNetworkRouters()
r.collectFromProxyServices()
}
r.collectResourceRouterBridge()
}
type resolver struct {
ctx context.Context
store store.Store
accountID string
change Change
changedGroupSet map[string]struct{}
changedPeerSet map[string]struct{}
groupSet map[string]struct{}
peerSet map[string]struct{}
matchedPolicies []*types.Policy
resourceIDs map[string]struct{}
networkIDs map[string]struct{}
// Memoized per-account collections: each is loaded from the store at most
// once per Resolve and only when a walker actually needs it.
cachedPolicies []*types.Policy
policiesLoaded bool
cachedResources []*resourceTypes.NetworkResource
resourcesLoaded bool
cachedRouters []*routerTypes.NetworkRouter
routersLoaded bool
}
func (r *resolver) policies() []*types.Policy {
if r.policiesLoaded {
return r.cachedPolicies
}
r.policiesLoaded = true
policies, err := r.store.GetAccountPolicies(r.ctx, store.LockingStrengthNone, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get policies for affected peers resolution: %v", err)
return nil
}
r.cachedPolicies = policies
return r.cachedPolicies
}
func (r *resolver) networkResources() []*resourceTypes.NetworkResource {
if r.resourcesLoaded {
return r.cachedResources
}
r.resourcesLoaded = true
resources, err := r.store.GetNetworkResourcesByAccountID(r.ctx, store.LockingStrengthNone, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get network resources for affected peers resolution: %v", err)
return nil
}
r.cachedResources = resources
return r.cachedResources
}
func (r *resolver) networkRouters() []*routerTypes.NetworkRouter {
if r.routersLoaded {
return r.cachedRouters
}
r.routersLoaded = true
routers, err := r.store.GetNetworkRoutersByAccountID(r.ctx, store.LockingStrengthNone, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get network routers for affected peers resolution: %v", err)
return nil
}
r.cachedRouters = routers
return r.cachedRouters
}
func (r *resolver) expand() ([]string, error) {
groupIDs := setToSlice(r.groupSet)
var peerIDs []string
if len(groupIDs) > 0 {
ids, err := r.store.GetPeerIDsByGroups(r.ctx, r.accountID, groupIDs)
if err != nil {
return nil, err
}
peerIDs = ids
}
log.WithContext(r.ctx).Tracef("affectedpeers resolve expand: account=%s affectedGroups=%v -> %d group-member peers; direct peers=%v",
r.accountID, groupIDs, len(peerIDs), setToSlice(r.peerSet))
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for id := range r.peerSet {
if _, ok := seen[id]; !ok {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
log.WithContext(r.ctx).Tracef("affectedpeers resolve done: account=%s -> %d affected peers: %v", r.accountID, len(peerIDs), peerIDs)
return peerIDs, nil
}
func (r *resolver) collectFromExplicitPolicies() {
for _, policy := range r.matchedPolicies {
if policy == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitPolicies: changed policy %s (%s) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
}
}
func (r *resolver) collectFromExplicitRoutes(routes []*route.Route) {
for _, rt := range routes {
if rt == nil {
continue
}
log.WithContext(r.ctx).Tracef("collectFromExplicitRoutes: changed route %s -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
}
}
func (r *resolver) collectFromPostureChecks(postureCheckIDs []string) {
if len(postureCheckIDs) == 0 {
return
}
ids := toSet(postureCheckIDs)
for _, policy := range r.policies() {
if !policyReferencesPostureChecks(policy, ids) {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPostureChecks: policy %s (%s) references changed posture checks %v -> folding rule groups %v + direct peers",
policy.ID, policy.Name, postureCheckIDs, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
func (r *resolver) collectFromPolicies() {
for _, policy := range r.policies() {
matchedByGroup := policyReferencesGroups(policy, r.changedGroupSet)
matchedByPeer := len(r.changedPeerSet) > 0 && policyReferencesDirectPeers(policy, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromPolicies: policy %s (%s) matched (byGroup=%t byPeer=%t) -> folding rule groups %v + direct peers",
policy.ID, policy.Name, matchedByGroup, matchedByPeer, policy.RuleGroups())
addAll(r.groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, r.peerSet)
r.matchedPolicies = append(r.matchedPolicies, policy)
}
}
func (r *resolver) collectFromRoutes() {
routes, err := r.store.GetAccountRoutes(r.ctx, store.LockingStrengthNone, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get routes for affected peers resolution: %v", err)
return
}
for _, rt := range routes {
matchedByGroup := anyInSet(rt.Groups, r.changedGroupSet) || anyInSet(rt.PeerGroups, r.changedGroupSet) || anyInSet(rt.AccessControlGroups, r.changedGroupSet)
matchedByPeer := rt.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(rt.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromRoutes: route %s matched (byGroup=%t byPeer=%t) -> folding groups=%v peerGroups=%v accessControlGroups=%v peer=%q",
rt.ID, matchedByGroup, matchedByPeer, rt.Groups, rt.PeerGroups, rt.AccessControlGroups, rt.Peer)
addAll(r.groupSet, rt.Groups, rt.PeerGroups, rt.AccessControlGroups)
if rt.Peer != "" {
r.peerSet[rt.Peer] = struct{}{}
}
}
}
func (r *resolver) collectFromNameServers() {
if len(r.changedGroupSet) == 0 {
return
}
nsGroups, err := r.store.GetAccountNameServerGroups(r.ctx, store.LockingStrengthNone, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get nameserver groups for affected peers resolution: %v", err)
return
}
for _, ns := range nsGroups {
if anyInSet(ns.Groups, r.changedGroupSet) {
log.WithContext(r.ctx).Tracef("collectFromNameServers: nameserver group %s references a changed group -> folding its groups %v", ns.ID, ns.Groups)
addAll(r.groupSet, ns.Groups)
}
}
}
func (r *resolver) collectFromDNSSettings() {
if len(r.changedGroupSet) == 0 {
return
}
dnsSettings, err := r.store.GetAccountDNSSettings(r.ctx, store.LockingStrengthNone, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get DNS settings for affected peers resolution: %v", err)
return
}
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := r.changedGroupSet[gID]; ok {
log.WithContext(r.ctx).Tracef("collectFromDNSSettings: changed group %s is in DisabledManagementGroups -> folding it", gID)
r.groupSet[gID] = struct{}{}
}
}
}
func (r *resolver) collectFromNetworkRouters() {
for _, router := range r.networkRouters() {
matchedByGroup := anyInSet(router.PeerGroups, r.changedGroupSet)
matchedByPeer := router.Peer != "" && len(r.changedPeerSet) > 0 && isInSet(router.Peer, r.changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
log.WithContext(r.ctx).Tracef("collectFromNetworkRouters: router %s on network %s matched (byGroup=%t byPeer=%t) -> folding peerGroups=%v peer=%q and marking network for source bridge",
router.ID, router.NetworkID, matchedByGroup, matchedByPeer, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
r.networkIDs[router.NetworkID] = struct{}{}
}
}
func (r *resolver) collectFromProxyServices() {
services, proxyByCluster, ok := r.loadProxyServiceContext()
if !ok {
return
}
expanded := r.expandChangedPeersWithGroups()
for _, svc := range services {
if svc == nil {
continue
}
proxyPeers := proxyByCluster[svc.ProxyCluster]
if len(proxyPeers) == 0 {
continue
}
matchedByPeer := serviceMatchesChangedPeers(svc, proxyPeers, expanded)
matchedByAccessGroup := anyInSet(svc.AccessGroups, r.changedGroupSet)
if !matchedByPeer && !matchedByAccessGroup {
continue
}
log.WithContext(r.ctx).Tracef("collectFromProxyServices: service %s (cluster=%s) matched (byProxyOrTargetPeer=%t byAccessGroup=%t) -> folding %d proxy peers, peer targets and access groups %v",
svc.ID, svc.ProxyCluster, matchedByPeer, matchedByAccessGroup, len(proxyPeers), svc.AccessGroups)
for _, pid := range proxyPeers {
r.peerSet[pid] = struct{}{}
}
for _, target := range svc.Targets {
if target.TargetType == rpservice.TargetTypePeer && target.TargetId != "" {
r.peerSet[target.TargetId] = struct{}{}
}
}
addAll(r.groupSet, svc.AccessGroups)
}
}
func (r *resolver) loadProxyServiceContext() ([]*rpservice.Service, map[string][]string, bool) {
// Embedded proxy peers are the prerequisite for any synthesized proxy policy.
// Probe that first (a narrow, indexed lookup) and skip the services table load
// entirely when the account has no embedded proxy peers.
proxyByCluster, err := r.store.GetEmbeddedProxyPeerIDsByCluster(r.ctx, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get embedded proxy peers for affected peers resolution: %v", err)
return nil, nil, false
}
if len(proxyByCluster) == 0 {
return nil, nil, false
}
services, err := r.store.GetAccountServices(r.ctx, store.LockingStrengthNone, r.accountID)
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get services for affected peers resolution: %v", err)
return nil, nil, false
}
if len(services) == 0 {
return nil, nil, false
}
return services, proxyByCluster, true
}
func (r *resolver) expandChangedPeersWithGroups() map[string]struct{} {
if len(r.changedGroupSet) == 0 {
return r.changedPeerSet
}
ids, err := r.store.GetPeerIDsByGroups(r.ctx, r.accountID, setToSlice(r.changedGroupSet))
if err != nil {
log.WithContext(r.ctx).Errorf("failed to expand changed groups to peers for service resolution: %v", err)
return r.changedPeerSet
}
if len(ids) == 0 {
return r.changedPeerSet
}
merged := make(map[string]struct{}, len(r.changedPeerSet)+len(ids))
for id := range r.changedPeerSet {
merged[id] = struct{}{}
}
for _, id := range ids {
merged[id] = struct{}{}
}
return merged
}
// collectResourceRouterBridge folds in the routing peers serving the resources
// targeted by matched/explicit policies (source -> router), and the source peers
// of policies serving resources on the affected networks (router -> source). The
// routing peer is reachable only through resource -> network -> router, never
// through the policy's own groups, so it must be collected here.
func (r *resolver) collectResourceRouterBridge() {
r.bridgeSourceToRouters()
r.bridgeRoutersToSources()
}
func (r *resolver) bridgeSourceToRouters() {
resourceIDs := r.policyDestinationResourceIDs(r.matchedPolicies...)
for id := range r.resourceIDs {
resourceIDs[id] = struct{}{}
}
if len(resourceIDs) == 0 {
return
}
networkIDs := r.resourceNetworkIDs(resourceIDs)
log.WithContext(r.ctx).Tracef("bridgeSourceToRouters: targeted resources %v -> networks %v (their routers become affected via the router->source pass)",
setToSlice(resourceIDs), setToSlice(networkIDs))
for id := range networkIDs {
r.networkIDs[id] = struct{}{}
}
}
func (r *resolver) bridgeRoutersToSources() {
if len(r.networkIDs) == 0 {
return
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: affected networks %v -> folding their routing peers and the source peers of policies targeting their resources",
setToSlice(r.networkIDs))
r.foldRoutersOnNetworks(r.networkIDs)
resourceIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := r.networkIDs[resource.NetworkID]; ok {
resourceIDs[resource.ID] = struct{}{}
}
}
if len(resourceIDs) == 0 {
return
}
for _, policy := range r.policies() {
if r.policyTargetsResources(policy, resourceIDs) {
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: policy %s (%s) targets an affected-network resource -> folding its source groups/peers", policy.ID, policy.Name)
collectPolicySources(policy, r.groupSet, r.peerSet)
}
}
}
func (r *resolver) foldRoutersOnNetworks(networkIDs map[string]struct{}) {
for _, router := range r.networkRouters() {
if _, ok := networkIDs[router.NetworkID]; !ok {
continue
}
log.WithContext(r.ctx).Tracef("bridgeRoutersToSources: router %s serves affected network %s -> folding peerGroups=%v peer=%q",
router.ID, router.NetworkID, router.PeerGroups, router.Peer)
addAll(r.groupSet, router.PeerGroups)
if router.Peer != "" {
r.peerSet[router.Peer] = struct{}{}
}
}
}
func (r *resolver) resourceNetworkIDs(resourceIDs map[string]struct{}) map[string]struct{} {
networkIDs := make(map[string]struct{})
for _, resource := range r.networkResources() {
if _, ok := resourceIDs[resource.ID]; ok {
networkIDs[resource.NetworkID] = struct{}{}
}
}
return networkIDs
}
func (r *resolver) policyTargetsResources(policy *types.Policy, resourceIDs map[string]struct{}) bool {
if policy == nil {
return false
}
destGroupSet := make(map[string]struct{})
for _, rule := range policy.Rules {
if rule.DestinationResource.Type != types.ResourceTypePeer && isInSet(rule.DestinationResource.ID, resourceIDs) {
return true
}
for _, gID := range rule.Destinations {
destGroupSet[gID] = struct{}{}
}
}
if len(destGroupSet) == 0 {
return false
}
groups, err := r.store.GetGroupsByIDs(r.ctx, store.LockingStrengthNone, r.accountID, setToSlice(destGroupSet))
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get destination groups for router policy bridge: %v", err)
return false
}
for _, group := range groups {
for _, res := range group.Resources {
if isInSet(res.ID, resourceIDs) {
return true
}
}
}
return false
}
func (r *resolver) policyDestinationResourceIDs(policies ...*types.Policy) map[string]struct{} {
resourceIDs := make(map[string]struct{})
destGroupSet := collectPolicyDestinations(resourceIDs, policies...)
r.addGroupResourceIDs(destGroupSet, resourceIDs)
return resourceIDs
}
// collectPolicyDestinations adds each rule's direct destination resource IDs to
// resourceIDs and returns the set of destination group IDs referenced.
func collectPolicyDestinations(resourceIDs map[string]struct{}, policies ...*types.Policy) map[string]struct{} {
destGroupSet := make(map[string]struct{})
for _, policy := range policies {
if policy == nil {
continue
}
for _, rule := range policy.Rules {
addAll(destGroupSet, rule.Destinations)
if rule.DestinationResource.Type != types.ResourceTypePeer && rule.DestinationResource.ID != "" {
resourceIDs[rule.DestinationResource.ID] = struct{}{}
}
}
}
return destGroupSet
}
// addGroupResourceIDs folds the resource IDs of the given groups into resourceIDs.
func (r *resolver) addGroupResourceIDs(groupIDs map[string]struct{}, resourceIDs map[string]struct{}) {
if len(groupIDs) == 0 {
return
}
groups, err := r.store.GetGroupsByIDs(r.ctx, store.LockingStrengthNone, r.accountID, setToSlice(groupIDs))
if err != nil {
log.WithContext(r.ctx).Errorf("failed to get destination groups for resource router bridge: %v", err)
return
}
for _, group := range groups {
for _, res := range group.Resources {
if res.ID != "" {
resourceIDs[res.ID] = struct{}{}
}
}
}
}
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
func collectPolicySources(policy *types.Policy, groupSet, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
addAll(groupSet, rule.Sources)
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func policyReferencesPostureChecks(policy *types.Policy, ids map[string]struct{}) bool {
for _, id := range policy.SourcePostureChecks {
if _, ok := ids[id]; ok {
return true
}
}
return false
}
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
if res.Type != types.ResourceTypePeer || res.ID == "" {
return false
}
_, ok := set[res.ID]
return ok
}
func serviceMatchesChangedPeers(svc *rpservice.Service, proxyPeers []string, changedPeers map[string]struct{}) bool {
for _, pid := range proxyPeers {
if _, ok := changedPeers[pid]; ok {
return true
}
}
for _, target := range svc.Targets {
if target.TargetType != rpservice.TargetTypePeer || target.TargetId == "" {
continue
}
if _, ok := changedPeers[target.TargetId]; ok {
return true
}
}
return false
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}
func isInSet(id string, set map[string]struct{}) bool {
_, ok := set[id]
return ok
}
func addAll(set map[string]struct{}, slices ...[]string) {
for _, s := range slices {
for _, id := range s {
set[id] = struct{}{}
}
}
}
func toSet(ids []string) map[string]struct{} {
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
func setToSlice(set map[string]struct{}) []string {
s := make([]string, 0, len(set))
for id := range set {
s = append(s, id)
}
return s
}

View File

@@ -0,0 +1,138 @@
package affectedpeers
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/netbirdio/netbird/management/server/types"
)
// policyGroupsAndPeers mirrors the explicit-policy extraction (RuleGroups +
// direct peers) the resolver folds in, for asserting the pure logic.
func policyGroupsAndPeers(policies ...*types.Policy) (groups []string, peers []string) {
peerSet := map[string]struct{}{}
for _, p := range policies {
if p == nil {
continue
}
groups = append(groups, p.RuleGroups()...)
collectPolicyDirectPeers(p, peerSet)
}
for id := range peerSet {
peers = append(peers, id)
}
return groups, peers
}
func TestPolicyGroupsAndPeers_Basic(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2", "g3"}, groups)
assert.Empty(t, peers)
}
func TestPolicyGroupsAndPeers_WithPeerResources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{ID: "p1", Type: types.ResourceTypePeer},
Destinations: []string{"g2"},
DestinationResource: types.Resource{ID: "p2", Type: types.ResourceTypePeer},
}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
assert.ElementsMatch(t, []string{"p1", "p2"}, peers)
}
func TestPolicyGroupsAndPeers_NilPolicy(t *testing.T) {
groups, peers := policyGroupsAndPeers(nil)
assert.Nil(t, groups)
assert.Nil(t, peers)
}
func TestPolicyGroupsAndPeers_MultiplePolicies(t *testing.T) {
old := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1"}, Destinations: []string{"g2"}}}}
updated := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g3"}, Destinations: []string{"g4"}}}}
groups, _ := policyGroupsAndPeers(updated, old)
assert.ElementsMatch(t, []string{"g1", "g2", "g3", "g4"}, groups)
}
func TestPolicyGroupsAndPeers_NonPeerResource(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{ID: "domain-1", Type: types.ResourceTypeDomain},
Destinations: []string{"g2"},
}}}
groups, peers := policyGroupsAndPeers(policy)
assert.ElementsMatch(t, []string{"g1", "g2"}, groups)
assert.Empty(t, peers, "domain resource type should not produce direct peer IDs")
}
func TestChangeIsEmpty(t *testing.T) {
assert.True(t, Change{}.isEmpty())
assert.False(t, Change{ChangedGroupIDs: []string{"g"}}.isEmpty())
assert.False(t, Change{ChangedPeerIDs: []string{"p"}}.isEmpty())
assert.False(t, Change{Policies: []*types.Policy{{}}}.isEmpty())
assert.False(t, Change{ResourceIDs: []string{"r"}}.isEmpty())
assert.False(t, Change{NetworkIDs: []string{"n"}}.isEmpty())
assert.False(t, Change{PostureCheckIDs: []string{"pc"}}.isEmpty())
}
func TestPolicyReferencesGroups(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{Sources: []string{"g1", "g2"}, Destinations: []string{"g3"}}}}
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g1": {}}))
assert.True(t, policyReferencesGroups(policy, map[string]struct{}{"g3": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{"g4": {}}))
assert.False(t, policyReferencesGroups(policy, map[string]struct{}{}))
}
func TestPolicyReferencesDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
assert.True(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"r1": {}}))
assert.False(t, policyReferencesDirectPeers(policy, map[string]struct{}{"p2": {}}))
}
func TestPolicyReferencesPostureChecks(t *testing.T) {
policy := &types.Policy{SourcePostureChecks: []string{"pc1", "pc2"}}
assert.True(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc1": {}}))
assert.False(t, policyReferencesPostureChecks(policy, map[string]struct{}{"pc3": {}}))
}
func TestCollectPolicyDirectPeers(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
DestinationResource: types.Resource{Type: types.ResourceTypePeer, ID: "p2"},
}, {
DestinationResource: types.Resource{Type: types.ResourceTypeHost, ID: "r1"},
}}}
peerSet := map[string]struct{}{}
collectPolicyDirectPeers(policy, peerSet)
assert.Contains(t, peerSet, "p1")
assert.Contains(t, peerSet, "p2")
assert.NotContains(t, peerSet, "r1")
}
func TestCollectPolicySources(t *testing.T) {
policy := &types.Policy{Rules: []*types.PolicyRule{{
Sources: []string{"g1"},
SourceResource: types.Resource{Type: types.ResourceTypePeer, ID: "p1"},
Destinations: []string{"g2"},
}}}
groupSet := map[string]struct{}{}
peerSet := map[string]struct{}{}
collectPolicySources(policy, groupSet, peerSet)
assert.Contains(t, groupSet, "g1")
assert.NotContains(t, groupSet, "g2", "destination groups must not be collected as sources")
assert.Contains(t, peerSet, "p1")
}

View File

@@ -47,8 +47,8 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var eventsToStore []func()
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateDNSSettings(ctx, transaction, accountID, dnsSettingsToSave); err != nil {
@@ -63,11 +63,6 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
addedGroups := util.Difference(dnsSettingsToSave.DisabledManagementGroups, oldSettings.DisabledManagementGroups)
removedGroups := util.Difference(oldSettings.DisabledManagementGroups, dnsSettingsToSave.DisabledManagementGroups)
updateAccountPeers, err = areDNSSettingChangesAffectPeers(ctx, transaction, accountID, addedGroups, removedGroups)
if err != nil {
return err
}
events := am.prepareDNSSettingsEvents(ctx, transaction, accountID, userID, addedGroups, removedGroups)
eventsToStore = append(eventsToStore, events...)
@@ -75,6 +70,9 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
return err
}
allGroups := slices.Concat(addedGroups, removedGroups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -85,8 +83,11 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceDNSSettings, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SaveDNSSettings: updating %d affected peers: %v", len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SaveDNSSettings: no affected peers")
}
return nil
@@ -133,20 +134,6 @@ func (am *DefaultAccountManager) prepareDNSSettingsEvents(ctx context.Context, t
return eventsToStore
}
// areDNSSettingChangesAffectPeers checks if the DNS settings changes affect any peers.
func areDNSSettingChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, addedGroups, removedGroups []string) (bool, error) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, addedGroups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, accountID, removedGroups)
}
// validateDNSSettings validates the DNS settings.
func validateDNSSettings(ctx context.Context, transaction store.Store, accountID string, settings *types.DNSSettings) error {
if len(settings.DisabledManagementGroups) == 0 {

View File

@@ -11,6 +11,7 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -79,7 +80,7 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -91,11 +92,6 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
events := am.prepareGroupEvents(ctx, transaction, accountID, userID, newGroup)
eventsToStore = append(eventsToStore, events...)
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
return err
}
if err := transaction.CreateGroup(ctx, newGroup); err != nil {
return status.Errorf(status.Internal, "failed to create group: %v", err)
}
@@ -106,6 +102,8 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
}
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -116,8 +114,11 @@ func (am *DefaultAccountManager) CreateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateGroup %s: updating %d affected peers: %v", newGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateGroup %s: no affected peers", newGroup.ID)
}
return nil
@@ -134,7 +135,7 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
}
var eventsToStore []func()
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNewGroup(ctx, transaction, accountID, newGroup); err != nil {
@@ -151,22 +152,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return status.Errorf(status.NotFound, "group with ID %s not found", newGroup.ID)
}
peersToAdd := util.Difference(newGroup.Peers, oldGroup.Peers)
peersToRemove := util.Difference(oldGroup.Peers, newGroup.Peers)
for _, peerID := range peersToAdd {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, newGroup.ID, err)
}
}
for _, peerID := range peersToRemove {
if err := transaction.RemovePeerFromGroup(ctx, peerID, newGroup.ID); err != nil {
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, newGroup.ID, err)
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{newGroup.ID})
if err != nil {
if err = syncGroupMembership(ctx, transaction, accountID, newGroup.ID, util.Difference(newGroup.Peers, oldGroup.Peers), peersToRemove); err != nil {
return err
}
@@ -178,6 +165,8 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{newGroup.ID}, ChangedPeerIDs: peersToRemove})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -188,13 +177,31 @@ func (am *DefaultAccountManager) UpdateGroup(ctx context.Context, accountID, use
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateGroup %s: updating %d affected peers: %v", newGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateGroup %s: no affected peers", newGroup.ID)
}
return nil
}
// syncGroupMembership applies the peer membership delta for a group within a transaction.
func syncGroupMembership(ctx context.Context, transaction store.Store, accountID, groupID string, peersToAdd, peersToRemove []string) error {
for _, peerID := range peersToAdd {
if err := transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return status.Errorf(status.Internal, "failed to add peer %s to group %s: %v", peerID, groupID, err)
}
}
for _, peerID := range peersToRemove {
if err := transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return status.Errorf(status.Internal, "failed to remove peer %s from group %s: %v", peerID, groupID, err)
}
}
return nil
}
// CreateGroups adds new groups to the account.
// Note: This function does not acquire the global lock.
// It is the caller's responsibility to ensure proper locking is in place before invoking this method.
@@ -209,7 +216,6 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -247,17 +253,16 @@ func (am *DefaultAccountManager) CreateGroups(ctx context.Context, accountID, us
}
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationCreate})
affectedPeerIDs := am.ResolveAffectedPeers(ctx, am.Store, accountID, affectedpeers.Change{ChangedGroupIDs: groupIDs})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateGroups %v: updating %d affected peers: %v", groupIDs, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateGroups %v: no affected peers", groupIDs)
}
return globalErr
@@ -277,7 +282,6 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
}
var eventsToStore []func()
var updateAccountPeers bool
var globalErr error
groupIDs := make([]string, 0, len(groups))
@@ -295,17 +299,16 @@ func (am *DefaultAccountManager) UpdateGroups(ctx context.Context, accountID, us
groupIDs = append(groupIDs, newGroup.ID)
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, am.Store, accountID, groupIDs)
if err != nil {
return err
}
for _, storeEvent := range eventsToStore {
storeEvent()
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
affectedPeerIDs := am.ResolveAffectedPeers(ctx, am.Store, accountID, affectedpeers.Change{ChangedGroupIDs: groupIDs})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateGroups %v: updating %d affected peers: %v", groupIDs, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateGroups %v: no affected peers", groupIDs)
}
return globalErr
@@ -438,6 +441,7 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
var allErrors error
var groupIDsToDelete []string
var deletedGroups []*types.Group
var affectedPeerIDs []string
extraSettings, err := am.settingsManager.GetExtraSettings(ctx, accountID)
if err != nil {
@@ -445,26 +449,17 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if err = validateDeleteGroup(ctx, transaction, group, userID, extraSettings.FlowGroups); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
groupIDsToDelete = append(groupIDsToDelete, groupID)
deletedGroups = append(deletedGroups, group)
deletedGroups, allErrors = collectDeletableGroups(ctx, transaction, accountID, userID, groupIDs, extraSettings.FlowGroups)
for _, group := range deletedGroups {
groupIDsToDelete = append(groupIDsToDelete, group.ID)
}
if len(groupIDsToDelete) == 0 {
return allErrors
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: groupIDsToDelete})
if err = transaction.DeleteGroups(ctx, accountID, groupIDsToDelete); err != nil {
return err
}
@@ -483,20 +478,42 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountID, us
am.StoreEvent(ctx, userID, group.ID, accountID, activity.GroupDeleted, group.EventMeta())
}
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteGroups %v: updating %d affected peers: %v", groupIDsToDelete, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteGroups %v: no affected peers", groupIDsToDelete)
}
return allErrors
}
// collectDeletableGroups loads and validates each group for deletion, returning
// the groups that may be deleted and the joined validation errors for the rest.
func collectDeletableGroups(ctx context.Context, transaction store.Store, accountID, userID string, groupIDs, flowGroups []string) ([]*types.Group, error) {
var deletable []*types.Group
var allErrors error
for _, groupID := range groupIDs {
group, err := transaction.GetGroupByID(ctx, store.LockingStrengthNone, accountID, groupID)
if err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
if err = validateDeleteGroup(ctx, transaction, group, userID, flowGroups); err != nil {
allErrors = errors.Join(allErrors, err)
continue
}
deletable = append(deletable, group)
}
return deletable, allErrors
}
// GroupAddPeer appends peer to the group
func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.AddPeerToGroup(ctx, accountID, peerID, groupID); err != nil {
return err
}
@@ -505,14 +522,19 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupAddPeer group=%s peer=%s: updating %d affected peers: %v", groupID, peerID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupAddPeer group=%s peer=%s: no affected peers", groupID, peerID)
}
return nil
@@ -521,7 +543,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
// GroupAddResource appends resource to the group
func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -534,23 +556,23 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupAddResource group=%s resource=%s: updating %d affected peers: %v", groupID, resource.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupAddResource group=%s resource=%s: no affected peers", groupID, resource.ID)
}
return nil
@@ -558,14 +580,12 @@ func (am *DefaultAccountManager) GroupAddResource(ctx context.Context, accountID
// GroupDeletePeer removes peer from the group
func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, groupID, peerID string) error {
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
// Resolve before removing, so the peer being removed is still included
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
if err = transaction.RemovePeerFromGroup(ctx, peerID, groupID); err != nil {
return err
@@ -581,8 +601,11 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupDeletePeer group=%s peer=%s: updating %d affected peers: %v", groupID, peerID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupDeletePeer group=%s peer=%s: no affected peers", groupID, peerID)
}
return nil
@@ -591,7 +614,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
// GroupDeleteResource removes resource from the group
func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accountID, groupID string, resource types.Resource) error {
var group *types.Group
var updateAccountPeers bool
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -604,23 +627,23 @@ func (am *DefaultAccountManager) GroupDeleteResource(ctx context.Context, accoun
return nil
}
updateAccountPeers, err = areGroupChangesAffectPeers(ctx, transaction, accountID, []string{groupID})
if err != nil {
return err
}
if err = transaction.UpdateGroup(ctx, group); err != nil {
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ChangedGroupIDs: []string{groupID}})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
return err
}
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("GroupDeleteResource group=%s resource=%s: updating %d affected peers: %v", groupID, resource.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("GroupDeleteResource group=%s resource=%s: no affected peers", groupID, resource.ID)
}
return nil
@@ -832,49 +855,103 @@ func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store,
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
// It fetches each collection once and checks all groupIDs against them in memory.
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
groupSet := make(map[string]struct{}, len(groupIDs))
for _, id := range groupIDs {
groupSet[id] = struct{}{}
}
if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil {
return affected, err
}
return false, nil
}
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil
}
// anyGroupHasPeersOrResources checks if any of the given groups in the account have peers or resources.
func anyGroupHasPeersOrResources(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, groupIDs)
func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, group := range groups {
if group.HasPeers() || group.HasResources() {
for _, ns := range nameServerGroups {
if anyInSet(ns.Groups, groupSet) {
return true, nil
}
}
return false, nil
}
func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true, nil
}
}
}
return false, nil
}
func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, r := range routes {
if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) {
return true, nil
}
}
return false, nil
}
func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, router := range routers {
if anyInSet(router.PeerGroups, groupSet) {
return true, nil
}
}
return false, nil
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/idp"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/posture"
@@ -132,6 +133,9 @@ type MockAccountManager struct {
AllowSyncFunc func(string, uint64) bool
UpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
UpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string)
BufferUpdateAffectedPeersFunc func(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason)
ResolveAffectedPeersFunc func(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string
BufferUpdateAccountPeersFunc func(ctx context.Context, accountID string, reason types.UpdateReason)
RecalculateNetworkMapCacheFunc func(ctx context.Context, accountId string) error
@@ -209,6 +213,25 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID
}
}
func (am *MockAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
if am.UpdateAffectedPeersFunc != nil {
am.UpdateAffectedPeersFunc(ctx, accountID, peerIDs)
}
}
func (am *MockAccountManager) ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string {
if am.ResolveAffectedPeersFunc != nil {
return am.ResolveAffectedPeersFunc(ctx, s, accountID, change)
}
return nil
}
func (am *MockAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
if am.BufferUpdateAffectedPeersFunc != nil {
am.BufferUpdateAffectedPeersFunc(ctx, accountID, peerIDs, reason)
}
}
func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
if am.BufferUpdateAccountPeersFunc != nil {
am.BufferUpdateAccountPeersFunc(ctx, accountID, reason)

View File

@@ -4,10 +4,12 @@ import (
"context"
"errors"
"fmt"
"slices"
"strings"
"unicode/utf8"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
@@ -57,22 +59,19 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
SearchDomainsEnabled: searchDomainEnabled,
}
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateNameServerGroup(ctx, transaction, accountID, newNSGroup); err != nil {
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, newNSGroup.Groups)
if err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, newNSGroup); err != nil {
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, newNSGroup.Groups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -81,8 +80,11 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateNameServerGroup %s: updating %d affected peers: %v", newNSGroup.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateNameServerGroup %s: no affected peers", newNSGroup.ID)
}
return newNSGroup.Copy(), nil
@@ -102,7 +104,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
oldNSGroup, err := transaction.GetNameServerGroupByID(ctx, store.LockingStrengthNone, accountID, nsGroupToSave.ID)
@@ -115,15 +117,13 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
return err
}
updateAccountPeers, err = areNameServerGroupChangesAffectPeers(ctx, transaction, nsGroupToSave, oldNSGroup)
if err != nil {
return err
}
if err = transaction.SaveNameServerGroup(ctx, nsGroupToSave); err != nil {
return err
}
allGroups := slices.Concat(nsGroupToSave.Groups, oldNSGroup.Groups)
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, allGroups, nil)
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -132,8 +132,11 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SaveNameServerGroup %s: updating %d affected peers: %v", nsGroupToSave.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SaveNameServerGroup %s: no affected peers", nsGroupToSave.ID)
}
return nil
@@ -150,7 +153,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
}
var nsGroup *nbdns.NameServerGroup
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
nsGroup, err = transaction.GetNameServerGroupByID(ctx, store.LockingStrengthUpdate, accountID, nsGroupID)
@@ -158,10 +161,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
return err
}
updateAccountPeers, err = anyGroupHasPeersOrResources(ctx, transaction, accountID, nsGroup.Groups)
if err != nil {
return err
}
affectedPeerIDs = am.resolvePeerIDs(ctx, transaction, accountID, nsGroup.Groups, nil)
if err = transaction.DeleteNameServerGroup(ctx, accountID, nsGroupID); err != nil {
return err
@@ -175,8 +175,11 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceNameServerGroup, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteNameServerGroup %s: updating %d affected peers: %v", nsGroupID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteNameServerGroup %s: no affected peers", nsGroupID)
}
return nil
@@ -224,24 +227,6 @@ func validateNameServerGroup(ctx context.Context, transaction store.Store, accou
return validateGroups(nameserverGroup.Groups, groups)
}
// areNameServerGroupChangesAffectPeers checks if the changes in the nameserver group affect the peers.
func areNameServerGroupChangesAffectPeers(ctx context.Context, transaction store.Store, newNSGroup, oldNSGroup *nbdns.NameServerGroup) (bool, error) {
if !newNSGroup.Enabled && !oldNSGroup.Enabled {
return false, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, newNSGroup.AccountID, newNSGroup.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, oldNSGroup.AccountID, oldNSGroup.Groups)
}
func validateDomainInput(primary bool, domains []string, searchDomainsEnabled bool) error {
if !primary && len(domains) == 0 {
return status.Errorf(status.InvalidArgument, "nameserver group primary status is false and domains are empty,"+

View File

@@ -5,9 +5,11 @@ import (
"fmt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/networks/resources"
"github.com/netbirdio/netbird/management/server/networks/routers"
"github.com/netbirdio/netbird/management/server/networks/types"
@@ -15,7 +17,6 @@ import (
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -127,7 +128,10 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
}
var eventsToStore []func()
var affectedPeerIDs []string
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{NetworkIDs: []string{networkID}})
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get resources in network: %w", err)
@@ -141,12 +145,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
eventsToStore = append(eventsToStore, event...)
}
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
netRouters, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
if err != nil {
return fmt.Errorf("failed to get routers in network: %w", err)
}
for _, router := range routers {
for _, router := range netRouters {
event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
if err != nil {
return fmt.Errorf("failed to delete router: %w", err)
@@ -178,7 +182,12 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetwork, Operation: serverTypes.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteNetwork %s: updating %d affected peers: %v", networkID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteNetwork %s: no affected peers", networkID)
}
return nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/groups"
"github.com/netbirdio/netbird/management/server/networks/resources/types"
"github.com/netbirdio/netbird/management/server/permissions"
@@ -114,45 +115,11 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
}
var eventsToStore []func()
var affectedPeerIDs []string
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
}
eventsToStore = append(eventsToStore, event)
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
var txErr error
eventsToStore, affectedPeerIDs, txErr = m.createResourceInTransaction(ctx, transaction, userID, resource)
return txErr
})
if err != nil {
return nil, fmt.Errorf("failed to create network resource: %w", err)
@@ -162,11 +129,57 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
event()
}
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateResource %s: updating %d affected peers: %v", resource.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateResource %s: no affected peers", resource.ID)
}
return resource, nil
}
func (m *managerImpl) createResourceInTransaction(ctx context.Context, transaction store.Store, userID string, resource *types.NetworkResource) ([]func(), []string, error) {
_, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return nil, nil, status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
}
var eventsToStore []func()
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
})
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return nil, nil, fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
if err = transaction.IncrementNetworkSerial(ctx, resource.AccountID); err != nil {
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
}
affectedPeerIDs := m.accountManager.ResolveAffectedPeers(ctx, transaction, resource.AccountID, affectedpeers.Change{ResourceIDs: []string{resource.ID}})
return eventsToStore, affectedPeerIDs, nil
}
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
@@ -207,6 +220,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
resource.Prefix = prefix
var eventsToStore []func()
var affectedPeerIDs []string
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
@@ -232,6 +246,15 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
return fmt.Errorf("failed to get network resource: %w", err)
}
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, resource.AccountID, resource.ID)
if err != nil {
return fmt.Errorf("failed to get old resource groups: %w", err)
}
var oldGroupIDs []string
for _, g := range oldGroups {
oldGroupIDs = append(oldGroupIDs, g.ID)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
@@ -247,6 +270,13 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
})
// Pass both old and new resource group IDs so policies that targeted the
// resource via a now-detached group still refresh their source peers.
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, resource.AccountID, affectedpeers.Change{
ResourceIDs: []string{resource.ID},
ChangedGroupIDs: append(oldGroupIDs, resource.GroupIDs...),
})
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
@@ -270,7 +300,12 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
}
}()
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateResource %s: updating %d affected peers: %v", resource.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateResource %s: no affected peers", resource.ID)
}
return resource, nil
}
@@ -331,7 +366,10 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
}
var events []func()
var affectedPeerIDs []string
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{ResourceIDs: []string{resourceID}})
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
if err != nil {
return fmt.Errorf("failed to delete resource: %w", err)
@@ -352,7 +390,12 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
event()
}
go m.accountManager.UpdateAccountPeers(ctx, accountID, nbtypes.UpdateReason{Resource: nbtypes.UpdateResourceNetworkResource, Operation: nbtypes.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteResource %s: updating %d affected peers: %v", resourceID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteResource %s: no affected peers", resourceID)
}
return nil
}

View File

@@ -6,16 +6,17 @@ import (
"fmt"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/networks/routers/types"
networkTypes "github.com/netbirdio/netbird/management/server/networks/types"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
serverTypes "github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -90,6 +91,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var affectedPeerIDs []string
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
@@ -112,6 +114,8 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
return fmt.Errorf("failed to increment network serial: %w", err)
}
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, router.AccountID, affectedpeers.Change{NetworkIDs: []string{router.NetworkID}})
return nil
})
if err != nil {
@@ -120,7 +124,12 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateRouter %s: updating %d affected peers: %v", router.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateRouter %s: no affected peers", router.ID)
}
return router, nil
}
@@ -156,36 +165,11 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
}
var network *networkTypes.Network
var affectedPeerIDs []string
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
if err != nil {
return fmt.Errorf("failed to get network router: %w", err)
}
if existing.AccountID != router.AccountID {
return status.NewNetworkRouterNotFoundError(router.ID)
}
if existing.NetworkID != router.NetworkID {
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
err = transaction.UpdateNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
return nil
var txErr error
network, affectedPeerIDs, txErr = m.updateRouterInTransaction(ctx, transaction, router)
return txErr
})
if err != nil {
return nil, err
@@ -193,11 +177,76 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("UpdateRouter %s: updating %d affected peers: %v", router.ID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("UpdateRouter %s: no affected peers", router.ID)
}
return router, nil
}
func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction store.Store, router *types.NetworkRouter) (*networkTypes.Network, []string, error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
existing, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthUpdate, router.AccountID, router.ID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network router: %w", err)
}
if existing.AccountID != router.AccountID {
return nil, nil, status.NewNetworkRouterNotFoundError(router.ID)
}
if existing.NetworkID != router.NetworkID {
return nil, nil, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
if err = transaction.UpdateNetworkRouter(ctx, router); err != nil {
return nil, nil, fmt.Errorf("failed to update network router: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, router.AccountID); err != nil {
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
}
networkIDs := []string{router.NetworkID}
if existing.NetworkID != router.NetworkID {
networkIDs = append(networkIDs, existing.NetworkID)
}
affectedPeerIDs := m.accountManager.ResolveAffectedPeers(ctx, transaction, router.AccountID, affectedpeers.Change{NetworkIDs: networkIDs})
// The previous routing peer / peer-group members lose their routing role and
// are no longer reachable from the post-update network state, so add them
// explicitly.
affectedPeerIDs = append(affectedPeerIDs, oldRoutingPeerIDs(ctx, transaction, router.AccountID, existing)...)
return network, affectedPeerIDs, nil
}
// oldRoutingPeerIDs returns the peer IDs that served as the router's routing peers
// before an update (direct Peer plus PeerGroups members).
func oldRoutingPeerIDs(ctx context.Context, transaction store.Store, accountID string, existing *types.NetworkRouter) []string {
var ids []string
if existing.Peer != "" {
ids = append(ids, existing.Peer)
}
if len(existing.PeerGroups) > 0 {
groupPeers, err := transaction.GetPeerIDsByGroups(ctx, accountID, existing.PeerGroups)
if err != nil {
log.WithContext(ctx).Errorf("failed to get old router peer-group members for affected peers: %v", err)
} else {
ids = append(ids, groupPeers...)
}
}
return ids
}
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
ok, ctx, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
@@ -208,7 +257,10 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
}
var event func()
var affectedPeerIDs []string
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
affectedPeerIDs = m.accountManager.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{NetworkIDs: []string{networkID}})
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
if err != nil {
return fmt.Errorf("failed to delete network router: %w", err)
@@ -227,7 +279,12 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
event()
go m.accountManager.UpdateAccountPeers(ctx, accountID, serverTypes.UpdateReason{Resource: serverTypes.UpdateResourceNetworkRouter, Operation: serverTypes.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteRouter %s: updating %d affected peers: %v", routerID, len(affectedPeerIDs), affectedPeerIDs)
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteRouter %s: no affected peers", routerID)
}
return nil
}

View File

@@ -27,6 +27,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
nbpeer "github.com/netbirdio/netbird/management/server/peer"
"github.com/netbirdio/netbird/management/server/telemetry"
"github.com/netbirdio/netbird/shared/management/status"
@@ -120,19 +121,23 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
}
// An embedded proxy peer flipping to connected is the trigger for
// SynthesizePrivateServiceZones to emit DNS A records pointing at its
// tunnel IP. Without an account-wide netmap recompute, user peers keep
// the stale synth (or no synth at all on first connect) until some
// other change pokes the controller. Fire OnPeersUpdated so the
// buffered recompute fans the new state out to every peer.
// tunnel IP. Fan the change out to every peer that has a synthesized
// policy or DNS-record edge from this proxy peer; otherwise their
// synth state stays stale until some other change pokes the controller.
if peer.ProxyMeta.Embedded {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s connect: %v", peer.ID, err)
}
}
@@ -174,11 +179,12 @@ func (am *DefaultAccountManager) MarkPeerDisconnected(ctx context.Context, peerP
am.metrics.AccountManagerMetrics().CountPeerStatusUpdate(telemetry.PeerStatusDisconnect, telemetry.PeerStatusApplied)
// Symmetric with MarkPeerConnected: when an embedded proxy peer goes
// offline, drive an account-wide netmap recompute so the synthesized
// DNS records that pointed at it are pulled. Without this the records
// linger client-side at TTL until something else triggers a refresh.
// offline, refresh the peers that had synthesized records pointing at
// it so they pull the stale entries instead of waiting out TTL.
if peer.ProxyMeta.Embedded {
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}); err != nil {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Warnf("notify network map controller of embedded proxy %s disconnect: %v", peer.ID, err)
}
}
@@ -345,7 +351,10 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
affectedPeerIDs = append(affectedPeerIDs, peer.ID)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -503,6 +512,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var peer *nbpeer.Peer
var settings *types.Settings
var eventsToStore []func()
var affectedPeerIDs []string
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
@@ -527,6 +537,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, []string{peerID})
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings)
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -550,7 +562,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
}
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
@@ -923,12 +935,18 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
if err != nil {
return p, nmap, pc, err
}
changedPeerIDs := []string{newPeer.ID}
affectedPeerIDs := affectedPeerIDsFromNetworkMap(nmap, newPeer.ID)
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
p, nmap, pc, _, err := am.networkMapController.GetValidatedPeerWithMap(ctx, false, accountID, newPeer)
return p, nmap, pc, err
return p, nmap, pc, nil
}
func getPeerIPDNSLabel(ip netip.Addr, peerHostName string) (string, error) {
@@ -1011,7 +1029,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || ipv6CapabilityChanged || (updated && (len(postureChecks) > 0 || versionChanged)) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1141,7 +1161,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
if updateRemotePeers || isStatusChanged || ipv6CapabilityChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1406,6 +1428,99 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
_ = am.networkMapController.UpdateAccountPeers(ctx, accountID, reason)
}
// UpdateAffectedPeers updates only the specified peers that belong to an account.
func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
ctx = context.WithoutCancel(ctx)
log.WithContext(ctx).Tracef("UpdateAffectedPeers: %d peers for account %s", len(peerIDs), accountID)
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs)
}
// resolvePeerIDs resolves group IDs and direct peer IDs into a deduplicated peer ID list.
func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Store, accountID string, groupIDs []string, directPeerIDs []string) []string {
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs by groups: %v", err)
return nil
}
if len(directPeerIDs) == 0 {
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v -> %d peers: %v", groupIDs, len(peerIDs), peerIDs)
return peerIDs
}
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v + directPeers=%v -> %d peers: %v", groupIDs, directPeerIDs, len(peerIDs), peerIDs)
return peerIDs
}
// BufferUpdateAffectedPeers accumulates peer IDs and flushes them after the buffer interval.
func (am *DefaultAccountManager) BufferUpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string, reason types.UpdateReason) {
_ = am.networkMapController.BufferUpdateAffectedPeers(ctx, accountID, peerIDs, reason)
}
// ResolveAffectedPeers resolves a description of what changed into the peer IDs
// whose network map may have changed. It is the single entry point shared by the
// server package and the networks managers (via the account.Manager interface).
func (am *DefaultAccountManager) ResolveAffectedPeers(ctx context.Context, s store.Store, accountID string, change affectedpeers.Change) []string {
peerIDs, err := affectedpeers.Resolve(ctx, s, accountID, change)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve affected peers: %v", err)
return nil
}
return peerIDs
}
// affectedPeerIDsFromNetworkMap returns the peer IDs referenced by a peer's
// network map (its connected and offline peers, which include routing and proxy
// peers), excluding the peer itself. For a freshly added peer these are, by ACL
// symmetry, exactly the peers its addition affects.
func affectedPeerIDsFromNetworkMap(nmap *types.NetworkMap, selfPeerID string) []string {
if nmap == nil {
return nil
}
seen := make(map[string]struct{}, len(nmap.Peers)+len(nmap.OfflinePeers))
ids := make([]string, 0, len(nmap.Peers)+len(nmap.OfflinePeers))
add := func(peers []*nbpeer.Peer) {
for _, p := range peers {
if p == nil || p.ID == "" || p.ID == selfPeerID {
continue
}
if _, ok := seen[p.ID]; ok {
continue
}
seen[p.ID] = struct{}{}
ids = append(ids, p.ID)
}
}
add(nmap.Peers)
add(nmap.OfflinePeers)
return ids
}
// resolveAffectedPeersForPeerChanges resolves changed peer IDs into the full set of affected peer IDs.
func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string {
groupIDs, err := s.GetGroupIDsByPeerIDs(ctx, accountID, changedPeerIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to get group IDs for changed peers: %v", err)
return nil
}
return am.ResolveAffectedPeers(ctx, s, accountID, affectedpeers.Change{
ChangedGroupIDs: groupIDs,
ChangedPeerIDs: changedPeerIDs,
})
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID, reason)
}

View File

@@ -1855,7 +1855,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("adding peer to unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg) //
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1880,7 +1880,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
t.Run("deleting peer with unlinked group", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2018,7 +2018,10 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
}
})
// Adding peer to group linked with route should update account peers and send peer update
// drain any buffered updates from previous subtests
drainPeerUpdates(updMsg)
// Adding peer to group linked with route should update peers in that group, not unrelated peers
t.Run("adding peer to group linked with route", func(t *testing.T) {
route := nbroute.Route{
ID: "testingRoute1",
@@ -2042,7 +2045,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2059,16 +2062,16 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting peer with linked group to route should update account peers and send peer update
// Deleting peer with linked group to route should update peers in that group, not unrelated peers
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2077,12 +2080,12 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Adding peer to group linked with name server group should update account peers and send peer update
// Adding peer to group linked with name server group should update peers in that group, not unrelated peers
t.Run("adding peer to group linked with name server group", func(t *testing.T) {
_, err = manager.CreateNameServerGroup(
context.Background(), account.Id, "nsGroup", "nsGroup", []nbdns.NameServer{{
@@ -2097,7 +2100,7 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2114,16 +2117,16 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
// Deleting peer with linked group to name server group should update account peers and send peer update
// Deleting peer with linked group to name server group should update peers in that group, not unrelated peers
t.Run("deleting peer with linked group to route", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2132,8 +2135,8 @@ func TestPeerAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})
}

View File

@@ -5,7 +5,7 @@ import (
_ "embed"
"github.com/rs/xid"
"github.com/sirupsen/logrus"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
@@ -13,6 +13,7 @@ import (
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -45,44 +46,37 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
}
var isUpdate = policy.ID != ""
var updateAccountPeers bool
var existingPolicy *types.Policy
var action = activity.PolicyAdded
var unchanged bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy)
existingPolicy, err = validatePolicy(ctx, transaction, accountID, policy)
if err != nil {
return err
}
if isUpdate {
if policy.Equal(existingPolicy) {
logrus.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
log.WithContext(ctx).Tracef("policy update skipped because equal to stored one - policy id %s", policy.ID)
unchanged = true
return nil
}
action = activity.PolicyUpdated
updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy)
if err != nil {
return err
}
if err = transaction.SavePolicy(ctx, policy); err != nil {
return err
}
} else {
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
if err = transaction.CreatePolicy(ctx, policy); err != nil {
return err
}
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Policies: []*types.Policy{policy, existingPolicy}})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -95,12 +89,11 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
policyOp := types.UpdateOperationCreate
if isUpdate {
policyOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: policyOp})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Tracef("SavePolicy %s: updating %d affected peers: %v", policy.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SavePolicy %s: no affected peers", policy.ID)
}
return policy, nil
@@ -117,7 +110,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
}
var policy *types.Policy
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
policy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthUpdate, accountID, policyID)
@@ -125,10 +118,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
return err
}
updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy)
if err != nil {
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Policies: []*types.Policy{policy}})
if err = transaction.DeletePolicy(ctx, accountID, policyID); err != nil {
return err
@@ -142,8 +132,11 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePolicy, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeletePolicy %s: updating %d affected peers: %v", policyID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeletePolicy %s: no affected peers", policyID)
}
return nil
@@ -162,46 +155,6 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us
return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
}
// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers.
func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) {
if !policy.Enabled && !existingPolicy.Enabled {
return false, nil
}
for _, rule := range existingPolicy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" {
return true, nil
}
}
return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups())
}
// validatePolicy validates the policy and its rules. For updates it returns
// the existing policy loaded from the store so callers can avoid a second read.
func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) {

View File

@@ -1319,12 +1319,14 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
}
})
// Updating disabled policy with destination and source groups containing peers should not update account's peers
// or send peer update
// Updating disabled policy with destination and source groups containing peers should still update account's peers
// because affected peer resolution does not filter by policy enabled state
t.Run("updating disabled policy with source and destination groups with peers", func(t *testing.T) {
drainPeerUpdates(updMsg)
done := make(chan struct{})
go func() {
peerShouldNotReceiveUpdate(t, updMsg)
peerShouldReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1335,8 +1337,8 @@ func TestPolicyAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
}
})

View File

@@ -5,13 +5,14 @@ import (
"slices"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/posture"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -41,9 +42,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
return nil, status.NewPermissionDeniedError()
}
var updateAccountPeers bool
var isUpdate = postureChecks.ID != ""
var action = activity.PostureCheckCreated
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validatePostureChecks(ctx, transaction, accountID, postureChecks); err != nil {
@@ -51,12 +52,9 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
}
if isUpdate {
updateAccountPeers, err = arePostureCheckChangesAffectPeers(ctx, transaction, accountID, postureChecks.ID)
if err != nil {
return err
}
action = activity.PostureCheckUpdated
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{PostureCheckIDs: []string{postureChecks.ID}})
}
postureChecks.AccountID = accountID
@@ -76,12 +74,11 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
postureOp := types.UpdateOperationCreate
if isUpdate {
postureOp = types.UpdateOperationUpdate
}
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePostureCheck, Operation: postureOp})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SavePostureChecks %s: updating %d affected peers: %v", postureChecks.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SavePostureChecks %s: no affected peers", postureChecks.ID)
}
return postureChecks, nil
@@ -137,29 +134,6 @@ func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountI
return am.Store.GetAccountPostureChecks(ctx, store.LockingStrengthNone, accountID)
}
// arePostureCheckChangesAffectPeers checks if the changes in posture checks are affecting peers.
func arePostureCheckChangesAffectPeers(ctx context.Context, transaction store.Store, accountID, postureCheckID string) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, policy := range policies {
if slices.Contains(policy.SourcePostureChecks, postureCheckID) {
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, accountID, policy.RuleGroups())
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
}
}
return false, nil
}
// validatePostureChecks validates the posture checks.
func validatePostureChecks(ctx context.Context, transaction store.Store, accountID string, postureChecks *posture.Checks) error {
if err := postureChecks.Validate(); err != nil {

View File

@@ -503,21 +503,20 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
require.NoError(t, err, "failed to save policy")
t.Run("posture check exists and is linked to policy with peers", func(t *testing.T) {
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check exists but is not linked to any policy", func(t *testing.T) {
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
require.NoError(t, err)
assert.False(t, result)
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckB.ID)
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
})
t.Run("posture check does not exist", func(t *testing.T) {
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, "unknown")
require.NoError(t, err)
assert.False(t, result)
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, "unknown")
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
})
t.Run("posture check is linked to policy with no peers in source groups", func(t *testing.T) {
@@ -526,9 +525,8 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check is linked to policy with no peers in destination groups", func(t *testing.T) {
@@ -537,9 +535,8 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.True(t, result)
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check is linked to policy but no peers in groups", func(t *testing.T) {
@@ -547,9 +544,9 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
err = manager.UpdateGroup(context.Background(), account.Id, adminUserID, groupA)
require.NoError(t, err, "failed to save groups")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
// The collector returns groups even if they have no peers — the groups are still referenced
groupIDs, _ := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.NotEmpty(t, groupIDs)
})
t.Run("posture check is linked to policy with non-existent group", func(t *testing.T) {
@@ -558,8 +555,10 @@ func TestArePostureCheckChangesAffectPeers(t *testing.T) {
_, err = manager.SavePolicy(context.Background(), account.Id, adminUserID, policy, true)
require.NoError(t, err, "failed to update policy")
result, err := arePostureCheckChangesAffectPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
require.NoError(t, err)
assert.False(t, result)
// Non-existent groups are filtered out during SavePolicy validation,
// so the saved policy has empty Sources/Destinations
groupIDs, directPeerIDs := collectPostureCheckAffectedGroupsAndPeers(context.Background(), manager.Store, account.Id, postureCheckA.ID)
assert.Empty(t, groupIDs)
assert.Empty(t, directPeerIDs)
})
}

View File

@@ -8,8 +8,10 @@ import (
"unicode/utf8"
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/affectedpeers"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
@@ -147,7 +149,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
}
var newRoute *route.Route
var updateAccountPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
newRoute = &route.Route{
@@ -173,15 +175,12 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, newRoute)
if err != nil {
return err
}
if err = transaction.SaveRoute(ctx, newRoute); err != nil {
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Routes: []*route.Route{newRoute}})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -190,8 +189,11 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationCreate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("CreateRoute %s: updating %d affected peers: %v", newRoute.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("CreateRoute %s: no affected peers", newRoute.ID)
}
return newRoute, nil
@@ -208,8 +210,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
}
var oldRoute *route.Route
var oldRouteAffectsPeers bool
var newRouteAffectsPeers bool
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
if err = validateRoute(ctx, transaction, accountID, routeToSave); err != nil {
@@ -221,21 +222,14 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
return err
}
oldRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, oldRoute)
if err != nil {
return err
}
newRouteAffectsPeers, err = areRouteChangesAffectPeers(ctx, transaction, routeToSave)
if err != nil {
return err
}
routeToSave.AccountID = accountID
if err = transaction.SaveRoute(ctx, routeToSave); err != nil {
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Routes: []*route.Route{routeToSave, oldRoute}})
return transaction.IncrementNetworkSerial(ctx, accountID)
})
if err != nil {
@@ -244,8 +238,11 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationUpdate})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("SaveRoute %s: updating %d affected peers: %v", routeToSave.ID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("SaveRoute %s: no affected peers", routeToSave.ID)
}
return nil
@@ -261,19 +258,16 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return status.NewPermissionDeniedError()
}
var route *route.Route
var updateAccountPeers bool
var rt *route.Route
var affectedPeerIDs []string
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
route, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
rt, err = transaction.GetRouteByID(ctx, store.LockingStrengthUpdate, accountID, string(routeID))
if err != nil {
return err
}
updateAccountPeers, err = areRouteChangesAffectPeers(ctx, transaction, route)
if err != nil {
return err
}
affectedPeerIDs = am.ResolveAffectedPeers(ctx, transaction, accountID, affectedpeers.Change{Routes: []*route.Route{rt}})
if err = transaction.DeleteRoute(ctx, accountID, string(routeID)); err != nil {
return err
@@ -285,10 +279,13 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
return fmt.Errorf("failed to delete route %s: %w", routeID, err)
}
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
am.StoreEvent(ctx, userID, string(rt.ID), accountID, activity.RouteRemoved, rt.EventMeta())
if updateAccountPeers {
am.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceRoute, Operation: types.UpdateOperationDelete})
if len(affectedPeerIDs) > 0 {
log.WithContext(ctx).Debugf("DeleteRoute %s: updating %d affected peers: %v", routeID, len(affectedPeerIDs), affectedPeerIDs)
am.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
} else {
log.WithContext(ctx).Tracef("DeleteRoute %s: no affected peers", routeID)
}
return nil
@@ -377,25 +374,6 @@ func getPlaceholderIP() netip.Prefix {
return netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 0, 2, 0}), 32)
}
// areRouteChangesAffectPeers checks if a given route affects peers by determining
// if it has a routing peer, distribution, or peer groups that include peers.
func areRouteChangesAffectPeers(ctx context.Context, transaction store.Store, route *route.Route) (bool, error) {
if route.Peer != "" {
return true, nil
}
hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.Groups)
if err != nil {
return false, err
}
if hasPeers {
return true, nil
}
return anyGroupHasPeersOrResources(ctx, transaction, route.AccountID, route.PeerGroups)
}
// GetRoutesByPrefixOrDomains return list of routes by account and route prefix
func getRoutesByPrefixOrDomains(ctx context.Context, transaction store.Store, accountID string, prefix netip.Prefix, domains domain.List) ([]*route.Route, error) {
accountRoutes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)

View File

@@ -1962,8 +1962,10 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
})
// Creating a route with no routing peer and having peers in groups should update account peers and send peer update
// Creating a route with no routing peer and having peers in groups that don't include peer1 should not send peer1 an update
t.Run("creating a route with peers in PeerGroups and Groups", func(t *testing.T) {
drainPeerUpdates(updMsg)
route := route.Route{
ID: "testingRoute2",
Network: netip.MustParsePrefix("192.0.2.0/32"),
@@ -1979,7 +1981,7 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -1992,8 +1994,8 @@ func TestRouteAccountPeersUpdate(t *testing.T) {
select {
case <-done:
case <-time.After(peerUpdateTimeout):
t.Error("timeout waiting for peerShouldReceiveUpdate")
case <-time.After(time.Second):
t.Error("timeout waiting for peerShouldNotReceiveUpdate")
}
})

View File

@@ -265,7 +265,8 @@ func (s *SqlStore) AcquireGlobalLock(ctx context.Context) (unlock func()) {
return unlock
}
// Deprecated: Full account operations are no longer supported
// Deprecated: Full
// account operations are no longer supported
func (s *SqlStore) SaveAccount(ctx context.Context, account *types.Account) error {
start := time.Now()
defer func() {
@@ -4894,6 +4895,64 @@ func (s *SqlStore) GetPeersByGroupIDs(ctx context.Context, accountID string, gro
return peers, nil
}
func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
if len(groupIDs) == 0 {
return nil, nil
}
var peerIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT peer_id").
Where("account_id = ? AND group_id IN ?", accountID, groupIDs).
Pluck("peer_id", &peerIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get peer IDs by groups: %s", result.Error)
}
return peerIDs, nil
}
func (s *SqlStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
if len(peerIDs) == 0 {
return nil, nil
}
var groupIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT group_id").
Where("account_id = ? AND peer_id IN ?", accountID, peerIDs).
Pluck("group_id", &groupIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get group IDs by peers: %s", result.Error)
}
return groupIDs, nil
}
// GetEmbeddedProxyPeerIDsByCluster returns peer IDs of all embedded proxy peers
// in the account, grouped by their ProxyCluster. The map is nil when no embedded
// proxy peers exist.
func (s *SqlStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
type row struct {
ID string
Cluster string
}
var rows []row
result := s.db.Model(&nbpeer.Peer{}).
Select("id, proxy_meta_cluster AS cluster").
Where("account_id = ? AND proxy_meta_embedded = ?", accountID, true).
Scan(&rows)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get embedded proxy peers: %s", result.Error)
}
out := make(map[string][]string, len(rows))
for _, r := range rows {
out[r.Cluster] = append(out[r.Cluster], r.ID)
}
return out, nil
}
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -162,6 +162,9 @@ type Store interface {
GetPeerByID(ctx context.Context, lockStrength LockingStrength, accountID string, peerID string) (*nbpeer.Peer, error)
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error)
GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error)
GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)

View File

@@ -1925,6 +1925,51 @@ func (mr *MockStoreMockRecorder) GetPeersByGroupIDs(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeersByGroupIDs", reflect.TypeOf((*MockStore)(nil).GetPeersByGroupIDs), ctx, accountID, groupIDs)
}
// GetPeerIDsByGroups mocks base method.
func (m *MockStore) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetPeerIDsByGroups", ctx, accountID, groupIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetPeerIDsByGroups indicates an expected call of GetPeerIDsByGroups.
func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetEmbeddedProxyPeerIDsByCluster mocks base method.
func (m *MockStore) GetEmbeddedProxyPeerIDsByCluster(ctx context.Context, accountID string) (map[string][]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetEmbeddedProxyPeerIDsByCluster", ctx, accountID)
ret0, _ := ret[0].(map[string][]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetEmbeddedProxyPeerIDsByCluster indicates an expected call of GetEmbeddedProxyPeerIDsByCluster.
func (mr *MockStoreMockRecorder) GetEmbeddedProxyPeerIDsByCluster(ctx, accountID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEmbeddedProxyPeerIDsByCluster", reflect.TypeOf((*MockStore)(nil).GetEmbeddedProxyPeerIDsByCluster), ctx, accountID)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -1157,7 +1157,8 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
}
}
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs)
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, peerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1273,6 +1274,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var userPeers []*nbpeer.Peer
var targetUser *types.User
var settings *types.Settings
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -1293,6 +1295,14 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
if len(userPeers) > 0 {
updateAccountPeers = true
var peerIDs []string
for _, peer := range userPeers {
peerIDs = append(peerIDs, peer.ID)
}
// Resolve before delete so group memberships are still present.
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, peerIDs)
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers, settings)
if err != nil {
return fmt.Errorf("failed to delete user peers: %w", err)
@@ -1316,7 +1326,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peer.ID, err)
}
}
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil {
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err)
}

View File

@@ -846,7 +846,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
permissionsManager := permissions.NewManager(store)
@@ -962,7 +962,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
@@ -1531,11 +1531,14 @@ func TestUserAccountPeersUpdate(t *testing.T) {
}
})
// drain any buffered updates from previous subtests
drainPeerUpdates(updMsg)
// deleting user with no linked peers should not update account peers and not send peer update
t.Run("deleting user with no linked peers", func(t *testing.T) {
done := make(chan struct{})
go func() {
peerShouldReceiveUpdate(t, updMsg)
peerShouldNotReceiveUpdate(t, updMsg)
close(done)
}()
@@ -2022,7 +2025,7 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()