From 85851bc4779193bcf50d54c6a7414a49b677844a Mon Sep 17 00:00:00 2001 From: pascal Date: Fri, 8 May 2026 16:43:27 +0200 Subject: [PATCH] extract submethods --- .../network_map/controller/controller.go | 46 ++- management/server/group.go | 390 ------------------ management/server/networks/manager.go | 90 ++-- .../server/networks/resources/manager.go | 198 +++++---- management/server/networks/routers/manager.go | 176 ++++---- 5 files changed, 288 insertions(+), 612 deletions(-) diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index d5be1fd65..455db9a74 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -248,19 +248,7 @@ func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string, func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error { log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName()) - affected := make(map[string]struct{}, len(peerIDs)) - for _, id := range peerIDs { - affected[id] = struct{}{} - } - - hasConnected := false - for _, id := range peerIDs { - if c.peersUpdateManager.HasChannel(id) { - hasConnected = true - break - } - } - if !hasConnected { + if !c.hasConnectedPeers(peerIDs) { log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs) return nil } @@ -272,13 +260,7 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s globalStart := time.Now() - var peersToUpdate []*nbpeer.Peer - for _, peer := range account.Peers { - if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) { - peersToUpdate = append(peersToUpdate, peer) - } - } - + peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs) if len(peersToUpdate) == 0 { log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)") return nil @@ -368,6 +350,30 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s return nil } +func (c *Controller) hasConnectedPeers(peerIDs []string) bool { + for _, id := range peerIDs { + if c.peersUpdateManager.HasChannel(id) { + return true + } + } + return false +} + +func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer { + affected := make(map[string]struct{}, len(peerIDs)) + for _, id := range peerIDs { + affected[id] = struct{}{} + } + + var result []*nbpeer.Peer + for _, peer := range account.Peers { + if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) { + result = append(result, peer) + } + } + return result +} + func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error { if !c.peersUpdateManager.HasChannel(peerId) { return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) diff --git a/management/server/group.go b/management/server/group.go index cce7bea90..125c11374 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -9,15 +9,12 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" - nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/activity" - routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" - "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/status" ) @@ -656,390 +653,3 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st return nil } - -func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error { - // disable a deleting integration group if the initiator is not an admin service user - if group.Issued == types.GroupIssuedIntegration { - executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID) - if err != nil { - return status.Errorf(status.Internal, "failed to get user") - } - if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser { - return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group") - } - } - - if group.IsGroupAll() { - return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed") - } - - if len(group.Resources) > 0 { - return &GroupLinkError{"network resource", group.Resources[0].ID} - } - - if slices.Contains(flowGroups, group.ID) { - return &GroupLinkError{"settings", "traffic event logging"} - } - - if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked { - return &GroupLinkError{"route", string(linkedRoute.NetID)} - } - - if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked { - return &GroupLinkError{"name server groups", linkedDns.Name} - } - - if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked { - return &GroupLinkError{"policy", linkedPolicy.Name} - } - - if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked { - return &GroupLinkError{"setup key", linkedSetupKey.Name} - } - - if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked { - return &GroupLinkError{"user", linkedUser.Id} - } - - if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked { - return &GroupLinkError{"network router", linkedRouter.ID} - } - - return checkGroupLinkedToSettings(ctx, transaction, group) -} - -// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account. -func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error { - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID) - if err != nil { - return status.Errorf(status.Internal, "failed to get DNS settings") - } - - if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) { - return &GroupLinkError{"disabled DNS management groups", group.Name} - } - - settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID) - if err != nil { - return status.Errorf(status.Internal, "failed to get account settings") - } - - if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) { - return &GroupLinkError{"integrated validator", group.Name} - } - - return nil -} - -// isGroupLinkedToRoute checks if a group is linked to any route in the account. -func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) { - routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err) - return false, nil - } - - for _, r := range routes { - isLinked := slices.Contains(r.Groups, groupID) || - slices.Contains(r.PeerGroups, groupID) || - slices.Contains(r.AccessControlGroups, groupID) - if isLinked { - return true, r - } - } - - return false, nil -} - -// isGroupLinkedToPolicy checks if a group is linked to any policy in the account. -func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) { - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err) - return false, nil - } - - for _, policy := range policies { - for _, rule := range policy.Rules { - if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) { - return true, policy - } - } - } - return false, nil -} - -// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account. -func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) { - nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err) - return false, nil - } - - for _, dns := range nameServerGroups { - for _, g := range dns.Groups { - if g == groupID { - return true, dns - } - } - } - - return false, nil -} - -// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account. -func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) { - setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err) - return false, nil - } - - for _, setupKey := range setupKeys { - if slices.Contains(setupKey.AutoGroups, groupID) { - return true, setupKey - } - } - return false, nil -} - -// isGroupLinkedToUser checks if a group is linked to any user in the account. -func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) { - users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err) - return false, nil - } - - for _, user := range users { - if slices.Contains(user.AutoGroups, groupID) { - return true, user - } - } - return false, nil -} - -// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account. -func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) { - routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err) - return false, nil - } - - for _, router := range routers { - if slices.Contains(router.PeerGroups, groupID) { - return true, router - } - } - return false, nil -} - -// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers. -func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) { - if len(groupIDs) == 0 { - return false, nil - } - - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) - if err != nil { - return false, err - } - - for _, groupID := range groupIDs { - if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) { - return true, nil - } - if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked { - return true, nil - } - if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked { - return true, nil - } - if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked { - return true, nil - } - if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked { - return true, nil - } - } - - return false, nil -} - -// collectGroupChangeAffectedGroups walks policies, routes, nameservers, DNS settings, -// and network routers to collect all group IDs and direct peer IDs affected by the changed groups. -func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) { - if len(changedGroupIDs) == 0 { - return nil, nil - } - - changedSet := make(map[string]struct{}, len(changedGroupIDs)) - for _, id := range changedGroupIDs { - changedSet[id] = struct{}{} - } - - log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs) - - groupSet := make(map[string]struct{}) - - peerSet := make(map[string]struct{}) - - policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err) - } else { - for _, policy := range policies { - if !policyReferencesGroups(policy, changedSet) { - continue - } - ruleGroups := policy.RuleGroups() - log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups %v", policy.ID, policy.Name, ruleGroups) - for _, gID := range ruleGroups { - groupSet[gID] = struct{}{} - } - for _, rule := range policy.Rules { - if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" { - log.WithContext(ctx).Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID) - peerSet[rule.SourceResource.ID] = struct{}{} - } - if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" { - log.WithContext(ctx).Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID) - peerSet[rule.DestinationResource.ID] = struct{}{} - } - } - } - } - - routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err) - } else { - for _, r := range routes { - if !routeReferencesGroups(r, changedSet) { - continue - } - log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description) - for _, gID := range r.Groups { - groupSet[gID] = struct{}{} - } - for _, gID := range r.PeerGroups { - groupSet[gID] = struct{}{} - } - for _, gID := range r.AccessControlGroups { - groupSet[gID] = struct{}{} - } - if r.Peer != "" { - log.WithContext(ctx).Tracef("route %s has direct peer %s", r.ID, r.Peer) - peerSet[r.Peer] = struct{}{} - } - } - } - - nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err) - } else { - for _, ns := range nsGroups { - for _, gID := range ns.Groups { - if _, ok := changedSet[gID]; ok { - log.WithContext(ctx).Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID) - for _, g := range ns.Groups { - groupSet[g] = struct{}{} - } - break - } - } - } - } - - dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err) - } else { - for _, gID := range dnsSettings.DisabledManagementGroups { - if _, ok := changedSet[gID]; ok { - log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID) - groupSet[gID] = struct{}{} - } - } - } - - routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err) - } else { - for _, router := range routers { - if !routerReferencesGroups(router, changedSet) { - continue - } - log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID) - for _, gID := range router.PeerGroups { - groupSet[gID] = struct{}{} - } - if router.Peer != "" { - log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer) - peerSet[router.Peer] = struct{}{} - } - } - } - - allGroupIDs = make([]string, 0, len(groupSet)) - for gID := range groupSet { - allGroupIDs = append(allGroupIDs, gID) - } - - directPeerIDs = make([]string, 0, len(peerSet)) - for pID := range peerSet { - directPeerIDs = append(directPeerIDs, pID) - } - - log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs) - - return allGroupIDs, directPeerIDs -} - -func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool { - for _, rule := range policy.Rules { - for _, gID := range rule.Sources { - if _, ok := groupSet[gID]; ok { - return true - } - } - for _, gID := range rule.Destinations { - if _, ok := groupSet[gID]; ok { - return true - } - } - } - return false -} - -func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool { - for _, gID := range r.Groups { - if _, ok := groupSet[gID]; ok { - return true - } - } - for _, gID := range r.PeerGroups { - if _, ok := groupSet[gID]; ok { - return true - } - } - for _, gID := range r.AccessControlGroups { - if _, ok := groupSet[gID]; ok { - return true - } - } - return false -} - -func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool { - for _, gID := range router.PeerGroups { - if _, ok := groupSet[gID]; ok { - return true - } - } - return false -} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 0f21ea9ba..7191694ac 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -245,62 +245,84 @@ func resolveNetworkAffectedPeers(ctx context.Context, s store.Store, accountID s } if len(data.resourceGroupIDs) > 0 { - destSet := make(map[string]struct{}, len(data.resourceGroupIDs)) for _, gID := range data.resourceGroupIDs { - destSet[gID] = struct{}{} groupSet[gID] = struct{}{} } - - for _, policy := range data.policies { - if policy == nil || !policy.Enabled { - continue - } - for _, rule := range policy.Rules { - if rule == nil || !rule.Enabled { - continue - } - for _, gID := range rule.Destinations { - if _, ok := destSet[gID]; ok { - for _, srcGID := range rule.Sources { - groupSet[srcGID] = struct{}{} - } - break - } - } - } - } + collectPolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet) } if len(groupSet) == 0 && len(data.directPeerIDs) == 0 { return nil } + peerIDs := resolveGroupsAndDirectPeers(ctx, s, accountID, groupSet, data.directPeerIDs) + + log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs) + return peerIDs +} + +// collectPolicySourceGroups finds policies whose rules reference any of the destination group IDs +// and adds their source groups to the groupSet. +func collectPolicySourceGroups(policies []*nbTypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) { + destSet := make(map[string]struct{}, len(destGroupIDs)) + for _, gID := range destGroupIDs { + destSet[gID] = struct{}{} + } + + for _, policy := range policies { + if policy == nil || !policy.Enabled { + continue + } + for _, rule := range policy.Rules { + if rule == nil || !rule.Enabled { + continue + } + if ruleMatchesDestinations(rule, destSet) { + for _, gID := range rule.Sources { + groupSet[gID] = struct{}{} + } + } + } + } +} + +// ruleMatchesDestinations checks if a policy rule references any of the destination groups. +func ruleMatchesDestinations(rule *nbTypes.PolicyRule, destSet map[string]struct{}) bool { + for _, gID := range rule.Destinations { + if _, ok := destSet[gID]; ok { + return true + } + } + return false +} + +// resolveGroupsAndDirectPeers resolves group IDs and direct peer IDs into a deduplicated peer ID list. +func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string { groupIDs := make([]string, 0, len(groupSet)) for gID := range groupSet { groupIDs = append(groupIDs, gID) } - log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: resolved groupIDs=%v", groupIDs) peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs) if err != nil { log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err) return nil } - if len(data.directPeerIDs) > 0 { - seen := make(map[string]struct{}, len(peerIDs)) - for _, id := range peerIDs { - seen[id] = struct{}{} - } - for _, id := range data.directPeerIDs { - if _, exists := seen[id]; !exists { - peerIDs = append(peerIDs, id) - seen[id] = struct{}{} - } - } + if len(directPeerIDs) == 0 { + return peerIDs } - log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs) + seen := make(map[string]struct{}, len(peerIDs)) + for _, id := range peerIDs { + seen[id] = struct{}{} + } + for _, id := range directPeerIDs { + if _, exists := seen[id]; !exists { + peerIDs = append(peerIDs, id) + seen[id] = struct{}{} + } + } return peerIDs } diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 03dbe542b..552daf37f 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -116,49 +116,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc var eventsToStore []func() var affectedData *resourceAffectedPeersData err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - _, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name) - if err == nil { - return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name) - } - - network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) - if err != nil { - return fmt.Errorf("failed to get network: %w", err) - } - - err = transaction.SaveNetworkResource(ctx, resource) - if err != nil { - return fmt.Errorf("failed to save network resource: %w", err) - } - - event := func() { - m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network)) - } - eventsToStore = append(eventsToStore, event) - - res := nbtypes.Resource{ - ID: resource.ID, - Type: nbtypes.ResourceType(resource.Type.String()), - } - for _, groupID := range resource.GroupIDs { - event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) - if err != nil { - return fmt.Errorf("failed to add resource to group: %w", err) - } - eventsToStore = append(eventsToStore, event) - } - - err = transaction.IncrementNetworkSerial(ctx, resource.AccountID) - if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } - - affectedData, err = loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, resource.GroupIDs) - if err != nil { - log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err) - } - - return nil + var txErr error + eventsToStore, affectedData, txErr = m.createResourceInTransaction(ctx, transaction, userID, resource) + return txErr }) if err != nil { return nil, fmt.Errorf("failed to create network resource: %w", err) @@ -178,6 +138,50 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc return resource, nil } +func (m *managerImpl) createResourceInTransaction(ctx context.Context, transaction store.Store, userID string, resource *types.NetworkResource) ([]func(), *resourceAffectedPeersData, error) { + _, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name) + if err == nil { + return nil, nil, status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name) + } + + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get network: %w", err) + } + + if err = transaction.SaveNetworkResource(ctx, resource); err != nil { + return nil, nil, fmt.Errorf("failed to save network resource: %w", err) + } + + var eventsToStore []func() + eventsToStore = append(eventsToStore, func() { + m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network)) + }) + + res := nbtypes.Resource{ + ID: resource.ID, + Type: nbtypes.ResourceType(resource.Type.String()), + } + for _, groupID := range resource.GroupIDs { + event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res) + if err != nil { + return nil, nil, fmt.Errorf("failed to add resource to group: %w", err) + } + eventsToStore = append(eventsToStore, event) + } + + if err = transaction.IncrementNetworkSerial(ctx, resource.AccountID); err != nil { + return nil, nil, fmt.Errorf("failed to increment network serial: %w", err) + } + + affectedData, err := loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, resource.GroupIDs) + if err != nil { + log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err) + } + + return eventsToStore, affectedData, nil +} + func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read) if err != nil { @@ -502,40 +506,9 @@ func (m *managerImpl) resolveResourceAffectedPeers(ctx context.Context, accountI log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: resourceGroupIDs=%v, routerPeerGroups=%v, routerDirectPeers=%v, policies=%d", data.resourceGroupIDs, data.routerPeerGroups, data.routerDirectPeers, len(data.policies)) + groupSet := make(map[string]struct{}) - var directPeerIDs []string - - destSet := make(map[string]struct{}, len(data.resourceGroupIDs)) - for _, gID := range data.resourceGroupIDs { - destSet[gID] = struct{}{} - } - - for _, policy := range data.policies { - if policy == nil || !policy.Enabled { - continue - } - for _, rule := range policy.Rules { - if rule == nil || !rule.Enabled { - continue - } - referencesResource := false - for _, gID := range rule.Destinations { - if _, ok := destSet[gID]; ok { - referencesResource = true - break - } - } - if !referencesResource { - continue - } - for _, gID := range rule.Sources { - groupSet[gID] = struct{}{} - } - if rule.SourceResource.Type == nbtypes.ResourceTypePeer && rule.SourceResource.ID != "" { - directPeerIDs = append(directPeerIDs, rule.SourceResource.ID) - } - } - } + directPeerIDs := collectResourcePolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet) for _, gID := range data.routerPeerGroups { groupSet[gID] = struct{}{} @@ -546,31 +519,78 @@ func (m *managerImpl) resolveResourceAffectedPeers(ctx context.Context, accountI return nil } + peerIDs := resolveGroupsAndDirectPeers(ctx, m.store, accountID, groupSet, directPeerIDs) + + log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs) + return peerIDs +} + +// collectResourcePolicySourceGroups finds policies whose rules reference the resource destination groups, +// adds their source groups to groupSet, and returns any direct peer IDs from source resources. +func collectResourcePolicySourceGroups(policies []*nbtypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) []string { + destSet := make(map[string]struct{}, len(destGroupIDs)) + for _, gID := range destGroupIDs { + destSet[gID] = struct{}{} + } + + var directPeerIDs []string + for _, policy := range policies { + if policy == nil || !policy.Enabled { + continue + } + for _, rule := range policy.Rules { + if rule == nil || !rule.Enabled { + continue + } + if !ruleMatchesDestinations(rule, destSet) { + continue + } + for _, gID := range rule.Sources { + groupSet[gID] = struct{}{} + } + if rule.SourceResource.Type == nbtypes.ResourceTypePeer && rule.SourceResource.ID != "" { + directPeerIDs = append(directPeerIDs, rule.SourceResource.ID) + } + } + } + return directPeerIDs +} + +func ruleMatchesDestinations(rule *nbtypes.PolicyRule, destSet map[string]struct{}) bool { + for _, gID := range rule.Destinations { + if _, ok := destSet[gID]; ok { + return true + } + } + return false +} + +func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string { groupIDs := make([]string, 0, len(groupSet)) for gID := range groupSet { groupIDs = append(groupIDs, gID) } - peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs) + peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs) if err != nil { log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err) return nil } - if len(directPeerIDs) > 0 { - seen := make(map[string]struct{}, len(peerIDs)) - for _, id := range peerIDs { - seen[id] = struct{}{} - } - for _, id := range directPeerIDs { - if _, exists := seen[id]; !exists { - peerIDs = append(peerIDs, id) - seen[id] = struct{}{} - } - } + if len(directPeerIDs) == 0 { + return peerIDs } - log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs) + seen := make(map[string]struct{}, len(peerIDs)) + for _, id := range peerIDs { + seen[id] = struct{}{} + } + for _, id := range directPeerIDs { + if _, exists := seen[id]; !exists { + peerIDs = append(peerIDs, id) + seen[id] = struct{}{} + } + } return peerIDs } diff --git a/management/server/networks/routers/manager.go b/management/server/networks/routers/manager.go index 0da184ca3..67c87fbdb 100644 --- a/management/server/networks/routers/manager.go +++ b/management/server/networks/routers/manager.go @@ -170,44 +170,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t var network *networkTypes.Network var affectedData *routerAffectedPeersData err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) - if err != nil { - return fmt.Errorf("failed to get network: %w", err) - } - - if network.ID != router.NetworkID { - return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) - } - - allPeerGroups := router.PeerGroups - var directPeers []string - if router.Peer != "" { - directPeers = append(directPeers, router.Peer) - } - oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID) - if err == nil { - allPeerGroups = append(allPeerGroups, oldRouter.PeerGroups...) - if oldRouter.Peer != "" { - directPeers = append(directPeers, oldRouter.Peer) - } - } - - err = transaction.SaveNetworkRouter(ctx, router) - if err != nil { - return fmt.Errorf("failed to update network router: %w", err) - } - - err = transaction.IncrementNetworkSerial(ctx, router.AccountID) - if err != nil { - return fmt.Errorf("failed to increment network serial: %w", err) - } - - affectedData, err = loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, allPeerGroups, directPeers...) - if err != nil { - log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err) - } - - return nil + var txErr error + network, affectedData, txErr = m.updateRouterInTransaction(ctx, transaction, router) + return txErr }) if err != nil { return nil, err @@ -225,6 +190,45 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t return router, nil } +func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction store.Store, router *types.NetworkRouter) (*networkTypes.Network, *routerAffectedPeersData, error) { + network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID) + if err != nil { + return nil, nil, fmt.Errorf("failed to get network: %w", err) + } + + if network.ID != router.NetworkID { + return nil, nil, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID) + } + + allPeerGroups := router.PeerGroups + var directPeers []string + if router.Peer != "" { + directPeers = append(directPeers, router.Peer) + } + oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID) + if err == nil { + allPeerGroups = append(allPeerGroups, oldRouter.PeerGroups...) + if oldRouter.Peer != "" { + directPeers = append(directPeers, oldRouter.Peer) + } + } + + if err = transaction.SaveNetworkRouter(ctx, router); err != nil { + return nil, nil, fmt.Errorf("failed to update network router: %w", err) + } + + if err = transaction.IncrementNetworkSerial(ctx, router.AccountID); err != nil { + return nil, nil, fmt.Errorf("failed to increment network serial: %w", err) + } + + affectedData, err := loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, allPeerGroups, directPeers...) + if err != nil { + log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err) + } + + return network, affectedData, nil +} + func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete) if err != nil { @@ -374,65 +378,79 @@ func (m *managerImpl) resolveRouterAffectedPeers(ctx context.Context, accountID } if len(data.resourceGroupIDs) > 0 { - destSet := make(map[string]struct{}, len(data.resourceGroupIDs)) - for _, gID := range data.resourceGroupIDs { - destSet[gID] = struct{}{} - } - - for _, policy := range data.policies { - if policy == nil || !policy.Enabled { - continue - } - for _, rule := range policy.Rules { - if rule == nil || !rule.Enabled { - continue - } - referencesResource := false - for _, gID := range rule.Destinations { - if _, ok := destSet[gID]; ok { - referencesResource = true - break - } - } - if !referencesResource { - continue - } - for _, gID := range rule.Sources { - groupSet[gID] = struct{}{} - } - } - } + collectPolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet) } if len(groupSet) == 0 && len(data.directPeerIDs) == 0 { return nil } + peerIDs := resolveGroupsAndDirectPeers(ctx, m.store, accountID, groupSet, data.directPeerIDs) + + log.WithContext(ctx).Tracef("resolveRouterAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs) + return peerIDs +} + +// collectPolicySourceGroups finds policies whose rules reference any of the destination group IDs +// and adds their source groups to the groupSet. +func collectPolicySourceGroups(policies []*nbtypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) { + destSet := make(map[string]struct{}, len(destGroupIDs)) + for _, gID := range destGroupIDs { + destSet[gID] = struct{}{} + } + + for _, policy := range policies { + if policy == nil || !policy.Enabled { + continue + } + for _, rule := range policy.Rules { + if rule == nil || !rule.Enabled { + continue + } + if ruleMatchesDestinations(rule, destSet) { + for _, gID := range rule.Sources { + groupSet[gID] = struct{}{} + } + } + } + } +} + +func ruleMatchesDestinations(rule *nbtypes.PolicyRule, destSet map[string]struct{}) bool { + for _, gID := range rule.Destinations { + if _, ok := destSet[gID]; ok { + return true + } + } + return false +} + +func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string { groupIDs := make([]string, 0, len(groupSet)) for gID := range groupSet { groupIDs = append(groupIDs, gID) } - peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs) + peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs) if err != nil { log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err) return nil } - if len(data.directPeerIDs) > 0 { - seen := make(map[string]struct{}, len(peerIDs)) - for _, id := range peerIDs { - seen[id] = struct{}{} - } - for _, id := range data.directPeerIDs { - if _, exists := seen[id]; !exists { - peerIDs = append(peerIDs, id) - seen[id] = struct{}{} - } - } + if len(directPeerIDs) == 0 { + return peerIDs } - log.WithContext(ctx).Tracef("resolveRouterAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs) + seen := make(map[string]struct{}, len(peerIDs)) + for _, id := range peerIDs { + seen[id] = struct{}{} + } + for _, id := range directPeerIDs { + if _, exists := seen[id]; !exists { + peerIDs = append(peerIDs, id) + seen[id] = struct{}{} + } + } return peerIDs }