From 2c81cf2c1ea3a55466090fbf9741880c861ace3a Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Thu, 3 Jul 2025 09:01:32 +0200 Subject: [PATCH] [management] Add account onboarding (#4084) This PR introduces a new onboarding feature to handle such flows in the dashboard by defining an AccountOnboarding model, persisting it in the store, exposing CRUD operations in the manager and HTTP handlers, and updating API schemas and tests accordingly. Add AccountOnboarding struct and embed it in Account Extend Store and DefaultAccountManager with onboarding methods and SQL migrations Update HTTP handlers, API types, OpenAPI spec, and add end-to-end tests --- management/server/account.go | 69 +++++++ management/server/account/manager.go | 2 + management/server/account_test.go | 71 +++++++ management/server/http/api/openapi.yml | 19 ++ management/server/http/api/types.gen.go | 17 +- .../handlers/accounts/accounts_handler.go | 32 ++- .../accounts/accounts_handler_test.go | 44 +++- management/server/mock_server/account_mock.go | 193 ++++++++++-------- management/server/status/error.go | 5 + management/server/store/sql_store.go | 28 ++- management/server/store/sql_store_test.go | 74 +++++++ management/server/store/store.go | 2 + management/server/testdata/store.sql | 4 +- management/server/types/account.go | 19 +- 14 files changed, 476 insertions(+), 103 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 8a80aefb6..cd0c933f0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1204,6 +1204,71 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) } +// GetAccountOnboarding retrieves the onboarding information for a specific account. +func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err) + return nil, err + } + + if onboarding == nil { + onboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + return onboarding, nil +} + +func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + return nil, fmt.Errorf("failed to get account onboarding: %w", err) + } + + if oldOnboarding == nil { + oldOnboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + if newOnboarding == nil { + return oldOnboarding, nil + } + + if oldOnboarding.IsEqual(*newOnboarding) { + log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID) + return oldOnboarding, nil + } + + newOnboarding.AccountID = accountID + err = am.Store.SaveAccountOnboarding(ctx, newOnboarding) + if err != nil { + return nil, fmt.Errorf("failed to update account onboarding: %w", err) + } + + return newOnboarding, nil +} + func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) @@ -1726,6 +1791,10 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string, dis PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, }, + Onboarding: types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, } if err := acc.AddAllGroup(disableDefaultPolicy); err != nil { diff --git a/management/server/account/manager.go b/management/server/account/manager.go index de5031c03..ed17fa5ec 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -39,6 +39,7 @@ type Manager interface { GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) + GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) @@ -89,6 +90,7 @@ type Manager interface { SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) + UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 60353389f..fcd40b082 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3448,3 +3448,74 @@ func TestPropagateUserGroupMemberships(t *testing.T) { } }) } + +func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + t.Run("should return account onboarding when onboarding exist", func(t *testing.T) { + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + assert.Equal(t, account.Id, onboarding.AccountID) + assert.Equal(t, true, onboarding.OnboardingFlowPending) + assert.Equal(t, true, onboarding.SignupFormPending) + if onboarding.UpdatedAt.IsZero() { + t.Errorf("Onboarding was not retrieved from the store") + } + }) + + t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) { + account.Id = "with-zero-onboarding" + account.Onboarding = types.AccountOnboarding{} + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + _, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id) + require.Error(t, err, "should return error when onboarding is not set") + }) +} + +func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + onboarding := &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + } + + t.Run("update onboarding with no change", func(t *testing.T) { + updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding) + require.NoError(t, err) + assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending) + assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) + if updated.UpdatedAt.IsZero() { + t.Errorf("Onboarding was updated in the store") + } + }) + + onboarding.OnboardingFlowPending = false + onboarding.SignupFormPending = false + + t.Run("update onboarding", func(t *testing.T) { + updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding) + require.NoError(t, err) + require.NotNil(t, updated) + assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending) + assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) + }) + + t.Run("update onboarding with no onboarding", func(t *testing.T) { + _, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil) + require.NoError(t, err) + }) +} diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 1c5ca9b04..f8c2b9854 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -60,6 +60,8 @@ components: description: Account creator type: string example: google-oauth2|277474792786460067937 + onboarding: + $ref: '#/components/schemas/AccountOnboarding' required: - id - settings @@ -67,6 +69,21 @@ components: - domain_category - created_at - created_by + - onboarding + AccountOnboarding: + type: object + properties: + signup_form_pending: + description: Indicates whether the account signup form is pending + type: boolean + example: true + onboarding_flow_pending: + description: Indicates whether the account onboarding flow is pending + type: boolean + example: false + required: + - signup_form_pending + - onboarding_flow_pending AccountSettings: type: object properties: @@ -153,6 +170,8 @@ components: properties: settings: $ref: '#/components/schemas/AccountSettings' + onboarding: + $ref: '#/components/schemas/AccountOnboarding' required: - settings User: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index d27fd2a57..a9f17aab4 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -250,8 +250,9 @@ type Account struct { DomainCategory string `json:"domain_category"` // Id Account ID - Id string `json:"id"` - Settings AccountSettings `json:"settings"` + Id string `json:"id"` + Onboarding AccountOnboarding `json:"onboarding"` + Settings AccountSettings `json:"settings"` } // AccountExtraSettings defines model for AccountExtraSettings. @@ -266,9 +267,19 @@ type AccountExtraSettings struct { PeerApprovalEnabled bool `json:"peer_approval_enabled"` } +// AccountOnboarding defines model for AccountOnboarding. +type AccountOnboarding struct { + // OnboardingFlowPending Indicates whether the account onboarding flow is pending + OnboardingFlowPending bool `json:"onboarding_flow_pending"` + + // SignupFormPending Indicates whether the account signup form is pending + SignupFormPending bool `json:"signup_form_pending"` +} + // AccountRequest defines model for AccountRequest. type AccountRequest struct { - Settings AccountSettings `json:"settings"` + Onboarding *AccountOnboarding `json:"onboarding,omitempty"` + Settings AccountSettings `json:"settings"` } // AccountSettings defines model for AccountSettings. diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index dfc782b3f..ab59434d1 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -59,7 +59,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, settings, meta) + onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toAccountResponse(accountID, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -126,6 +132,20 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled } + var onboarding *types.AccountOnboarding + if req.Onboarding != nil { + onboarding = &types.AccountOnboarding{ + OnboardingFlowPending: req.Onboarding.OnboardingFlowPending, + SignupFormPending: req.Onboarding.SignupFormPending, + } + } + + updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) @@ -138,7 +158,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, updatedSettings, meta) + resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding) util.WriteJSONObject(r.Context(), w, &resp) } @@ -167,7 +187,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -188,6 +208,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A DnsDomain: &settings.DNSDomain, } + apiOnboarding := api.AccountOnboarding{ + OnboardingFlowPending: onboarding.OnboardingFlowPending, + SignupFormPending: onboarding.SignupFormPending, + } + if settings.Extra != nil { apiSettings.Extra = &api.AccountExtraSettings{ PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, @@ -203,5 +228,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A CreatedBy: meta.CreatedBy, Domain: meta.Domain, DomainCategory: meta.DomainCategory, + Onboarding: apiOnboarding, } } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index a18798743..dbf0c22bc 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -54,6 +54,18 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { return account.GetMeta(), nil }, + GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + return &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, nil + }, + UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + return &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, nil + }, }, settingsManager: settingsMockManager, } @@ -117,7 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -139,7 +151,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -161,7 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 554400, @@ -178,12 +190,34 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedArray: false, expectedID: accountID, }, + { + name: "PutAccount OK without onboarding", + expectedBody: true, + requestType: http.MethodPut, + requestPath: "/api/accounts/" + accountID, + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), + LazyConnectionEnabled: br(false), + DnsDomain: sr(""), + }, + expectedArray: false, + expectedID: accountID, + }, { name: "Update account failure with high peer_login_expiration more than 180 days", expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusUnprocessableEntity, expectedArray: false, }, @@ -192,7 +226,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusUnprocessableEntity, expectedArray: false, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 3caa6744a..8837f9f50 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -30,94 +30,95 @@ type MockAccountManager struct { GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) - GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) - AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) - GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) - ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) - GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error - GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) - GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error - DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error - DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) - DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error - GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) - DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error - ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) - UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) - GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error - DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error - ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) - ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) - SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) - SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) - SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) - DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error - CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) - DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) - GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) - GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error - ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) - DeleteAccountFunc func(ctx context.Context, accountID, userID string) error - GetDNSDomainFunc func(settings *types.Settings) string - StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) - SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error - GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) - LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error - GetAllConnectedPeersFunc func() (map[string]struct{}, error) - HasConnectedChannelFunc func(peerID string) bool - GetExternalCacheManagerFunc func() account.ExternalCacheManager - GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) - DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error - ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) - GetIdpManagerFunc func() idp.Manager - UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error - GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) - SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error - FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) - GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) - GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) - GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) - DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error - BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) - GetStoreFunc func() store.Store - UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) - GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) - GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) - + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) + GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error + DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error + DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error + GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) + DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) + DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) + UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error + UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) + SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error + DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error + ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) + DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) + DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) + GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error + ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + DeleteAccountFunc func(ctx context.Context, accountID, userID string) error + GetDNSDomainFunc func(settings *types.Settings) string + StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error + GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) + LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) + HasConnectedChannelFunc func(peerID string) bool + GetExternalCacheManagerFunc func() account.ExternalCacheManager + GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) + DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error + ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) + GetIdpManagerFunc func() idp.Manager + UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error + GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) + SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error + FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) + DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error + BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) + GetStoreFunc func() store.Store + UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) + GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) + GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) + UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) } @@ -814,6 +815,22 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented") } +// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface +func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + if am.GetAccountOnboardingFunc != nil { + return am.GetAccountOnboardingFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented") +} + +// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface +func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + if am.UpdateAccountOnboardingFunc != nil { + return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding) + } + return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented") +} + // GetUserByID mocks GetUserByID of the AccountManager interface func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { diff --git a/management/server/status/error.go b/management/server/status/error.go index 5a6f6d1a7..47c236e93 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error { return Errorf(NotFound, "account not found: %s", accountKey) } +// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding +func NewAccountOnboardingNotFoundError(accountKey string) error { + return Errorf(NotFound, "account onboarding not found: %s", accountKey) +} + // NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account func NewPeerNotPartOfAccountError() error { return Errorf(PermissionDenied, "peer is not part of this account") diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 197255ab6..baee4ad28 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -99,7 +99,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, - &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, ) if err != nil { return nil, fmt.Errorf("auto migratePreAuto: %w", err) @@ -728,6 +728,32 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren return &accountMeta, nil } +// GetAccountOnboarding retrieves the onboarding information for a specific account. +func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) { + var accountOnboarding types.AccountOnboarding + result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID) + if result.Error != nil { + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountOnboardingNotFoundError(accountID) + } + log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error) + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + return &accountOnboarding, nil +} + +// SaveAccountOnboarding updates the onboarding information for a specific account. +func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error { + result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) + return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) + } + + return nil +} + func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 928486ab4..738c5a28c 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -354,9 +354,16 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } + o, err := store.GetAccountOnboarding(context.Background(), account.Id) + require.NoError(t, err) + require.Equal(t, o.AccountID, account.Id) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) + _, err = store.GetAccountOnboarding(context.Background(), account.Id) + require.Error(t, err, "expecting error after removing DeleteAccount when getting onboarding") + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } @@ -414,12 +421,21 @@ func Test_GetAccount(t *testing.T) { account, err := store.GetAccount(context.Background(), id) require.NoError(t, err) require.Equal(t, id, account.Id, "account id should match") + require.Equal(t, false, account.Onboarding.OnboardingFlowPending) + + id = "9439-34653001fc3b-bf1c8084-ba50-4ce7" + + account, err = store.GetAccount(context.Background(), id) + require.NoError(t, err) + require.Equal(t, id, account.Id, "account id should match") + require.Equal(t, true, account.Onboarding.OnboardingFlowPending) _, err = store.GetAccount(context.Background(), "non-existing-account") assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } @@ -2096,6 +2112,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, }, + Onboarding: types.AccountOnboarding{SignupFormPending: true, OnboardingFlowPending: true}, } if err := acc.AddAllGroup(false); err != nil { @@ -3440,6 +3457,63 @@ func TestSqlStore_GetAccountMeta(t *testing.T) { require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) } +func TestSqlStore_GetAccountOnboarding(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7" + a, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + t.Logf("Onboarding: %+v", a.Onboarding) + err = store.SaveAccount(context.Background(), a) + require.NoError(t, err) + onboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.NotNil(t, onboarding) + require.Equal(t, accountID, onboarding.AccountID) + require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), onboarding.CreatedAt.UTC()) +} + +func TestSqlStore_SaveAccountOnboarding(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + t.Run("New onboarding should be saved correctly", func(t *testing.T) { + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + onboarding := &types.AccountOnboarding{ + AccountID: accountID, + SignupFormPending: true, + OnboardingFlowPending: true, + } + + err = store.SaveAccountOnboarding(context.Background(), onboarding) + require.NoError(t, err) + + savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending) + require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending) + }) + + t.Run("Existing onboarding should be updated correctly", func(t *testing.T) { + accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7" + onboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + + onboarding.OnboardingFlowPending = !onboarding.OnboardingFlowPending + onboarding.SignupFormPending = !onboarding.SignupFormPending + + err = store.SaveAccountOnboarding(context.Background(), onboarding) + require.NoError(t, err) + + savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending) + require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending) + }) +} + func TestSqlStore_GetAnyAccountID(t *testing.T) { t.Run("should return account ID when accounts exist", func(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) diff --git a/management/server/store/store.go b/management/server/store/store.go index 30ff1549d..b3254c4c9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -52,6 +52,7 @@ type Store interface { GetAllAccounts(ctx context.Context) []*types.Account GetAccount(ctx context.Context, accountID string) (*types.Account, error) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) + GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) @@ -74,6 +75,7 @@ type Store interface { SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) + SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 4b126c618..a21783857 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -1,4 +1,5 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `account_onboardings` (`account_id` text, `created_at` datetime,`updated_at` datetime, `onboarding_flow_pending` numeric, `signup_form_pending` numeric, PRIMARY KEY (`account_id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -38,7 +39,8 @@ CREATE INDEX `idx_networks_id` ON `networks`(`id`); CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO accounts VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','90d6-0242ac120003-edafee4e-63fb-11ec','2024-10-02 16:01:38.210000+02:00','test2.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO account_onboardings VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','2024-10-02 16:01:38.210000+02:00','2021-08-19 20:46:20.005936822+02:00',1,0);INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); diff --git a/management/server/types/account.go b/management/server/types/account.go index 5a62ee4c6..f0887be07 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -82,11 +82,11 @@ type Account struct { DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` - + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` + Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` } // Subclass used in gorm to only load network and not whole account @@ -104,6 +104,20 @@ type AccountSettings struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +type AccountOnboarding struct { + AccountID string `gorm:"primaryKey"` + OnboardingFlowPending bool + SignupFormPending bool + CreatedAt time.Time + UpdatedAt time.Time +} + +// IsEqual compares two AccountOnboarding objects and returns true if they are equal +func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool { + return o.OnboardingFlowPending == onboarding.OnboardingFlowPending && + o.SignupFormPending == onboarding.SignupFormPending +} + // GetRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. @@ -866,6 +880,7 @@ func (a *Account) Copy() *Account { Networks: nets, NetworkRouters: networkRouters, NetworkResources: networkResources, + Onboarding: a.Onboarding, } }