extract submethods

This commit is contained in:
pascal
2026-05-08 16:43:27 +02:00
parent fed4f1b024
commit 85851bc477
5 changed files with 288 additions and 612 deletions

View File

@@ -248,19 +248,7 @@ func (c *Controller) UpdateAffectedPeers(ctx context.Context, accountID string,
func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: account %s, %d affected peers: %v (caller: %s)", accountID, len(peerIDs), peerIDs, util.GetCallerName())
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
hasConnected := false
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
hasConnected = true
break
}
}
if !hasConnected {
if !c.hasConnectedPeers(peerIDs) {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no connected peers among %v, skipping", peerIDs)
return nil
}
@@ -272,13 +260,7 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s
globalStart := time.Now()
var peersToUpdate []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
peersToUpdate = append(peersToUpdate, peer)
}
}
peersToUpdate := c.filterConnectedAffectedPeers(account, peerIDs)
if len(peersToUpdate) == 0 {
log.WithContext(ctx).Tracef("sendUpdateForAffectedPeers: no peers to update (affected peers not found in account or no channels)")
return nil
@@ -368,6 +350,30 @@ func (c *Controller) sendUpdateForAffectedPeers(ctx context.Context, accountID s
return nil
}
func (c *Controller) hasConnectedPeers(peerIDs []string) bool {
for _, id := range peerIDs {
if c.peersUpdateManager.HasChannel(id) {
return true
}
}
return false
}
func (c *Controller) filterConnectedAffectedPeers(account *types.Account, peerIDs []string) []*nbpeer.Peer {
affected := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
affected[id] = struct{}{}
}
var result []*nbpeer.Peer
for _, peer := range account.Peers {
if _, ok := affected[peer.ID]; ok && c.peersUpdateManager.HasChannel(peer.ID) {
result = append(result, peer)
}
}
return result
}
func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error {
if !c.peersUpdateManager.HasChannel(peerId) {
return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId)

View File

@@ -9,15 +9,12 @@ import (
"github.com/rs/xid"
log "github.com/sirupsen/logrus"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/activity"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
"github.com/netbirdio/netbird/management/server/permissions/modules"
"github.com/netbirdio/netbird/management/server/permissions/operations"
"github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -656,390 +653,3 @@ func validateNewGroup(ctx context.Context, transaction store.Store, accountID st
return nil
}
func validateDeleteGroup(ctx context.Context, transaction store.Store, group *types.Group, userID string, flowGroups []string) error {
// disable a deleting integration group if the initiator is not an admin service user
if group.Issued == types.GroupIssuedIntegration {
executingUser, err := transaction.GetUserByUserID(ctx, store.LockingStrengthNone, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to get user")
}
if executingUser.Role != types.UserRoleAdmin || !executingUser.IsServiceUser {
return status.Errorf(status.PermissionDenied, "only service users with admin power can delete integration group")
}
}
if group.IsGroupAll() {
return status.Errorf(status.InvalidArgument, "deleting group ALL is not allowed")
}
if len(group.Resources) > 0 {
return &GroupLinkError{"network resource", group.Resources[0].ID}
}
if slices.Contains(flowGroups, group.ID) {
return &GroupLinkError{"settings", "traffic event logging"}
}
if isLinked, linkedRoute := isGroupLinkedToRoute(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"route", string(linkedRoute.NetID)}
}
if isLinked, linkedDns := isGroupLinkedToDns(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"name server groups", linkedDns.Name}
}
if isLinked, linkedPolicy := isGroupLinkedToPolicy(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"policy", linkedPolicy.Name}
}
if isLinked, linkedSetupKey := isGroupLinkedToSetupKey(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"setup key", linkedSetupKey.Name}
}
if isLinked, linkedUser := isGroupLinkedToUser(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"user", linkedUser.Id}
}
if isLinked, linkedRouter := isGroupLinkedToNetworkRouter(ctx, transaction, group.AccountID, group.ID); isLinked {
return &GroupLinkError{"network router", linkedRouter.ID}
}
return checkGroupLinkedToSettings(ctx, transaction, group)
}
// checkGroupLinkedToSettings verifies if a group is linked to any settings in the account.
func checkGroupLinkedToSettings(ctx context.Context, transaction store.Store, group *types.Group) error {
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get DNS settings")
}
if slices.Contains(dnsSettings.DisabledManagementGroups, group.ID) {
return &GroupLinkError{"disabled DNS management groups", group.Name}
}
settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthNone, group.AccountID)
if err != nil {
return status.Errorf(status.Internal, "failed to get account settings")
}
if settings.Extra != nil && slices.Contains(settings.Extra.IntegratedValidatorGroups, group.ID) {
return &GroupLinkError{"integrated validator", group.Name}
}
return nil
}
// isGroupLinkedToRoute checks if a group is linked to any route in the account.
func isGroupLinkedToRoute(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *route.Route) {
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving routes while checking group linkage: %v", err)
return false, nil
}
for _, r := range routes {
isLinked := slices.Contains(r.Groups, groupID) ||
slices.Contains(r.PeerGroups, groupID) ||
slices.Contains(r.AccessControlGroups, groupID)
if isLinked {
return true, r
}
}
return false, nil
}
// isGroupLinkedToPolicy checks if a group is linked to any policy in the account.
func isGroupLinkedToPolicy(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.Policy) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving policies while checking group linkage: %v", err)
return false, nil
}
for _, policy := range policies {
for _, rule := range policy.Rules {
if slices.Contains(rule.Sources, groupID) || slices.Contains(rule.Destinations, groupID) {
return true, policy
}
}
}
return false, nil
}
// isGroupLinkedToDns checks if a group is linked to any nameserver group in the account.
func isGroupLinkedToDns(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *nbdns.NameServerGroup) {
nameServerGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving name server groups while checking group linkage: %v", err)
return false, nil
}
for _, dns := range nameServerGroups {
for _, g := range dns.Groups {
if g == groupID {
return true, dns
}
}
}
return false, nil
}
// isGroupLinkedToSetupKey checks if a group is linked to any setup key in the account.
func isGroupLinkedToSetupKey(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.SetupKey) {
setupKeys, err := transaction.GetAccountSetupKeys(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving setup keys while checking group linkage: %v", err)
return false, nil
}
for _, setupKey := range setupKeys {
if slices.Contains(setupKey.AutoGroups, groupID) {
return true, setupKey
}
}
return false, nil
}
// isGroupLinkedToUser checks if a group is linked to any user in the account.
func isGroupLinkedToUser(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *types.User) {
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving users while checking group linkage: %v", err)
return false, nil
}
for _, user := range users {
if slices.Contains(user.AutoGroups, groupID) {
return true, user
}
}
return false, nil
}
// isGroupLinkedToNetworkRouter checks if a group is linked to any network router in the account.
func isGroupLinkedToNetworkRouter(ctx context.Context, transaction store.Store, accountID string, groupID string) (bool, *routerTypes.NetworkRouter) {
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("error retrieving network routers while checking group linkage: %v", err)
return false, nil
}
for _, router := range routers {
if slices.Contains(router.PeerGroups, groupID) {
return true, router
}
}
return false, nil
}
// areGroupChangesAffectPeers checks if any changes to the specified groups will affect peers.
func areGroupChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, groupIDs []string) (bool, error) {
if len(groupIDs) == 0 {
return false, nil
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
for _, groupID := range groupIDs {
if slices.Contains(dnsSettings.DisabledManagementGroups, groupID) {
return true, nil
}
if linked, _ := isGroupLinkedToDns(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToPolicy(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToRoute(ctx, transaction, accountID, groupID); linked {
return true, nil
}
if linked, _ := isGroupLinkedToNetworkRouter(ctx, transaction, accountID, groupID); linked {
return true, nil
}
}
return false, nil
}
// collectGroupChangeAffectedGroups walks policies, routes, nameservers, DNS settings,
// and network routers to collect all group IDs and direct peer IDs affected by the changed groups.
func collectGroupChangeAffectedGroups(ctx context.Context, transaction store.Store, accountID string, changedGroupIDs []string) (allGroupIDs []string, directPeerIDs []string) {
if len(changedGroupIDs) == 0 {
return nil, nil
}
changedSet := make(map[string]struct{}, len(changedGroupIDs))
for _, id := range changedGroupIDs {
changedSet[id] = struct{}{}
}
log.WithContext(ctx).Tracef("collecting affected groups for changed groups %v", changedGroupIDs)
groupSet := make(map[string]struct{})
peerSet := make(map[string]struct{})
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get policies for group change resolution: %v", err)
} else {
for _, policy := range policies {
if !policyReferencesGroups(policy, changedSet) {
continue
}
ruleGroups := policy.RuleGroups()
log.WithContext(ctx).Tracef("policy %s (%s) references changed groups, adding rule groups %v", policy.ID, policy.Name, ruleGroups)
for _, gID := range ruleGroups {
groupSet[gID] = struct{}{}
}
for _, rule := range policy.Rules {
if rule.SourceResource.Type == types.ResourceTypePeer && rule.SourceResource.ID != "" {
log.WithContext(ctx).Tracef("policy %s rule %s has direct source peer %s", policy.ID, rule.ID, rule.SourceResource.ID)
peerSet[rule.SourceResource.ID] = struct{}{}
}
if rule.DestinationResource.Type == types.ResourceTypePeer && rule.DestinationResource.ID != "" {
log.WithContext(ctx).Tracef("policy %s rule %s has direct destination peer %s", policy.ID, rule.ID, rule.DestinationResource.ID)
peerSet[rule.DestinationResource.ID] = struct{}{}
}
}
}
}
routes, err := transaction.GetAccountRoutes(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get routes for group change resolution: %v", err)
} else {
for _, r := range routes {
if !routeReferencesGroups(r, changedSet) {
continue
}
log.WithContext(ctx).Tracef("route %s (%s) references changed groups", r.ID, r.Description)
for _, gID := range r.Groups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.PeerGroups {
groupSet[gID] = struct{}{}
}
for _, gID := range r.AccessControlGroups {
groupSet[gID] = struct{}{}
}
if r.Peer != "" {
log.WithContext(ctx).Tracef("route %s has direct peer %s", r.ID, r.Peer)
peerSet[r.Peer] = struct{}{}
}
}
}
nsGroups, err := transaction.GetAccountNameServerGroups(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get nameserver groups for group change resolution: %v", err)
} else {
for _, ns := range nsGroups {
for _, gID := range ns.Groups {
if _, ok := changedSet[gID]; ok {
log.WithContext(ctx).Tracef("nameserver group %s (%s) references changed group %s", ns.ID, ns.Name, gID)
for _, g := range ns.Groups {
groupSet[g] = struct{}{}
}
break
}
}
}
}
dnsSettings, err := transaction.GetAccountDNSSettings(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get DNS settings for group change resolution: %v", err)
} else {
for _, gID := range dnsSettings.DisabledManagementGroups {
if _, ok := changedSet[gID]; ok {
log.WithContext(ctx).Tracef("DNS disabled management group %s matches changed group", gID)
groupSet[gID] = struct{}{}
}
}
}
routers, err := transaction.GetNetworkRoutersByAccountID(ctx, store.LockingStrengthNone, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get network routers for group change resolution: %v", err)
} else {
for _, router := range routers {
if !routerReferencesGroups(router, changedSet) {
continue
}
log.WithContext(ctx).Tracef("network router %s references changed groups", router.ID)
for _, gID := range router.PeerGroups {
groupSet[gID] = struct{}{}
}
if router.Peer != "" {
log.WithContext(ctx).Tracef("network router %s has direct peer %s", router.ID, router.Peer)
peerSet[router.Peer] = struct{}{}
}
}
}
allGroupIDs = make([]string, 0, len(groupSet))
for gID := range groupSet {
allGroupIDs = append(allGroupIDs, gID)
}
directPeerIDs = make([]string, 0, len(peerSet))
for pID := range peerSet {
directPeerIDs = append(directPeerIDs, pID)
}
log.WithContext(ctx).Tracef("affected groups resolution: changed=%v -> affectedGroups=%v, directPeers=%v", changedGroupIDs, allGroupIDs, directPeerIDs)
return allGroupIDs, directPeerIDs
}
func policyReferencesGroups(policy *types.Policy, groupSet map[string]struct{}) bool {
for _, rule := range policy.Rules {
for _, gID := range rule.Sources {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range rule.Destinations {
if _, ok := groupSet[gID]; ok {
return true
}
}
}
return false
}
func routeReferencesGroups(r *route.Route, groupSet map[string]struct{}) bool {
for _, gID := range r.Groups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
for _, gID := range r.AccessControlGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}
func routerReferencesGroups(router *routerTypes.NetworkRouter, groupSet map[string]struct{}) bool {
for _, gID := range router.PeerGroups {
if _, ok := groupSet[gID]; ok {
return true
}
}
return false
}

View File

@@ -245,62 +245,84 @@ func resolveNetworkAffectedPeers(ctx context.Context, s store.Store, accountID s
}
if len(data.resourceGroupIDs) > 0 {
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
for _, gID := range data.resourceGroupIDs {
destSet[gID] = struct{}{}
groupSet[gID] = struct{}{}
}
for _, policy := range data.policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
for _, srcGID := range rule.Sources {
groupSet[srcGID] = struct{}{}
}
break
}
}
}
}
collectPolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet)
}
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
return nil
}
peerIDs := resolveGroupsAndDirectPeers(ctx, s, accountID, groupSet, data.directPeerIDs)
log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
return peerIDs
}
// collectPolicySourceGroups finds policies whose rules reference any of the destination group IDs
// and adds their source groups to the groupSet.
func collectPolicySourceGroups(policies []*nbTypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) {
destSet := make(map[string]struct{}, len(destGroupIDs))
for _, gID := range destGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
if ruleMatchesDestinations(rule, destSet) {
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
}
}
}
}
// ruleMatchesDestinations checks if a policy rule references any of the destination groups.
func ruleMatchesDestinations(rule *nbTypes.PolicyRule, destSet map[string]struct{}) bool {
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
return true
}
}
return false
}
// resolveGroupsAndDirectPeers resolves group IDs and direct peer IDs into a deduplicated peer ID list.
func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string {
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: resolved groupIDs=%v", groupIDs)
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(data.directPeerIDs) > 0 {
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range data.directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
if len(directPeerIDs) == 0 {
return peerIDs
}
log.WithContext(ctx).Tracef("resolveNetworkAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
return peerIDs
}

View File

@@ -116,49 +116,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
var eventsToStore []func()
var affectedData *resourceAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
err = transaction.SaveNetworkResource(ctx, resource)
if err != nil {
return fmt.Errorf("failed to save network resource: %w", err)
}
event := func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
}
eventsToStore = append(eventsToStore, event)
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, resource.GroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return nil
var txErr error
eventsToStore, affectedData, txErr = m.createResourceInTransaction(ctx, transaction, userID, resource)
return txErr
})
if err != nil {
return nil, fmt.Errorf("failed to create network resource: %w", err)
@@ -178,6 +138,50 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
return resource, nil
}
func (m *managerImpl) createResourceInTransaction(ctx context.Context, transaction store.Store, userID string, resource *types.NetworkResource) ([]func(), *resourceAffectedPeersData, error) {
_, err := transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
if err == nil {
return nil, nil, status.Errorf(status.InvalidArgument, "resource with name %s already exists", resource.Name)
}
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
if err = transaction.SaveNetworkResource(ctx, resource); err != nil {
return nil, nil, fmt.Errorf("failed to save network resource: %w", err)
}
var eventsToStore []func()
eventsToStore = append(eventsToStore, func() {
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceCreated, resource.EventMeta(network))
})
res := nbtypes.Resource{
ID: resource.ID,
Type: nbtypes.ResourceType(resource.Type.String()),
}
for _, groupID := range resource.GroupIDs {
event, err := m.groupsManager.AddResourceToGroupInTransaction(ctx, transaction, resource.AccountID, userID, groupID, &res)
if err != nil {
return nil, nil, fmt.Errorf("failed to add resource to group: %w", err)
}
eventsToStore = append(eventsToStore, event)
}
if err = transaction.IncrementNetworkSerial(ctx, resource.AccountID); err != nil {
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err := loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, resource.GroupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return eventsToStore, affectedData, nil
}
func (m *managerImpl) GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Read)
if err != nil {
@@ -502,40 +506,9 @@ func (m *managerImpl) resolveResourceAffectedPeers(ctx context.Context, accountI
log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: resourceGroupIDs=%v, routerPeerGroups=%v, routerDirectPeers=%v, policies=%d",
data.resourceGroupIDs, data.routerPeerGroups, data.routerDirectPeers, len(data.policies))
groupSet := make(map[string]struct{})
var directPeerIDs []string
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
for _, gID := range data.resourceGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range data.policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
referencesResource := false
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
referencesResource = true
break
}
}
if !referencesResource {
continue
}
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
if rule.SourceResource.Type == nbtypes.ResourceTypePeer && rule.SourceResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
}
}
}
directPeerIDs := collectResourcePolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet)
for _, gID := range data.routerPeerGroups {
groupSet[gID] = struct{}{}
@@ -546,31 +519,78 @@ func (m *managerImpl) resolveResourceAffectedPeers(ctx context.Context, accountI
return nil
}
peerIDs := resolveGroupsAndDirectPeers(ctx, m.store, accountID, groupSet, directPeerIDs)
log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
return peerIDs
}
// collectResourcePolicySourceGroups finds policies whose rules reference the resource destination groups,
// adds their source groups to groupSet, and returns any direct peer IDs from source resources.
func collectResourcePolicySourceGroups(policies []*nbtypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) []string {
destSet := make(map[string]struct{}, len(destGroupIDs))
for _, gID := range destGroupIDs {
destSet[gID] = struct{}{}
}
var directPeerIDs []string
for _, policy := range policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
if !ruleMatchesDestinations(rule, destSet) {
continue
}
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
if rule.SourceResource.Type == nbtypes.ResourceTypePeer && rule.SourceResource.ID != "" {
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
}
}
}
return directPeerIDs
}
func ruleMatchesDestinations(rule *nbtypes.PolicyRule, destSet map[string]struct{}) bool {
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
return true
}
}
return false
}
func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string {
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs)
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(directPeerIDs) > 0 {
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
if len(directPeerIDs) == 0 {
return peerIDs
}
log.WithContext(ctx).Tracef("resolveResourceAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
return peerIDs
}

View File

@@ -170,44 +170,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
var network *networkTypes.Network
var affectedData *routerAffectedPeersData
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return fmt.Errorf("failed to get network: %w", err)
}
if network.ID != router.NetworkID {
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
allPeerGroups := router.PeerGroups
var directPeers []string
if router.Peer != "" {
directPeers = append(directPeers, router.Peer)
}
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
if err == nil {
allPeerGroups = append(allPeerGroups, oldRouter.PeerGroups...)
if oldRouter.Peer != "" {
directPeers = append(directPeers, oldRouter.Peer)
}
}
err = transaction.SaveNetworkRouter(ctx, router)
if err != nil {
return fmt.Errorf("failed to update network router: %w", err)
}
err = transaction.IncrementNetworkSerial(ctx, router.AccountID)
if err != nil {
return fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, allPeerGroups, directPeers...)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return nil
var txErr error
network, affectedData, txErr = m.updateRouterInTransaction(ctx, transaction, router)
return txErr
})
if err != nil {
return nil, err
@@ -225,6 +190,45 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
return router, nil
}
func (m *managerImpl) updateRouterInTransaction(ctx context.Context, transaction store.Store, router *types.NetworkRouter) (*networkTypes.Network, *routerAffectedPeersData, error) {
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
if err != nil {
return nil, nil, fmt.Errorf("failed to get network: %w", err)
}
if network.ID != router.NetworkID {
return nil, nil, status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
}
allPeerGroups := router.PeerGroups
var directPeers []string
if router.Peer != "" {
directPeers = append(directPeers, router.Peer)
}
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
if err == nil {
allPeerGroups = append(allPeerGroups, oldRouter.PeerGroups...)
if oldRouter.Peer != "" {
directPeers = append(directPeers, oldRouter.Peer)
}
}
if err = transaction.SaveNetworkRouter(ctx, router); err != nil {
return nil, nil, fmt.Errorf("failed to update network router: %w", err)
}
if err = transaction.IncrementNetworkSerial(ctx, router.AccountID); err != nil {
return nil, nil, fmt.Errorf("failed to increment network serial: %w", err)
}
affectedData, err := loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, allPeerGroups, directPeers...)
if err != nil {
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
}
return network, affectedData, nil
}
func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, networkID, routerID string) error {
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
if err != nil {
@@ -374,65 +378,79 @@ func (m *managerImpl) resolveRouterAffectedPeers(ctx context.Context, accountID
}
if len(data.resourceGroupIDs) > 0 {
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
for _, gID := range data.resourceGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range data.policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
referencesResource := false
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
referencesResource = true
break
}
}
if !referencesResource {
continue
}
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
}
}
collectPolicySourceGroups(data.policies, data.resourceGroupIDs, groupSet)
}
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
return nil
}
peerIDs := resolveGroupsAndDirectPeers(ctx, m.store, accountID, groupSet, data.directPeerIDs)
log.WithContext(ctx).Tracef("resolveRouterAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
return peerIDs
}
// collectPolicySourceGroups finds policies whose rules reference any of the destination group IDs
// and adds their source groups to the groupSet.
func collectPolicySourceGroups(policies []*nbtypes.Policy, destGroupIDs []string, groupSet map[string]struct{}) {
destSet := make(map[string]struct{}, len(destGroupIDs))
for _, gID := range destGroupIDs {
destSet[gID] = struct{}{}
}
for _, policy := range policies {
if policy == nil || !policy.Enabled {
continue
}
for _, rule := range policy.Rules {
if rule == nil || !rule.Enabled {
continue
}
if ruleMatchesDestinations(rule, destSet) {
for _, gID := range rule.Sources {
groupSet[gID] = struct{}{}
}
}
}
}
}
func ruleMatchesDestinations(rule *nbtypes.PolicyRule, destSet map[string]struct{}) bool {
for _, gID := range rule.Destinations {
if _, ok := destSet[gID]; ok {
return true
}
}
return false
}
func resolveGroupsAndDirectPeers(ctx context.Context, s store.Store, accountID string, groupSet map[string]struct{}, directPeerIDs []string) []string {
groupIDs := make([]string, 0, len(groupSet))
for gID := range groupSet {
groupIDs = append(groupIDs, gID)
}
peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs)
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
if err != nil {
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
return nil
}
if len(data.directPeerIDs) > 0 {
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range data.directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
if len(directPeerIDs) == 0 {
return peerIDs
}
log.WithContext(ctx).Tracef("resolveRouterAffectedPeers: result %d peers: %v", len(peerIDs), peerIDs)
seen := make(map[string]struct{}, len(peerIDs))
for _, id := range peerIDs {
seen[id] = struct{}{}
}
for _, id := range directPeerIDs {
if _, exists := seen[id]; !exists {
peerIDs = append(peerIDs, id)
seen[id] = struct{}{}
}
}
return peerIDs
}