further improve db calls

This commit is contained in:
pascal
2026-05-08 20:51:46 +02:00
parent 13d26106f8
commit c948d7398f
3 changed files with 137 additions and 226 deletions

View File

@@ -5,68 +5,137 @@ import (
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/route"
)
// collectGroupChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
// and network routers to collect all group IDs and direct peer IDs affected by the changed groups.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 {
// collectPeerChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
// and network routers to collect all group IDs and direct peer IDs affected by the
// changed groups and/or changed peers. Each collection is fetched from the store exactly once.
func collectPeerChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs, changedPeerIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 && len(changedPeerIDs) == 0 {
return nil, nil
}
changedSet := make(map[string]struct{}, len(changedGroupIDs))
for _, id := range changedGroupIDs {
changedSet[id] = struct{}{}
}
log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs)
changedGroupSet := toSet(changedGroupIDs)
changedPeerSet := toSet(changedPeerIDs)
groupSet := make(map[string]struct{})
peerSet := make(map[string]struct{})
collectPolicyAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
collectRouteAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
collectNameServerAffectedGroups(ctx, transaction, accountID, changedSet, groupSet)
collectDNSSettingsAffectedGroups(ctx, transaction, accountID, changedSet, groupSet)
collectNetworkRouterAffectedGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
collectAffectedFromPolicies(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
collectAffectedFromRoutes(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
collectAffectedFromNameServers(ctx, transaction, accountID, changedGroupSet, groupSet)
collectAffectedFromDNSSettings(ctx, transaction, accountID, changedGroupSet, groupSet)
collectAffectedFromNetworkRouters(ctx, transaction, accountID, changedGroupSet, changedPeerSet, groupSet, peerSet)
allGroupIDs = make([]string, 0, len(groupSet))
for gID := range groupSet {
allGroupIDs = append(allGroupIDs, gID)
}
allGroupIDs = setToSlice(groupSet)
directPeerIDs = setToSlice(peerSet)
directPeerIDs = make([]string, 0, len(peerSet))
for pID := range peerSet {
directPeerIDs = append(directPeerIDs, pID)
}
log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs)
log.WithContext(ctx).Tracef("affected groups resolution: changedGroups=%v changedPeers=%v -> affectedGroups=%v, directPeers=%v",
changedGroupIDs, changedPeerIDs, allGroupIDs, directPeerIDs)
return allGroupIDs, directPeerIDs
}
func collectPolicyAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
// collectGroupChangeAffectedGroups is a convenience wrapper used by callers that only have changed groups.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) ([]string, []string) {
return collectPeerChangeAffectedGroups(ctx, transaction, accountID, changedGroupIDs, nil)
}
func collectAffectedFromPolicies(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
log.WithContext(ctx).Errorf("failed to get policies for affected group resolution: %v", err)
return
}
for _, policy := range policies {
if !policyReferencesGroups(policy, changedSet) {
matchedByGroup := policyReferencesGroups(policy, changedGroupSet)
matchedByPeer := len(changedPeerSet) > 0 && policyReferencesDirectPeers(policy, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
ruleGroups := policy.RuleGroups()
log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups %v", policy.ID, policy.Name, ruleGroups)
for _, gID := range ruleGroups {
addAllToSet(groupSet, policy.RuleGroups())
collectPolicyDirectPeers(policy, peerSet)
}
}
func collectAffectedFromRoutes(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for affected group resolution: %v", err)
return
}
for _, r := range routes {
matchedByGroup := routeReferencesGroups(r, changedGroupSet)
matchedByPeer := r.Peer != "" && len(changedPeerSet) > 0 && isInSet(r.Peer, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
if r.Peer != "" {
peerSet[r.Peer] = struct{}{}
}
}
}
func collectAffectedFromNameServers(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
if len(changedGroupSet) == 0 {
return
}
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for affected group resolution: %v", err)
return
}
for _, ns := range nsGroups {
if anyInSet(ns.Groups, changedGroupSet) {
addAllToSet(groupSet, ns.Groups)
}
}
}
func collectAffectedFromDNSSettings(ctx context.Context, transaction store.Store, accountID string, changedGroupSet map[string]struct{}, groupSet map[string]struct{}) {
if len(changedGroupSet) == 0 {
return
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for affected group resolution: %v", err)
return
}
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedGroupSet[gID]; ok {
groupSet[gID] = struct{}{}
}
collectPolicyDirectPeers(policy, peerSet)
}
}
func collectAffectedFromNetworkRouters(ctx context.Context, transaction store.Store, accountID string, changedGroupSet, changedPeerSet map[string]struct{}, groupSet, peerSet map[string]struct{}) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for affected group resolution: %v", err)
return
}
for _, router := range routers {
matchedByGroup := routerReferencesGroups(router, changedGroupSet)
matchedByPeer := router.Peer != "" && len(changedPeerSet) > 0 && isInSet(router.Peer, changedPeerSet)
if !matchedByGroup && !matchedByPeer {
continue
}
addAllToSet(groupSet, router.PeerGroups)
if router.Peer != "" {
peerSet[router.Peer] = struct{}{}
}
}
}
@@ -81,139 +150,15 @@ func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{})
}
}
func collectRouteAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
return
}
for _, r := range routes {
if !routeReferencesGroups(r, changedSet) {
continue
}
log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description)
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
if r.Peer != "" {
peerSet[r.Peer] = struct{}{}
}
}
}
func collectNameServerAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) {
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
return
}
for _, ns := range nsGroups {
if !nsReferencesGroups(ns, changedSet) {
continue
}
for _, g := range ns.Groups {
groupSet[g] = struct{}{}
}
}
}
func nsReferencesGroups(ns *nbdns.NameServerGroup, changedSet map[string]struct{}) bool {
for _, gID := range ns.Groups {
if _, ok := changedSet[gID]; ok {
log.Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID)
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func collectDNSSettingsAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet map[string]struct{}) {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
return
}
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedSet[gID]; ok {
log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID)
groupSet[gID] = struct{}{}
}
}
}
func collectNetworkRouterAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
return
}
for _, router := range routers {
if !routerReferencesGroups(router, changedSet) {
continue
}
log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID)
for _, gID := range router.PeerGroups {
groupSet[gID] = struct{}{}
}
if router.Peer != "" {
log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer)
peerSet[router.Peer] = struct{}{}
}
}
}
// collectDirectPeerRefAffectedGroups finds entities (policies, routes, network routers) that reference
// the changed peers directly by peer ID (not via group membership) and collects the affected groups and peers.
func collectDirectPeerRefAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedPeerIDs []string) (groupIDs []string, directPeerIDs []string) {
if len(changedPeerIDs) == 0 {
return nil, nil
}
changedSet := make(map[string]struct{}, len(changedPeerIDs))
for _, id := range changedPeerIDs {
changedSet[id] = struct{}{}
}
groupSet := make(map[string]struct{})
peerSet := make(map[string]struct{})
collectPolicyDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
collectRouteDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
collectRouterDirectPeerRefGroups(ctx, transaction, accountID, changedSet, groupSet, peerSet)
groupIDs = make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
directPeerIDs = make([]string, 0, len(peerSet))
for pID := range peerSet {
directPeerIDs = append(directPeerIDs, pID)
}
return groupIDs, directPeerIDs
}
func collectPolicyDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for direct peer ref resolution: %v", err)
return
}
for _, policy := range policies {
if !policyReferencesDirectPeers(policy, changedSet) {
continue
}
for _, gID := range policy.RuleGroups() {
groupSet[gID] = struct{}{}
}
collectPolicyDirectPeers(policy, peerSet)
}
}
func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) {
@@ -231,55 +176,6 @@ func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool {
return ok
}
func collectRouteDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for direct peer ref resolution: %v", err)
return
}
for _, r := range routes {
if r.Peer == "" {
continue
}
if _, ok := changedSet[r.Peer]; !ok {
continue
}
addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups)
peerSet[r.Peer] = struct{}{}
}
}
func collectRouterDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for direct peer ref resolution: %v", err)
return
}
for _, router := range routers {
if router.Peer == "" {
continue
}
if _, ok := changedSet[router.Peer]; !ok {
continue
}
for _, gID := range router.PeerGroups {
groupSet[gID] = struct{}{}
}
peerSet[router.Peer] = struct{}{}
}
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) {
return true
}
}
return false
}
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
return anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet)
}
@@ -288,6 +184,20 @@ func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[stri
return anyInSet(router.PeerGroups, groupSet)
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}
func isInSet(id string, set map[string]struct{}) bool {
_, ok := set[id]
return ok
}
func addAllToSet(set map[string]struct{}, slices ...[]string) {
for _, s := range slices {
for _, id := range s {
@@ -295,3 +205,19 @@ func addAllToSet(set map[string]struct{}, slices ...[]string) {
}
}
}
func toSet(ids []string) map[string]struct{} {
set := make(map[string]struct{}, len(ids))
for _, id := range ids {
set[id] = struct{}{}
}
return set
}
func setToSlice(set map[string]struct{}) []string {
s := make([]string, 0, len(set))
for id := range set {
s = append(s, id)
}
return s
}

View File

@@ -224,15 +224,6 @@ func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, ac
return false, nil
}
func anyInSet(ids []string, set map[string]struct{}) bool {
for _, id := range ids {
if _, ok := set[id]; ok {
return true
}
}
return false
}
func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {

View File

@@ -1343,14 +1343,8 @@ func (am *DefaultAccountManager) resolveAffectedPeersForPeerChanges(ctx context.
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> groups=%v", changedPeerIDs, groupIDs)
allGroupIDs, directPeerIDs := collectGroupChangeAffectedGroups(ctx, s, accountID, groupIDs)
// Also collect groups/peers from entities that reference the changed peers directly by ID
// (e.g. Route.Peer, PolicyRule.SourceResource/DestinationResource, NetworkRouter.Peer)
directRefGroups, directRefPeers := collectDirectPeerRefAffectedGroups(ctx, s, accountID, changedPeerIDs)
allGroupIDs = append(allGroupIDs, directRefGroups...)
directPeerIDs = append(directPeerIDs, directRefPeers...)
// Single pass: find entities referencing the changed groups OR the changed peers directly
allGroupIDs, directPeerIDs := collectPeerChangeAffectedGroups(ctx, s, accountID, groupIDs, changedPeerIDs)
result := am.resolvePeerIDs(ctx, s, accountID, allGroupIDs, directPeerIDs)
log.WithContext(ctx).Tracef("resolveAffectedPeersForPeerChanges: changedPeers=%v -> %d affected peers", changedPeerIDs, len(result))