enable pat creation on setup

This commit is contained in:
jnfrati
2026-04-27 16:54:47 +02:00
parent 53b04e512a
commit be9f1b46e6
9 changed files with 738 additions and 20 deletions

View File

@@ -66,10 +66,11 @@ import (
)
const (
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
apiPrefix = "/api"
rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED"
rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST"
rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM"
setupCreatePATEnabledKey = "NB_SETUP_PAT_ENABLED"
)
// NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints.
@@ -171,7 +172,8 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks
zonesManager.RegisterEndpoints(router, zManager)
recordsManager.RegisterEndpoints(router, rManager)
idp.AddEndpoints(accountManager, router)
instance.AddEndpoints(instanceManager, router)
setupCreatePATEnabled := os.Getenv(setupCreatePATEnabledKey) == "true"
instance.AddEndpoints(instanceManager, accountManager, setupCreatePATEnabled, router)
instance.AddVersionEndpoint(instanceManager, router)
if serviceManager != nil && reverseProxyDomainManager != nil {
reverseproxymanager.RegisterEndpoints(serviceManager, *reverseProxyDomainManager, reverseProxyAccessLogsManager, permissionsManager, router)

View File

@@ -7,21 +7,41 @@ import (
"github.com/gorilla/mux"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/http/util"
"github.com/netbirdio/netbird/shared/management/status"
)
// patMinExpireDays and patMaxExpireDays mirror the bounds enforced by
// DefaultAccountManager.CreatePAT in management/server/user.go. They are
// duplicated here so /api/setup can reject invalid input before it creates
// the embedded-IdP user.
const (
patMinExpireDays = 1
patMaxExpireDays = 365
)
// handler handles the instance setup HTTP endpoints
type handler struct {
instanceManager nbinstance.Manager
instanceManager nbinstance.Manager
setupManager *nbinstance.SetupService
createPATEnabled bool
}
// AddEndpoints registers the instance setup endpoints.
// These endpoints bypass authentication for initial setup.
func AddEndpoints(instanceManager nbinstance.Manager, router *mux.Router) {
//
// createPATEnabled toggles the setup-time Personal Access Token feature
// (NB_SETUP_PAT_ENABLED). When false, the create_pat / pat_expire_in request
// fields are silently ignored, matching the "env-gated, opt-in" pattern used
// by other optional features (e.g. rate limiting).
func AddEndpoints(instanceManager nbinstance.Manager, accountManager account.Manager, createPATEnabled bool, router *mux.Router) {
h := &handler{
instanceManager: instanceManager,
instanceManager: instanceManager,
setupManager: nbinstance.NewSetupService(instanceManager, accountManager),
createPATEnabled: createPATEnabled,
}
router.HandleFunc("/instance", h.getInstanceStatus).Methods("GET", "OPTIONS")
@@ -55,24 +75,54 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) {
// setup creates the initial admin user for the instance.
// This endpoint is unauthenticated but only works when setup is required.
func (h *handler) setup(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var req api.SetupRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
util.WriteErrorResponse("invalid request body", http.StatusBadRequest, w)
return
}
userData, err := h.instanceManager.CreateOwnerUser(r.Context(), req.Email, req.Password, req.Name)
wantPAT := h.createPATEnabled && req.CreatePat != nil && *req.CreatePat
if wantPAT {
if req.PatExpireIn == nil {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "pat_expire_in is required when create_pat is true"), w)
return
}
if *req.PatExpireIn < patMinExpireDays || *req.PatExpireIn > patMaxExpireDays {
util.WriteError(ctx, status.Errorf(status.InvalidArgument, "pat_expire_in must be between %d and %d", patMinExpireDays, patMaxExpireDays), w)
return
}
}
result, err := h.setupManager.SetupOwner(ctx, req.Email, req.Password, req.Name, nbinstance.SetupOptions{
CreatePAT: wantPAT,
PATExpireInDays: expireInDays(req.PatExpireIn),
})
if err != nil {
util.WriteError(r.Context(), err, w)
util.WriteError(ctx, err, w)
return
}
log.WithContext(r.Context()).Infof("instance setup completed: created user %s", req.Email)
log.WithContext(ctx).Infof("instance setup completed: created user %s", req.Email)
util.WriteJSONObject(r.Context(), w, api.SetupResponse{
UserId: userData.ID,
Email: userData.Email,
})
resp := api.SetupResponse{
UserId: result.User.ID,
Email: result.User.Email,
}
if result.PATPlainToken != "" {
resp.PersonalAccessToken = &result.PATPlainToken
}
util.WriteJSONObject(ctx, w, resp)
}
func expireInDays(expireIn *int) int {
if expireIn == nil {
return 0
}
return *expireIn
}
// getVersionInfo returns version information for NetBird components.

