mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-03 15:46:38 +00:00
add to networks modules
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbTypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -111,6 +113,14 @@ func (m *managerImpl) UpdateNetwork(ctx context.Context, userID string, network
|
||||
return network, m.store.SaveNetwork(ctx, network)
|
||||
}
|
||||
|
||||
// networkAffectedPeersData holds data loaded inside the transaction for affected peer resolution.
|
||||
type networkAffectedPeersData struct {
|
||||
resourceGroupIDs []string
|
||||
routerPeerGroups []string
|
||||
directPeerIDs []string
|
||||
policies []*nbTypes.Policy
|
||||
}
|
||||
|
||||
func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, networkID string) error {
|
||||
ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Networks, operations.Delete)
|
||||
if err != nil {
|
||||
@@ -126,13 +136,22 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var affectedData *networkAffectedPeersData
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get resources in network: %w", err)
|
||||
}
|
||||
|
||||
var resourceGroupIDs []string
|
||||
for _, resource := range resources {
|
||||
groups, err := transaction.GetResourceGroups(ctx, store.LockingStrengthNone, accountID, resource.ID)
|
||||
if err == nil {
|
||||
for _, g := range groups {
|
||||
resourceGroupIDs = append(resourceGroupIDs, g.ID)
|
||||
}
|
||||
}
|
||||
|
||||
event, err := m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete resource: %w", err)
|
||||
@@ -140,12 +159,19 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
eventsToStore = append(eventsToStore, event...)
|
||||
}
|
||||
|
||||
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
netRouters, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get routers in network: %w", err)
|
||||
}
|
||||
|
||||
for _, router := range routers {
|
||||
var routerPeerGroups []string
|
||||
var directPeerIDs []string
|
||||
for _, router := range netRouters {
|
||||
routerPeerGroups = append(routerPeerGroups, router.PeerGroups...)
|
||||
if router.Peer != "" {
|
||||
directPeerIDs = append(directPeerIDs, router.Peer)
|
||||
}
|
||||
|
||||
event, err := m.routersManager.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, router.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete router: %w", err)
|
||||
@@ -153,6 +179,24 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
eventsToStore = append(eventsToStore, event)
|
||||
}
|
||||
|
||||
// load policies before deleting so group memberships are still present
|
||||
var policies []*nbTypes.Policy
|
||||
if len(resourceGroupIDs) > 0 {
|
||||
policies, err = transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to get policies for affected peers: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(resourceGroupIDs) > 0 || len(routerPeerGroups) > 0 || len(directPeerIDs) > 0 {
|
||||
affectedData = &networkAffectedPeersData{
|
||||
resourceGroupIDs: resourceGroupIDs,
|
||||
routerPeerGroups: routerPeerGroups,
|
||||
directPeerIDs: directPeerIDs,
|
||||
policies: policies,
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.DeleteNetwork(ctx, accountID, networkID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete network: %w", err)
|
||||
@@ -177,11 +221,82 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
if affectedData != nil {
|
||||
affectedPeerIDs := resolveNetworkAffectedPeers(ctx, m.store, accountID, affectedData)
|
||||
if len(affectedPeerIDs) > 0 {
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveNetworkAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
|
||||
func resolveNetworkAffectedPeers(ctx context.Context, s store.Store, accountID string, data *networkAffectedPeersData) []string {
|
||||
groupSet := make(map[string]struct{})
|
||||
|
||||
for _, gID := range data.routerPeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
|
||||
if len(data.resourceGroupIDs) > 0 {
|
||||
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
|
||||
for _, gID := range data.resourceGroupIDs {
|
||||
destSet[gID] = struct{}{}
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, policy := range data.policies {
|
||||
if policy == nil || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if rule == nil || !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, gID := range rule.Destinations {
|
||||
if _, ok := destSet[gID]; ok {
|
||||
for _, srcGID := range rule.Sources {
|
||||
groupSet[srcGID] = struct{}{}
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groupIDs := make([]string, 0, len(groupSet))
|
||||
for gID := range groupSet {
|
||||
groupIDs = append(groupIDs, gID)
|
||||
}
|
||||
|
||||
peerIDs, err := s.GetPeerIDsByGroups(ctx, accountID, groupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(data.directPeerIDs) > 0 {
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for _, id := range data.directPeerIDs {
|
||||
if _, exists := seen[id]; !exists {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &mockManager{}
|
||||
}
|
||||
|
||||
@@ -114,6 +114,7 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
}
|
||||
|
||||
var eventsToStore []func()
|
||||
var affectedData *resourceAffectedPeersData
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
_, err = transaction.GetNetworkResourceByName(ctx, store.LockingStrengthNone, resource.AccountID, resource.Name)
|
||||
if err == nil {
|
||||
@@ -152,6 +153,11 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, resource.GroupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -162,7 +168,9 @@ func (m *managerImpl) CreateResource(ctx context.Context, userID string, resourc
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, resource.AccountID, affectedData); len(affectedPeerIDs) > 0 {
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
@@ -207,6 +215,7 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
resource.Prefix = prefix
|
||||
|
||||
var eventsToStore []func()
|
||||
var affectedData *resourceAffectedPeersData
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err := transaction.GetNetworkByID(ctx, store.LockingStrengthUpdate, resource.AccountID, resource.NetworkID)
|
||||
if err != nil {
|
||||
@@ -232,6 +241,15 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
return fmt.Errorf("failed to get network resource: %w", err)
|
||||
}
|
||||
|
||||
oldGroups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, oldResource.AccountID, oldResource.ID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get old resource groups: %w", err)
|
||||
}
|
||||
var oldGroupIDs []string
|
||||
for _, g := range oldGroups {
|
||||
oldGroupIDs = append(oldGroupIDs, g.ID)
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkResource(ctx, resource)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save network resource: %w", err)
|
||||
@@ -247,6 +265,11 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
m.accountManager.StoreEvent(ctx, userID, resource.ID, resource.AccountID, activity.NetworkResourceUpdated, resource.EventMeta(network))
|
||||
})
|
||||
|
||||
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, resource.AccountID, resource.NetworkID, append(resource.GroupIDs, oldGroupIDs...))
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
|
||||
}
|
||||
|
||||
err = transaction.IncrementNetworkSerial(ctx, resource.AccountID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
@@ -270,7 +293,9 @@ func (m *managerImpl) UpdateResource(ctx context.Context, userID string, resourc
|
||||
}
|
||||
}()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, resource.AccountID)
|
||||
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, resource.AccountID, affectedData); len(affectedPeerIDs) > 0 {
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, resource.AccountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return resource, nil
|
||||
}
|
||||
@@ -331,7 +356,22 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
}
|
||||
|
||||
var events []func()
|
||||
var affectedData *resourceAffectedPeersData
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
groups, err := m.groupsManager.GetResourceGroupsInTransaction(ctx, transaction, store.LockingStrengthNone, accountID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get resource groups: %w", err)
|
||||
}
|
||||
var resourceGroupIDs []string
|
||||
for _, g := range groups {
|
||||
resourceGroupIDs = append(resourceGroupIDs, g.ID)
|
||||
}
|
||||
|
||||
affectedData, err = loadResourceAffectedPeersData(ctx, transaction, accountID, networkID, resourceGroupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
|
||||
}
|
||||
|
||||
events, err = m.DeleteResourceInTransaction(ctx, transaction, accountID, userID, networkID, resourceID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete resource: %w", err)
|
||||
@@ -352,7 +392,9 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net
|
||||
event()
|
||||
}
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
if affectedPeerIDs := m.resolveResourceAffectedPeers(ctx, accountID, affectedData); len(affectedPeerIDs) > 0 {
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -399,6 +441,127 @@ func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transacti
|
||||
return eventsToStore, nil
|
||||
}
|
||||
|
||||
// resourceAffectedPeersData holds data loaded inside a transaction for affected peer resolution.
|
||||
type resourceAffectedPeersData struct {
|
||||
resourceGroupIDs []string
|
||||
policies []*nbtypes.Policy
|
||||
routerPeerGroups []string
|
||||
routerDirectPeers []string
|
||||
}
|
||||
|
||||
// loadResourceAffectedPeersData loads the data needed to determine affected peers within a transaction.
|
||||
func loadResourceAffectedPeersData(ctx context.Context, transaction store.Store, accountID, networkID string, resourceGroupIDs []string) (*resourceAffectedPeersData, error) {
|
||||
if len(resourceGroupIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get policies: %w", err)
|
||||
}
|
||||
|
||||
routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get routers: %w", err)
|
||||
}
|
||||
|
||||
var routerPeerGroups []string
|
||||
var routerDirectPeers []string
|
||||
for _, router := range routers {
|
||||
if !router.Enabled {
|
||||
continue
|
||||
}
|
||||
routerPeerGroups = append(routerPeerGroups, router.PeerGroups...)
|
||||
if router.Peer != "" {
|
||||
routerDirectPeers = append(routerDirectPeers, router.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
return &resourceAffectedPeersData{
|
||||
resourceGroupIDs: resourceGroupIDs,
|
||||
policies: policies,
|
||||
routerPeerGroups: routerPeerGroups,
|
||||
routerDirectPeers: routerDirectPeers,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveResourceAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
|
||||
func (m *managerImpl) resolveResourceAffectedPeers(ctx context.Context, accountID string, data *resourceAffectedPeersData) []string {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
var directPeerIDs []string
|
||||
|
||||
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
|
||||
for _, gID := range data.resourceGroupIDs {
|
||||
destSet[gID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, policy := range data.policies {
|
||||
if policy == nil || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if rule == nil || !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
referencesResource := false
|
||||
for _, gID := range rule.Destinations {
|
||||
if _, ok := destSet[gID]; ok {
|
||||
referencesResource = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !referencesResource {
|
||||
continue
|
||||
}
|
||||
for _, gID := range rule.Sources {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
if rule.SourceResource.Type == nbtypes.ResourceTypePeer && rule.SourceResource.ID != "" {
|
||||
directPeerIDs = append(directPeerIDs, rule.SourceResource.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, gID := range data.routerPeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
directPeerIDs = append(directPeerIDs, data.routerDirectPeers...)
|
||||
|
||||
if len(groupSet) == 0 && len(directPeerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groupIDs := make([]string, 0, len(groupSet))
|
||||
for gID := range groupSet {
|
||||
groupIDs = append(groupIDs, gID)
|
||||
}
|
||||
|
||||
peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(directPeerIDs) > 0 {
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for _, id := range directPeerIDs {
|
||||
if _, exists := seen[id]; !exists {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerIDs, nil
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &mockManager{}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
nbtypes "github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
|
||||
@@ -89,6 +91,7 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
}
|
||||
|
||||
var network *networkTypes.Network
|
||||
var affectedData *routerAffectedPeersData
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
@@ -111,6 +114,11 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, router.PeerGroups, router.Peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -119,7 +127,9 @@ func (m *managerImpl) CreateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterCreated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, router.AccountID, affectedData); len(affectedPeerIDs) > 0 {
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return router, nil
|
||||
}
|
||||
@@ -155,6 +165,7 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
}
|
||||
|
||||
var network *networkTypes.Network
|
||||
var affectedData *routerAffectedPeersData
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
network, err = transaction.GetNetworkByID(ctx, store.LockingStrengthNone, router.AccountID, router.NetworkID)
|
||||
if err != nil {
|
||||
@@ -165,6 +176,16 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return status.NewRouterNotPartOfNetworkError(router.ID, router.NetworkID)
|
||||
}
|
||||
|
||||
allPeerGroups := router.PeerGroups
|
||||
directPeers := []string{router.Peer}
|
||||
oldRouter, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, router.AccountID, router.ID)
|
||||
if err == nil {
|
||||
allPeerGroups = append(allPeerGroups, oldRouter.PeerGroups...)
|
||||
if oldRouter.Peer != "" {
|
||||
directPeers = append(directPeers, oldRouter.Peer)
|
||||
}
|
||||
}
|
||||
|
||||
err = transaction.SaveNetworkRouter(ctx, router)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update network router: %w", err)
|
||||
@@ -175,6 +196,11 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
return fmt.Errorf("failed to increment network serial: %w", err)
|
||||
}
|
||||
|
||||
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, router.AccountID, router.NetworkID, allPeerGroups, directPeers...)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
@@ -183,7 +209,9 @@ func (m *managerImpl) UpdateRouter(ctx context.Context, userID string, router *t
|
||||
|
||||
m.accountManager.StoreEvent(ctx, userID, router.ID, router.AccountID, activity.NetworkRouterUpdated, router.EventMeta(network))
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, router.AccountID)
|
||||
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, router.AccountID, affectedData); len(affectedPeerIDs) > 0 {
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, router.AccountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return router, nil
|
||||
}
|
||||
@@ -198,7 +226,19 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
}
|
||||
|
||||
var event func()
|
||||
var affectedData *routerAffectedPeersData
|
||||
err = m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
|
||||
router, err := transaction.GetNetworkRouterByID(ctx, store.LockingStrengthNone, accountID, routerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get router: %w", err)
|
||||
}
|
||||
|
||||
// load before delete so group memberships are still present
|
||||
affectedData, err = loadRouterAffectedPeersData(ctx, transaction, accountID, networkID, router.PeerGroups, router.Peer)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to load affected peers data: %v", err)
|
||||
}
|
||||
|
||||
event, err = m.DeleteRouterInTransaction(ctx, transaction, accountID, userID, networkID, routerID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete network router: %w", err)
|
||||
@@ -217,7 +257,9 @@ func (m *managerImpl) DeleteRouter(ctx context.Context, accountID, userID, netwo
|
||||
|
||||
event()
|
||||
|
||||
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||
if affectedPeerIDs := m.resolveRouterAffectedPeers(ctx, accountID, affectedData); len(affectedPeerIDs) > 0 {
|
||||
go m.accountManager.UpdateAffectedPeers(ctx, accountID, affectedPeerIDs)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -249,6 +291,136 @@ func (m *managerImpl) DeleteRouterInTransaction(ctx context.Context, transaction
|
||||
return event, nil
|
||||
}
|
||||
|
||||
// routerAffectedPeersData holds data loaded inside a transaction for affected peer resolution.
|
||||
type routerAffectedPeersData struct {
|
||||
routerPeerGroups []string
|
||||
directPeerIDs []string
|
||||
resourceGroupIDs []string
|
||||
policies []*nbtypes.Policy
|
||||
}
|
||||
|
||||
// loadRouterAffectedPeersData loads the data needed to determine affected peers within a transaction.
|
||||
func loadRouterAffectedPeersData(ctx context.Context, transaction store.Store, accountID, networkID string, routerPeerGroups []string, directPeers ...string) (*routerAffectedPeersData, error) {
|
||||
var directPeerIDs []string
|
||||
for _, p := range directPeers {
|
||||
if p != "" {
|
||||
directPeerIDs = append(directPeerIDs, p)
|
||||
}
|
||||
}
|
||||
|
||||
if len(routerPeerGroups) == 0 && len(directPeerIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthNone, accountID, networkID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get network resources: %w", err)
|
||||
}
|
||||
|
||||
var resourceGroupIDs []string
|
||||
for _, resource := range resources {
|
||||
if !resource.Enabled {
|
||||
continue
|
||||
}
|
||||
groups, err := transaction.GetResourceGroups(ctx, store.LockingStrengthNone, accountID, resource.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get groups for resource %s: %w", resource.ID, err)
|
||||
}
|
||||
for _, g := range groups {
|
||||
resourceGroupIDs = append(resourceGroupIDs, g.ID)
|
||||
}
|
||||
}
|
||||
|
||||
var policies []*nbtypes.Policy
|
||||
if len(resourceGroupIDs) > 0 {
|
||||
policies, err = transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get policies: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return &routerAffectedPeersData{
|
||||
routerPeerGroups: routerPeerGroups,
|
||||
directPeerIDs: directPeerIDs,
|
||||
resourceGroupIDs: resourceGroupIDs,
|
||||
policies: policies,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// resolveRouterAffectedPeers computes affected peer IDs from preloaded data outside the transaction.
|
||||
func (m *managerImpl) resolveRouterAffectedPeers(ctx context.Context, accountID string, data *routerAffectedPeersData) []string {
|
||||
if data == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
groupSet := make(map[string]struct{})
|
||||
|
||||
for _, gID := range data.routerPeerGroups {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
|
||||
if len(data.resourceGroupIDs) > 0 {
|
||||
destSet := make(map[string]struct{}, len(data.resourceGroupIDs))
|
||||
for _, gID := range data.resourceGroupIDs {
|
||||
destSet[gID] = struct{}{}
|
||||
}
|
||||
|
||||
for _, policy := range data.policies {
|
||||
if policy == nil || !policy.Enabled {
|
||||
continue
|
||||
}
|
||||
for _, rule := range policy.Rules {
|
||||
if rule == nil || !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
referencesResource := false
|
||||
for _, gID := range rule.Destinations {
|
||||
if _, ok := destSet[gID]; ok {
|
||||
referencesResource = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !referencesResource {
|
||||
continue
|
||||
}
|
||||
for _, gID := range rule.Sources {
|
||||
groupSet[gID] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(groupSet) == 0 && len(data.directPeerIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
groupIDs := make([]string, 0, len(groupSet))
|
||||
for gID := range groupSet {
|
||||
groupIDs = append(groupIDs, gID)
|
||||
}
|
||||
|
||||
peerIDs, err := m.store.GetPeerIDsByGroups(ctx, accountID, groupIDs)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve peer IDs: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(data.directPeerIDs) > 0 {
|
||||
seen := make(map[string]struct{}, len(peerIDs))
|
||||
for _, id := range peerIDs {
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
for _, id := range data.directPeerIDs {
|
||||
if _, exists := seen[id]; !exists {
|
||||
peerIDs = append(peerIDs, id)
|
||||
seen[id] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return peerIDs
|
||||
}
|
||||
|
||||
func NewManagerMock() Manager {
|
||||
return &mockManager{}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user