diff --git a/management/server/account.go b/management/server/account.go index b3228f83e..b2f7b4ce7 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1052,39 +1052,21 @@ func BuildManager( metrics: metrics, requestBuffer: NewAccountRequestBuffer(ctx, store), } - allAccounts := store.GetAllAccounts(ctx) + allAccountIDs, err := store.GetAllAccountIDs(ctx, LockingStrengthShare) + if err != nil { + return nil, err + } + // enable single account mode only if configured by user and number of existing accounts is not grater than 1 - am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 + am.singleAccountMode = singleAccountModeDomain != "" && len(allAccountIDs) <= 1 if am.singleAccountMode { if !isDomainValid(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain - log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccountIDs)) } else { - log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts)) - } - - // if account doesn't have a default group - // we create 'all' group and add all peers into it - // also we create default rule with source as destination - for _, account := range allAccounts { - shouldSave := false - - _, err := account.GetGroupAll() - if err != nil { - if err := addAllGroup(account); err != nil { - return nil, err - } - shouldSave = true - } - - if shouldSave { - err = store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } + log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccountIDs)) } goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) @@ -1290,19 +1272,18 @@ func (am *DefaultAccountManager) newAccount(ctx context.Context, userID, domain for i := 0; i < 2; i++ { accountId := xid.New().String() - _, err := am.Store.GetAccount(ctx, accountId) - statusErr, _ := status.FromError(err) - switch { - case err == nil: - log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") - continue - case statusErr.Type() == status.NotFound: + exists, err := am.Store.AccountExists(ctx, LockingStrengthShare, accountId) + if err != nil { + log.WithContext(ctx).Errorf("error while checking account existence: %v", err) + return nil, err + } + + if !exists { newAccount := newAccountWithId(ctx, accountId, userID, domain) am.StoreEvent(ctx, userID, newAccount.Id, accountId, activity.AccountCreated, nil) return newAccount, nil - default: - return nil, err } + log.WithContext(ctx).Warnf("an account with ID already exists, retrying...") } return nil, status.Errorf(status.Internal, "error while creating new account") @@ -1321,16 +1302,16 @@ func (am *DefaultAccountManager) warmupIDPCache(ctx context.Context) error { // update their AppMetadata with the AccountID. if unsetData, ok := userData[idp.UnsetAccountID]; ok { for _, user := range unsetData { - accountID, err := am.Store.GetAccountByUser(ctx, user.ID) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, user.ID) if err == nil { - data := userData[accountID.Id] + data := userData[userAccountID] if data == nil { data = make([]*idp.UserData, 0, 1) } - user.AppMetadata.WTAccountID = accountID.Id + user.AppMetadata.WTAccountID = userAccountID - userData[accountID.Id] = append(data, user) + userData[userAccountID] = append(data, user) } } } @@ -1416,7 +1397,7 @@ func (am *DefaultAccountManager) GetAccountIDByUserID(ctx context.Context, userI return "", status.Errorf(status.NotFound, "no valid userID provided") } - accountID, err := am.Store.GetAccountIDByUserID(userID) + accountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) if err != nil { if s, ok := status.FromError(err); ok && s.Type() == status.NotFound { account, err := am.GetOrCreateAccountByUser(ctx, userID, domain) @@ -1696,9 +1677,6 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx return nil } - unlockAccount := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlockAccount() - accountDomain, domainCategory, err := am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, accountID) if err != nil { log.WithContext(ctx).Errorf("error getting account domain and category: %v", err) @@ -1716,7 +1694,7 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx } newDomain := accountDomain - newCategoty := domainCategory + newCategory := domainCategory lowerDomain := strings.ToLower(claims.Domain) if accountDomain != lowerDomain && user.HasAdminPower() { @@ -1724,10 +1702,10 @@ func (am *DefaultAccountManager) updateAccountDomainAttributesIfNotUpToDate(ctx } if accountDomain == lowerDomain { - newCategoty = claims.DomainCategory + newCategory = claims.DomainCategory } - return am.Store.UpdateAccountDomainAttributes(ctx, accountID, newDomain, newCategoty, primaryDomain) + return am.Store.UpdateAccountDomainAttributes(ctx, LockingStrengthUpdate, accountID, newDomain, newCategory, primaryDomain) } // handleExistingUserAccount handles existing User accounts and update its domain attributes. @@ -2163,7 +2141,7 @@ func (am *DefaultAccountManager) getAccountIDWithAuthorizationClaims(ctx context return "", err } - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if handleNotFound(err) != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err @@ -2209,7 +2187,7 @@ func (am *DefaultAccountManager) getPrivateDomainWithGlobalLock(ctx context.Cont } func (am *DefaultAccountManager) handlePrivateAccountWithIDFromClaim(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, error) { - userAccountID, err := am.Store.GetAccountIDByUserID(claims.UserId) + userAccountID, err := am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, claims.UserId) if err != nil { log.WithContext(ctx).Errorf("error getting account ID by user ID: %v", err) return "", err diff --git a/management/server/peer.go b/management/server/peer.go index 0455cf719..9f3fc8b0c 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -411,7 +411,7 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s addedByUser := false if len(userID) > 0 { addedByUser = true - accountID, err = am.Store.GetAccountIDByUserID(userID) + accountID, err = am.Store.GetAccountIDByUserID(ctx, LockingStrengthShare, userID) } else { accountID, err = am.Store.GetAccountIDBySetupKey(ctx, encodedHashedKey) } diff --git a/management/server/sql_store.go b/management/server/sql_store.go index f10e2f8ff..649638d81 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -324,7 +324,7 @@ func (s *SqlStore) SavePeer(ctx context.Context, lockStrength LockingStrength, a return nil } -func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error { +func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error { accountCopy := Account{ Domain: domain, DomainCategory: category, @@ -332,7 +332,7 @@ func (s *SqlStore) UpdateAccountDomainAttributes(ctx context.Context, accountID } fieldsToUpdate := []string{"domain", "domain_category", "is_domain_primary_account"} - result := s.db.WithContext(ctx).Model(&Account{}). + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}). Select(fieldsToUpdate). Where(idQueryCondition, accountID). Updates(&accountCopy) @@ -563,6 +563,18 @@ func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*Account) { return all } +func (s *SqlStore) GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error) { + var accountIDs []string + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}). + Model(&Account{}).Pluck("id", &accountIDs) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get account IDs from the store: %s", result.Error) + return nil, status.Errorf(status.Internal, "failed to get account IDs from store") + } + + return accountIDs, nil +} + func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*Account, error) { start := time.Now() defer func() { @@ -704,14 +716,15 @@ func (s *SqlStore) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) return accountID, nil } -func (s *SqlStore) GetAccountIDByUserID(userID string) (string, error) { +func (s *SqlStore) GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) { var accountID string - result := s.db.Model(&User{}).Select("account_id").Where(idQueryCondition, userID).First(&accountID) + result := s.db.WithContext(ctx).Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&User{}). + Select("account_id").Where(idQueryCondition, userID).First(&accountID) if result.Error != nil { if errors.Is(result.Error, gorm.ErrRecordNotFound) { return "", status.NewAccountNotFoundError() } - + log.WithContext(ctx).Errorf("failed to get accountID from the store: %s", result.Error) return "", status.NewGetAccountFromStoreError(result.Error) } diff --git a/management/server/store.go b/management/server/store.go index f7d5e9348..141e98abf 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -47,8 +47,9 @@ type Store interface { GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*Account, error) GetAccountByPeerPubKey(ctx context.Context, peerKey string) (*Account, error) + GetAllAccountIDs(ctx context.Context, lockStrength LockingStrength) ([]string, error) GetAccountIDByPeerPubKey(ctx context.Context, peerKey string) (string, error) - GetAccountIDByUserID(userID string) (string, error) + GetAccountIDByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (string, error) GetAccountIDBySetupKey(ctx context.Context, peerKey string) (string, error) GetAccountIDByPeerID(ctx context.Context, lockStrength LockingStrength, peerID string) (string, error) GetAccountByPeerID(ctx context.Context, peerID string) (*Account, error) @@ -62,7 +63,7 @@ type Store interface { SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *Settings) error SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error - UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error + UpdateAccountDomainAttributes(ctx context.Context, lockStrength LockingStrength, accountID string, domain string, category string, isPrimaryDomain bool) error GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*User, error)