diff --git a/management/server/affected_groups.go b/management/server/affected_groups.go index d0857ddec..4b765ec41 100644 --- a/management/server/affected_groups.go +++ b/management/server/affected_groups.go @@ -5,68 +5,137 @@ import ( log "github.com/sirupsen/logrus" - nbdns "github.com/netbirdio/netbird/dns" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/route" ) -// collectGroupChangeAffectedGroups walks policies, routes, nameservers, DNS settings, -// and network routers to collect all group IDs and direct peer IDs affected by the changed groups. -func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) { - if len(changedGroupIDs) == 0 { +// collectPeerChangeAffectedGroups walks policies, routes, nameservers, DNS settings, +// and network routers to collect all group IDs and direct peer IDs affected by the +// changed groups and/or changed peers. Each collection is fetched from the store exactly once. +func collectPeerChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs, changedPeerIDs []string) (allGroupIDs []string, directPeerIDs []string) { + if len(changedGroupIDs) == 0 && len(changedPeerIDs) == 0 { return nil, nil } - changedSet := make(map[string]struct{}, len(changedGroupIDs)) - for _, id := range changedGroupIDs { - changedSet[id] = struct{}{} - } - - log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs) + changedGroupSet := toSet(changedGroupIDs) + changedPeerSet := toSet(changedPeerIDs) groupSet := make(map[string]struct{}) peerSet := make(map[string]struct{}) - collectPolicyAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) - collectRouteAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) - collectNameServerAffectedGroups(ctx, transaction, accountID, changedSet, groupSet) - collectDNSSettingsAffectedGroups(ctx, transaction, accountID, changedSet, groupSet) - collectNetworkRouterAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) + collectAffectedFromPolicies(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet) + collectAffectedFromRoutes(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet) + collectAffectedFromNameServers(ctx, transaction, accountID, changedGroupSet, groupSet) + collectAffectedFromDNSSettings(ctx, transaction, accountID, changedGroupSet, groupSet) + collectAffectedFromNetworkRouters(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet) - allGroupIDs = make([]string, 0, len(groupSet)) - for gID := range groupSet { - allGroupIDs = append(allGroupIDs, gID) - } + allGroupIDs = setToSlice(groupSet) + directPeerIDs = setToSlice(peerSet) - directPeerIDs = make([]string, 0, len(peerSet)) - for pID := range peerSet { - directPeerIDs = append(directPeerIDs, pID) - } - - log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs) + log.WithContext(ctx).Tracef("affected groups resolution: changedGroups=%v changedPeers=%v -> affectedGroups=%v, directPeers=%v", + changedGroupIDs, changedPeerIDs, allGroupIDs, directPeerIDs) return allGroupIDs, directPeerIDs } -func collectPolicyAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { +// collectGroupChangeAffectedGroups is a convenience wrapper used by callers that only have changed groups. +func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) ([]string, []string) { + return collectPeerChangeAffectedGroups(ctx, transaction, accountID, changedGroupIDs, nil) +} + +func collectAffectedFromPolicies(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { - log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err) + log.WithContext(ctx).Errorf("failed to get policies for affected group resolution: %v", err) return } for _, policy := range policies { - if !policyReferencesGroups(policy, changedSet) { + matchedByGroup := policyReferencesGroups(policy, changedGroupSet) + matchedByPeer := len(changedPeerSet) > 0 && policyReferencesDirectPeers(policy, changedPeerSet) + if !matchedByGroup && !matchedByPeer { continue } - ruleGroups := policy.RuleGroups() - log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups %v", policy.ID, policy.Name, ruleGroups) - for _, gID := range ruleGroups { + addAllToSet(groupSet, policy.RuleGroups()) + collectPolicyDirectPeers(policy, peerSet) + } +} + +func collectAffectedFromRoutes(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get routes for affected group resolution: %v", err) + return + } + + for _, r := range routes { + matchedByGroup := routeReferencesGroups(r, changedGroupSet) + matchedByPeer := r.Peer != "" && len(changedPeerSet) > 0 && isInSet(r.Peer, changedPeerSet) + if !matchedByGroup && !matchedByPeer { + continue + } + addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups) + if r.Peer != "" { + peerSet[r.Peer] = struct{}{} + } + } +} + +func collectAffectedFromNameServers(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) { + if len(changedGroupSet) == 0 { + return + } + + nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get nameserver groups for affected group resolution: %v", err) + return + } + + for _, ns := range nsGroups { + if anyInSet(ns.Groups, changedGroupSet) { + addAllToSet(groupSet, ns.Groups) + } + } +} + +func collectAffectedFromDNSSettings(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) { + if len(changedGroupSet) == 0 { + return + } + + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get DNS settings for affected group resolution: %v", err) + return + } + + for _, gID := range dnsSettings.DisabledManagementGroups { + if _, ok := changedGroupSet[gID]; ok { groupSet[gID] = struct{}{} } - collectPolicyDirectPeers(policy, peerSet) + } +} + +func collectAffectedFromNetworkRouters(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) { + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get network routers for affected group resolution: %v", err) + return + } + + for _, router := range routers { + matchedByGroup := routerReferencesGroups(router, changedGroupSet) + matchedByPeer := router.Peer != "" && len(changedPeerSet) > 0 && isInSet(router.Peer, changedPeerSet) + if !matchedByGroup && !matchedByPeer { + continue + } + addAllToSet(groupSet, router.PeerGroups) + if router.Peer != "" { + peerSet[router.Peer] = struct{}{} + } } } @@ -81,139 +150,15 @@ func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) } } -func collectRouteAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { - routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err) - return - } - - for _, r := range routes { - if !routeReferencesGroups(r, changedSet) { - continue - } - log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description) - addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups) - if r.Peer != "" { - peerSet[r.Peer] = struct{}{} - } - } -} - -func collectNameServerAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) { - nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err) - return - } - - for _, ns := range nsGroups { - if !nsReferencesGroups(ns, changedSet) { - continue - } - for _, g := range ns.Groups { - groupSet[g] = struct{}{} - } - } -} - -func nsReferencesGroups(ns *nbdns.NameServerGroup, changedSet map[string]struct{}) bool { - for _, gID := range ns.Groups { - if _, ok := changedSet[gID]; ok { - log.Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID) +func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { + for _, rule := range policy.Rules { + if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) { return true } } return false } -func collectDNSSettingsAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) { - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err) - return - } - - for _, gID := range dnsSettings.DisabledManagementGroups { - if _, ok := changedSet[gID]; ok { - log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID) - groupSet[gID] = struct{}{} - } - } -} - -func collectNetworkRouterAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { - routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err) - return - } - - for _, router := range routers { - if !routerReferencesGroups(router, changedSet) { - continue - } - log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID) - for _, gID := range router.PeerGroups { - groupSet[gID] = struct{}{} - } - if router.Peer != "" { - log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer) - peerSet[router.Peer] = struct{}{} - } - } -} - -// collectDirectPeerRefAffectedGroups finds entities (policies, routes, network routers) that reference -// the changed peers directly by peer ID (not via group membership) and collects the affected groups and peers. -func collectDirectPeerRefAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedPeerIDs []string) (groupIDs []string, directPeerIDs []string) { - if len(changedPeerIDs) == 0 { - return nil, nil - } - - changedSet := make(map[string]struct{}, len(changedPeerIDs)) - for _, id := range changedPeerIDs { - changedSet[id] = struct{}{} - } - - groupSet := make(map[string]struct{}) - peerSet := make(map[string]struct{}) - - collectPolicyDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) - collectRouteDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) - collectRouterDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet) - - groupIDs = make([]string, 0, len(groupSet)) - for gID := range groupSet { - groupIDs = append(groupIDs, gID) - } - - directPeerIDs = make([]string, 0, len(peerSet)) - for pID := range peerSet { - directPeerIDs = append(directPeerIDs, pID) - } - - return groupIDs, directPeerIDs -} - -func collectPolicyDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get policies for direct peer ref resolution: %v", err) - return - } - - for _, policy := range policies { - if !policyReferencesDirectPeers(policy, changedSet) { - continue - } - for _, gID := range policy.RuleGroups() { - groupSet[gID] = struct{}{} - } - collectPolicyDirectPeers(policy, peerSet) - } -} - func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool { for _, rule := range policy.Rules { if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) { @@ -231,55 +176,6 @@ func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool { return ok } -func collectRouteDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { - routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get routes for direct peer ref resolution: %v", err) - return - } - - for _, r := range routes { - if r.Peer == "" { - continue - } - if _, ok := changedSet[r.Peer]; !ok { - continue - } - addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups) - peerSet[r.Peer] = struct{}{} - } -} - -func collectRouterDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { - routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get network routers for direct peer ref resolution: %v", err) - return - } - - for _, router := range routers { - if router.Peer == "" { - continue - } - if _, ok := changedSet[router.Peer]; !ok { - continue - } - for _, gID := range router.PeerGroups { - groupSet[gID] = struct{}{} - } - peerSet[router.Peer] = struct{}{} - } -} - -func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { - for _, rule := range policy.Rules { - if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) { - return true - } - } - return false -} - func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool { return anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) } @@ -288,6 +184,20 @@ func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[stri return anyInSet(router.PeerGroups, groupSet) } +func anyInSet(ids []string, set map[string]struct{}) bool { + for _, id := range ids { + if _, ok := set[id]; ok { + return true + } + } + return false +} + +func isInSet(id string, set map[string]struct{}) bool { + _, ok := set[id] + return ok +} + func addAllToSet(set map[string]struct{}, slices ...[]string) { for _, s := range slices { for _, id := range s { @@ -295,3 +205,19 @@ func addAllToSet(set map[string]struct{}, slices ...[]string) { } } } + +func toSet(ids []string) map[string]struct{} { + set := make(map[string]struct{}, len(ids)) + for _, id := range ids { + set[id] = struct{}{} + } + return set +} + +func setToSlice(set map[string]struct{}) []string { + s := make([]string, 0, len(set)) + for id := range set { + s = append(s, id) + } + return s +} diff --git a/management/server/group_linkage.go b/management/server/group_linkage.go index 56491f9df..626ef7956 100644 --- a/management/server/group_linkage.go +++ b/management/server/group_linkage.go @@ -224,15 +224,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac return false, nil } -func anyInSet(ids []string, set map[string]struct{}) bool { - for _, id := range ids { - if _, ok := set[id]; ok { - return true - } - } - return false -} - func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) { dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { diff --git a/management/server/peer.go b/management/server/peer.go index 8d25754a1..94c015e05 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1343,14 +1343,8 @@ func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context. log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs) - allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, s, accountID, groupIDs) - - // Also collect groups/peers from entities that reference the changed peers directly by ID - // (e.g. Route.Peer, PolicyRule.SourceResource/DestinationResource, NetworkRouter.Peer) - directRefGroups, directRefPeers := collectDirectPeerRefAffectedGroups(ctx, s, accountID, changedPeerIDs) - allGroupIDs = append(allGroupIDs, directRefGroups...) - directPeerIDs = append(directPeerIDs, directRefPeers...) - + // Single pass: find entities referencing the changed groups OR the changed peers directly + allGroupIDs, directPeerIDs := collectPeerChangeAffectedGroups(ctx, s, accountID, groupIDs, changedPeerIDs) result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs) log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result))