diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index ec3a5261e..fff7a5f6d 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -252,20 +252,29 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne func (h *handler) generateNetworkResponse(networks []*types.Network, routers map[string][]*routerTypes.NetworkRouter, resourceIDs map[string][]string, groups map[string]*nbtypes.Group) []*api.Network { var networkResponse []*api.Network for _, network := range networks { - routerIDs := []string{} - peerCounter := 0 - for _, router := range routers[network.ID] { - routerIDs = append(routerIDs, router.ID) - if router.Peer != "" { - peerCounter++ - } - if len(router.PeerGroups) > 0 { - for _, groupID := range router.PeerGroups { - peerCounter += len(groups[groupID].Peers) - } - } - } + routerIDs, peerCounter := getRouterIDs(network, routers, groups) networkResponse = append(networkResponse, network.ToAPIResponse(routerIDs, resourceIDs[network.ID], peerCounter)) } return networkResponse } + +func getRouterIDs(network *types.Network, routers map[string][]*routerTypes.NetworkRouter, groups map[string]*nbtypes.Group) ([]string, int) { + routerIDs := []string{} + peerCounter := 0 + for _, router := range routers[network.ID] { + routerIDs = append(routerIDs, router.ID) + if router.Peer != "" { + peerCounter++ + } + if len(router.PeerGroups) > 0 { + for _, groupID := range router.PeerGroups { + group, ok := groups[groupID] + if !ok { + continue + } + peerCounter += len(group.Peers) + } + } + } + return routerIDs, peerCounter +} diff --git a/management/server/networks/manager.go b/management/server/networks/manager.go index 61dd59cb8..ddc88b05f 100644 --- a/management/server/networks/manager.go +++ b/management/server/networks/manager.go @@ -2,6 +2,7 @@ package networks import ( "context" + "fmt" "github.com/rs/xid" @@ -98,7 +99,36 @@ func (m *managerImpl) DeleteNetwork(ctx context.Context, accountID, userID, netw return status.NewPermissionDeniedError() } - return m.store.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + resources, err := transaction.GetNetworkResourcesByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get resources in network: %w", err) + } + + for _, resource := range resources { + err = m.resourcesManager.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resource.ID) + if err != nil { + return fmt.Errorf("failed to delete resource: %w", err) + } + } + + routers, err := transaction.GetNetworkRoutersByNetID(ctx, store.LockingStrengthUpdate, accountID, networkID) + if err != nil { + return fmt.Errorf("failed to get routers in network: %w", err) + } + + for _, router := range routers { + err = transaction.DeleteNetworkRouter(ctx, store.LockingStrengthUpdate, accountID, router.ID) + if err != nil { + return fmt.Errorf("failed to delete router: %w", err) + } + } + + return transaction.DeleteNetwork(ctx, store.LockingStrengthUpdate, accountID, networkID) + }) } func (m *managerImpl) GetResourceManager() resources.Manager { diff --git a/management/server/networks/resources/manager.go b/management/server/networks/resources/manager.go index 8d13659bd..907d926cd 100644 --- a/management/server/networks/resources/manager.go +++ b/management/server/networks/resources/manager.go @@ -19,6 +19,7 @@ type Manager interface { GetResource(ctx context.Context, accountID, userID, networkID, resourceID string) (*types.NetworkResource, error) UpdateResource(ctx context.Context, userID string, resource *types.NetworkResource) (*types.NetworkResource, error) DeleteResource(ctx context.Context, accountID, userID, networkID, resourceID string) error + DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) error } type managerImpl struct { @@ -161,5 +162,39 @@ func (m *managerImpl) DeleteResource(ctx context.Context, accountID, userID, net return status.NewPermissionDeniedError() } - return m.store.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) + unlock := m.store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + return m.store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + return m.DeleteResourceInTransaction(ctx, transaction, accountID, networkID, resourceID) + }) +} + +func (m *managerImpl) DeleteResourceInTransaction(ctx context.Context, transaction store.Store, accountID, networkID, resourceID string) error { + resource, err := transaction.GetNetworkResourceByID(ctx, store.LockingStrengthUpdate, accountID, resourceID) + if err != nil { + return fmt.Errorf("failed to get network resource: %w", err) + } + + if resource.NetworkID != networkID { + return errors.New("resource not part of network") + } + + account, err := transaction.GetAccount(ctx, accountID) + if err != nil { + return fmt.Errorf("failed to get account: %w", err) + } + account.DeleteResource(resource.ID) + + err = transaction.SaveAccount(ctx, account) + if err != nil { + return fmt.Errorf("failed to save account: %w", err) + } + + err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return fmt.Errorf("failed to increment network serial: %w", err) + } + + return transaction.DeleteNetworkResource(ctx, store.LockingStrengthUpdate, accountID, resourceID) } diff --git a/management/server/types/account.go b/management/server/types/account.go index 281c8ea63..2b8c37178 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -569,10 +569,29 @@ func (a *Account) DeletePeer(peerID string) { } } + for i, r := range a.NetworkRouters { + if r.Peer == peerID { + a.NetworkRouters = append(a.NetworkRouters[:i], a.NetworkRouters[i+1:]...) + break + } + } + delete(a.Peers, peerID) a.Network.IncSerial() } +func (a *Account) DeleteResource(resourceID string) { + // delete resource from groups + for _, g := range a.Groups { + for i, pk := range g.Resources { + if pk.ID == resourceID { + g.Resources = append(g.Resources[:i], g.Resources[i+1:]...) + break + } + } + } +} + // FindPeerByPubKey looks for a Peer by provided WireGuard public key in the Account or returns error if it wasn't found. // It will return an object copy of the peer. func (a *Account) FindPeerByPubKey(peerPubKey string) (*nbpeer.Peer, error) {