diff --git a/management/server/account.go b/management/server/account.go index 8c91afe53..ef58f6545 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1048,39 +1048,21 @@ func BuildManager( metrics: metrics, requestBuffer: NewAccountRequestBuffer(ctx, store), } - allAccounts := store.GetAllAccounts(ctx) + totalAccounts, err := store.GetTotalAccounts(ctx, LockingStrengthShare) + if err != nil { + return nil, err + } + // enable single account mode only if configured by user and number of existing accounts is not grater than 1 - am.singleAccountMode = singleAccountModeDomain != "" && len(allAccounts) <= 1 + am.singleAccountMode = singleAccountModeDomain != "" && totalAccounts <= 1 if am.singleAccountMode { if !isDomainValid(singleAccountModeDomain) { return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for a single account mode. Please review your input for --single-account-mode-domain", singleAccountModeDomain) } am.singleAccountModeDomain = singleAccountModeDomain - log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", len(allAccounts)) + log.WithContext(ctx).Infof("single account mode enabled, accounts number %d", totalAccounts) } else { - log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", len(allAccounts)) - } - - // if account doesn't have a default group - // we create 'all' group and add all peers into it - // also we create default rule with source as destination - for _, account := range allAccounts { - shouldSave := false - - _, err := account.GetGroupAll() - if err != nil { - if err := addAllGroup(account); err != nil { - return nil, err - } - shouldSave = true - } - - if shouldSave { - err = store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - } + log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", totalAccounts) } goCacheClient := gocache.New(CacheExpirationMax, 30*time.Minute) diff --git a/management/server/sql_store.go b/management/server/sql_store.go index e312d4c40..48fc69421 100644 --- a/management/server/sql_store.go +++ b/management/server/sql_store.go @@ -877,6 +877,17 @@ func (s *SqlStore) GetAccountCreatedBy(ctx context.Context, lockStrength Locking return createdBy, nil } +func (s *SqlStore) GetTotalAccounts(ctx context.Context, lockStrength LockingStrength) (int64, error) { + var count int64 + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&Account{}).Count(&count) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to get total accounts from store: %s", result.Error) + return 0, status.Errorf(status.Internal, "failed to get total accounts from store") + } + + return count, nil +} + // SaveUserLastLogin stores the last login time for a user in DB. func (s *SqlStore) SaveUserLastLogin(ctx context.Context, accountID, userID string, lastLogin time.Time) error { var user User diff --git a/management/server/sql_store_test.go b/management/server/sql_store_test.go index 8b717b80b..36c6eac32 100644 --- a/management/server/sql_store_test.go +++ b/management/server/sql_store_test.go @@ -2662,3 +2662,13 @@ func TestSqlStore_SaveAccountSettings(t *testing.T) { require.NoError(t, err) require.Equal(t, settings, saveSettings) } + +func TestSqlStore_GetTotalAccounts(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + totalAccounts, err := store.GetTotalAccounts(context.Background(), LockingStrengthShare) + require.NoError(t, err) + require.Equal(t, int64(1), totalAccounts) +} diff --git a/management/server/store.go b/management/server/store.go index 2b265b2fd..c64561ead 100644 --- a/management/server/store.go +++ b/management/server/store.go @@ -58,6 +58,7 @@ type Store interface { GetAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*Settings, error) GetAccountDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string) (*DNSSettings, error) GetAccountCreatedBy(ctx context.Context, lockStrength LockingStrength, accountID string) (string, error) + GetTotalAccounts(ctx context.Context, lockStrength LockingStrength) (int64, error) SaveAccount(ctx context.Context, account *Account) error DeleteAccount(ctx context.Context, account *Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error