diff --git a/management/server/account.go b/management/server/account.go index 4f3220de0..e57d7092b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -275,41 +275,6 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { return am.idpManager } -func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - if newOnboarding == nil { - return nil, status.Errorf(status.InvalidArgument, "new onboarding data cannot be nil") - } - - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("failed to get account onboarding: %w", err) - } - if oldOnboarding.IsEqual(*newOnboarding) { - log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID) - return oldOnboarding, nil - } - - newOnboarding.AccountID = accountID - err = am.Store.SaveAccountOnboarding(ctx, newOnboarding) - if err != nil { - return nil, fmt.Errorf("failed to update account onboarding: %w", err) - } - - return newOnboarding, nil -} - // UpdateAccountSettings updates Account settings. // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. @@ -1234,7 +1199,58 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountOnboarding(ctx, accountID) + onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err) + return nil, err + } + + if onboarding == nil { + onboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + return onboarding, nil +} + +func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + return nil, fmt.Errorf("failed to get account onboarding: %w", err) + } + + if oldOnboarding == nil { + oldOnboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + if newOnboarding == nil { + return oldOnboarding, nil + } + + if oldOnboarding.IsEqual(*newOnboarding) { + log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID) + return oldOnboarding, nil + } + + newOnboarding.AccountID = accountID + err = am.Store.SaveAccountOnboarding(ctx, newOnboarding) + if err != nil { + return nil, fmt.Errorf("failed to update account onboarding: %w", err) + } + + return newOnboarding, nil } func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index dc21cf71f..137179fa0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3448,15 +3448,29 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") require.NoError(t, err) - onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) - require.NoError(t, err) - require.NotNil(t, onboarding) - assert.Equal(t, account.Id, onboarding.AccountID) - assert.Equal(t, true, onboarding.OnboardingFlowPending) - assert.Equal(t, true, onboarding.SignupFormPending) - if onboarding.UpdatedAt.IsZero() { - t.Errorf("Onboarding was not retrieved from the store") - } + t.Run("should return account onboarding when onboarding exist", func(t *testing.T) { + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + assert.Equal(t, account.Id, onboarding.AccountID) + assert.Equal(t, true, onboarding.OnboardingFlowPending) + assert.Equal(t, true, onboarding.SignupFormPending) + if onboarding.UpdatedAt.IsZero() { + t.Errorf("Onboarding was not retrieved from the store") + } + }) + + t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) { + account.Id = "with-zero-onboarding" + account.Onboarding = types.AccountOnboarding{} + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + _, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id) + require.Error(t, err, "should return error when onboarding is not set") + }) } func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { @@ -3492,8 +3506,8 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) }) - t.Run("update onboarding with no change", func(t *testing.T) { + t.Run("update onboarding with no onboarding", func(t *testing.T) { _, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil) - require.Error(t, err) + require.NoError(t, err) }) } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index fa901677b..dbf0c22bc 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -61,9 +61,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { }, nil }, UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { - if onboarding == nil { - return nil, status.Errorf(status.InvalidArgument, "onboarding cannot be nil") - } return &types.AccountOnboarding{ OnboardingFlowPending: true, SignupFormPending: true, @@ -194,13 +191,26 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedID: accountID, }, { - name: "PutAccount fail without onboarding", + name: "PutAccount OK without onboarding", expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), - expectedStatus: http.StatusUnprocessableEntity, - expectedArray: false, + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), + LazyConnectionEnabled: br(false), + DnsDomain: sr(""), + }, + expectedArray: false, + expectedID: accountID, }, { name: "Update account failure with high peer_login_expiration more than 180 days", diff --git a/management/server/status/error.go b/management/server/status/error.go index 5a6f6d1a7..47c236e93 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error { return Errorf(NotFound, "account not found: %s", accountKey) } +// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding +func NewAccountOnboardingNotFoundError(accountKey string) error { + return Errorf(NotFound, "account onboarding not found: %s", accountKey) +} + // NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account func NewPeerNotPartOfAccountError() error { return Errorf(PermissionDenied, "peer is not part of this account") diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 9e33a6972..c24ebf6ec 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -730,12 +730,10 @@ func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) ( var accountOnboarding types.AccountOnboarding result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID) if result.Error != nil { - log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { - //accountOnboarding.AccountID = accountID - //return &accountOnboarding, nil - return nil, status.NewAccountNotFoundError(accountID) + return nil, status.NewAccountOnboardingNotFoundError(accountID) } + log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error) return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -744,7 +742,7 @@ func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) ( // SaveAccountOnboarding updates the onboarding information for a specific account. func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error { - result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Save(onboarding) + result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding) if result.Error != nil { log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)