affected filtering on peers update

This commit is contained in:
pascal
2026-04-28 13:48:01 +02:00
parent 5a16c812fd
commit cefb37e920
11 changed files with 138 additions and 41 deletions

View File

@@ -2215,7 +2215,9 @@ func (am *DefaultAccountManager) UpdatePeerIP(ctx context.Context, accountID, us
if err != nil {
return err
}
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, []string{peerID})
changedPeerIDs := []string{peerID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, peer.AccountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}

View File

@@ -835,11 +835,9 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
changedSet[id] = struct{}{}
}
log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs)
groupSet := make(map[string]struct{})
// Always include the changed groups themselves
for _, id := range changedGroupIDs {
groupSet[id] = struct{}{}
}
peerSet := make(map[string]struct{})
@@ -852,14 +850,17 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
if !policyReferencesGroups(policy, changedSet) {
continue
}
log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups", policy.ID, policy.Name)
for _, gID := range policy.RuleGroups() {
groupSet[gID] = struct{}{}
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
log.WithContext(ctx).Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID)
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
log.WithContext(ctx).Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID)
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
@@ -875,6 +876,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
if !routeReferencesGroups(r, changedSet) {
continue
}
log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description)
for _, gID := range r.Groups {
groupSet[gID] = struct{}{}
}
@@ -885,6 +887,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
groupSet[gID] = struct{}{}
}
if r.Peer != "" {
log.WithContext(ctx).Tracef("route %s has direct peer %s", r.ID, r.Peer)
peerSet[r.Peer] = struct{}{}
}
}
@@ -898,6 +901,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
for _, ns := range nsGroups {
for _, gID := range ns.Groups {
if _, ok := changedSet[gID]; ok {
log.WithContext(ctx).Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID)
for _, g := range ns.Groups {
groupSet[g] = struct{}{}
}
@@ -914,6 +918,7 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
} else {
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedSet[gID]; ok {
log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID)
groupSet[gID] = struct{}{}
}
}
@@ -928,10 +933,12 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
if !routerReferencesGroups(router, changedSet) {
continue
}
log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID)
for _, gID := range router.PeerGroups {
groupSet[gID] = struct{}{}
}
if router.Peer != "" {
log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer)
peerSet[router.Peer] = struct{}{}
}
}
@@ -947,6 +954,8 @@ func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Sto
directPeerIDs = append(directPeerIDs, pID)
}
log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs)
return allGroupIDs, directPeerIDs
}

View File

@@ -150,7 +150,9 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
}
if expired {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -334,7 +336,9 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
}
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -492,6 +496,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
var peer *nbpeer.Peer
var settings *types.Settings
var eventsToStore []func()
var affectedPeerIDs []string
serviceID, err := am.serviceManager.GetServiceIDByTargetID(ctx, accountID, peerID)
if err != nil {
@@ -516,6 +521,8 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
return err
}
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, []string{peerID})
eventsToStore, err = deletePeers(ctx, am, transaction, accountID, userID, []*nbpeer.Peer{peer}, settings)
if err != nil {
return fmt.Errorf("failed to delete peer: %w", err)
@@ -539,7 +546,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peerID, err)
}
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}); err != nil {
if err = am.networkMapController.OnPeersDeleted(ctx, accountID, []string{peerID}, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peer %s from network map: %v", peerID, err)
}
@@ -863,7 +870,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, accountID, setupKe
am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta)
}
if err := am.networkMapController.OnPeersAdded(ctx, accountID, []string{newPeer.ID}); err != nil {
changedPeerIDs := []string{newPeer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err := am.networkMapController.OnPeersAdded(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for peer %s: %v", newPeer.ID, err)
}
@@ -946,7 +955,9 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
if isStatusChanged || sync.UpdateAccountPeers || (updated && (len(postureChecks) > 0 || versionChanged)) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, nil, nil, 0, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1073,7 +1084,9 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
if updateRemotePeers || isStatusChanged || (isPeerUpdated && len(postureChecks) > 0) {
err = am.networkMapController.OnPeersUpdated(ctx, accountID, []string{peer.ID})
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs)
if err != nil {
return nil, nil, nil, fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1297,6 +1310,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account
// UpdateAffectedPeers updates only the specified peers that belong to an account.
// Should be called when a change is known to affect only a subset of peers.
func (am *DefaultAccountManager) UpdateAffectedPeers(ctx context.Context, accountID string, peerIDs []string) {
log.WithContext(ctx).Tracef("UpdateAffectedPeers: %d peers for account %s", len(peerIDs), accountID)
_ = am.networkMapController.UpdateAffectedPeers(ctx, accountID, peerIDs)
}
@@ -1310,6 +1324,7 @@ func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Sto
}
if len(directPeerIDs) == 0 {
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v -> %d peers", groupIDs, len(peerIDs))
return peerIDs
}
@@ -1323,6 +1338,8 @@ func (am *DefaultAccountManager) resolvePeerIDs(ctx context.Context, s store.Sto
seen[id] = struct{}{}
}
}
log.WithContext(ctx).Tracef("resolvePeerIDs: groups=%v + directPeers=%v -> %d peers", groupIDs, directPeerIDs, len(peerIDs))
return peerIDs
}
@@ -1332,6 +1349,25 @@ func (am *DefaultAccountManager) BufferUpdateAffectedPeers(ctx context.Context,
_ = am.networkMapController.BufferUpdateAffectedPeers(ctx, accountID, peerIDs)
}
// resolveAffectedPeersForPeerChanges resolves changed peer IDs into the full set of
// affected peers: finds groups containing the changed peers, walks all entity linkages,
// and resolves back to peer IDs.
func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.Context, s store.Store, accountID string, changedPeerIDs []string) []string {
groupIDs, err := s.GetGroupIDsByPeerIDs(ctx, accountID, changedPeerIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to get group IDs for changed peers: %v", err)
return nil
}
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs)
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, s, accountID, groupIDs)
result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs)
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result))
return result
}
func (am *DefaultAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
_ = am.networkMapController.BufferUpdateAccountPeers(ctx, accountID)
}

