diff --git a/management/server/file_store.go b/management/server/file_store.go index a18e0e539..7b766a2e3 100644 --- a/management/server/file_store.go +++ b/management/server/file_store.go @@ -986,6 +986,24 @@ func (s *FileStore) UpdateAccount(_ context.Context, _ LockingStrength, _ *Accou func (s *FileStore) GetGroupByID(_ context.Context, _, _ string) (*nbgroup.Group, error) { return nil, status.Errorf(status.Internal, "GetGroupByID is not implemented") } + func (s *FileStore) GetGroupByName(_ context.Context, _ LockingStrength, _, _ string) (*nbgroup.Group, error) { return nil, status.Errorf(status.Internal, "GetGroupByName is not implemented") } + +func (s *FileStore) GetAccountPolicies(_ context.Context, _ string) ([]*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") +} + +func (s *FileStore) GetPolicyByID(_ context.Context, _ LockingStrength, _ string, _ string) (*Policy, error) { + return nil, status.Errorf(status.Internal, "GetPolicyByID is not implemented") + +} + +func (s *FileStore) GetAccountPostureChecks(_ context.Context, _ string) ([]*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetAccountPostureChecks is not implemented") +} + +func (s *FileStore) GetPostureChecksByID(_ context.Context, _ LockingStrength, _ string, _ string) (*posture.Checks, error) { + return nil, status.Errorf(status.Internal, "GetPostureChecksByID is not implemented") +} diff --git a/management/server/http/policies_handler.go b/management/server/http/policies_handler.go index 1b0992cdd..5b6e1121c 100644 --- a/management/server/http/policies_handler.go +++ b/management/server/http/policies_handler.go @@ -3,7 +3,6 @@ package http import ( "encoding/json" "net/http" - "slices" "strconv" "github.com/gorilla/mux" @@ -84,18 +83,12 @@ func (h *Policies) UpdatePolicy(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, userID) + _, err = h.accountManager.GetPolicy(r.Context(), accountID, policyID, userID) if err != nil { util.WriteError(r.Context(), err, w) return } - policyIdx := slices.IndexFunc(account.Policies, func(policy *server.Policy) bool { return policy.ID == policyID }) - if policyIdx < 0 { - util.WriteError(r.Context(), status.Errorf(status.NotFound, "couldn't find policy id %s", policyID), w) - return - } - h.savePolicy(w, r, accountID, userID, policyID) } diff --git a/management/server/policy.go b/management/server/policy.go index 833f97d39..204d719c1 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -315,30 +315,16 @@ func (a *Account) connResourcesGenerator(ctx context.Context) (func(*PolicyRule, // GetPolicy from the store func (am *DefaultAccountManager) GetPolicy(ctx context.Context, accountID, policyID, userID string) (*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !(user.HasAdminPower() || user.IsServiceUser) { + if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - for _, policy := range account.Policies { - if policy.ID == policyID { - return policy, nil - } - } - - return nil, status.Errorf(status.NotFound, "policy with ID %s not found", policyID) + return am.Store.GetPolicyByID(ctx, LockingStrengthShare, policyID, accountID) } // SavePolicy in the store @@ -400,24 +386,16 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po // ListPolicies from the store func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, userID string) ([]*Policy, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err + if (!user.HasAdminPower() && !user.IsServiceUser) || user.AccountID != accountID { + return nil, status.Errorf(status.PermissionDenied, "only users with admin power are allowed to view policies") } - if !(user.HasAdminPower() || user.IsServiceUser) { - return nil, status.Errorf(status.PermissionDenied, "only users with admin power can view policies") - } - - return account.Policies, nil + return am.Store.GetAccountPolicies(ctx, accountID) } func (am *DefaultAccountManager) deletePolicy(account *Account, policyID string) (*Policy, error) { diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index 4180550e6..7a03effb1 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -15,30 +15,16 @@ const ( ) func (am *DefaultAccountManager) GetPostureChecks(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - for _, postureChecks := range account.PostureChecks { - if postureChecks.ID == postureChecksID { - return postureChecks, nil - } - } - - return nil, status.Errorf(status.NotFound, "posture checks with ID %s not found", postureChecksID) + return am.Store.GetPostureChecksByID(ctx, LockingStrengthShare, postureChecksID, accountID) } func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountID, userID string, postureChecks *posture.Checks) error { @@ -121,24 +107,16 @@ func (am *DefaultAccountManager) DeletePostureChecks(ctx context.Context, accoun } func (am *DefaultAccountManager) ListPostureChecks(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - account, err := am.Store.GetAccount(ctx, accountID) + user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID) if err != nil { return nil, err } - user, err := account.FindUser(userID) - if err != nil { - return nil, err - } - - if !user.HasAdminPower() { + if !user.HasAdminPower() || user.AccountID != accountID { return nil, status.Errorf(status.PermissionDenied, errMsgPostureAdminOnly) } - return account.PostureChecks, nil + return am.Store.GetAccountPostureChecks(ctx, accountID) } func (am *DefaultAccountManager) savePostureChecks(account *Account, postureChecks *posture.Checks) (exists, uniqName bool) { diff --git a/management/server/sql_store.go b/management/server/sql_store.go index d843e6f1d..5094c589b 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -489,7 +489,7 @@ func (s *SqlStore) GetUserByTokenID(ctx context.Context, tokenID string) (*User, func (s *SqlStore) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error) { var user User result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). - First(&user, idQueryCondition, userID) + Preload(clause.Associations).First(&user, idQueryCondition, userID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.NewUserNotFoundError(userID) @@ -1095,7 +1095,8 @@ func (s *SqlStore) GetAccountDomainAndCategory(ctx context.Context, lockStrength func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group - result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Where(accountAndIDQueryCondition, accountID, groupID).First(&group) + result := s.db.WithContext(ctx).Model(&nbgroup.Group{}).Preload(clause.Associations). + Where(accountAndIDQueryCondition, accountID, groupID).First(&group) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") @@ -1109,7 +1110,7 @@ func (s *SqlStore) GetGroupByID(ctx context.Context, groupID, accountID string) func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) { var group nbgroup.Group result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&nbgroup.Group{}). - Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group) + Preload(clause.Associations).Where("name = ? and account_id = ?", groupName, accountID).Order("json_array_length(peers) DESC").First(&group) if err := result.Error; err != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return nil, status.Errorf(status.NotFound, "group not found") @@ -1118,3 +1119,48 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren } return &group, nil } + +func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) { + var policies []*Policy + result := s.db.WithContext(ctx).Model(&Policy{}).Where(accountIDCondition, accountID). + Preload(clause.Associations).Find(&policies) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error) + } + return policies, nil +} + +func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) { + var policy *Policy + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Policy{}). + Preload(clause.Associations).Where(accountAndIDQueryCondition, accountID, policyID).First(&policy) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "posture checks not found") + } + return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error) + } + return policy, nil +} + +func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) { + var postureChecks []*posture.Checks + result := s.db.WithContext(ctx).Model(&posture.Checks{}).Where(accountIDCondition, accountID).Find(&postureChecks) + if result.Error != nil { + return nil, status.Errorf(status.Internal, "failed to get account posture checks: %v", result.Error) + } + return postureChecks, nil +} + +func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) { + var postureCheck *posture.Checks + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&posture.Checks{}). + Where(accountAndIDQueryCondition, accountID, postureCheckID).First(&postureCheck) + if err := result.Error; err != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.Errorf(status.NotFound, "posture checks not found") + } + return nil, status.Errorf(status.Internal, "failed to get posture checks from store: %s", result.Error) + } + return postureCheck, nil +} diff --git a/management/server/store.go b/management/server/store.go index 73e68531c..601e173e2 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -68,7 +68,12 @@ type Store interface { GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*nbgroup.Group, error) SaveGroups(accountID string, groups map[string]*nbgroup.Group) error + GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) + GetPolicyByID(ctx context.Context, lockStrength LockingStrength, policyID string, accountID string) (*Policy, error) + GetPostureCheckByChecksDefinition(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) + GetPostureChecksByID(ctx context.Context, lockStrength LockingStrength, postureCheckID string, accountID string) (*posture.Checks, error) GetPeerLabelsInAccount(ctx context.Context, lockStrength LockingStrength, accountId string) ([]string, error) AddPeerToAllGroup(ctx context.Context, accountID string, peerID string) error