Files
netbird/management/server/affected_groups.go
2026-05-08 20:51:46 +02:00

224 lines
7.7 KiB
Go

package server
import (
"context"
log "github.com/sirupsen/logrus"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// collectPeerChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
// and network routers to collect all group IDs and direct peer IDs affected by the
// changed groups and/or changed peers. Each collection is fetched from the store exactly once.
func collectPeerChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs, changedPeerIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 && len(changedPeerIDs) == 0 {
return nil, nil
}
changedGroupSet := toSet(changedGroupIDs)
changedPeerSet := toSet(changedPeerIDs)
groupSet := make(map[string]struct{})
peerSet := make(map[string]struct{})
collectAffectedFromPolicies(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
collectAffectedFromRoutes(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
collectAffectedFromNameServers(ctx, transaction, accountID, changedGroupSet, groupSet)
collectAffectedFromDNSSettings(ctx, transaction, accountID, changedGroupSet, groupSet)
collectAffectedFromNetworkRouters(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
allGroupIDs = setToSlice(groupSet)
directPeerIDs = setToSlice(peerSet)
log.WithContext(ctx).Tracef("affected groups resolution: changedGroups=%v changedPeers=%v -> affectedGroups=%v, directPeers=%v",
changedGroupIDs, changedPeerIDs, allGroupIDs, directPeerIDs)
return allGroupIDs, directPeerIDs
}
// collectGroupChangeAffectedGroups is a convenience wrapper used by callers that only have changed groups.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) ([]string, []string) {
return collectPeerChangeAffectedGroups(ctx, transaction, accountID, changedGroupIDs, nil)
}
func collectAffectedFromPolicies(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for affected group resolution: %v", err)
return
}
for _, policy := range policies {
matchedByGroup := policyReferencesGroups(policy, changedGroupSet)
matchedByPeer := len(changedPeerSet) > 0 && policyReferencesDirectPeers(policy, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, peerSet)
}
}
func collectAffectedFromRoutes(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for affected group resolution: %v", err)
return
}
for _, r := range routes {
matchedByGroup := routeReferencesGroups(r, changedGroupSet)
matchedByPeer := r.Peer != "" && len(changedPeerSet) > 0 && isInSet(r.Peer, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
if r.Peer != "" {
peerSet[r.Peer] = struct{}{}
}
}
}
func collectAffectedFromNameServers(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
if len(changedGroupSet) == 0 {
return
}
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for affected group resolution: %v", err)
return
}
for _, ns := range nsGroups {
if anyInSet(ns.Groups, changedGroupSet) {
addAllToSet(groupSet, ns.Groups)
}
}
}
func collectAffectedFromDNSSettings(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
if len(changedGroupSet) == 0 {
return
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for affected group resolution: %v", err)
return
}
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedGroupSet[gID]; ok {
groupSet[gID] = struct{}{}
}
}
}
func collectAffectedFromNetworkRouters(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for affected group resolution: %v", err)
return
}
for _, router := range routers {
matchedByGroup := routerReferencesGroups(router, changedGroupSet)
matchedByPeer := router.Peer != "" && len(changedPeerSet) > 0 && isInSet(router.Peer, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, router.PeerGroups)
if router.Peer != "" {
peerSet[router.Peer] = struct{}{}
}
}
}
func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) {
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
return true
}
}
return false
}
func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
if res.Type != types.ResourceTypePeer || res.ID == "" {
return false
}
_, ok := set[res.ID]
return ok
}
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
return anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet)
}
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
return anyInSet(router.PeerGroups, groupSet)
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}
func isInSet(id string, set map[string]struct{}) bool {
_, ok := set[id]
return ok
}
func addAllToSet(set map[string]struct{}, slices ...[]string) {
for _, s := range slices {
for _, id := range s {
set[id] = struct{}{}
}
}
}
func toSet(ids []string) map[string]struct{} {
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
func setToSlice(set map[string]struct{}) []string {
s := make([]string, 0, len(set))
for id := range set {
s = append(s, id)
}
return s
}