Refactor retrieval of policy and posture checks

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-09-24 21:57:33 +03:00
parent 7561706627
commit eab85644cd
6 changed files with 86 additions and 68 deletions

View File

@@ -489,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User,
func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) {
var user User
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).
First(&user, idQueryCondition, userID)
Preload(clause.Associations).First(&user, idQueryCondition, userID)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.NewUserNotFoundError(userID)
@@ -1095,7 +1095,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength
func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations).
Where(accountAndIDQueryCondition, accountID, groupID).First(&group)
if result.Error != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
@@ -1109,7 +1110,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string)
func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) {
var group nbgroup.Group
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}).
Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "group not found")
@@ -1118,3 +1119,48 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
return &group, nil
}
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
var policies []*Policy
result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID).
Preload(clause.Associations).Find(&policies)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
}
return policies, nil
}
func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) {
var policy *Policy
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}).
Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
}
return policy, nil
}
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
var postureChecks []*posture.Checks
result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks)
if result.Error != nil {
return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error)
}
return postureChecks, nil
}
func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) {
var postureCheck *posture.Checks
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}).
Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck)
if err := result.Error; err != nil {
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
return nil, status.Errorf(status.NotFound, "posture checks not found")
}
return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error)
}
return postureCheck, nil
}