From 13d26106f8b1275433ab66431c2d3bded0e62498 Mon Sep 17 00:00:00 2001 From: pascal Date: Fri, 8 May 2026 20:44:17 +0200 Subject: [PATCH] improve db calls --- management/server/affected_groups.go | 87 +++++++-------------- management/server/group_linkage.go | 108 ++++++++++++++++++++++----- 2 files changed, 114 insertions(+), 81 deletions(-) diff --git a/management/server/affected_groups.go b/management/server/affected_groups.go index 7554fd763..d0857ddec 100644 --- a/management/server/affected_groups.go +++ b/management/server/affected_groups.go @@ -66,18 +66,16 @@ func collectPolicyAffectedGroups(ctx context.Context, transaction store.Store, a for _, gID := range ruleGroups { groupSet[gID] = struct{}{} } - collectPolicyDirectPeers(ctx, policy, peerSet) + collectPolicyDirectPeers(policy, peerSet) } } -func collectPolicyDirectPeers(ctx context.Context, policy *types.Policy, peerSet map[string]struct{}) { +func collectPolicyDirectPeers(policy *types.Policy, peerSet map[string]struct{}) { for _, rule := range policy.Rules { if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { - log.WithContext(ctx).Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID) peerSet[rule.SourceResource.ID] = struct{}{} } if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { - log.WithContext(ctx).Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID) peerSet[rule.DestinationResource.ID] = struct{}{} } } @@ -95,17 +93,8 @@ func collectRouteAffectedGroups(ctx context.Context, transaction store.Store, ac continue } log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description) - for _, gID := range r.Groups { - groupSet[gID] = struct{}{} - } - for _, gID := range r.PeerGroups { - groupSet[gID] = struct{}{} - } - for _, gID := range r.AccessControlGroups { - groupSet[gID] = struct{}{} - } + addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups) if r.Peer != "" { - log.WithContext(ctx).Tracef("route %s has direct peer %s", r.ID, r.Peer) peerSet[r.Peer] = struct{}{} } } @@ -221,26 +210,27 @@ func collectPolicyDirectPeerRefGroups(ctx context.Context, transaction store.Sto for _, gID := range policy.RuleGroups() { groupSet[gID] = struct{}{} } - collectPolicyDirectPeers(ctx, policy, peerSet) + collectPolicyDirectPeers(policy, peerSet) } } func policyReferencesDirectPeers(policy *types.Policy, changedSet map[string]struct{}) bool { for _, rule := range policy.Rules { - if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { - if _, ok := changedSet[rule.SourceResource.ID]; ok { - return true - } - } - if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { - if _, ok := changedSet[rule.DestinationResource.ID]; ok { - return true - } + if isDirectPeerInSet(rule.SourceResource, changedSet) || isDirectPeerInSet(rule.DestinationResource, changedSet) { + return true } } return false } +func isDirectPeerInSet(res types.Resource, set map[string]struct{}) bool { + if res.Type != types.ResourceTypePeer || res.ID == "" { + return false + } + _, ok := set[res.ID] + return ok +} + func collectRouteDirectPeerRefGroups(ctx context.Context, transaction store.Store, accountID string, changedSet, groupSet, peerSet map[string]struct{}) { routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) if err != nil { @@ -255,15 +245,7 @@ func collectRouteDirectPeerRefGroups(ctx context.Context, transaction store.Stor if _, ok := changedSet[r.Peer]; !ok { continue } - for _, gID := range r.Groups { - groupSet[gID] = struct{}{} - } - for _, gID := range r.PeerGroups { - groupSet[gID] = struct{}{} - } - for _, gID := range r.AccessControlGroups { - groupSet[gID] = struct{}{} - } + addAllToSet(groupSet, r.Groups, r.PeerGroups, r.AccessControlGroups) peerSet[r.Peer] = struct{}{} } } @@ -291,44 +273,25 @@ func collectRouterDirectPeerRefGroups(ctx context.Context, transaction store.Sto func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { for _, rule := range policy.Rules { - for _, gID := range rule.Sources { - if _, ok := groupSet[gID]; ok { - return true - } - } - for _, gID := range rule.Destinations { - if _, ok := groupSet[gID]; ok { - return true - } + if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) { + return true } } return false } func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool { - for _, gID := range r.Groups { - if _, ok := groupSet[gID]; ok { - return true - } - } - for _, gID := range r.PeerGroups { - if _, ok := groupSet[gID]; ok { - return true - } - } - for _, gID := range r.AccessControlGroups { - if _, ok := groupSet[gID]; ok { - return true - } - } - return false + return anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) } func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool { - for _, gID := range router.PeerGroups { - if _, ok := groupSet[gID]; ok { - return true + return anyInSet(router.PeerGroups, groupSet) +} + +func addAllToSet(set map[string]struct{}, slices ...[]string) { + for _, s := range slices { + for _, id := range s { + set[id] = struct{}{} } } - return false } diff --git a/management/server/group_linkage.go b/management/server/group_linkage.go index a69e7d7a7..56491f9df 100644 --- a/management/server/group_linkage.go +++ b/management/server/group_linkage.go @@ -194,33 +194,103 @@ func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, } // areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. +// It fetches each collection once and checks all groupIDs against them in memory. func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { if len(groupIDs) == 0 { return false, nil } - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return false, err + groupSet := make(map[string]struct{}, len(groupIDs)) + for _, id := range groupIDs { + groupSet[id] = struct{}{} } - for _, groupID := range groupIDs { - if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { - return true, nil - } - if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { - return true, nil - } - if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { - return true, nil - } - if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { - return true, nil - } - if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked { - return true, nil - } + if affected, err := dnsSettingsReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil { + return affected, err + } + if affected, err := nameServersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil { + return affected, err + } + if affected, err := policiesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil { + return affected, err + } + if affected, err := routesReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil { + return affected, err + } + if affected, err := networkRoutersReferenceGroups(ctx, transaction, accountID, groupSet); affected || err != nil { + return affected, err } return false, nil } + +func anyInSet(ids []string, set map[string]struct{}) bool { + for _, id := range ids { + if _, ok := set[id]; ok { + return true + } + } + return false +} + +func dnsSettingsReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) { + dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return false, err + } + return anyInSet(dnsSettings.DisabledManagementGroups, groupSet), nil +} + +func nameServersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) { + nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return false, err + } + for _, ns := range nameServerGroups { + if anyInSet(ns.Groups, groupSet) { + return true, nil + } + } + return false, nil +} + +func policiesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) { + policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return false, err + } + for _, policy := range policies { + for _, rule := range policy.Rules { + if anyInSet(rule.Sources, groupSet) || anyInSet(rule.Destinations, groupSet) { + return true, nil + } + } + } + return false, nil +} + +func routesReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) { + routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return false, err + } + for _, r := range routes { + if anyInSet(r.Groups, groupSet) || anyInSet(r.PeerGroups, groupSet) || anyInSet(r.AccessControlGroups, groupSet) { + return true, nil + } + } + return false, nil +} + +func networkRoutersReferenceGroups(ctx context.Context, transaction store.Store, accountID string, groupSet map[string]struct{}) (bool, error) { + routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) + if err != nil { + return false, err + } + for _, router := range routers { + if anyInSet(router.PeerGroups, groupSet) { + return true, nil + } + } + return false, nil +}