diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index f13eafbcf..b14d0d81a 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -793,21 +793,24 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t return false, nil } -func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { - err := c.bufferSendUpdateAccountPeers(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) +func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error { + if len(affectedPeerIDs) == 0 { + log.WithContext(ctx).Tracef("no affected peers for peer update in account %s, skipping", accountID) + return nil } - - return nil + return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } -func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { +func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error { log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) - return c.bufferSendUpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) == 0 { + log.WithContext(ctx).Tracef("no affected peers for peer add in account %s, skipping", accountID) + return nil + } + return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } -func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { +func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error { network, err := c.repo.GetAccountNetwork(ctx, accountID) if err != nil { return err @@ -840,7 +843,11 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI c.peersUpdateManager.CloseChannel(ctx, peerID) } - return c.bufferSendUpdateAccountPeers(ctx, accountID) + if len(affectedPeerIDs) == 0 { + log.WithContext(ctx).Tracef("no affected peers for peer delete in account %s, skipping network map update", accountID) + return nil + } + return c.BufferUpdateAffectedPeers(ctx, accountID, affectedPeerIDs) } // GetNetworkMap returns Network map for a given peer (omits original peer from the Peers result) diff --git a/management/internals/controllers/network_map/interface.go b/management/internals/controllers/network_map/interface.go index 4b9fdee12..8cd84b605 100644 --- a/management/internals/controllers/network_map/interface.go +++ b/management/internals/controllers/network_map/interface.go @@ -29,9 +29,9 @@ type Controller interface { GetNetworkMap(ctx context.Context, peerID string) (*types.NetworkMap, error) CountStreams() int - OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error - OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error - OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error + OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error + OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error + OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error DisconnectPeers(ctx context.Context, accountId string, peerIDs []string) OnPeerConnected(ctx context.Context, accountID string, peerID string) (chan *UpdateMessage, error) OnPeerDisconnected(ctx context.Context, accountID string, peerID string) diff --git a/management/internals/controllers/network_map/interface_mock.go b/management/internals/controllers/network_map/interface_mock.go index 15b6bdc56..1178969ff 100644 --- a/management/internals/controllers/network_map/interface_mock.go +++ b/management/internals/controllers/network_map/interface_mock.go @@ -172,45 +172,45 @@ func (mr *MockControllerMockRecorder) OnPeerDisconnected(ctx, accountID, peerID } // OnPeersAdded mocks base method. -func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { +func (m *MockController) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs) + ret := m.ctrl.Call(m, "OnPeersAdded", ctx, accountID, peerIDs, affectedPeerIDs) ret0, _ := ret[0].(error) return ret0 } // OnPeersAdded indicates an expected call of OnPeersAdded. -func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs any) *gomock.Call { +func (mr *MockControllerMockRecorder) OnPeersAdded(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersAdded", reflect.TypeOf((*MockController)(nil).OnPeersAdded), ctx, accountID, peerIDs, affectedPeerIDs) } // OnPeersDeleted mocks base method. -func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string) error { +func (m *MockController) OnPeersDeleted(ctx context.Context, accountID string, peerIDs []string, affectedPeerIDs []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs) + ret := m.ctrl.Call(m, "OnPeersDeleted", ctx, accountID, peerIDs, affectedPeerIDs) ret0, _ := ret[0].(error) return ret0 } // OnPeersDeleted indicates an expected call of OnPeersDeleted. -func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs any) *gomock.Call { +func (mr *MockControllerMockRecorder) OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersDeleted", reflect.TypeOf((*MockController)(nil).OnPeersDeleted), ctx, accountID, peerIDs, affectedPeerIDs) } // OnPeersUpdated mocks base method. -func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string) error { +func (m *MockController) OnPeersUpdated(ctx context.Context, accountId string, peerIDs []string, affectedPeerIDs []string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs) + ret := m.ctrl.Call(m, "OnPeersUpdated", ctx, accountId, peerIDs, affectedPeerIDs) ret0, _ := ret[0].(error) return ret0 } // OnPeersUpdated indicates an expected call of OnPeersUpdated. -func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs any) *gomock.Call { +func (mr *MockControllerMockRecorder) OnPeersUpdated(ctx, accountId, peerIDs, affectedPeerIDs any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OnPeersUpdated", reflect.TypeOf((*MockController)(nil).OnPeersUpdated), ctx, accountId, peerIDs, affectedPeerIDs) } // StartWarmup mocks base method. diff --git a/management/server/account.go b/management/server/account.go index 7d53cef03..cbf67904c 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -2215,7 +2215,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us if err != nil { return err } - err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID}) + changedPeerIDs := []string{peerID} + affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs) + err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs) if err != nil { return fmt.Errorf("notify network map controller of peer update: %w", err) } diff --git a/management/server/group.go b/management/server/group.go index 4bd249398..cc19cf1a4 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -835,11 +835,9 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto changedSet[id] = struct{}{} } + log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs) + groupSet := make(map[string]struct{}) - // Always include the changed groups themselves - for _, id := range changedGroupIDs { - groupSet[id] = struct{}{} - } peerSet := make(map[string]struct{}) @@ -852,14 +850,17 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto if !policyReferencesGroups(policy, changedSet) { continue } + log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups", policy.ID, policy.Name) for _, gID := range policy.RuleGroups() { groupSet[gID] = struct{}{} } for _, rule := range policy.Rules { if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { + log.WithContext(ctx).Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID) peerSet[rule.SourceResource.ID] = struct{}{} } if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { + log.WithContext(ctx).Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID) peerSet[rule.DestinationResource.ID] = struct{}{} } } @@ -875,6 +876,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto if !routeReferencesGroups(r, changedSet) { continue } + log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description) for _, gID := range r.Groups { groupSet[gID] = struct{}{} } @@ -885,6 +887,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto groupSet[gID] = struct{}{} } if r.Peer != "" { + log.WithContext(ctx).Tracef("route %s has direct peer %s", r.ID, r.Peer) peerSet[r.Peer] = struct{}{} } } @@ -898,6 +901,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto for _, ns := range nsGroups { for _, gID := range ns.Groups { if _, ok := changedSet[gID]; ok { + log.WithContext(ctx).Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID) for _, g := range ns.Groups { groupSet[g] = struct{}{} } @@ -914,6 +918,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto } else { for _, gID := range dnsSettings.DisabledManagementGroups { if _, ok := changedSet[gID]; ok { + log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID) groupSet[gID] = struct{}{} } } @@ -928,10 +933,12 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto if !routerReferencesGroups(router, changedSet) { continue } + log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID) for _, gID := range router.PeerGroups { groupSet[gID] = struct{}{} } if router.Peer != "" { + log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer) peerSet[router.Peer] = struct{}{} } } @@ -947,6 +954,8 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto directPeerIDs = append(directPeerIDs, pID) } + log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs) + return allGroupIDs, directPeerIDs } diff --git a/management/server/peer.go b/management/server/peer.go index 509d53ebb..ee3a6369f 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -150,7 +150,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK } if expired { - err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + changedPeerIDs := []string{peer.ID} + affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs) if err != nil { return fmt.Errorf("notify network map controller of peer update: %w", err) } @@ -334,7 +336,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } } - err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + changedPeerIDs := []string{peer.ID} + affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs) if err != nil { return nil, fmt.Errorf("notify network map controller of peer update: %w", err) } @@ -492,6 +496,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer var peer *nbpeer.Peer var settings *types.Settings var eventsToStore []func() + var affectedPeerIDs []string serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID) if err != nil { @@ -516,6 +521,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } + affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, []string{peerID}) + eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings) if err != nil { return fmt.Errorf("failed to delete peer: %w", err) @@ -539,7 +546,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err) } - if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil { + if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil { log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err) } @@ -863,7 +870,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) } - if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil { + changedPeerIDs := []string{newPeer.ID} + affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs) + if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil { log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err) } @@ -946,7 +955,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy } if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) { - err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + changedPeerIDs := []string{peer.ID} + affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs) if err != nil { return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err) } @@ -1073,7 +1084,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer } if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) { - err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID}) + changedPeerIDs := []string{peer.ID} + affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs) if err != nil { return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err) } @@ -1297,6 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account // UpdateAffectedPeers updates only the specified peers that belong to an account. // Should be called when a change is known to affect only a subset of peers. func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) { + log.WithContext(ctx).Tracef("UpdateAffectedPeers: %d peers for account %s", len(peerIDs), accountID) _ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs) } @@ -1310,6 +1324,7 @@ func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Sto } if len(directPeerIDs) == 0 { + log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v -> %d peers", groupIDs, len(peerIDs)) return peerIDs } @@ -1323,6 +1338,8 @@ func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Sto seen[id] = struct{}{} } } + + log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v + directPeers=%v -> %d peers", groupIDs, directPeerIDs, len(peerIDs)) return peerIDs } @@ -1332,6 +1349,25 @@ func (am *DefaultAccountManager) BufferUpdateAffectedPeers(ctx context.Context, _ = am.networkMapController.BufferUpdateAffectedPeers(ctx, accountID, peerIDs) } +// resolveAffectedPeersForPeerChanges resolves changed peer IDs into the full set of +// affected peers: finds groups containing the changed peers, walks all entity linkages, +// and resolves back to peer IDs. +func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string { + groupIDs, err := s.GetGroupIDsByPeerIDs(ctx, accountID, changedPeerIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to get group IDs for changed peers: %v", err) + return nil + } + + log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs) + + allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, s, accountID, groupIDs) + result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs) + + log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result)) + return result +} + func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { _ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index c285b70c7..0a2e31538 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -4679,6 +4679,23 @@ func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, gro return peerIDs, nil } +func (s *SqlStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) { + if len(peerIDs) == 0 { + return nil, nil + } + + var groupIDs []string + result := s.db.Model(&types.GroupPeer{}). + Select("DISTINCT group_id"). + Where("account_id = ? AND peer_id IN ?", accountID, peerIDs). + Pluck("group_id", &groupIDs) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get group IDs by peers: %s", result.Error) + } + + return groupIDs, nil +} + func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) { tx := s.db if lockStrength != LockingStrengthNone { diff --git a/management/server/store/store.go b/management/server/store/store.go index 82489615f..fddc17f1c 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -160,6 +160,7 @@ type Store interface { GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error) GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error) GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error) + GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error) GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error) diff --git a/management/server/store/store_mock.go b/management/server/store/store_mock.go index 70366ed44..dac042a18 100644 --- a/management/server/store/store_mock.go +++ b/management/server/store/store_mock.go @@ -1868,6 +1868,21 @@ func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs int return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs) } +// GetGroupIDsByPeerIDs mocks base method. +func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs) + ret0, _ := ret[0].([]string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs. +func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs) +} + // GetPeersByIDs mocks base method. func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) { m.ctrl.T.Helper() diff --git a/management/server/user.go b/management/server/user.go index c1f984f2f..52a4eadc8 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -1154,7 +1154,8 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou } } - err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs) + affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, peerIDs) + err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs, affectedPeerIDs) if err != nil { return fmt.Errorf("notify network map controller of peer update: %w", err) } @@ -1270,6 +1271,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI var userPeers []*nbpeer.Peer var targetUser *types.User var settings *types.Settings + var affectedPeerIDs []string var err error err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { @@ -1290,6 +1292,14 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI if len(userPeers) > 0 { updateAccountPeers = true + + var peerIDs []string + for _, peer := range userPeers { + peerIDs = append(peerIDs, peer.ID) + } + // Resolve before delete so group memberships are still present. + affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, peerIDs) + addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers, settings) if err != nil { return fmt.Errorf("failed to delete user peers: %w", err) @@ -1313,7 +1323,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peer.ID, err) } } - if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil { + if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs); err != nil { log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err) } diff --git a/management/server/user_test.go b/management/server/user_test.go index c77ea53d1..68fc58eef 100644 --- a/management/server/user_test.go +++ b/management/server/user_test.go @@ -846,7 +846,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) { ctrl := gomock.NewController(t) networkMapControllerMock := network_map.NewMockController(ctrl) networkMapControllerMock.EXPECT(). - OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil) permissionsManager := permissions.NewManager(store) @@ -962,7 +962,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) { ctrl := gomock.NewController(t) networkMapControllerMock := network_map.NewMockController(ctrl) networkMapControllerMock.EXPECT(). - OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil). AnyTimes() @@ -2022,7 +2022,7 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) { ctrl := gomock.NewController(t) networkMapControllerMock := network_map.NewMockController(ctrl) networkMapControllerMock.EXPECT(). - OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()). + OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). Return(nil). AnyTimes()