diff --git a/management/server/file_store.go b/management/server/file_store.go index be4c6ec16..316feb867 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -982,3 +982,7 @@ func (s *FileStore) AccountExists(_ context.Context, id string) (bool, error) { func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Account) error { return nil } + +func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { + return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") +} diff --git a/management/server/group.go b/management/server/group.go index 3f69c52ae..9343f2dd2 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -62,32 +62,7 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID str // GetGroupByName filters all groups in an account by name and returns the one with the most peers func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*nbgroup.Group, error) { - groups, err := am.Store.GetAccountGroups(ctx, accountID) - if err != nil { - return nil, err - } - - matchingGroups := make([]*nbgroup.Group, 0) - for _, group := range groups { - if group.Name == groupName { - matchingGroups = append(matchingGroups, group) - } - } - - if len(matchingGroups) == 0 { - return nil, status.Errorf(status.NotFound, "group with name %s not found", groupName) - } - - maxPeers := -1 - var groupWithMostPeers *nbgroup.Group - for i, group := range matchingGroups { - if len(group.Peers) > maxPeers { - maxPeers = len(group.Peers) - groupWithMostPeers = matchingGroups[i] - } - } - - return groupWithMostPeers, nil + return am.Store.GetGroupByName(ctx, LockingStrengthShare, groupName, accountID) } // SaveGroup object of the peers diff --git a/management/server/sql_store.go b/management/server/sql_store.go index 58b258404..b76846c9f 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1092,3 +1092,17 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength return account.Domain, account.DomainCategory, nil } + +// GetGroupByName retrieves a group by name and account ID. +func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { + var group nbgroup.Group + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}). + Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "group not found") + } + return nil, status.Errorf(status.Internal, "failed to retrieve group fields") + } + return &group, nil +} diff --git a/management/server/store.go b/management/server/store.go index 8f00f62d6..10a52db98 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -64,6 +64,7 @@ type Store interface { DeleteTokenID2UserIDIndex(tokenID string) error GetAccountGroups(ctx context.Context, accountID string) ([]*nbgroup.Group, error) + GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error)