mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 23:26:41 +00:00
fix rollback on account id returning empty
This commit is contained in:
8
management/server/account/pat.go
Normal file
8
management/server/account/pat.go
Normal file
@@ -0,0 +1,8 @@
|
||||
package account
|
||||
|
||||
const (
|
||||
// PATMinExpireDays is the minimum allowed Personal Access Token expiration in days.
|
||||
PATMinExpireDays = 1
|
||||
// PATMaxExpireDays is the maximum allowed Personal Access Token expiration in days.
|
||||
PATMaxExpireDays = 365
|
||||
)
|
||||
@@ -86,6 +86,7 @@ func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
|
||||
resp.PersonalAccessToken = &result.PATPlainToken
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "no-store")
|
||||
util.WriteJSONObject(ctx, w, resp)
|
||||
}
|
||||
|
||||
|
||||
@@ -179,6 +179,7 @@ func TestSetup_Success(t *testing.T) {
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
|
||||
|
||||
var response api.SetupResponse
|
||||
err := json.NewDecoder(rec.Body).Decode(&response)
|
||||
@@ -320,6 +321,7 @@ func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) {
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
@@ -338,6 +340,7 @@ func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) {
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
@@ -377,6 +380,7 @@ func TestSetup_PAT_MissingExpireIn_DefaultsToOneDay(t *testing.T) {
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
@@ -397,6 +401,7 @@ func TestSetup_PAT_ExpireOutOfRange(t *testing.T) {
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
@@ -438,11 +443,13 @@ func TestSetup_PAT_Success(t *testing.T) {
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, rec.Code)
|
||||
assert.Equal(t, "no-store", rec.Header().Get("Cache-Control"))
|
||||
var response api.SetupResponse
|
||||
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
|
||||
assert.Equal(t, "owner-id", response.UserId)
|
||||
@@ -454,6 +461,10 @@ func TestSetup_PAT_Success(t *testing.T) {
|
||||
func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) {
|
||||
t.Setenv(nbinstance.SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("", status.Errorf(status.NotFound, "account not found"))
|
||||
|
||||
rolledBackFor := ""
|
||||
manager := &mockInstanceManager{
|
||||
isSetupRequired: true,
|
||||
@@ -469,12 +480,16 @@ func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) {
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
|
||||
return "", errors.New("db down")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
}
|
||||
|
||||
router := setupTestRouterWithPAT(manager, accountMgr)
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
@@ -519,6 +534,7 @@ func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) {
|
||||
|
||||
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
router.ServeHTTP(rec, req)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/store"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
)
|
||||
@@ -20,13 +21,6 @@ const (
|
||||
SetupPATEnabledEnvKey = "NB_SETUP_PAT_ENABLED"
|
||||
|
||||
setupPATDefaultExpireDays = 1
|
||||
|
||||
// patMinExpireDays and patMaxExpireDays mirror the bounds enforced by
|
||||
// DefaultAccountManager.CreatePAT in management/server/user.go. They are
|
||||
// duplicated here so /api/setup can reject invalid input before it creates
|
||||
// the embedded-IdP user.
|
||||
patMinExpireDays = 1
|
||||
patMaxExpireDays = 365
|
||||
)
|
||||
|
||||
// SetupOptions controls optional work performed during initial instance setup.
|
||||
@@ -79,8 +73,8 @@ func normalizeSetupOptions(opts SetupOptions, setupPATEnabled bool) (SetupOption
|
||||
opts.PATExpireInDays = &defaultExpireInDays
|
||||
}
|
||||
|
||||
if *opts.PATExpireInDays < patMinExpireDays || *opts.PATExpireInDays > patMaxExpireDays {
|
||||
return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", patMinExpireDays, patMaxExpireDays)
|
||||
if *opts.PATExpireInDays < account.PATMinExpireDays || *opts.PATExpireInDays > account.PATMaxExpireDays {
|
||||
return opts, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays)
|
||||
}
|
||||
|
||||
return opts, nil
|
||||
@@ -96,6 +90,10 @@ func (m *SetupService) SetupOwner(ctx context.Context, email, password, name str
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.CreatePAT && m.accountManager == nil {
|
||||
return nil, fmt.Errorf("account manager is required to create setup PAT")
|
||||
}
|
||||
|
||||
userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -106,12 +104,6 @@ func (m *SetupService) SetupOwner(ctx context.Context, email, password, name str
|
||||
return result, nil
|
||||
}
|
||||
|
||||
if m.accountManager == nil {
|
||||
err := fmt.Errorf("account manager is required to create setup PAT")
|
||||
m.rollbackSetup(ctx, userData.ID, "setup PAT requested without account manager", err, "")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userAuth := auth.UserAuth{
|
||||
UserId: userData.ID,
|
||||
Email: userData.Email,
|
||||
@@ -137,6 +129,14 @@ func (m *SetupService) SetupOwner(ctx context.Context, email, password, name str
|
||||
}
|
||||
|
||||
func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) {
|
||||
if accountID == "" {
|
||||
resolvedAccountID, err := m.lookupSetupAccountIDForRollback(ctx, userID)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to resolve setup account for user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, err)
|
||||
}
|
||||
accountID = resolvedAccountID
|
||||
}
|
||||
|
||||
if accountID != "" {
|
||||
if err := m.rollbackSetupAccount(ctx, accountID); err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, err)
|
||||
@@ -152,6 +152,27 @@ func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string,
|
||||
log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr)
|
||||
}
|
||||
|
||||
func (m *SetupService) lookupSetupAccountIDForRollback(ctx context.Context, userID string) (string, error) {
|
||||
if m.accountManager == nil {
|
||||
return "", fmt.Errorf("account manager is required to resolve setup account")
|
||||
}
|
||||
|
||||
accountStore := m.accountManager.GetStore()
|
||||
if accountStore == nil {
|
||||
return "", fmt.Errorf("account store is unavailable")
|
||||
}
|
||||
|
||||
accountID, err := accountStore.GetAccountIDByUserID(ctx, store.LockingStrengthNone, userID)
|
||||
if err != nil {
|
||||
if isNotFoundError(err) {
|
||||
return "", nil
|
||||
}
|
||||
return "", fmt.Errorf("get setup account ID for rollback: %w", err)
|
||||
}
|
||||
|
||||
return accountID, nil
|
||||
}
|
||||
|
||||
// rollbackSetupAccount removes only the setup-created account data from the
|
||||
// store. It intentionally avoids accountManager.DeleteAccount because the normal
|
||||
// account deletion path also deletes users from the IdP; embedded IdP cleanup is
|
||||
|
||||
@@ -112,6 +112,76 @@ func TestSetupOwner_PATFeatureEnabled_MissingExpireDefaultsToOneDay(t *testing.T
|
||||
assert.Equal(t, "nbp_plain", result.PATPlainToken)
|
||||
}
|
||||
|
||||
func TestSetupOwner_PATFeatureEnabled_MissingAccountManagerFailsBeforeCreateUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
createCalled := false
|
||||
rollbackCalled := false
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
createOwnerUserFn: func(_ context.Context, email, _, name string) (*idp.UserData, error) {
|
||||
createCalled = true
|
||||
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
|
||||
},
|
||||
rollbackSetupFn: func(_ context.Context, _ string) error {
|
||||
rollbackCalled = true
|
||||
return nil
|
||||
},
|
||||
},
|
||||
nil,
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "account manager is required")
|
||||
assert.False(t, createCalled)
|
||||
assert.False(t, rollbackCalled)
|
||||
}
|
||||
|
||||
func TestSetupOwner_AccountProvisioningFails_RollsBackSideEffectAccountAndUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
accountStore := nbstore.NewMockStore(ctrl)
|
||||
account := &types.Account{Id: "acc-1"}
|
||||
accountStore.EXPECT().GetAccountIDByUserID(gomock.Any(), nbstore.LockingStrengthNone, "owner-id").Return("acc-1", nil)
|
||||
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
|
||||
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
|
||||
|
||||
rolledBackFor := ""
|
||||
setupManager := NewSetupService(
|
||||
&setupInstanceManagerMock{
|
||||
rollbackSetupFn: func(_ context.Context, userID string) error {
|
||||
rolledBackFor = userID
|
||||
return nil
|
||||
},
|
||||
},
|
||||
&mock_server.MockAccountManager{
|
||||
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
|
||||
assert.Equal(t, "owner-id", userAuth.UserId)
|
||||
return "", errors.New("metadata update failed")
|
||||
},
|
||||
GetStoreFunc: func() nbstore.Store {
|
||||
return accountStore
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
|
||||
CreatePAT: true,
|
||||
PATExpireInDays: intPtr(30),
|
||||
})
|
||||
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, result)
|
||||
assert.Contains(t, err.Error(), "create account for setup user")
|
||||
assert.Equal(t, "owner-id", rolledBackFor)
|
||||
}
|
||||
|
||||
func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) {
|
||||
t.Setenv(SetupPATEnabledEnvKey, "true")
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/idp/dex"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
"github.com/netbirdio/netbird/management/server/activity"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
nbpeer "github.com/netbirdio/netbird/management/server/peer"
|
||||
@@ -395,8 +396,8 @@ func (am *DefaultAccountManager) CreatePAT(ctx context.Context, accountID string
|
||||
return nil, status.Errorf(status.InvalidArgument, "token name can't be empty")
|
||||
}
|
||||
|
||||
if expiresIn < 1 || expiresIn > 365 {
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between 1 and 365")
|
||||
if expiresIn < account.PATMinExpireDays || expiresIn > account.PATMaxExpireDays {
|
||||
return nil, status.Errorf(status.InvalidArgument, "expiration has to be between %d and %d", account.PATMinExpireDays, account.PATMaxExpireDays)
|
||||
}
|
||||
|
||||
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, initiatorUserID, modules.Pats, operations.Create)
|
||||
|
||||
Reference in New Issue
Block a user