diff --git a/management/server/groups/manager.go b/management/server/groups/manager.go index f5abb212e..02b669e41 100644 --- a/management/server/groups/manager.go +++ b/management/server/groups/manager.go @@ -13,7 +13,8 @@ import ( ) type Manager interface { - GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) + GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) GetResourceGroupsInTransaction(ctx context.Context, transaction store.Store, lockingStrength store.LockingStrength, accountID, resourceID string) ([]*types.Group, error) AddResourceToGroup(ctx context.Context, accountID, userID, groupID string, resourceID *types.Resource) error AddResourceToGroupInTransaction(ctx context.Context, transaction store.Store, accountID, userID, groupID string, resourceID *types.Resource) (func(), error) @@ -37,7 +38,7 @@ func NewManager(store store.Store, permissionsManager permissions.Manager, accou } } -func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { ok, err := m.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, permissions.Groups, permissions.Read) if err != nil { return nil, err @@ -51,6 +52,15 @@ func (m *managerImpl) GetAllGroups(ctx context.Context, accountID, userID string return nil, fmt.Errorf("error getting account groups: %w", err) } + return groups, nil +} + +func (m *managerImpl) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { + groups, err := m.GetAllGroups(ctx, accountID, userID) + if err != nil { + return nil, err + } + groupsMap := make(map[string]*types.Group) for _, group := range groups { groupsMap[group.ID] = group @@ -130,7 +140,7 @@ func (m *managerImpl) GetResourceGroupsInTransaction(ctx context.Context, transa return transaction.GetResourceGroups(ctx, lockingStrength, accountID, resourceID) } -func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum { +func ToGroupsInfo(groups []*types.Group, id string) []api.GroupMinimum { groupsInfo := []api.GroupMinimum{} groupsChecked := make(map[string]struct{}) for _, group := range groups { @@ -167,7 +177,11 @@ func ToGroupsInfo(groups map[string]*types.Group, id string) []api.GroupMinimum return groupsInfo } -func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { +func (m *mockManager) GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error) { + return []*types.Group{}, nil +} + +func (m *mockManager) GetAllGroupsMap(ctx context.Context, accountID, userID string) (map[string]*types.Group, error) { return map[string]*types.Group{}, nil } diff --git a/management/server/http/handlers/networks/handler.go b/management/server/http/handlers/networks/handler.go index 6b36a8fce..316b93611 100644 --- a/management/server/http/handlers/networks/handler.go +++ b/management/server/http/handlers/networks/handler.go @@ -82,7 +82,7 @@ func (h *handler) getAllNetworks(w http.ResponseWriter, r *http.Request) { return } - groups, err := h.groupsManager.GetAllGroups(r.Context(), accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(r.Context(), accountID, userID) if err != nil { util.WriteError(r.Context(), err, w) return @@ -267,7 +267,7 @@ func (h *handler) collectIDsInNetwork(ctx context.Context, accountID, userID, ne return nil, nil, 0, fmt.Errorf("failed to get routers in network: %w", err) } - groups, err := h.groupsManager.GetAllGroups(ctx, accountID, userID) + groups, err := h.groupsManager.GetAllGroupsMap(ctx, accountID, userID) if err != nil { return nil, nil, 0, fmt.Errorf("failed to get groups: %w", err) } diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 09d63ea6f..76a0149c6 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -72,13 +72,8 @@ func (h *Handler) getPeer(ctx context.Context, accountID, peerID, userID string, } dnsDomain := h.accountManager.GetDNSDomain() - groupsMap := map[string]*types.Group{} - grps, _ := h.accountManager.GetAllGroups(ctx, accountID, userID) - for _, group := range grps { - groupsMap[group.ID] = group - } - - groupsInfo := groups.ToGroupsInfo(groupsMap, peerID) + grps, _ := h.accountManager.GetPeerGroups(ctx, accountID, peerID) + groupsInfo := groups.ToGroupsInfo(grps, peerID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -128,12 +123,7 @@ func (h *Handler) updatePeer(ctx context.Context, accountID, userID, peerID stri return } - groupsMap := map[string]*types.Group{} - for _, group := range peerGroups { - groupsMap[group.ID] = group - } - - groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(peerGroups, peer.ID) validPeers, err := h.accountManager.GetValidatedPeers(ctx, accountID) if err != nil { @@ -204,11 +194,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { dnsDomain := h.accountManager.GetDNSDomain() - groupsMap := map[string]*types.Group{} grps, _ := h.accountManager.GetAllGroups(r.Context(), accountID, userID) - for _, group := range grps { - groupsMap[group.ID] = group - } respBody := make([]*api.PeerBatch, 0, len(peers)) for _, peer := range peers { @@ -217,7 +203,7 @@ func (h *Handler) GetAllPeers(w http.ResponseWriter, r *http.Request) { util.WriteError(r.Context(), err, w) return } - groupMinimumInfo := groups.ToGroupsInfo(groupsMap, peer.ID) + groupMinimumInfo := groups.ToGroupsInfo(grps, peer.ID) respBody = append(respBody, toPeerListItemResponse(peerToReturn, groupMinimumInfo, dnsDomain, 0)) }