mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-25 03:36:41 +00:00
handle empty onboard to avoid breaking clients and dashboard
This commit is contained in:
@@ -275,41 +275,6 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
|
||||
return am.idpManager
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
|
||||
defer unlock()
|
||||
|
||||
if newOnboarding == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "new onboarding data cannot be nil")
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
|
||||
}
|
||||
if oldOnboarding.IsEqual(*newOnboarding) {
|
||||
log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID)
|
||||
return oldOnboarding, nil
|
||||
}
|
||||
|
||||
newOnboarding.AccountID = accountID
|
||||
err = am.Store.SaveAccountOnboarding(ctx, newOnboarding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update account onboarding: %w", err)
|
||||
}
|
||||
|
||||
return newOnboarding, nil
|
||||
}
|
||||
|
||||
// UpdateAccountSettings updates Account settings.
|
||||
// Only users with role UserRoleAdmin can update the account.
|
||||
// User that performs the update has to belong to the account.
|
||||
@@ -1234,7 +1199,58 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
return am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if onboarding == nil {
|
||||
onboarding = &types.AccountOnboarding{
|
||||
AccountID: accountID,
|
||||
}
|
||||
}
|
||||
|
||||
return onboarding, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to validate user permissions: %w", err)
|
||||
}
|
||||
|
||||
if !allowed {
|
||||
return nil, status.NewPermissionDeniedError()
|
||||
}
|
||||
|
||||
oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID)
|
||||
if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() {
|
||||
return nil, fmt.Errorf("failed to get account onboarding: %w", err)
|
||||
}
|
||||
|
||||
if oldOnboarding == nil {
|
||||
oldOnboarding = &types.AccountOnboarding{
|
||||
AccountID: accountID,
|
||||
}
|
||||
}
|
||||
|
||||
if newOnboarding == nil {
|
||||
return oldOnboarding, nil
|
||||
}
|
||||
|
||||
if oldOnboarding.IsEqual(*newOnboarding) {
|
||||
log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID)
|
||||
return oldOnboarding, nil
|
||||
}
|
||||
|
||||
newOnboarding.AccountID = accountID
|
||||
err = am.Store.SaveAccountOnboarding(ctx, newOnboarding)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to update account onboarding: %w", err)
|
||||
}
|
||||
|
||||
return newOnboarding, nil
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) {
|
||||
|
||||
@@ -3448,15 +3448,29 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) {
|
||||
account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "")
|
||||
require.NoError(t, err)
|
||||
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
assert.Equal(t, account.Id, onboarding.AccountID)
|
||||
assert.Equal(t, true, onboarding.OnboardingFlowPending)
|
||||
assert.Equal(t, true, onboarding.SignupFormPending)
|
||||
if onboarding.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was not retrieved from the store")
|
||||
}
|
||||
t.Run("should return account onboarding when onboarding exist", func(t *testing.T) {
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
assert.Equal(t, account.Id, onboarding.AccountID)
|
||||
assert.Equal(t, true, onboarding.OnboardingFlowPending)
|
||||
assert.Equal(t, true, onboarding.SignupFormPending)
|
||||
if onboarding.UpdatedAt.IsZero() {
|
||||
t.Errorf("Onboarding was not retrieved from the store")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) {
|
||||
account.Id = "with-zero-onboarding"
|
||||
account.Onboarding = types.AccountOnboarding{}
|
||||
err = manager.Store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, onboarding)
|
||||
_, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id)
|
||||
require.Error(t, err, "should return error when onboarding is not set")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
@@ -3492,8 +3506,8 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) {
|
||||
assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending)
|
||||
})
|
||||
|
||||
t.Run("update onboarding with no change", func(t *testing.T) {
|
||||
t.Run("update onboarding with no onboarding", func(t *testing.T) {
|
||||
_, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil)
|
||||
require.Error(t, err)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -61,9 +61,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
|
||||
}, nil
|
||||
},
|
||||
UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) {
|
||||
if onboarding == nil {
|
||||
return nil, status.Errorf(status.InvalidArgument, "onboarding cannot be nil")
|
||||
}
|
||||
return &types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
SignupFormPending: true,
|
||||
@@ -194,13 +191,26 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "PutAccount fail without onboarding",
|
||||
name: "PutAccount OK without onboarding",
|
||||
expectedBody: true,
|
||||
requestType: http.MethodPut,
|
||||
requestPath: "/api/accounts/" + accountID,
|
||||
requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"),
|
||||
expectedStatus: http.StatusUnprocessableEntity,
|
||||
expectedArray: false,
|
||||
expectedStatus: http.StatusOK,
|
||||
expectedSettings: api.AccountSettings{
|
||||
PeerLoginExpiration: 15552000,
|
||||
PeerLoginExpirationEnabled: false,
|
||||
GroupsPropagationEnabled: br(false),
|
||||
JwtGroupsClaimName: sr("roles"),
|
||||
JwtGroupsEnabled: br(true),
|
||||
JwtAllowGroups: &[]string{"test"},
|
||||
RegularUsersViewBlocked: true,
|
||||
RoutingPeerDnsResolutionEnabled: br(false),
|
||||
LazyConnectionEnabled: br(false),
|
||||
DnsDomain: sr(""),
|
||||
},
|
||||
expectedArray: false,
|
||||
expectedID: accountID,
|
||||
},
|
||||
{
|
||||
name: "Update account failure with high peer_login_expiration more than 180 days",
|
||||
|
||||
@@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account not found: %s", accountKey)
|
||||
}
|
||||
|
||||
// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding
|
||||
func NewAccountOnboardingNotFoundError(accountKey string) error {
|
||||
return Errorf(NotFound, "account onboarding not found: %s", accountKey)
|
||||
}
|
||||
|
||||
// NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account
|
||||
func NewPeerNotPartOfAccountError() error {
|
||||
return Errorf(PermissionDenied, "peer is not part of this account")
|
||||
|
||||
@@ -730,12 +730,10 @@ func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (
|
||||
var accountOnboarding types.AccountOnboarding
|
||||
result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error)
|
||||
if errors.Is(result.Error, gorm.ErrRecordNotFound) {
|
||||
//accountOnboarding.AccountID = accountID
|
||||
//return &accountOnboarding, nil
|
||||
return nil, status.NewAccountNotFoundError(accountID)
|
||||
return nil, status.NewAccountOnboardingNotFoundError(accountID)
|
||||
}
|
||||
log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error)
|
||||
return nil, status.NewGetAccountFromStoreError(result.Error)
|
||||
}
|
||||
|
||||
@@ -744,7 +742,7 @@ func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (
|
||||
|
||||
// SaveAccountOnboarding updates the onboarding information for a specific account.
|
||||
func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error {
|
||||
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Save(onboarding)
|
||||
result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
||||
return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error)
|
||||
|
||||
Reference in New Issue
Block a user