package server

import (
	"context"

	log "github.com/sirupsen/logrus"

	routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
	"github.com/netbirdio/netbird/management/server/store"
	"github.com/netbirdio/netbird/management/server/types"
	"github.com/netbirdio/netbird/route"
)

// collectPeerChangeAffectedGroups resolves which groups and directly referenced
// peers are touched by a change to the given groups and/or peers. It runs one
// collector each for policies, routes, nameserver groups, DNS settings and
// network routers, and every collector is invoked exactly once.
//
// It returns (nil, nil) when both changedGroupIDs and changedPeerIDs are empty.
// Otherwise it returns the de-duplicated group IDs and direct peer IDs gathered
// by the collectors; both slices come from map iteration, so their order is
// unspecified.
func collectPeerChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs, changedPeerIDs []string) (allGroupIDs []string, directPeerIDs []string) {
	// Nothing changed, so nothing can be affected.
	if len(changedGroupIDs)+len(changedPeerIDs) == 0 {
		return nil, nil
	}

	changedGroups, changedPeers := toSet(changedGroupIDs), toSet(changedPeerIDs)

	// Accumulators shared by all collectors; sets keep the results de-duplicated.
	affectedGroups := map[string]struct{}{}
	affectedPeers := map[string]struct{}{}

	collectAffectedFromPolicies(ctx, transaction, accountID, changedGroups, changedPeers, affectedGroups, affectedPeers)
	collectAffectedFromRoutes(ctx, transaction, accountID, changedGroups, changedPeers, affectedGroups, affectedPeers)
	// Nameserver groups and DNS settings only reference groups, never direct peers.
	collectAffectedFromNameServers(ctx, transaction, accountID, changedGroups, affectedGroups)
	collectAffectedFromDNSSettings(ctx, transaction, accountID, changedGroups, affectedGroups)
	collectAffectedFromNetworkRouters(ctx, transaction, accountID, changedGroups, changedPeers, affectedGroups, affectedPeers)

	allGroupIDs, directPeerIDs = setToSlice(affectedGroups), setToSlice(affectedPeers)

	log.WithContext(ctx).Tracef(
		"affected groups resolution: changedGroups=%v changedPeers=%v -> affectedGroups=%v, directPeers=%v",
		changedGroupIDs, changedPeerIDs, allGroupIDs, directPeerIDs,
	)

	return allGroupIDs, directPeerIDs
}

// collectGroupChangeAffectedGroups is a convenience wrapper used by callers that only have changed groups.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) ([]string, []string) { return collectPeerChangeAffectedGroups(ctx, transaction, accountID, changedGroupIDs, nil) } func collectAffectedFromPolicies(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) { policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get policies for affected group resolution: %v", err) return } for _, policy := range policies { matchedByGroup := policyReferencesGroups(policy, changedGroupSet) matchedByPeer := len(changedPeerSet) > 0 && policyReferencesDirectPeers(policy, changedPeerSet) if !matchedByGroup && !matchedByPeer { continue } addAllToSet(groupSet, policy.RuleGroups()) collectPolicyDirectPeers(policy, peerSet) } } func collectAffectedFromRoutes(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) { routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get routes for affected group resolution: %v", err) return } for _, r := range routes { matchedByGroup := routeReferencesGroups(r, changedGroupSet) matchedByPeer := r.Peer != "" && len(changedPeerSet) > 0 && isInSet(r.Peer, changedPeerSet) if !matchedByGroup && !matchedByPeer { continue } addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups) if r.Peer != "" { peerSet[r.Peer] = struct{}{} } } } func collectAffectedFromNameServers(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) { if len(changedGroupSet) == 0 { return } nsGroups, err := transaction.GetAccountNameServerGroups(ctx, 
store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get nameserver groups for affected group resolution: %v", err) return } for _, ns := range nsGroups { if anyInSet(ns.Groups, changedGroupSet) { addAllToSet(groupSet, ns.Groups) } } } func collectAffectedFromDNSSettings(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) { if len(changedGroupSet) == 0 { return } dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get DNS settings for affected group resolution: %v", err) return } for _, gID := range dnsSettings.DisabledManagementGroups { if _, ok := changedGroupSet[gID]; ok { groupSet[gID] = struct{}{} } } } func collectAffectedFromNetworkRouters(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) { routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get network routers for affected group resolution: %v", err) return } for _, router := range routers { matchedByGroup := routerReferencesGroups(router, changedGroupSet) matchedByPeer := router.Peer != "" && len(changedPeerSet) > 0 && isInSet(router.Peer, changedPeerSet) if !matchedByGroup && !matchedByPeer { continue } addAllToSet(groupSet, router.PeerGroups) if router.Peer != "" { peerSet[router.Peer] = struct{}{} } } } func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) { for _, rule := range policy.Rules { if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { peerSet[rule.SourceResource.ID] = struct{}{} } if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { 
peerSet[rule.DestinationResource.ID] = struct{}{} } } } func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { for _, rule := range policy.Rules { if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) { return true } } return false } func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool { for _, rule := range policy.Rules { if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) { return true } } return false } func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool { if res.Type != types.ResourceTypePeer || res.ID == "" { return false } _, ok := set[res.ID] return ok } func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool { return anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) } func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool { return anyInSet(router.PeerGroups, groupSet) } func anyInSet(ids []string, set map[string]struct{}) bool { for _, id := range ids { if _, ok := set[id]; ok { return true } } return false } func isInSet(id string, set map[string]struct{}) bool { _, ok := set[id] return ok } func addAllToSet(set map[string]struct{}, slices ...[]string) { for _, s := range slices { for _, id := range s { set[id] = struct{}{} } } } func toSet(ids []string) map[string]struct{} { set := make(map[string]struct{}, len(ids)) for _, id := range ids { set[id] = struct{}{} } return set } func setToSlice(set map[string]struct{}) []string { s := make([]string, 0, len(set)) for id := range set { s = append(s, id) } return s }