View File

@@ -4679,6 +4679,23 @@ func (s *SqlStore) GetPeerIDsByGroups(ctx context.Context, accountID string, gro
return peerIDs, nil
}
func (s *SqlStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
if len(peerIDs) == 0 {
return nil, nil
}
var groupIDs []string
result := s.db.Model(&types.GroupPeer{}).
Select("DISTINCT group_id").
Where("account_id = ? AND peer_id IN ?", accountID, peerIDs).
Pluck("group_id", &groupIDs)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get group IDs by peers: %s", result.Error)
}
return groupIDs, nil
}
func (s *SqlStore) GetUserIDByPeerKey(ctx context.Context, lockStrength LockingStrength, peerKey string) (string, error) {
tx := s.db
if lockStrength != LockingStrengthNone {

View File

@@ -160,6 +160,7 @@ type Store interface {
GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*nbpeer.Peer, error)
GetPeersByGroupIDs(ctx context.Context, accountID string, groupIDs []string) ([]*nbpeer.Peer, error)
GetPeerIDsByGroups(ctx context.Context, accountID string, groupIDs []string) ([]string, error)
GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error)
GetAccountPeersWithExpiration(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAccountPeersWithInactivity(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbpeer.Peer, error)
GetAllEphemeralPeers(ctx context.Context, lockStrength LockingStrength) ([]*nbpeer.Peer, error)

View File

@@ -1868,6 +1868,21 @@ func (mr *MockStoreMockRecorder) GetPeerIDsByGroups(ctx, accountID, groupIDs int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetPeerIDsByGroups", reflect.TypeOf((*MockStore)(nil).GetPeerIDsByGroups), ctx, accountID, groupIDs)
}
// GetGroupIDsByPeerIDs mocks base method.
func (m *MockStore) GetGroupIDsByPeerIDs(ctx context.Context, accountID string, peerIDs []string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupIDsByPeerIDs", ctx, accountID, peerIDs)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupIDsByPeerIDs indicates an expected call of GetGroupIDsByPeerIDs.
func (mr *MockStoreMockRecorder) GetGroupIDsByPeerIDs(ctx, accountID, peerIDs interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupIDsByPeerIDs", reflect.TypeOf((*MockStore)(nil).GetGroupIDsByPeerIDs), ctx, accountID, peerIDs)
}
// GetPeersByIDs mocks base method.
func (m *MockStore) GetPeersByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, peerIDs []string) (map[string]*peer.Peer, error) {
m.ctrl.T.Helper()

View File

@@ -1154,7 +1154,8 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
}
}
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs)
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, peerIDs)
err = am.networkMapController.OnPeersUpdated(ctx, accountID, peerIDs, affectedPeerIDs)
if err != nil {
return fmt.Errorf("notify network map controller of peer update: %w", err)
}
@@ -1270,6 +1271,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
var userPeers []*nbpeer.Peer
var targetUser *types.User
var settings *types.Settings
var affectedPeerIDs []string
var err error
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
@@ -1290,6 +1292,14 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
if len(userPeers) > 0 {
updateAccountPeers = true
var peerIDs []string
for _, peer := range userPeers {
peerIDs = append(peerIDs, peer.ID)
}
// Resolve before delete so group memberships are still present.
affectedPeerIDs = am.resolveAffectedPeersForPeerChanges(ctx, transaction, accountID, peerIDs)
addPeerRemovedEvents, err = deletePeers(ctx, am, transaction, accountID, targetUserInfo.ID, userPeers, settings)
if err != nil {
return fmt.Errorf("failed to delete user peers: %w", err)
@@ -1313,7 +1323,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, accountI
log.WithContext(ctx).Errorf("failed to delete peer %s from integrated validator: %v", peer.ID, err)
}
}
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs); err != nil {
if err := am.networkMapController.OnPeersDeleted(ctx, accountID, peerIDs, affectedPeerIDs); err != nil {
log.WithContext(ctx).Errorf("failed to delete peers %s from network map: %v", peerIDs, err)
}

View File

@@ -846,7 +846,7 @@ func TestUser_DeleteUser_regularUser(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil)
permissionsManager := permissions.NewManager(store)
@@ -962,7 +962,7 @@ func TestUser_DeleteUser_RegularUsers(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()
@@ -2022,7 +2022,7 @@ func TestUser_Operations_WithEmbeddedIDP(t *testing.T) {
ctrl := gomock.NewController(t)
networkMapControllerMock := network_map.NewMockController(ctrl)
networkMapControllerMock.EXPECT().
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any()).
OnPeersDeleted(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).
Return(nil).
AnyTimes()