Files
netbird/management/server/affected_groups.go
2026-05-08 16:48:09 +02:00

221 lines
7.3 KiB
Go

package server
import (
"context"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// collectGroupChangeAffectedGroups resolves the impact of a change to the given groups.
// It consults policies, routes, nameserver groups, DNS settings and network routers (in that
// order) and returns the IDs of every group collected from them, together with the IDs of
// peers those resources reference directly.
//
// Both results are nil when changedGroupIDs is empty. Otherwise they are non-nil,
// deduplicated, and in no particular order (they are produced by ranging over maps).
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
	if len(changedGroupIDs) == 0 {
		return nil, nil
	}

	logger := log.WithContext(ctx)

	// Index the changed IDs once so every collector can do constant-time membership checks.
	changedSet := make(map[string]struct{}, len(changedGroupIDs))
	for i := range changedGroupIDs {
		changedSet[changedGroupIDs[i]] = struct{}{}
	}

	logger.Tracef("collecting affected groups for changed groups %v", changedGroupIDs)

	// The collectors below accumulate into these two sets.
	var (
		groupSet = make(map[string]struct{})
		peerSet  = make(map[string]struct{})
	)

	collectPolicyAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
	collectRouteAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
	collectNameServerAffectedGroups(ctx, transaction, accountID, changedSet, groupSet)
	collectDNSSettingsAffectedGroups(ctx, transaction, accountID, changedSet, groupSet)
	collectNetworkRouterAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)

	// flatten turns a set into a slice of its keys.
	flatten := func(set map[string]struct{}) []string {
		ids := make([]string, 0, len(set))
		for id := range set {
			ids = append(ids, id)
		}
		return ids
	}

	allGroupIDs = flatten(groupSet)
	directPeerIDs = flatten(peerSet)

	logger.Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs)

	return allGroupIDs, directPeerIDs
}
// collectPolicyAffectedGroups loads the account's policies and, for each policy that has a
// changed group among its rule sources or destinations, adds all of that policy's rule groups
// to groupSet and its directly referenced peers to peerSet.
//
// A failure to load the policies is logged and otherwise ignored; the sets are left untouched.
func collectPolicyAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
	logger := log.WithContext(ctx)

	policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		logger.Errorf("failed to get policies for group change resolution: %v", err)
		return
	}

	for _, p := range policies {
		if policyReferencesGroups(p, changedSet) {
			referenced := p.RuleGroups()
			logger.Tracef("policy %s (%s) references changed groups, adding rule groups %v", p.ID, p.Name, referenced)

			for i := range referenced {
				groupSet[referenced[i]] = struct{}{}
			}

			collectPolicyDirectPeers(ctx, p, peerSet)
		}
	}
}
// collectPolicyDirectPeers adds to peerSet the ID of every peer that a rule of the policy
// names directly as its source or destination resource. Resources of any other type, or
// with an empty ID, are skipped.
func collectPolicyDirectPeers(ctx context.Context, policy *types.Policy, peerSet map[string]struct{}) {
	logger := log.WithContext(ctx)

	for _, rule := range policy.Rules {
		src := rule.SourceResource
		if src.ID != "" && src.Type == types.ResourceTypePeer {
			logger.Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, src.ID)
			peerSet[src.ID] = struct{}{}
		}

		dst := rule.DestinationResource
		if dst.ID != "" && dst.Type == types.ResourceTypePeer {
			logger.Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, dst.ID)
			peerSet[dst.ID] = struct{}{}
		}
	}
}
// collectRouteAffectedGroups loads the account's routes and, for each route that references a
// changed group, adds the route's distribution groups, peer groups and access control groups
// to groupSet. If the route names a single routing peer, that peer is added to peerSet.
//
// A failure to load the routes is logged and otherwise ignored; the sets are left untouched.
func collectRouteAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
	logger := log.WithContext(ctx)

	routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		logger.Errorf("failed to get routes for group change resolution: %v", err)
		return
	}

	for _, rt := range routes {
		if !routeReferencesGroups(rt, changedSet) {
			continue
		}

		logger.Tracef("route %s (%s) references changed groups", rt.ID, rt.Description)

		// Every group list on a matching route contributes to the affected set.
		for _, ids := range [][]string{rt.Groups, rt.PeerGroups, rt.AccessControlGroups} {
			for _, id := range ids {
				groupSet[id] = struct{}{}
			}
		}

		if rt.Peer == "" {
			continue
		}
		logger.Tracef("route %s has direct peer %s", rt.ID, rt.Peer)
		peerSet[rt.Peer] = struct{}{}
	}
}
// collectNameServerAffectedGroups loads the account's nameserver groups and, for each one
// that references a changed group, adds all of its groups to groupSet.
//
// A failure to load the nameserver groups is logged and otherwise ignored; groupSet is left
// untouched.
func collectNameServerAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) {
	nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
		return
	}

	for _, nsGroup := range nsGroups {
		if nsReferencesGroups(nsGroup, changedSet) {
			for i := range nsGroup.Groups {
				groupSet[nsGroup.Groups[i]] = struct{}{}
			}
		}
	}
}
// nsReferencesGroups reports whether any group of the nameserver group is present in
// changedSet. The first match is trace-logged and ends the search.
//
// The function receives no context, so the trace line goes through the package-level
// logger rather than a context-bound entry.
func nsReferencesGroups(ns *nbdns.NameServerGroup, changedSet map[string]struct{}) bool {
	for i := range ns.Groups {
		gID := ns.Groups[i]
		if _, changed := changedSet[gID]; !changed {
			continue
		}
		log.Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID)
		return true
	}
	return false
}
// collectDNSSettingsAffectedGroups loads the account's DNS settings and adds to groupSet every
// changed group that appears in the settings' DisabledManagementGroups list. Unlike the other
// collectors, only the matching groups themselves are added, not the rest of the list.
//
// A failure to load the settings is logged and otherwise ignored; groupSet is left untouched.
func collectDNSSettingsAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) {
	logger := log.WithContext(ctx)

	dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		logger.Errorf("failed to get DNS settings for group change resolution: %v", err)
		return
	}

	for _, disabledID := range dnsSettings.DisabledManagementGroups {
		if _, changed := changedSet[disabledID]; !changed {
			continue
		}
		logger.Tracef("DNS disabled management group %s matches changed group", disabledID)
		groupSet[disabledID] = struct{}{}
	}
}
// collectNetworkRouterAffectedGroups loads the account's network routers and, for each router
// whose peer groups include a changed group, adds all of its peer groups to groupSet. If the
// router names a single peer, that peer is added to peerSet.
//
// A failure to load the routers is logged and otherwise ignored; the sets are left untouched.
func collectNetworkRouterAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
	logger := log.WithContext(ctx)

	routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
	if err != nil {
		logger.Errorf("failed to get network routers for group change resolution: %v", err)
		return
	}

	for _, nr := range routers {
		if routerReferencesGroups(nr, changedSet) {
			logger.Tracef("network router %s references changed groups", nr.ID)

			for i := range nr.PeerGroups {
				groupSet[nr.PeerGroups[i]] = struct{}{}
			}

			if peerID := nr.Peer; peerID != "" {
				logger.Tracef("network router %s has direct peer %s", nr.ID, peerID)
				peerSet[peerID] = struct{}{}
			}
		}
	}
}
// policyReferencesGroups reports whether any rule of the policy lists a member of groupSet
// among its sources or destinations.
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
	// anyIn reports whether at least one of the IDs is a member of groupSet.
	anyIn := func(ids []string) bool {
		for i := range ids {
			if _, found := groupSet[ids[i]]; found {
				return true
			}
		}
		return false
	}

	for _, rule := range policy.Rules {
		if anyIn(rule.Sources) || anyIn(rule.Destinations) {
			return true
		}
	}
	return false
}
// routeReferencesGroups reports whether the route's distribution groups, peer groups or
// access control groups contain a member of groupSet.
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
	for _, ids := range [][]string{r.Groups, r.PeerGroups, r.AccessControlGroups} {
		for _, id := range ids {
			if _, found := groupSet[id]; found {
				return true
			}
		}
	}
	return false
}
// routerReferencesGroups reports whether any of the network router's peer groups is a
// member of groupSet.
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
	for i := range router.PeerGroups {
		if _, found := groupSet[router.PeerGroups[i]]; !found {
			continue
		}
		return true
	}
	return false
}