mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-13 12:19:54 +00:00
further improve db calls
This commit is contained in:
@@ -5,68 +5,137 @@ import (
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// collectGroupChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
|
||||
// and network routers to collect all group IDs and direct peer IDs affected by the changed groups.
|
||||
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
|
||||
if len(changedGroupIDs) == 0 {
|
||||
// collectPeerChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
|
||||
// and network routers to collect all group IDs and direct peer IDs affected by the
|
||||
// changed groups and/or changed peers. Each collection is fetched from the store exactly once.
|
||||
func collectPeerChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs, changedPeerIDs []string) (allGroupIDs []string, directPeerIDs []string) {
|
||||
if len(changedGroupIDs) == 0 && len(changedPeerIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
changedSet := make(map[string]struct{}, len(changedGroupIDs))
|
||||
for _, id := range changedGroupIDs {
|
||||
changedSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs)
|
||||
changedGroupSet := toSet(changedGroupIDs)
|
||||
changedPeerSet := toSet(changedPeerIDs)
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
peerSet := make(map[string]struct{})
|
||||
|
||||
collectPolicyAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
|
||||
collectRouteAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
|
||||
collectNameServerAffectedGroups(ctx, transaction, accountID, changedSet, groupSet)
|
||||
collectDNSSettingsAffectedGroups(ctx, transaction, accountID, changedSet, groupSet)
|
||||
collectNetworkRouterAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
|
||||
collectAffectedFromPolicies(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
|
||||
collectAffectedFromRoutes(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
|
||||
collectAffectedFromNameServers(ctx, transaction, accountID, changedGroupSet, groupSet)
|
||||
collectAffectedFromDNSSettings(ctx, transaction, accountID, changedGroupSet, groupSet)
|
||||
collectAffectedFromNetworkRouters(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
|
||||
|
||||
allGroupIDs = make([]string, 0, len(groupSet))
|
||||
for gID := range groupSet {
|
||||
allGroupIDs = append(allGroupIDs, gID)
|
||||
}
|
||||
allGroupIDs = setToSlice(groupSet)
|
||||
directPeerIDs = setToSlice(peerSet)
|
||||
|
||||
directPeerIDs = make([]string, 0, len(peerSet))
|
||||
for pID := range peerSet {
|
||||
directPeerIDs = append(directPeerIDs, pID)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs)
|
||||
log.WithContext(ctx).Tracef("affected groups resolution: changedGroups=%v changedPeers=%v -> affectedGroups=%v, directPeers=%v",
|
||||
changedGroupIDs, changedPeerIDs, allGroupIDs, directPeerIDs)
|
||||
|
||||
return allGroupIDs, directPeerIDs
|
||||
}
|
||||
|
||||
func collectPolicyAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
|
||||
// collectGroupChangeAffectedGroups is a convenience wrapper used by callers that only have changed groups.
|
||||
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) ([]string, []string) {
|
||||
return collectPeerChangeAffectedGroups(ctx, transaction, accountID, changedGroupIDs, nil)
|
||||
}
|
||||
|
||||
func collectAffectedFromPolicies(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
|
||||
log.WithContext(ctx).Errorf("failed to get policies for affected group resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if !policyReferencesGroups(policy, changedSet) {
|
||||
matchedByGroup := policyReferencesGroups(policy, changedGroupSet)
|
||||
matchedByPeer := len(changedPeerSet) > 0 && policyReferencesDirectPeers(policy, changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
ruleGroups := policy.RuleGroups()
|
||||
log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups %v", policy.ID, policy.Name, ruleGroups)
|
||||
for _, gID := range ruleGroups {
|
||||
addAllToSet(groupSet, policy.RuleGroups())
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
func collectAffectedFromRoutes(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes for affected group resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
matchedByGroup := routeReferencesGroups(r, changedGroupSet)
|
||||
matchedByPeer := r.Peer != "" && len(changedPeerSet) > 0 && isInSet(r.Peer, changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
|
||||
if r.Peer != "" {
|
||||
peerSet[r.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectAffectedFromNameServers(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
|
||||
if len(changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get nameserver groups for affected group resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, ns := range nsGroups {
|
||||
if anyInSet(ns.Groups, changedGroupSet) {
|
||||
addAllToSet(groupSet, ns.Groups)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectAffectedFromDNSSettings(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
|
||||
if len(changedGroupSet) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get DNS settings for affected group resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, gID := range dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := changedGroupSet[gID]; ok {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
func collectAffectedFromNetworkRouters(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get network routers for affected group resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, router := range routers {
|
||||
matchedByGroup := routerReferencesGroups(router, changedGroupSet)
|
||||
matchedByPeer := router.Peer != "" && len(changedPeerSet) > 0 && isInSet(router.Peer, changedPeerSet)
|
||||
if !matchedByGroup && !matchedByPeer {
|
||||
continue
|
||||
}
|
||||
addAllToSet(groupSet, router.PeerGroups)
|
||||
if router.Peer != "" {
|
||||
peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -81,139 +150,15 @@ func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{})
|
||||
}
|
||||
}
|
||||
|
||||
func collectRouteAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if !routeReferencesGroups(r, changedSet) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description)
|
||||
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
|
||||
if r.Peer != "" {
|
||||
peerSet[r.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectNameServerAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) {
|
||||
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, ns := range nsGroups {
|
||||
if !nsReferencesGroups(ns, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, g := range ns.Groups {
|
||||
groupSet[g] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func nsReferencesGroups(ns *nbdns.NameServerGroup, changedSet map[string]struct{}) bool {
|
||||
for _, gID := range ns.Groups {
|
||||
if _, ok := changedSet[gID]; ok {
|
||||
log.Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID)
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func collectDNSSettingsAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, gID := range dnsSettings.DisabledManagementGroups {
|
||||
if _, ok := changedSet[gID]; ok {
|
||||
log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID)
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func collectNetworkRouterAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, router := range routers {
|
||||
if !routerReferencesGroups(router, changedSet) {
|
||||
continue
|
||||
}
|
||||
log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID)
|
||||
for _, gID := range router.PeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
if router.Peer != "" {
|
||||
log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer)
|
||||
peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// collectDirectPeerRefAffectedGroups finds entities (policies, routes, network routers) that reference
|
||||
// the changed peers directly by peer ID (not via group membership) and collects the affected groups and peers.
|
||||
func collectDirectPeerRefAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedPeerIDs []string) (groupIDs []string, directPeerIDs []string) {
|
||||
if len(changedPeerIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
changedSet := make(map[string]struct{}, len(changedPeerIDs))
|
||||
for _, id := range changedPeerIDs {
|
||||
changedSet[id] = struct{}{}
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
peerSet := make(map[string]struct{})
|
||||
|
||||
collectPolicyDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
|
||||
collectRouteDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
|
||||
collectRouterDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
|
||||
|
||||
groupIDs = make([]string, 0, len(groupSet))
|
||||
for gID := range groupSet {
|
||||
groupIDs = append(groupIDs, gID)
|
||||
}
|
||||
|
||||
directPeerIDs = make([]string, 0, len(peerSet))
|
||||
for pID := range peerSet {
|
||||
directPeerIDs = append(directPeerIDs, pID)
|
||||
}
|
||||
|
||||
return groupIDs, directPeerIDs
|
||||
}
|
||||
|
||||
func collectPolicyDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policies for direct peer ref resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, policy := range policies {
|
||||
if !policyReferencesDirectPeers(policy, changedSet) {
|
||||
continue
|
||||
}
|
||||
for _, gID := range policy.RuleGroups() {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
collectPolicyDirectPeers(policy, peerSet)
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
|
||||
@@ -231,55 +176,6 @@ func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
|
||||
return ok
|
||||
}
|
||||
|
||||
func collectRouteDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
|
||||
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get routes for direct peer ref resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, r := range routes {
|
||||
if r.Peer == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedSet[r.Peer]; !ok {
|
||||
continue
|
||||
}
|
||||
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
|
||||
peerSet[r.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func collectRouterDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
|
||||
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get network routers for direct peer ref resolution: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
for _, router := range routers {
|
||||
if router.Peer == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := changedSet[router.Peer]; !ok {
|
||||
continue
|
||||
}
|
||||
for _, gID := range router.PeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
peerSet[router.Peer] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
|
||||
for _, rule := range policy.Rules {
|
||||
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
|
||||
return anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet)
|
||||
}
|
||||
@@ -288,6 +184,20 @@ func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[stri
|
||||
return anyInSet(router.PeerGroups, groupSet)
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func isInSet(id string, set map[string]struct{}) bool {
|
||||
_, ok := set[id]
|
||||
return ok
|
||||
}
|
||||
|
||||
func addAllToSet(set map[string]struct{}, slices ...[]string) {
|
||||
for _, s := range slices {
|
||||
for _, id := range s {
|
||||
@@ -295,3 +205,19 @@ func addAllToSet(set map[string]struct{}, slices ...[]string) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toSet(ids []string) map[string]struct{} {
|
||||
set := make(map[string]struct{}, len(ids))
|
||||
for _, id := range ids {
|
||||
set[id] = struct{}{}
|
||||
}
|
||||
return set
|
||||
}
|
||||
|
||||
func setToSlice(set map[string]struct{}) []string {
|
||||
s := make([]string, 0, len(set))
|
||||
for id := range set {
|
||||
s = append(s, id)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -224,15 +224,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func anyInSet(ids []string, set map[string]struct{}) bool {
|
||||
for _, id := range ids {
|
||||
if _, ok := set[id]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
|
||||
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
|
||||
@@ -1343,14 +1343,8 @@ func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.
|
||||
|
||||
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs)
|
||||
|
||||
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, s, accountID, groupIDs)
|
||||
|
||||
// Also collect groups/peers from entities that reference the changed peers directly by ID
|
||||
// (e.g. Route.Peer, PolicyRule.SourceResource/DestinationResource, NetworkRouter.Peer)
|
||||
directRefGroups, directRefPeers := collectDirectPeerRefAffectedGroups(ctx, s, accountID, changedPeerIDs)
|
||||
allGroupIDs = append(allGroupIDs, directRefGroups...)
|
||||
directPeerIDs = append(directPeerIDs, directRefPeers...)
|
||||
|
||||
// Single pass: find entities referencing the changed groups OR the changed peers directly
|
||||
allGroupIDs, directPeerIDs := collectPeerChangeAffectedGroups(ctx, s, accountID, groupIDs, changedPeerIDs)
|
||||
result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs)
|
||||
|
||||
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result))
|
||||
|
||||
Reference in New Issue
Block a user