add store lock

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-09-26 18:46:23 +03:00
parent 871595d15f
commit 4575ae2841
9 changed files with 34 additions and 33 deletions

View File

@@ -1054,10 +1054,11 @@ func (s *SqlStore) GetAccountDNSSettings(ctx context.Context, lockStrength Locki
}
// AccountExists checks whether an account exists by the given ID.
func (s *SqlStore) AccountExists(ctx context.Context, id string) (bool, error) {
func (s *SqlStore) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) {
var count int64
result := s.db.WithContext(ctx).Model(&Account{}).Where(idQueryCondition, id).Count(&count)
result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).
Where(idQueryCondition, id).Count(&count)
if result.Error != nil {
return false, result.Error
}
@@ -1101,8 +1102,8 @@ func (s *SqlStore) GetGroupByName(ctx context.Context, lockStrength LockingStren
}
// GetAccountPolicies retrieves policies for an account.
func (s *SqlStore) GetAccountPolicies(ctx context.Context, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), accountID)
func (s *SqlStore) GetAccountPolicies(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*Policy, error) {
return getRecords[*Policy](s.db.WithContext(ctx).Preload(clause.Associations), lockStrength, accountID)
}
// GetPolicyByID retrieves a policy by its ID and account ID.
@@ -1111,8 +1112,8 @@ func (s *SqlStore) GetPolicyByID(ctx context.Context, lockStrength LockingStreng
}
// GetAccountPostureChecks retrieves posture checks for an account.
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountPostureChecks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*posture.Checks, error) {
return getRecords[*posture.Checks](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetPostureChecksByID retrieves posture checks by their ID and account ID.
@@ -1121,8 +1122,8 @@ func (s *SqlStore) GetPostureChecksByID(ctx context.Context, lockStrength Lockin
}
// GetAccountRoutes retrieves network routes for an account.
func (s *SqlStore) GetAccountRoutes(ctx context.Context, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountRoutes(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*route.Route, error) {
return getRecords[*route.Route](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetRouteByID retrieves a route by its ID and account ID.
@@ -1131,8 +1132,8 @@ func (s *SqlStore) GetRouteByID(ctx context.Context, lockStrength LockingStrengt
}
// GetAccountSetupKeys retrieves setup keys for an account.
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountSetupKeys(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*SetupKey, error) {
return getRecords[*SetupKey](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetSetupKeyByID retrieves a setup key by its ID and account ID.
@@ -1141,8 +1142,8 @@ func (s *SqlStore) GetSetupKeyByID(ctx context.Context, lockStrength LockingStre
}
// GetAccountNameServerGroups retrieves name server groups for an account.
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), accountID)
func (s *SqlStore) GetAccountNameServerGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*nbdns.NameServerGroup, error) {
return getRecords[*nbdns.NameServerGroup](s.db.WithContext(ctx), lockStrength, accountID)
}
// GetNameServerGroupByID retrieves a name server group by its ID and account ID.
@@ -1151,10 +1152,10 @@ func (s *SqlStore) GetNameServerGroupByID(ctx context.Context, lockStrength Lock
}
// getRecords retrieves records from the database based on the account ID.
func getRecords[T any](db *gorm.DB, accountID string) ([]T, error) {
func getRecords[T any](db *gorm.DB, lockStrength LockingStrength, accountID string) ([]T, error) {
var record []T
result := db.Find(&record, accountIDCondition, accountID)
result := db.Clauses(clause.Locking{Strength: string(lockStrength)}).Find(&record, accountIDCondition, accountID)
if err := result.Error; err != nil {
parts := strings.Split(fmt.Sprintf("%T", record), ".")
recordType := parts[len(parts)-1]