View File

@@ -10,12 +10,18 @@ import (
"net/mail"
"testing"
"github.com/golang/mock/gomock"
"github.com/gorilla/mux"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/idp"
nbinstance "github.com/netbirdio/netbird/management/server/instance"
"github.com/netbirdio/netbird/management/server/mock_server"
nbstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/http/api"
"github.com/netbirdio/netbird/shared/management/status"
)
@@ -25,6 +31,7 @@ type mockInstanceManager struct {
isSetupRequired bool
isSetupRequiredFn func(ctx context.Context) (bool, error)
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
rollbackSetupFn func(ctx context.Context, userID string) error
getVersionInfoFn func(ctx context.Context) (*nbinstance.VersionInfo, error)
}
@@ -67,6 +74,13 @@ func (m *mockInstanceManager) CreateOwnerUser(ctx context.Context, email, passwo
}, nil
}
func (m *mockInstanceManager) RollbackSetup(ctx context.Context, userID string) error {
if m.rollbackSetupFn != nil {
return m.rollbackSetupFn(ctx, userID)
}
return nil
}
func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.VersionInfo, error) {
if m.getVersionInfoFn != nil {
return m.getVersionInfoFn(ctx)
@@ -82,8 +96,12 @@ func (m *mockInstanceManager) GetVersionInfo(ctx context.Context) (*nbinstance.V
var _ nbinstance.Manager = (*mockInstanceManager)(nil)
func setupTestRouter(manager nbinstance.Manager) *mux.Router {
return setupTestRouterWithPAT(manager, nil, false)
}
func setupTestRouterWithPAT(manager nbinstance.Manager, accountManager account.Manager, createPATEnabled bool) *mux.Router {
router := mux.NewRouter()
AddEndpoints(manager, router)
AddEndpoints(manager, accountManager, createPATEnabled, router)
return router
}
@@ -293,6 +311,190 @@ func TestSetup_ManagerError(t *testing.T) {
assert.Equal(t, http.StatusInternalServerError, rec.Code)
}
func TestSetup_PAT_FeatureDisabled_IgnoresCreatePAT(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
// createPATEnabled=false: request fields must be silently ignored
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}, false)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Nil(t, response.PersonalAccessToken)
}
func TestSetup_PAT_FlagOmitted_NoPAT(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}, true)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin"}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Nil(t, response.PersonalAccessToken)
}
func TestSetup_PAT_MissingExpireIn(t *testing.T) {
createCalled := false
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
createCalled = true
return &idp.UserData{ID: "u1", Email: email, Name: name}, nil
},
}
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}, true)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
assert.False(t, createCalled, "CreateOwnerUser must not run when input is rejected")
}
func TestSetup_PAT_ExpireOutOfRange(t *testing.T) {
manager := &mockInstanceManager{isSetupRequired: true}
router := setupTestRouterWithPAT(manager, &mock_server.MockAccountManager{}, true)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 0}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusUnprocessableEntity, rec.Code)
}
func TestSetup_PAT_Success(t *testing.T) {
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
}
gotAccountArgs := struct {
userID string
email string
}{}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
gotAccountArgs.userID = userAuth.UserId
gotAccountArgs.email = userAuth.Email
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiator, target, name string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiator)
assert.Equal(t, "owner-id", target)
assert.Equal(t, "setup-token", name)
assert.Equal(t, 30, expiresIn)
return &types.PersonalAccessTokenGenerated{PlainToken: "nbp_plain"}, nil
},
}
router := setupTestRouterWithPAT(manager, accountMgr, true)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusOK, rec.Code)
var response api.SetupResponse
require.NoError(t, json.NewDecoder(rec.Body).Decode(&response))
assert.Equal(t, "owner-id", response.UserId)
require.NotNil(t, response.PersonalAccessToken)
assert.Equal(t, "nbp_plain", *response.PersonalAccessToken)
assert.Equal(t, "owner-id", gotAccountArgs.userID)
}
func TestSetup_PAT_AccountCreationFails_Rollback(t *testing.T) {
rolledBackFor := ""
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "", errors.New("db down")
},
}
router := setupTestRouterWithPAT(manager, accountMgr, true)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called with the created user id")
}
func TestSetup_PAT_CreatePATFails_Rollback(t *testing.T) {
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
rolledBackFor := ""
manager := &mockInstanceManager{
isSetupRequired: true,
createOwnerUserFn: func(ctx context.Context, email, password, name string) (*idp.UserData, error) {
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
},
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
}
accountMgr := &mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, status.Errorf(status.Internal, "token store unavailable")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
}
router := setupTestRouterWithPAT(manager, accountMgr, true)
body := `{"email": "admin@example.com", "password": "securepassword123", "name": "Admin", "create_pat": true, "pat_expire_in": 30}`
req := httptest.NewRequest(http.MethodPost, "/setup", bytes.NewBufferString(body))
rec := httptest.NewRecorder()
router.ServeHTTP(rec, req)
assert.Equal(t, http.StatusInternalServerError, rec.Code)
assert.Equal(t, "owner-id", rolledBackFor, "RollbackSetup must be called when CreatePAT fails")
}
func TestGetVersionInfo_Success(t *testing.T) {
manager := &mockInstanceManager{}
router := mux.NewRouter()

View File

@@ -12,6 +12,7 @@ import (
"sync"
"time"
"github.com/dexidp/dex/storage"
goversion "github.com/hashicorp/go-version"
log "github.com/sirupsen/logrus"
@@ -60,6 +61,13 @@ type Manager interface {
// This should only be called when IsSetupRequired returns true.
CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error)
// RollbackSetup reverses a successful CreateOwnerUser by deleting the user
// from the embedded IDP and reloading setupRequired from persistent state, so
// /api/setup can be retried only when no accounts or local users remain. Used
// when post-user steps (account or PAT creation) fail and the caller wants a
// clean slate.
RollbackSetup(ctx context.Context, userID string) error
// GetVersionInfo returns version information for NetBird components.
GetVersionInfo(ctx context.Context) (*VersionInfo, error)
}
@@ -70,6 +78,7 @@ type instanceStore interface {
type embeddedIdP interface {
CreateUserWithPassword(ctx context.Context, email, password, name string) (*idp.UserData, error)
DeleteUser(ctx context.Context, userID string) error
GetAllAccounts(ctx context.Context) (map[string][]*idp.UserData, error)
}
@@ -187,6 +196,51 @@ func (m *DefaultManager) CreateOwnerUser(ctx context.Context, email, password, n
return userData, nil
}
// RollbackSetup undoes a successful CreateOwnerUser: deletes the user from the
// embedded IDP and reloads setupRequired from persistent state.
func (m *DefaultManager) RollbackSetup(ctx context.Context, userID string) error {
if m.embeddedIdpManager == nil {
return errors.New("embedded IDP is not enabled")
}
var deleteErr error
if err := m.embeddedIdpManager.DeleteUser(ctx, userID); err != nil {
if isNotFoundError(err) {
log.WithContext(ctx).Debugf("setup rollback user %s already deleted", userID)
} else {
deleteErr = fmt.Errorf("failed to delete user from embedded IdP: %w", err)
}
}
if err := m.loadSetupRequired(ctx); err != nil {
reloadErr := fmt.Errorf("failed to reload setup state after rollback: %w", err)
if deleteErr != nil {
return errors.Join(deleteErr, reloadErr)
}
return reloadErr
}
if deleteErr != nil {
return deleteErr
}
log.WithContext(ctx).Infof("rolled back setup for user %s", userID)
return nil
}
func isNotFoundError(err error) bool {
if err == nil {
return false
}
if errors.Is(err, storage.ErrNotFound) {
return true
}
if s, ok := status.FromError(err); ok {
return s.Type() == status.NotFound
}
return false
}
func (m *DefaultManager) checkSetupRequiredFromDB(ctx context.Context) error {
numAccounts, err := m.store.GetAccountsCounter(ctx)
if err != nil {

View File

@@ -10,16 +10,19 @@ import (
"testing"
"time"
"github.com/dexidp/dex/storage"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/shared/management/status"
)
type mockIdP struct {
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
users map[string][]*idp.UserData
mu sync.Mutex
createUserFunc func(ctx context.Context, email, password, name string) (*idp.UserData, error)
deleteUserFunc func(ctx context.Context, userID string) error
users map[string][]*idp.UserData
getAllAccountsErr error
}
@@ -30,6 +33,13 @@ func (m *mockIdP) CreateUserWithPassword(ctx context.Context, email, password, n
return &idp.UserData{ID: "test-user-id", Email: email, Name: name}, nil
}
func (m *mockIdP) DeleteUser(ctx context.Context, userID string) error {
if m.deleteUserFunc != nil {
return m.deleteUserFunc(ctx, userID)
}
return nil
}
func (m *mockIdP) GetAllAccounts(_ context.Context) (map[string][]*idp.UserData, error) {
m.mu.Lock()
defer m.mu.Unlock()
@@ -223,6 +233,77 @@ func TestIsSetupRequired_ReturnsFlag(t *testing.T) {
assert.False(t, required)
}
func TestRollbackSetup_UserAlreadyDeletedIsSuccess(t *testing.T) {
tests := []struct {
name string
err error
}{
{
name: "management status not found",
err: status.NewUserNotFoundError("owner-id"),
},
{
name: "dex storage not found",
err: fmt.Errorf("failed to get user for deletion: %w", storage.ErrNotFound),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, userID string) error {
assert.Equal(t, "owner-id", userID)
return tt.err
},
}
mgr := newTestManager(idpMock, &mockStore{})
mgr.setupRequired = false
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.NoError(t, err)
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup should be required when no accounts or local users remain")
})
}
}
func TestRollbackSetup_RecomputesSetupStateWhenAccountStillExists(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, _ string) error {
return status.NewUserNotFoundError("owner-id")
},
}
mgr := newTestManager(idpMock, &mockStore{accountsCount: 1})
mgr.setupRequired = true
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.NoError(t, err)
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.False(t, required, "setup should not be required while an account still exists")
}
func TestRollbackSetup_ReturnsDeleteErrorButReloadsSetupState(t *testing.T) {
idpMock := &mockIdP{
deleteUserFunc: func(_ context.Context, _ string) error {
return errors.New("idp unavailable")
},
}
mgr := newTestManager(idpMock, &mockStore{})
mgr.setupRequired = false
err := mgr.RollbackSetup(context.Background(), "owner-id")
require.Error(t, err)
assert.Contains(t, err.Error(), "idp unavailable")
required, err := mgr.IsSetupRequired(context.Background())
require.NoError(t, err)
assert.True(t, required, "setup state should be reloaded even when user deletion fails")
}
func TestDefaultManager_ValidateSetupRequest(t *testing.T) {
manager := &DefaultManager{setupRequired: true}

View File

@@ -0,0 +1,134 @@
package instance
import (
"context"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/account"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/shared/auth"
)
const setupPATTokenName = "setup-token"
// SetupOptions controls optional work performed during initial instance setup.
type SetupOptions struct {
CreatePAT bool
PATExpireInDays int
}
// SetupResult contains resources created during initial instance setup.
type SetupResult struct {
User *idp.UserData
PATPlainToken string
}
// SetupService orchestrates the initial setup use case across the instance and
// account bounded contexts and owns the compensation logic when a later step
// fails.
type SetupService struct {
instanceManager Manager
accountManager account.Manager
}
// NewSetupService creates a setup use-case service.
func NewSetupService(instanceManager Manager, accountManager account.Manager) *SetupService {
return &SetupService{
instanceManager: instanceManager,
accountManager: accountManager,
}
}
// SetupOwner creates the initial owner user and, optionally, provisions the
// account and a setup Personal Access Token. If account or PAT provisioning
// fails, created resources are rolled back so setup can be retried.
func (m *SetupService) SetupOwner(ctx context.Context, email, password, name string, opts SetupOptions) (*SetupResult, error) {
userData, err := m.instanceManager.CreateOwnerUser(ctx, email, password, name)
if err != nil {
return nil, err
}
result := &SetupResult{User: userData}
if !opts.CreatePAT {
return result, nil
}
if m.accountManager == nil {
err := fmt.Errorf("account manager is required to create setup PAT")
m.rollbackSetup(ctx, userData.ID, "setup PAT requested without account manager", err, "")
return nil, err
}
userAuth := auth.UserAuth{
UserId: userData.ID,
Email: userData.Email,
Name: userData.Name,
}
accountID, err := m.accountManager.GetAccountIDByUserID(ctx, userAuth)
if err != nil {
err = fmt.Errorf("create account for setup user: %w", err)
m.rollbackSetup(ctx, userData.ID, "account provisioning failed", err, "")
return nil, err
}
pat, err := m.accountManager.CreatePAT(ctx, accountID, userData.ID, userData.ID, setupPATTokenName, opts.PATExpireInDays)
if err != nil {
err = fmt.Errorf("create setup PAT: %w", err)
m.rollbackSetup(ctx, userData.ID, "setup PAT provisioning failed", err, accountID)
return nil, err
}
result.PATPlainToken = pat.PlainToken
return result, nil
}
func (m *SetupService) rollbackSetup(ctx context.Context, userID, reason string, origErr error, accountID string) {
if accountID != "" {
if err := m.rollbackSetupAccount(ctx, accountID); err != nil {
log.WithContext(ctx).Errorf("failed to roll back setup account %s for user %s after %s: original error: %v, rollback error: %v", accountID, userID, reason, origErr, err)
} else {
log.WithContext(ctx).Warnf("rolled back setup account %s for user %s after %s: %v", accountID, userID, reason, origErr)
}
}
if err := m.instanceManager.RollbackSetup(ctx, userID); err != nil {
log.WithContext(ctx).Errorf("failed to roll back setup user %s after %s: original error: %v, rollback error: %v", userID, reason, origErr, err)
return
}
log.WithContext(ctx).Warnf("rolled back setup user %s after %s: %v", userID, reason, origErr)
}
// rollbackSetupAccount removes only the setup-created account data from the
// store. It intentionally avoids accountManager.DeleteAccount because the normal
// account deletion path also deletes users from the IdP; embedded IdP cleanup is
// owned by instanceManager.RollbackSetup.
func (m *SetupService) rollbackSetupAccount(ctx context.Context, accountID string) error {
if m.accountManager == nil {
return fmt.Errorf("account manager is required to roll back setup account")
}
accountStore := m.accountManager.GetStore()
if accountStore == nil {
return fmt.Errorf("account store is unavailable")
}
account, err := accountStore.GetAccount(ctx, accountID)
if err != nil {
if isNotFoundError(err) {
return nil
}
return fmt.Errorf("get setup account for rollback: %w", err)
}
if err := accountStore.DeleteAccount(ctx, account); err != nil {
if isNotFoundError(err) {
return nil
}
return fmt.Errorf("delete setup account for rollback: %w", err)
}
return nil
}

View File

@@ -0,0 +1,169 @@
package instance
import (
"context"
"errors"
"testing"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/management/server/idp"
"github.com/netbirdio/netbird/management/server/mock_server"
nbstore "github.com/netbirdio/netbird/management/server/store"
"github.com/netbirdio/netbird/management/server/types"
"github.com/netbirdio/netbird/shared/auth"
"github.com/netbirdio/netbird/shared/management/status"
)
type setupInstanceManagerMock struct {
createOwnerUserFn func(ctx context.Context, email, password, name string) (*idp.UserData, error)
rollbackSetupFn func(ctx context.Context, userID string) error
}
func (m *setupInstanceManagerMock) IsSetupRequired(context.Context) (bool, error) {
return true, nil
}
func (m *setupInstanceManagerMock) CreateOwnerUser(ctx context.Context, email, password, name string) (*idp.UserData, error) {
if m.createOwnerUserFn != nil {
return m.createOwnerUserFn(ctx, email, password, name)
}
return &idp.UserData{ID: "owner-id", Email: email, Name: name}, nil
}
func (m *setupInstanceManagerMock) RollbackSetup(ctx context.Context, userID string) error {
if m.rollbackSetupFn != nil {
return m.rollbackSetupFn(ctx, userID)
}
return nil
}
func (m *setupInstanceManagerMock) GetVersionInfo(context.Context) (*VersionInfo, error) {
return nil, nil
}
var _ Manager = (*setupInstanceManagerMock)(nil)
func TestSetupOwner_CreatePATFails_RollsBackSetupAccountAndUser(t *testing.T) {
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(nil)
rollbackCalls := 0
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rollbackCalls++
assert.Equal(t, "owner-id", userID)
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, userAuth auth.UserAuth) (string, error) {
assert.Equal(t, "owner-id", userAuth.UserId)
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, accountID, initiatorUserID, targetUserID, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) {
assert.Equal(t, "acc-1", accountID)
assert.Equal(t, "owner-id", initiatorUserID)
assert.Equal(t, "owner-id", targetUserID)
assert.Equal(t, setupPATTokenName, tokenName)
assert.Equal(t, 30, expiresIn)
return nil, status.Errorf(status.Internal, "token store unavailable")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: 30,
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, 1, rollbackCalls)
}
func TestSetupOwner_CreatePATFails_AccountAlreadyGoneStillRollsBackUser(t *testing.T) {
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(nil, status.NewAccountNotFoundError("acc-1"))
rolledBackFor := ""
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, errors.New("token failure")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: 30,
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, "owner-id", rolledBackFor)
}
func TestSetupOwner_CreatePATFails_AccountRollbackFailureStillRollsBackUser(t *testing.T) {
ctrl := gomock.NewController(t)
accountStore := nbstore.NewMockStore(ctrl)
account := &types.Account{Id: "acc-1"}
accountStore.EXPECT().GetAccount(gomock.Any(), "acc-1").Return(account, nil)
accountStore.EXPECT().DeleteAccount(gomock.Any(), account).Return(errors.New("delete failed"))
rolledBackFor := ""
setupManager := NewSetupService(
&setupInstanceManagerMock{
rollbackSetupFn: func(_ context.Context, userID string) error {
rolledBackFor = userID
return nil
},
},
&mock_server.MockAccountManager{
GetAccountIDByUserIdFunc: func(_ context.Context, _ auth.UserAuth) (string, error) {
return "acc-1", nil
},
CreatePATFunc: func(_ context.Context, _, _, _, _ string, _ int) (*types.PersonalAccessTokenGenerated, error) {
return nil, errors.New("token failure")
},
GetStoreFunc: func() nbstore.Store {
return accountStore
},
},
)
result, err := setupManager.SetupOwner(context.Background(), "admin@example.com", "securepassword123", "Admin", SetupOptions{
CreatePAT: true,
PATExpireInDays: 30,
})
require.Error(t, err)
assert.Nil(t, result)
assert.Contains(t, err.Error(), "create setup PAT")
assert.Equal(t, "owner-id", rolledBackFor)
}