diff --git a/management/server/policy.go b/management/server/policy.go index eb44a0436..6dcb96316 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -521,12 +521,12 @@ func validatePolicy(ctx context.Context, transaction Store, accountID string, po policy.AccountID = accountID } - groups, err := transaction.GetAccountGroups(ctx, LockingStrengthShare, accountID) + groups, err := transaction.GetGroupsByIDs(ctx, LockingStrengthShare, accountID, policy.ruleGroups()) if err != nil { return err } - postureChecks, err := transaction.GetAccountPostureChecks(ctx, LockingStrengthShare, accountID) + postureChecks, err := transaction.GetPostureChecksByIDs(ctx, LockingStrengthShare, accountID, policy.SourcePostureChecks) if err != nil { return err } @@ -629,15 +629,10 @@ func (a *Account) getPostureChecks(postureChecksID string) *posture.Checks { } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. -func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds []string) []string { - validPostureCheckIDs := make(map[string]struct{}) - for _, check := range postureChecks { - validPostureCheckIDs[check.ID] = struct{}{} - } - +func getValidPostureCheckIDs(postureChecks map[string]*posture.Checks, postureChecksIds []string) []string { validIDs := make([]string, 0, len(postureChecksIds)) for _, id := range postureChecksIds { - if _, exists := validPostureCheckIDs[id]; exists { + if _, exists := postureChecks[id]; exists { validIDs = append(validIDs, id) } } @@ -646,15 +641,10 @@ func getValidPostureCheckIDs(postureChecks []*posture.Checks, postureChecksIds [ } // getValidGroupIDs filters and returns only the valid group IDs from the provided list. -func getValidGroupIDs(groups []*nbgroup.Group, groupIDs []string) []string { - validGroupIDs := make(map[string]struct{}) - for _, group := range groups { - validGroupIDs[group.ID] = struct{}{} - } - +func getValidGroupIDs(groups map[string]*nbgroup.Group, groupIDs []string) []string { validIDs := make([]string, 0, len(groupIDs)) for _, id := range groupIDs { - if _, exists := validGroupIDs[id]; exists { + if _, exists := groups[id]; exists { validIDs = append(validIDs, id) } } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index e7a2e50d8..a4191de9f 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -1234,8 +1234,8 @@ func (s *SqlStore) GetGroupsByIDs(ctx context.Context, lockStrength LockingStren var groups []*nbgroup.Group result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&groups, accountAndIDsQueryCondition, accountID, groupIDs) if result.Error != nil { - log.WithContext(ctx).Errorf("failed to get groups by ID's from the store: %s", result.Error) - return nil, status.Errorf(status.Internal, "failed to get groups by ID's from the store") + log.WithContext(ctx).Errorf("failed to get groups by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get groups by ID's from store") } groupsMap := make(map[string]*nbgroup.Group) @@ -1377,6 +1377,23 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin return postureCheck, nil } +// GetPostureChecksByIDs retrieves posture checks by their IDs and account ID. +func (s *SqlStore) GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) { + var postureChecks []*posture.Checks + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&postureChecks, accountAndIDsQueryCondition, accountID, postureChecksIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get posture checks by ID's from store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get posture checks by ID's from store") + } + + postureChecksMap := make(map[string]*posture.Checks) + for _, postureCheck := range postureChecks { + postureChecksMap[postureCheck.ID] = postureCheck + } + + return postureChecksMap, nil +} + // SavePostureChecks saves a posture checks to the database. func (s *SqlStore) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error { result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Save(postureCheck) diff --git a/management/server/store.go b/management/server/store.go index 108b262b1..ba61d552d 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -88,6 +88,7 @@ type Store interface { GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, accountID, postureCheckID string) (*posture.Checks, error) + GetPostureChecksByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, postureChecksIDs []string) (map[string]*posture.Checks, error) SavePostureChecks(ctx context.Context, lockStrength LockingStrength, postureCheck *posture.Checks) error DeletePostureChecks(ctx context.Context, lockStrength LockingStrength, accountID, postureChecksID string) error