diff --git a/management/server/affected_groups.go b/management/server/affected_groups.go index 05e40db0e..7554fd763 100644 --- a/management/server/affected_groups.go +++ b/management/server/affected_groups.go @@ -175,6 +175,120 @@ func collectNetworkRouterAffectedGroups(ctx context.Context, transaction store.S } } +// collectDirectPeerRefAffectedGroups finds entities (policies, routes, network routers) that reference +// the changed peers directly by peer ID (not via group membership) and collects the affected groups and peers. +func collectDirectPeerRefAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedPeerIDs []string) (groupIDs []string, directPeerIDs []string) { + if len(changedPeerIDs) == 0 { + return nil, nil + } + + changedSet := make(map[string]struct{}, len(changedPeerIDs)) + for _, id := range changedPeerIDs { + changedSet[id] = struct{}{} + } + + groupSet := make(map[string]struct{}) + peerSet := make(map[string]struct{}) + + collectPolicyDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) + collectRouteDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) + collectRouterDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) + + groupIDs = make([]string, 0, len(groupSet)) + for gID := range groupSet { + groupIDs = append(groupIDs, gID) + } + + directPeerIDs = make([]string, 0, len(peerSet)) + for pID := range peerSet { + directPeerIDs = append(directPeerIDs, pID) + } + + return groupIDs, directPeerIDs +} + +func collectPolicyDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get policies for direct peer ref resolution: %v", err) + return + } + + for _, policy := range policies { + if !policyReferencesDirectPeers(policy, changedSet) { + continue + } + for _, gID := range policy.RuleGroups() { + groupSet[gID] = struct{}{} + } + collectPolicyDirectPeers(ctx, policy, peerSet) + } +} + +func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool { + for _, rule := range policy.Rules { + if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { + if _, ok := changedSet[rule.SourceResource.ID]; ok { + return true + } + } + if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { + if _, ok := changedSet[rule.DestinationResource.ID]; ok { + return true + } + } + } + return false +} + +func collectRouteDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get routes for direct peer ref resolution: %v", err) + return + } + + for _, r := range routes { + if r.Peer == "" { + continue + } + if _, ok := changedSet[r.Peer]; !ok { + continue + } + for _, gID := range r.Groups { + groupSet[gID] = struct{}{} + } + for _, gID := range r.PeerGroups { + groupSet[gID] = struct{}{} + } + for _, gID := range r.AccessControlGroups { + groupSet[gID] = struct{}{} + } + peerSet[r.Peer] = struct{}{} + } +} + +func collectRouterDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get network routers for direct peer ref resolution: %v", err) + return + } + + for _, router := range routers { + if router.Peer == "" { + continue + } + if _, ok := changedSet[router.Peer]; !ok { + continue + } + for _, gID := range router.PeerGroups { + groupSet[gID] = struct{}{} + } + peerSet[router.Peer] = struct{}{} + } +} + func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { for _, rule := range policy.Rules { for _, gID := range rule.Sources { diff --git a/management/server/peer.go b/management/server/peer.go index 5ea0197bc..8d25754a1 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1344,6 +1344,13 @@ func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context. log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs) allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, s, accountID, groupIDs) + + // Also collect groups/peers from entities that reference the changed peers directly by ID + // (e.g. Route.Peer, PolicyRule.SourceResource/DestinationResource, NetworkRouter.Peer) + directRefGroups, directRefPeers := collectDirectPeerRefAffectedGroups(ctx, s, accountID, changedPeerIDs) + allGroupIDs = append(allGroupIDs, directRefGroups...) + directPeerIDs = append(directPeerIDs, directRefPeers...) + result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs) log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result))