From 432dc42bf563e6fb5c1ed7e7e9a8f36b7a9343af Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 1 Jul 2025 11:51:46 +0200 Subject: [PATCH 01/12] add account onboarding --- management/server/account.go | 12 +++++++ management/server/account/manager.go | 1 + management/server/http/api/openapi.yml | 19 ++++++++++++ management/server/http/api/types.gen.go | 17 ++++++++-- .../handlers/accounts/accounts_handler.go | 31 ++++++++++++++++--- management/server/mock_server/account_mock.go | 9 ++++++ management/server/types/account.go | 10 ++++++ 7 files changed, 92 insertions(+), 7 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index b376f6f5e..023818c67 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1188,6 +1188,18 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) } +func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) + if err != nil { + return nil, status.NewPermissionValidationError(err) + } + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + return am.Store.GetAccountOnboarding(ctx, store.LockingStrengthShare, accountID) +} + func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { if userAuth.UserId == "" { return "", "", errors.New(emptyUserID) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index de5031c03..f298bce10 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -39,6 +39,7 @@ type Manager interface { GetSetupKey(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) GetAccountByID(ctx context.Context, accountID string, userID string) (*types.Account, error) GetAccountMeta(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) + GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, accountID string) (bool, error) GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) diff --git a/management/server/http/api/openapi.yml b/management/server/http/api/openapi.yml index 1c5ca9b04..f8c2b9854 100644 --- a/management/server/http/api/openapi.yml +++ b/management/server/http/api/openapi.yml @@ -60,6 +60,8 @@ components: description: Account creator type: string example: google-oauth2|277474792786460067937 + onboarding: + $ref: '#/components/schemas/AccountOnboarding' required: - id - settings @@ -67,6 +69,21 @@ components: - domain_category - created_at - created_by + - onboarding + AccountOnboarding: + type: object + properties: + signup_form_pending: + description: Indicates whether the account signup form is pending + type: boolean + example: true + onboarding_flow_pending: + description: Indicates whether the account onboarding flow is pending + type: boolean + example: false + required: + - signup_form_pending + - onboarding_flow_pending AccountSettings: type: object properties: @@ -153,6 +170,8 @@ components: properties: settings: $ref: '#/components/schemas/AccountSettings' + onboarding: + $ref: '#/components/schemas/AccountOnboarding' required: - settings User: diff --git a/management/server/http/api/types.gen.go b/management/server/http/api/types.gen.go index d27fd2a57..a9f17aab4 100644 --- a/management/server/http/api/types.gen.go +++ b/management/server/http/api/types.gen.go @@ -250,8 +250,9 @@ type Account struct { DomainCategory string `json:"domain_category"` // Id Account ID - Id string `json:"id"` - Settings AccountSettings `json:"settings"` + Id string `json:"id"` + Onboarding AccountOnboarding `json:"onboarding"` + Settings AccountSettings `json:"settings"` } // AccountExtraSettings defines model for AccountExtraSettings. @@ -266,9 +267,19 @@ type AccountExtraSettings struct { PeerApprovalEnabled bool `json:"peer_approval_enabled"` } +// AccountOnboarding defines model for AccountOnboarding. +type AccountOnboarding struct { + // OnboardingFlowPending Indicates whether the account onboarding flow is pending + OnboardingFlowPending bool `json:"onboarding_flow_pending"` + + // SignupFormPending Indicates whether the account signup form is pending + SignupFormPending bool `json:"signup_form_pending"` +} + // AccountRequest defines model for AccountRequest. type AccountRequest struct { - Settings AccountSettings `json:"settings"` + Onboarding *AccountOnboarding `json:"onboarding,omitempty"` + Settings AccountSettings `json:"settings"` } // AccountSettings defines model for AccountSettings. diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index dfc782b3f..84728c48d 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -59,7 +59,13 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, settings, meta) + onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toAccountResponse(accountID, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -126,7 +132,12 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled } - updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) + onboarding := &types.AccountOnboarding{ + OnboardingFlowPending: req.Onboarding.OnboardingFlowPending, + SignupFormPending: req.Onboarding.SignupFormPending, + } + + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings, onboarding) if err != nil { util.WriteError(r.Context(), err, w) return @@ -138,7 +149,13 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, updatedSettings, meta) + onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + resp := toAccountResponse(accountID, updatedSettings, meta, onboarding) util.WriteJSONObject(r.Context(), w, &resp) } @@ -167,7 +184,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -188,6 +205,11 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A DnsDomain: &settings.DNSDomain, } + apiOnboarding := api.AccountOnboarding{ + OnboardingFlowPending: onboarding.OnboardingFlowPending, + SignupFormPending: onboarding.SignupFormPending, + } + if settings.Extra != nil { apiSettings.Extra = &api.AccountExtraSettings{ PeerApprovalEnabled: settings.Extra.PeerApprovalEnabled, @@ -203,5 +225,6 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A CreatedBy: meta.CreatedBy, Domain: meta.Domain, DomainCategory: meta.DomainCategory, + Onboarding: apiOnboarding, } } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 3caa6744a..723f64b69 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -117,6 +117,7 @@ type MockAccountManager struct { GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) + GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) } @@ -814,6 +815,14 @@ func (am *MockAccountManager) GetAccountMeta(ctx context.Context, accountID stri return nil, status.Errorf(codes.Unimplemented, "method GetAccountMeta is not implemented") } +// GetAccountOnboarding mocks GetAccountOnboarding of the AccountManager interface +func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + if am.GetAccountOnboardingFunc != nil { + return am.GetAccountOnboardingFunc(ctx, accountID, userID) + } + return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented") +} + // GetUserByID mocks GetUserByID of the AccountManager interface func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { diff --git a/management/server/types/account.go b/management/server/types/account.go index 090ba76e4..ac05ee363 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -87,6 +87,7 @@ type Account struct { Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` + Onboarding AccountOnboarding } // Subclass used in gorm to only load network and not whole account @@ -104,6 +105,14 @@ type AccountSettings struct { Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` } +type AccountOnboarding struct { + AccountID string `gorm:"primaryKey"` + OnboardingFlowPending bool + SignupFormPending bool + CreatedAt time.Time + UpdatedAt time.Time +} + // GetRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. @@ -866,6 +875,7 @@ func (a *Account) Copy() *Account { Networks: nets, NetworkRouters: networkRouters, NetworkResources: networkResources, + Onboarding: a.Onboarding, } } From 2dc230ab9abd180735e53031d9a13355a34a5802 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 1 Jul 2025 19:54:53 +0200 Subject: [PATCH 02/12] add store and account manager methods add store tests --- management/server/account.go | 39 +++- management/server/account/manager.go | 1 + .../handlers/accounts/accounts_handler.go | 25 +-- .../accounts/accounts_handler_test.go | 34 +++- management/server/mock_server/account_mock.go | 186 +++++++++--------- management/server/store/sql_store.go | 30 ++- management/server/store/sql_store_test.go | 74 +++++++ management/server/store/store.go | 2 + management/server/testdata/store.sql | 4 +- management/server/types/account.go | 11 +- 10 files changed, 295 insertions(+), 111 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 023818c67..dd3a72f60 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -275,6 +275,41 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { return am.idpManager } +func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) + defer unlock() + + if newOnboarding == nil { + return nil, status.Errorf(status.InvalidArgument, "new onboarding data cannot be nil") + } + + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil { + return nil, fmt.Errorf("failed to get account onboarding: %w", err) + } + if oldOnboarding.IsEqual(*newOnboarding) { + log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID) + return oldOnboarding, nil + } + + newOnboarding.AccountID = accountID + err = am.Store.SaveAccountOnboarding(ctx, newOnboarding) + if err != nil { + return nil, fmt.Errorf("failed to update account onboarding: %w", err) + } + + return newOnboarding, nil +} + // UpdateAccountSettings updates Account settings. // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. @@ -1188,16 +1223,18 @@ func (am *DefaultAccountManager) GetAccountMeta(ctx context.Context, accountID s return am.Store.GetAccountMeta(ctx, store.LockingStrengthShare, accountID) } +// GetAccountOnboarding retrieves the onboarding information for a specific account. func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Accounts, operations.Read) if err != nil { return nil, status.NewPermissionValidationError(err) } + if !allowed { return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountOnboarding(ctx, store.LockingStrengthShare, accountID) + return am.Store.GetAccountOnboarding(ctx, accountID) } func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { diff --git a/management/server/account/manager.go b/management/server/account/manager.go index f298bce10..ed17fa5ec 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -90,6 +90,7 @@ type Manager interface { SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) + UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 84728c48d..ab59434d1 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -132,12 +132,21 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled } - onboarding := &types.AccountOnboarding{ - OnboardingFlowPending: req.Onboarding.OnboardingFlowPending, - SignupFormPending: req.Onboarding.SignupFormPending, + var onboarding *types.AccountOnboarding + if req.Onboarding != nil { + onboarding = &types.AccountOnboarding{ + OnboardingFlowPending: req.Onboarding.OnboardingFlowPending, + SignupFormPending: req.Onboarding.SignupFormPending, + } } - updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings, onboarding) + updatedOnboarding, err := h.accountManager.UpdateAccountOnboarding(r.Context(), accountID, userID, onboarding) + if err != nil { + util.WriteError(r.Context(), err, w) + return + } + + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return @@ -149,13 +158,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - onboarding, err := h.accountManager.GetAccountOnboarding(r.Context(), accountID, userID) - if err != nil { - util.WriteError(r.Context(), err, w) - return - } - - resp := toAccountResponse(accountID, updatedSettings, meta, onboarding) + resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding) util.WriteJSONObject(r.Context(), w, &resp) } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index a18798743..fa901677b 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -54,6 +54,21 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { GetAccountMetaFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountMeta, error) { return account.GetMeta(), nil }, + GetAccountOnboardingFunc: func(ctx context.Context, accountID string, userID string) (*types.AccountOnboarding, error) { + return &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, nil + }, + UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + if onboarding == nil { + return nil, status.Errorf(status.InvalidArgument, "onboarding cannot be nil") + } + return &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, nil + }, }, settingsManager: settingsMockManager, } @@ -117,7 +132,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -139,7 +154,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 15552000, @@ -161,7 +176,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 554400,\"peer_login_expiration_enabled\": true,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"groups\",\"groups_propagation_enabled\":true,\"regular_users_view_blocked\":true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusOK, expectedSettings: api.AccountSettings{ PeerLoginExpiration: 554400, @@ -178,12 +193,21 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedArray: false, expectedID: accountID, }, + { + name: "PutAccount fail without onboarding", + expectedBody: true, + requestType: http.MethodPut, + requestPath: "/api/accounts/" + accountID, + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), + expectedStatus: http.StatusUnprocessableEntity, + expectedArray: false, + }, { name: "Update account failure with high peer_login_expiration more than 180 days", expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552001,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusUnprocessableEntity, expectedArray: false, }, @@ -192,7 +216,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, - requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true}}"), + requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 3599,\"peer_login_expiration_enabled\": true},\"onboarding\": {\"onboarding_flow_pending\": true,\"signup_form_pending\": true}}"), expectedStatus: http.StatusUnprocessableEntity, expectedArray: false, }, diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 723f64b69..8837f9f50 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -30,95 +30,95 @@ type MockAccountManager struct { GetAccountFunc func(ctx context.Context, accountID string) (*types.Account, error) CreateSetupKeyFunc func(ctx context.Context, accountId string, keyName string, keyType types.SetupKeyType, expiresIn time.Duration, autoGroups []string, usageLimit int, userID string, ephemeral bool, allowExtraDNSLabels bool) (*types.SetupKey, error) - GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) - AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) - GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) - GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) - ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) - GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error - GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) - GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) - AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) - GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) - GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) - SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error - SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error - DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error - DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error - GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error - GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) - DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error - GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) - SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) - DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error - ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) - GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) - UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error - UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) - CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) - GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) - SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error - DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error - ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) - SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) - ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) - SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) - SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) - SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) - DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error - DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error - CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) - DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error - GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) - GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) - GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) - CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) - SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error - DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error - ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) - CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) - GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) - DeleteAccountFunc func(ctx context.Context, accountID, userID string) error - GetDNSDomainFunc func(settings *types.Settings) string - StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) - GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) - GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) - SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error - GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) - LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) - InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error - GetAllConnectedPeersFunc func() (map[string]struct{}, error) - HasConnectedChannelFunc func(peerID string) bool - GetExternalCacheManagerFunc func() account.ExternalCacheManager - GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) - SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) - DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error - ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) - GetIdpManagerFunc func() idp.Manager - UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error - GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) - SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error - FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) - GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) - GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) - GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) - GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) - DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error - BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) - GetStoreFunc func() store.Store - UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) - GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) - GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) - GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) - GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) - + GetSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) (*types.SetupKey, error) + AccountExistsFunc func(ctx context.Context, accountID string) (bool, error) + GetAccountIDByUserIdFunc func(ctx context.Context, userId, domain string) (string, error) + GetUserFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*types.User, error) + ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) + GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error + GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) + GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) + AddPeerFunc func(ctx context.Context, setupKey string, userId string, peer *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error) + GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error) + GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error) + SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error + SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error + DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error + DeleteGroupsFunc func(ctx context.Context, accountId, userId string, groupIDs []string) error + GroupAddPeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GroupDeletePeerFunc func(ctx context.Context, accountID, groupID, peerID string) error + GetPeerGroupsFunc func(ctx context.Context, accountID, peerID string) ([]*types.Group, error) + DeleteRuleFunc func(ctx context.Context, accountID, ruleID, userID string) error + GetPolicyFunc func(ctx context.Context, accountID, policyID, userID string) (*types.Policy, error) + SavePolicyFunc func(ctx context.Context, accountID, userID string, policy *types.Policy, create bool) (*types.Policy, error) + DeletePolicyFunc func(ctx context.Context, accountID, policyID, userID string) error + ListPoliciesFunc func(ctx context.Context, accountID, userID string) ([]*types.Policy, error) + GetUsersFromAccountFunc func(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error) + UpdatePeerMetaFunc func(ctx context.Context, peerID string, meta nbpeer.PeerSystemMeta) error + UpdatePeerFunc func(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) + CreateRouteFunc func(ctx context.Context, accountID string, prefix netip.Prefix, networkType route.NetworkType, domains domain.List, peer string, peerGroups []string, description string, netID route.NetID, masquerade bool, metric int, groups, accessControlGroupIDs []string, enabled bool, userID string, keepRoute bool) (*route.Route, error) + GetRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) (*route.Route, error) + SaveRouteFunc func(ctx context.Context, accountID string, userID string, route *route.Route) error + DeleteRouteFunc func(ctx context.Context, accountID string, routeID route.ID, userID string) error + ListRoutesFunc func(ctx context.Context, accountID, userID string) ([]*route.Route, error) + SaveSetupKeyFunc func(ctx context.Context, accountID string, key *types.SetupKey, userID string) (*types.SetupKey, error) + ListSetupKeysFunc func(ctx context.Context, accountID, userID string) ([]*types.SetupKey, error) + SaveUserFunc func(ctx context.Context, accountID, userID string, user *types.User) (*types.UserInfo, error) + SaveOrAddUserFunc func(ctx context.Context, accountID, userID string, user *types.User, addIfNotExists bool) (*types.UserInfo, error) + SaveOrAddUsersFunc func(ctx context.Context, accountID, initiatorUserID string, update []*types.User, addIfNotExists bool) ([]*types.UserInfo, error) + DeleteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserID string) error + DeleteRegularUsersFunc func(ctx context.Context, accountID, initiatorUserID string, targetUserIDs []string, userInfos map[string]*types.UserInfo) error + CreatePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenName string, expiresIn int) (*types.PersonalAccessTokenGenerated, error) + DeletePATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) error + GetPATFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string, tokenID string) (*types.PersonalAccessToken, error) + GetAllPATsFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserId string) ([]*types.PersonalAccessToken, error) + GetNameServerGroupFunc func(ctx context.Context, accountID, userID, nsGroupID string) (*nbdns.NameServerGroup, error) + CreateNameServerGroupFunc func(ctx context.Context, accountID string, name, description string, nameServerList []nbdns.NameServer, groups []string, primary bool, domains []string, enabled bool, userID string, searchDomainsEnabled bool) (*nbdns.NameServerGroup, error) + SaveNameServerGroupFunc func(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error + DeleteNameServerGroupFunc func(ctx context.Context, accountID, nsGroupID, userID string) error + ListNameServerGroupsFunc func(ctx context.Context, accountID string, userID string) ([]*nbdns.NameServerGroup, error) + CreateUserFunc func(ctx context.Context, accountID, userID string, key *types.UserInfo) (*types.UserInfo, error) + GetAccountIDFromUserAuthFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) + DeleteAccountFunc func(ctx context.Context, accountID, userID string) error + GetDNSDomainFunc func(settings *types.Settings) string + StoreEventFunc func(ctx context.Context, initiatorID, targetID, accountID string, activityID activity.ActivityDescriber, meta map[string]any) + GetEventsFunc func(ctx context.Context, accountID, userID string) ([]*activity.Event, error) + GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) + SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error + GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) + LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) + InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error + GetAllConnectedPeersFunc func() (map[string]struct{}, error) + HasConnectedChannelFunc func(peerID string) bool + GetExternalCacheManagerFunc func() account.ExternalCacheManager + GetPostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) (*posture.Checks, error) + SavePostureChecksFunc func(ctx context.Context, accountID, userID string, postureChecks *posture.Checks, create bool) (*posture.Checks, error) + DeletePostureChecksFunc func(ctx context.Context, accountID, postureChecksID, userID string) error + ListPostureChecksFunc func(ctx context.Context, accountID, userID string) ([]*posture.Checks, error) + GetIdpManagerFunc func() idp.Manager + UpdateIntegratedValidatorGroupsFunc func(ctx context.Context, accountID string, userID string, groups []string) error + GroupValidationFunc func(ctx context.Context, accountId string, groups []string) (bool, error) + SyncPeerMetaFunc func(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error + FindExistingPostureCheckFunc func(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) + GetAccountIDForPeerKeyFunc func(ctx context.Context, peerKey string) (string, error) + GetAccountByIDFunc func(ctx context.Context, accountID string, userID string) (*types.Account, error) + GetUserByIDFunc func(ctx context.Context, id string) (*types.User, error) + GetAccountSettingsFunc func(ctx context.Context, accountID string, userID string) (*types.Settings, error) + DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error + BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) + GetStoreFunc func() store.Store + UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) + GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) + GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) + GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) + GetAccountOnboardingFunc func(ctx context.Context, accountID, userID string) (*types.AccountOnboarding, error) + UpdateAccountOnboardingFunc func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) GetOrCreateAccountByPrivateDomainFunc func(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) } @@ -823,6 +823,14 @@ func (am *MockAccountManager) GetAccountOnboarding(ctx context.Context, accountI return nil, status.Errorf(codes.Unimplemented, "method GetAccountOnboarding is not implemented") } +// UpdateAccountOnboarding mocks UpdateAccountOnboarding of the AccountManager interface +func (am *MockAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID string, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + if am.UpdateAccountOnboardingFunc != nil { + return am.UpdateAccountOnboardingFunc(ctx, accountID, userID, onboarding) + } + return nil, status.Errorf(codes.Unimplemented, "method UpdateAccountOnboarding is not implemented") +} + // GetUserByID mocks GetUserByID of the AccountManager interface func (am *MockAccountManager) GetUserByID(ctx context.Context, id string) (*types.User, error) { if am.GetUserByIDFunc != nil { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index a6c4d56bf..9e33a6972 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -99,7 +99,7 @@ func NewSqlStore(ctx context.Context, db *gorm.DB, storeEngine types.Engine, met &types.SetupKey{}, &nbpeer.Peer{}, &types.User{}, &types.PersonalAccessToken{}, &types.Group{}, &types.Account{}, &types.Policy{}, &types.PolicyRule{}, &route.Route{}, &nbdns.NameServerGroup{}, &installation{}, &types.ExtraSettings{}, &posture.Checks{}, &nbpeer.NetworkAddress{}, - &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, + &networkTypes.Network{}, &routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{}, &types.AccountOnboarding{}, ) if err != nil { return nil, fmt.Errorf("auto migrate: %w", err) @@ -725,6 +725,34 @@ func (s *SqlStore) GetAccountMeta(ctx context.Context, lockStrength LockingStren return &accountMeta, nil } +// GetAccountOnboarding retrieves the onboarding information for a specific account. +func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) { + var accountOnboarding types.AccountOnboarding + result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + //accountOnboarding.AccountID = accountID + //return &accountOnboarding, nil + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + return &accountOnboarding, nil +} + +// SaveAccountOnboarding updates the onboarding information for a specific account. +func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error { + result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Save(onboarding) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) + return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) + } + + return nil +} + func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index fab9048e5..2552f4791 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -353,9 +353,16 @@ func TestSqlite_DeleteAccount(t *testing.T) { t.Errorf("expecting 1 Accounts to be stored after SaveAccount()") } + o, err := store.GetAccountOnboarding(context.Background(), account.Id) + require.NoError(t, err) + require.Equal(t, o.AccountID, account.Id) + err = store.DeleteAccount(context.Background(), account) require.NoError(t, err) + _, err = store.GetAccountOnboarding(context.Background(), account.Id) + require.Error(t, err, "expecting error after removing DeleteAccount when getting onboarding") + if len(store.GetAllAccounts(context.Background())) != 0 { t.Errorf("expecting 0 Accounts to be stored after DeleteAccount()") } @@ -413,12 +420,21 @@ func Test_GetAccount(t *testing.T) { account, err := store.GetAccount(context.Background(), id) require.NoError(t, err) require.Equal(t, id, account.Id, "account id should match") + require.Equal(t, false, account.Onboarding.OnboardingFlowPending) + + id = "9439-34653001fc3b-bf1c8084-ba50-4ce7" + + account, err = store.GetAccount(context.Background(), id) + require.NoError(t, err) + require.Equal(t, id, account.Id, "account id should match") + require.Equal(t, true, account.Onboarding.OnboardingFlowPending) _, err = store.GetAccount(context.Background(), "non-existing-account") assert.Error(t, err) parsedErr, ok := status.FromError(err) require.True(t, ok) require.Equal(t, status.NotFound, parsedErr.Type(), "should return not found error") + }) } @@ -2042,6 +2058,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty PeerInactivityExpirationEnabled: false, PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, }, + Onboarding: types.AccountOnboarding{SignupFormPending: true, OnboardingFlowPending: true}, } if err := acc.AddAllGroup(); err != nil { @@ -3386,6 +3403,63 @@ func TestSqlStore_GetAccountMeta(t *testing.T) { require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), accountMeta.CreatedAt.UTC()) } +func TestSqlStore_GetAccountOnboarding(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7" + a, err := store.GetAccount(context.Background(), accountID) + require.NoError(t, err) + t.Logf("Onboarding: %+v", a.Onboarding) + err = store.SaveAccount(context.Background(), a) + require.NoError(t, err) + onboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.NotNil(t, onboarding) + require.Equal(t, accountID, onboarding.AccountID) + require.Equal(t, time.Date(2024, time.October, 2, 14, 1, 38, 210000000, time.UTC), onboarding.CreatedAt.UTC()) +} + +func TestSqlStore_SaveAccountOnboarding(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + t.Run("New onboarding should be saved correctly", func(t *testing.T) { + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + onboarding := &types.AccountOnboarding{ + AccountID: accountID, + SignupFormPending: true, + OnboardingFlowPending: true, + } + + err = store.SaveAccountOnboarding(context.Background(), onboarding) + require.NoError(t, err) + + savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending) + require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending) + }) + + t.Run("Existing onboarding should be updated correctly", func(t *testing.T) { + accountID := "9439-34653001fc3b-bf1c8084-ba50-4ce7" + onboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + + onboarding.OnboardingFlowPending = !onboarding.OnboardingFlowPending + onboarding.SignupFormPending = !onboarding.SignupFormPending + + err = store.SaveAccountOnboarding(context.Background(), onboarding) + require.NoError(t, err) + + savedOnboarding, err := store.GetAccountOnboarding(context.Background(), accountID) + require.NoError(t, err) + require.Equal(t, onboarding.SignupFormPending, savedOnboarding.SignupFormPending) + require.Equal(t, onboarding.OnboardingFlowPending, savedOnboarding.OnboardingFlowPending) + }) +} + func TestSqlStore_GetAnyAccountID(t *testing.T) { t.Run("should return account ID when accounts exist", func(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/extended-store.sql", t.TempDir()) diff --git a/management/server/store/store.go b/management/server/store/store.go index d41379b1c..9bdfa8b2f 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -52,6 +52,7 @@ type Store interface { GetAllAccounts(ctx context.Context) []*types.Account GetAccount(ctx context.Context, accountID string) (*types.Account, error) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) + GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error) GetAccountDomainAndCategory(ctx context.Context, lockStrength LockingStrength, accountID string) (string, string, error) GetAccountByUser(ctx context.Context, userID string) (*types.Account, error) @@ -74,6 +75,7 @@ type Store interface { SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) + SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByUserID(ctx context.Context, lockStrength LockingStrength, userID string) (*types.User, error) diff --git a/management/server/testdata/store.sql b/management/server/testdata/store.sql index 41b8fa2f7..02e20e057 100644 --- a/management/server/testdata/store.sql +++ b/management/server/testdata/store.sql @@ -1,4 +1,5 @@ CREATE TABLE `accounts` (`id` text,`created_by` text,`created_at` datetime,`domain` text,`domain_category` text,`is_domain_primary_account` numeric,`network_identifier` text,`network_net` text,`network_dns` text,`network_serial` integer,`dns_settings_disabled_management_groups` text,`settings_peer_login_expiration_enabled` numeric,`settings_peer_login_expiration` integer,`settings_regular_users_view_blocked` numeric,`settings_groups_propagation_enabled` numeric,`settings_jwt_groups_enabled` numeric,`settings_jwt_groups_claim_name` text,`settings_jwt_allow_groups` text,`settings_extra_peer_approval_enabled` numeric,`settings_extra_integrated_validator_groups` text,PRIMARY KEY (`id`)); +CREATE TABLE `account_onboardings` (`account_id` text, `created_at` datetime,`updated_at` datetime, `onboarding_flow_pending` numeric, `signup_form_pending` numeric, PRIMARY KEY (`account_id`)); CREATE TABLE `setup_keys` (`id` text,`account_id` text,`key` text,`name` text,`type` text,`created_at` datetime,`expires_at` datetime,`updated_at` datetime,`revoked` numeric,`used_times` integer,`last_used` datetime DEFAULT NULL,`auto_groups` text,`usage_limit` integer,`ephemeral` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_setup_keys_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `peers` (`id` text,`account_id` text,`key` text,`setup_key` text,`ip` text,`meta_hostname` text,`meta_go_os` text,`meta_kernel` text,`meta_core` text,`meta_platform` text,`meta_os` text,`meta_os_version` text,`meta_wt_version` text,`meta_ui_version` text,`meta_kernel_version` text,`meta_network_addresses` text,`meta_system_serial_number` text,`meta_system_product_name` text,`meta_system_manufacturer` text,`meta_environment` text,`meta_files` text,`name` text,`dns_label` text,`peer_status_last_seen` datetime,`peer_status_connected` numeric,`peer_status_login_expired` numeric,`peer_status_requires_approval` numeric,`user_id` text,`ssh_key` text,`ssh_enabled` numeric,`login_expiration_enabled` numeric,`last_login` datetime,`created_at` datetime,`ephemeral` numeric,`location_connection_ip` text,`location_country_code` text,`location_city_name` text,`location_geo_name_id` integer,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_peers_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); CREATE TABLE `users` (`id` text,`account_id` text,`role` text,`is_service_user` numeric,`non_deletable` numeric,`service_user_name` text,`auto_groups` text,`blocked` numeric,`last_login` datetime,`created_at` datetime,`issued` text DEFAULT "api",`integration_ref_id` integer,`integration_ref_integration_type` text,PRIMARY KEY (`id`),CONSTRAINT `fk_accounts_users_g` FOREIGN KEY (`account_id`) REFERENCES `accounts`(`id`)); @@ -38,7 +39,8 @@ CREATE INDEX `idx_networks_id` ON `networks`(`id`); CREATE INDEX `idx_networks_account_id` ON `networks`(`account_id`); INSERT INTO accounts VALUES('bf1c8084-ba50-4ce7-9439-34653001fc3b','edafee4e-63fb-11ec-90d6-0242ac120003','2024-10-02 16:03:06.778746+02:00','test.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); -INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); +INSERT INTO accounts VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','90d6-0242ac120003-edafee4e-63fb-11ec','2024-10-02 16:01:38.210000+02:00','test2.com','private',1,'af1c8024-ha40-4ce2-9418-34653101fc3c','{"IP":"100.64.0.0","Mask":"//8AAA=="}','',0,'[]',0,86400000000000,0,0,0,'',NULL,NULL,NULL); +INSERT INTO account_onboardings VALUES('9439-34653001fc3b-bf1c8084-ba50-4ce7','2024-10-02 16:01:38.210000+02:00','2021-08-19 20:46:20.005936822+02:00',1,0);INSERT INTO "groups" VALUES('cs1tnh0hhcjnqoiuebeg','bf1c8084-ba50-4ce7-9439-34653001fc3b','All','api','[]',0,''); INSERT INTO setup_keys VALUES('','bf1c8084-ba50-4ce7-9439-34653001fc3b','A2C8E62B-38F5-4553-B31E-DD66C696CEBB','Default key','reusable','2021-08-19 20:46:20.005936822+02:00','2321-09-18 20:46:20.005936822+02:00','2021-08-19 20:46:20.005936822+02:00',0,0,NULL,'["cs1tnh0hhcjnqoiuebeg"]',0,0); INSERT INTO users VALUES('a23efe53-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','owner',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 16:03:06.779156+02:00','api',0,''); diff --git a/management/server/types/account.go b/management/server/types/account.go index ac05ee363..dc9a33e09 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -82,12 +82,11 @@ type Account struct { DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"` PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"` // Settings is a dictionary of Account settings - Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` - + Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"` Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"` NetworkRouters []*routerTypes.NetworkRouter `gorm:"foreignKey:AccountID;references:id"` NetworkResources []*resourceTypes.NetworkResource `gorm:"foreignKey:AccountID;references:id"` - Onboarding AccountOnboarding + Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` } // Subclass used in gorm to only load network and not whole account @@ -113,6 +112,12 @@ type AccountOnboarding struct { UpdatedAt time.Time } +// IsEqual compares two AccountOnboarding objects and returns true if they are equal +func (o AccountOnboarding) IsEqual(onboarding AccountOnboarding) bool { + return o.OnboardingFlowPending == onboarding.OnboardingFlowPending && + o.SignupFormPending == onboarding.SignupFormPending +} + // GetRoutesToSync returns the enabled routes for the peer ID and the routes // from the ACL peers that have distribution groups associated with the peer ID. // Please mind, that the returned route.Route objects will contain Peer.Key instead of Peer.ID. From 7a5edb3894e31c9c84d8bab1aab681f1ba6e927c Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 1 Jul 2025 23:25:52 +0200 Subject: [PATCH 03/12] create accounts with pending onboarding --- management/server/account.go | 4 +++ management/server/account_test.go | 57 +++++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) diff --git a/management/server/account.go b/management/server/account.go index dd3a72f60..4f3220de0 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1778,6 +1778,10 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty PeerInactivityExpiration: types.DefaultPeerInactivityExpiration, RoutingPeerDNSResolutionEnabled: true, }, + Onboarding: types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + }, } if err := acc.AddAllGroup(); err != nil { diff --git a/management/server/account_test.go b/management/server/account_test.go index c3b1f31a6..dc21cf71f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3440,3 +3440,60 @@ func TestPropagateUserGroupMemberships(t *testing.T) { } }) } + +func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + assert.Equal(t, account.Id, onboarding.AccountID) + assert.Equal(t, true, onboarding.OnboardingFlowPending) + assert.Equal(t, true, onboarding.SignupFormPending) + if onboarding.UpdatedAt.IsZero() { + t.Errorf("Onboarding was not retrieved from the store") + } +} + +func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") + require.NoError(t, err) + + onboarding := &types.AccountOnboarding{ + OnboardingFlowPending: true, + SignupFormPending: true, + } + + t.Run("update onboarding with no change", func(t *testing.T) { + updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding) + require.NoError(t, err) + assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending) + assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) + if updated.UpdatedAt.IsZero() { + t.Errorf("Onboarding was updated in the store") + } + }) + + onboarding.OnboardingFlowPending = false + onboarding.SignupFormPending = false + + t.Run("update onboarding", func(t *testing.T) { + updated, err := manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, onboarding) + require.NoError(t, err) + require.NotNil(t, updated) + assert.Equal(t, onboarding.OnboardingFlowPending, updated.OnboardingFlowPending) + assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) + }) + + t.Run("update onboarding with no change", func(t *testing.T) { + _, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil) + require.Error(t, err) + }) +} From d806fc4a036a23d81ee05dd67b1a522800e5d379 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Wed, 2 Jul 2025 01:38:40 +0200 Subject: [PATCH 04/12] handle empty onboard to avoid breaking clients and dashboard --- management/server/account.go | 88 +++++++++++-------- management/server/account_test.go | 36 +++++--- .../accounts/accounts_handler_test.go | 22 +++-- management/server/status/error.go | 5 ++ management/server/store/sql_store.go | 8 +- 5 files changed, 101 insertions(+), 58 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 4f3220de0..e57d7092b 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -275,41 +275,6 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { return am.idpManager } -func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) - defer unlock() - - if newOnboarding == nil { - return nil, status.Errorf(status.InvalidArgument, "new onboarding data cannot be nil") - } - - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) - if err != nil { - return nil, fmt.Errorf("failed to validate user permissions: %w", err) - } - - if !allowed { - return nil, status.NewPermissionDeniedError() - } - - oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("failed to get account onboarding: %w", err) - } - if oldOnboarding.IsEqual(*newOnboarding) { - log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID) - return oldOnboarding, nil - } - - newOnboarding.AccountID = accountID - err = am.Store.SaveAccountOnboarding(ctx, newOnboarding) - if err != nil { - return nil, fmt.Errorf("failed to update account onboarding: %w", err) - } - - return newOnboarding, nil -} - // UpdateAccountSettings updates Account settings. // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. @@ -1234,7 +1199,58 @@ func (am *DefaultAccountManager) GetAccountOnboarding(ctx context.Context, accou return nil, status.NewPermissionDeniedError() } - return am.Store.GetAccountOnboarding(ctx, accountID) + onboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + log.Errorf("failed to get account onboarding for accountssssssss %s: %v", accountID, err) + return nil, err + } + + if onboarding == nil { + onboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + return onboarding, nil +} + +func (am *DefaultAccountManager) UpdateAccountOnboarding(ctx context.Context, accountID, userID string, newOnboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { + allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) + if err != nil { + return nil, fmt.Errorf("failed to validate user permissions: %w", err) + } + + if !allowed { + return nil, status.NewPermissionDeniedError() + } + + oldOnboarding, err := am.Store.GetAccountOnboarding(ctx, accountID) + if err != nil && err.Error() != status.NewAccountOnboardingNotFoundError(accountID).Error() { + return nil, fmt.Errorf("failed to get account onboarding: %w", err) + } + + if oldOnboarding == nil { + oldOnboarding = &types.AccountOnboarding{ + AccountID: accountID, + } + } + + if newOnboarding == nil { + return oldOnboarding, nil + } + + if oldOnboarding.IsEqual(*newOnboarding) { + log.WithContext(ctx).Debugf("no changes in onboarding for account %s", accountID) + return oldOnboarding, nil + } + + newOnboarding.AccountID = accountID + err = am.Store.SaveAccountOnboarding(ctx, newOnboarding) + if err != nil { + return nil, fmt.Errorf("failed to update account onboarding: %w", err) + } + + return newOnboarding, nil } func (am *DefaultAccountManager) GetAccountIDFromUserAuth(ctx context.Context, userAuth nbcontext.UserAuth) (string, string, error) { diff --git a/management/server/account_test.go b/management/server/account_test.go index dc21cf71f..137179fa0 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3448,15 +3448,29 @@ func TestDefaultAccountManager_GetAccountOnboarding(t *testing.T) { account, err := manager.GetOrCreateAccountByUser(context.Background(), userID, "") require.NoError(t, err) - onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) - require.NoError(t, err) - require.NotNil(t, onboarding) - assert.Equal(t, account.Id, onboarding.AccountID) - assert.Equal(t, true, onboarding.OnboardingFlowPending) - assert.Equal(t, true, onboarding.SignupFormPending) - if onboarding.UpdatedAt.IsZero() { - t.Errorf("Onboarding was not retrieved from the store") - } + t.Run("should return account onboarding when onboarding exist", func(t *testing.T) { + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + assert.Equal(t, account.Id, onboarding.AccountID) + assert.Equal(t, true, onboarding.OnboardingFlowPending) + assert.Equal(t, true, onboarding.SignupFormPending) + if onboarding.UpdatedAt.IsZero() { + t.Errorf("Onboarding was not retrieved from the store") + } + }) + + t.Run("should return account onboarding when onboard don't exist", func(t *testing.T) { + account.Id = "with-zero-onboarding" + account.Onboarding = types.AccountOnboarding{} + err = manager.Store.SaveAccount(context.Background(), account) + require.NoError(t, err) + onboarding, err := manager.GetAccountOnboarding(context.Background(), account.Id, userID) + require.NoError(t, err) + require.NotNil(t, onboarding) + _, err = manager.Store.GetAccountOnboarding(context.Background(), account.Id) + require.Error(t, err, "should return error when onboarding is not set") + }) } func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { @@ -3492,8 +3506,8 @@ func TestDefaultAccountManager_UpdateAccountOnboarding(t *testing.T) { assert.Equal(t, onboarding.SignupFormPending, updated.SignupFormPending) }) - t.Run("update onboarding with no change", func(t *testing.T) { + t.Run("update onboarding with no onboarding", func(t *testing.T) { _, err = manager.UpdateAccountOnboarding(context.Background(), account.Id, userID, nil) - require.Error(t, err) + require.NoError(t, err) }) } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index fa901677b..dbf0c22bc 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -61,9 +61,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { }, nil }, UpdateAccountOnboardingFunc: func(ctx context.Context, accountID, userID string, onboarding *types.AccountOnboarding) (*types.AccountOnboarding, error) { - if onboarding == nil { - return nil, status.Errorf(status.InvalidArgument, "onboarding cannot be nil") - } return &types.AccountOnboarding{ OnboardingFlowPending: true, SignupFormPending: true, @@ -194,13 +191,26 @@ func TestAccounts_AccountsHandler(t *testing.T) { expectedID: accountID, }, { - name: "PutAccount fail without onboarding", + name: "PutAccount OK without onboarding", expectedBody: true, requestType: http.MethodPut, requestPath: "/api/accounts/" + accountID, requestBody: bytes.NewBufferString("{\"settings\": {\"peer_login_expiration\": 15552000,\"peer_login_expiration_enabled\": false,\"jwt_groups_enabled\":true,\"jwt_groups_claim_name\":\"roles\",\"jwt_allow_groups\":[\"test\"],\"regular_users_view_blocked\":true}}"), - expectedStatus: http.StatusUnprocessableEntity, - expectedArray: false, + expectedStatus: http.StatusOK, + expectedSettings: api.AccountSettings{ + PeerLoginExpiration: 15552000, + PeerLoginExpirationEnabled: false, + GroupsPropagationEnabled: br(false), + JwtGroupsClaimName: sr("roles"), + JwtGroupsEnabled: br(true), + JwtAllowGroups: &[]string{"test"}, + RegularUsersViewBlocked: true, + RoutingPeerDnsResolutionEnabled: br(false), + LazyConnectionEnabled: br(false), + DnsDomain: sr(""), + }, + expectedArray: false, + expectedID: accountID, }, { name: "Update account failure with high peer_login_expiration more than 180 days", diff --git a/management/server/status/error.go b/management/server/status/error.go index 5a6f6d1a7..47c236e93 100644 --- a/management/server/status/error.go +++ b/management/server/status/error.go @@ -90,6 +90,11 @@ func NewAccountNotFoundError(accountKey string) error { return Errorf(NotFound, "account not found: %s", accountKey) } +// NewAccountOnboardingNotFoundError creates a new Error with NotFound type for a missing account onboarding +func NewAccountOnboardingNotFoundError(accountKey string) error { + return Errorf(NotFound, "account onboarding not found: %s", accountKey) +} + // NewPeerNotPartOfAccountError creates a new Error with PermissionDenied type for a peer not being part of an account func NewPeerNotPartOfAccountError() error { return Errorf(PermissionDenied, "peer is not part of this account") diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 9e33a6972..c24ebf6ec 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -730,12 +730,10 @@ func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) ( var accountOnboarding types.AccountOnboarding result := s.db.Model(&accountOnboarding).First(&accountOnboarding, accountIDCondition, accountID) if result.Error != nil { - log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error) if errors.Is(result.Error, gorm.ErrRecordNotFound) { - //accountOnboarding.AccountID = accountID - //return &accountOnboarding, nil - return nil, status.NewAccountNotFoundError(accountID) + return nil, status.NewAccountOnboardingNotFoundError(accountID) } + log.WithContext(ctx).Errorf("error when getting account onboarding %s from the store: %s", accountID, result.Error) return nil, status.NewGetAccountFromStoreError(result.Error) } @@ -744,7 +742,7 @@ func (s *SqlStore) GetAccountOnboarding(ctx context.Context, accountID string) ( // SaveAccountOnboarding updates the onboarding information for a specific account. func (s *SqlStore) SaveAccountOnboarding(ctx context.Context, onboarding *types.AccountOnboarding) error { - result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Save(onboarding) + result := s.db.Clauses(clause.OnConflict{UpdateAll: true}).Create(onboarding) if result.Error != nil { log.WithContext(ctx).Errorf("error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) return status.Errorf(status.Internal, "error when saving account onboarding %s in the store: %s", onboarding.AccountID, result.Error) From 1ffc8933de13113c45ef41a7d6e207222e10e19d Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 7 Jul 2025 19:15:54 +0200 Subject: [PATCH 05/12] add rate limit --- management/server/grpcserver.go | 85 +++++++++++++++++++++++++++++++-- 1 file changed, 81 insertions(+), 4 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 2b27f9e0f..86d4e7fab 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,8 +3,11 @@ package server import ( "context" "fmt" + "math/rand/v2" "net" "net/netip" + "os" + "strconv" "strings" "sync" "time" @@ -13,6 +16,7 @@ import ( "github.com/golang/protobuf/ptypes/timestamp" "github.com/grpc-ecosystem/go-grpc-middleware/v2/interceptors/realip" log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "google.golang.org/grpc/codes" "google.golang.org/grpc/peer" @@ -47,6 +51,10 @@ type GRPCServer struct { ephemeralManager *EphemeralManager peerLocks sync.Map authManager auth.Manager + syncLimiter *rate.Limiter + loginLimiterStore sync.Map + loginPeerBooster int + loginPeerLimit rate.Limit } // NewServer creates a new Management server @@ -76,6 +84,41 @@ func NewServer( } } + multiplier := time.Minute + d, e := time.ParseDuration(os.Getenv("NB_LOGIN_RATE")) + if e == nil { + multiplier = d + } + + loginRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_RATE_PER_M")) + if loginRatePerS == 0 || err != nil { + loginRatePerS = 200 + } + + loginBurst, err := strconv.Atoi(os.Getenv("NB_LOGIN_BURST")) + if loginBurst == 0 || err != nil { + loginBurst = 200 + } + log.WithContext(ctx).Infof("login burst limit set to %d", loginBurst) + + loginPeerRatePerS, err := strconv.Atoi(os.Getenv("NB_LOGIN_PEER_RATE_PER_M")) + if loginPeerRatePerS == 0 || err != nil { + loginPeerRatePerS = 200 + } + log.WithContext(ctx).Infof("login rate limit set to %d/min", loginRatePerS) + + syncRatePerS, err := strconv.Atoi(os.Getenv("NB_SYNC_RATE_PER_M")) + if syncRatePerS == 0 || err != nil { + syncRatePerS = 20000 + } + log.WithContext(ctx).Infof("sync rate limit set to %d/min", syncRatePerS) + + syncBurst, err := strconv.Atoi(os.Getenv("NB_SYNC_BURST")) + if syncBurst == 0 || err != nil { + syncBurst = 30000 + } + log.WithContext(ctx).Infof("sync burst limit set to %d", syncBurst) + return &GRPCServer{ wgKey: key, // peerKey -> event channel @@ -87,6 +130,9 @@ func NewServer( authManager: authManager, appMetrics: appMetrics, ephemeralManager: ephemeralManager, + syncLimiter: rate.NewLimiter(rate.Every(time.Minute/time.Duration(syncRatePerS)), syncBurst), + loginPeerLimit: rate.Every(multiplier / time.Duration(loginPeerRatePerS)), + loginPeerBooster: loginBurst, }, nil } @@ -128,11 +174,18 @@ func getRealIP(ctx context.Context) net.IP { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { - reqStart := time.Now() if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountSyncRequest() } + if !s.syncLimiter.Allow() { + time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + log.Warnf("sync rate limit exceeded for peer %s", req.WgPubKey) + return status.Errorf(codes.Internal, "temp rate limit reached") + } + + reqStart := time.Now() + ctx := srv.Context() syncReq := &proto.SyncRequest{} @@ -428,15 +481,39 @@ func (s *GRPCServer) parseRequest(ctx context.Context, req *proto.EncryptedMessa // In case it isn't, the endpoint checks whether setup key is provided within the request and tries to register a peer. // In case of the successful registration login is also successful func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*proto.EncryptedMessage, error) { + if s.appMetrics != nil { + s.appMetrics.GRPCMetrics().CountLoginRequest() + } + + limiterIface, ok := s.loginLimiterStore.Load(req.WgPubKey) + if !ok { + // Create new limiter for this peer + newLimiter := rate.NewLimiter(s.loginPeerLimit, s.loginPeerBooster) + s.loginLimiterStore.Store(req.WgPubKey, newLimiter) + + if !newLimiter.Allow() { + time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey) + return nil, fmt.Errorf("temp rate limit reached (new peer limit)") + } + } else { + // Use existing limiter for this peer + limiter := limiterIface.(*rate.Limiter) + if !limiter.Allow() { + time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey) + return nil, fmt.Errorf("temp rate limit reached (peer limit)") + } + } reqStart := time.Now() defer func() { if s.appMetrics != nil { s.appMetrics.GRPCMetrics().CountLoginRequestDuration(time.Since(reqStart)) } }() - if s.appMetrics != nil { - s.appMetrics.GRPCMetrics().CountLoginRequest() - } + //if s.appMetrics != nil { + // s.appMetrics.GRPCMetrics().CountLoginRequest() + //} realIP := getRealIP(ctx) log.WithContext(ctx).Debugf("Login request from peer [%s] [%s]", req.WgPubKey, realIP.String()) From bcccd6500893a754f2d0c35fa81ffef86ba56b95 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 7 Jul 2025 19:35:48 +0200 Subject: [PATCH 06/12] add rate limit --- management/server/grpcserver.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/management/server/grpcserver.go b/management/server/grpcserver.go index 86d4e7fab..a9fb60c21 100644 --- a/management/server/grpcserver.go +++ b/management/server/grpcserver.go @@ -3,7 +3,6 @@ package server import ( "context" "fmt" - "math/rand/v2" "net" "net/netip" "os" @@ -179,7 +178,6 @@ func (s *GRPCServer) Sync(req *proto.EncryptedMessage, srv proto.ManagementServi } if !s.syncLimiter.Allow() { - time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) log.Warnf("sync rate limit exceeded for peer %s", req.WgPubKey) return status.Errorf(codes.Internal, "temp rate limit reached") } @@ -492,7 +490,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p s.loginLimiterStore.Store(req.WgPubKey, newLimiter) if !newLimiter.Allow() { - time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + //time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey) return nil, fmt.Errorf("temp rate limit reached (new peer limit)") } @@ -500,7 +498,7 @@ func (s *GRPCServer) Login(ctx context.Context, req *proto.EncryptedMessage) (*p // Use existing limiter for this peer limiter := limiterIface.(*rate.Limiter) if !limiter.Allow() { - time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) + //time.Sleep(time.Second + (time.Millisecond * time.Duration(rand.IntN(20)*100))) log.WithContext(ctx).Warnf("rate limit exceeded for peer %s", req.WgPubKey) return nil, fmt.Errorf("temp rate limit reached (peer limit)") } From 4c5808831108ac49656246b7cbed981cfd6d83fe Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Mon, 7 Jul 2025 19:38:02 +0200 Subject: [PATCH 07/12] fix go mod --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index a12058278..6ddd0dac5 100644 --- a/go.mod +++ b/go.mod @@ -105,6 +105,7 @@ require ( golang.org/x/oauth2 v0.24.0 golang.org/x/sync v0.13.0 golang.org/x/term v0.31.0 + golang.org/x/time v0.5.0 google.golang.org/api v0.177.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 @@ -240,7 +241,6 @@ require ( golang.org/x/image v0.18.0 // indirect golang.org/x/mod v0.17.0 // indirect golang.org/x/text v0.24.0 // indirect - golang.org/x/time v0.5.0 // indirect golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20240509183442-62759503f434 // indirect From 7aa2ca87f21bf175625b9214355e46e7c3fdf356 Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 8 Jul 2025 01:41:07 +0200 Subject: [PATCH 08/12] split call to BufferUpdateAccountPeers when system user initiates --- management/server/account/manager.go | 1 + management/server/ephemeral.go | 7 +++++++ management/server/mock_server/account_mock.go | 4 ++++ management/server/peer.go | 2 +- 4 files changed, 13 insertions(+), 1 deletion(-) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index ed17fa5ec..f8aa2756a 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -112,6 +112,7 @@ type Manager interface { GetAccountSettings(ctx context.Context, accountID string, userID string) (*types.Settings, error) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error UpdateAccountPeers(ctx context.Context, accountID string) + BufferUpdateAccountPeers(ctx context.Context, accountID string) BuildUserInfosForAccount(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error GetStore() store.Store diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 3cb9b7536..7cdfb51ce 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -164,13 +164,20 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { e.peersLock.Unlock() + bufferAccountCall := make(map[string]struct{}) + for id, p := range deletePeers { log.WithContext(ctx).Debugf("delete ephemeral peer: %s", id) err := e.accountManager.DeletePeer(ctx, p.accountID, id, activity.SystemInitiator) if err != nil { log.WithContext(ctx).Errorf("failed to delete ephemeral peer: %s", err) + } else { + bufferAccountCall[p.accountID] = struct{}{} } } + for accountID := range bufferAccountCall { + e.accountManager.BufferUpdateAccountPeers(ctx, accountID) + } } func (e *EphemeralManager) addPeer(accountID string, peerID string, deadline time.Time) { diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 8837f9f50..4004f1b57 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -126,6 +126,10 @@ func (am *MockAccountManager) UpdateAccountPeers(ctx context.Context, accountID // do nothing } +func (am *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + // do nothing +} + func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, userID, keyID string) error { if am.DeleteSetupKeyFunc != nil { return am.DeleteSetupKeyFunc(ctx, accountID, userID, keyID) diff --git a/management/server/peer.go b/management/server/peer.go index 254048a96..77e69d370 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -391,7 +391,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer storeEvent() } - if updateAccountPeers { + if updateAccountPeers && userID != activity.SystemInitiator { am.BufferUpdateAccountPeers(ctx, accountID) } From ad0b78a7acc917ded6b58469204a3fafe261897e Mon Sep 17 00:00:00 2001 From: Maycon Santos Date: Tue, 8 Jul 2025 02:04:05 +0200 Subject: [PATCH 09/12] add request id to ephemeral cleanup --- management/server/ephemeral.go | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/management/server/ephemeral.go b/management/server/ephemeral.go index 7cdfb51ce..4bf7dba2b 100644 --- a/management/server/ephemeral.go +++ b/management/server/ephemeral.go @@ -5,10 +5,12 @@ import ( "sync" "time" + "github.com/google/uuid" log "github.com/sirupsen/logrus" nbAccount "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + nbContext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/store" ) @@ -138,6 +140,9 @@ func (e *EphemeralManager) loadEphemeralPeers(ctx context.Context) { func (e *EphemeralManager) cleanup(ctx context.Context) { log.Tracef("on ephemeral cleanup") + reqID := uuid.New().String() + //nolint + ctx = context.WithValue(ctx, nbContext.RequestIDKey, reqID) deletePeers := make(map[string]*ephemeralPeer) e.peersLock.Lock() @@ -176,6 +181,7 @@ func (e *EphemeralManager) cleanup(ctx context.Context) { } } for accountID := range bufferAccountCall { + log.WithContext(ctx).Debugf("ephemeral - buffer update account peers for account: %s", accountID) e.accountManager.BufferUpdateAccountPeers(ctx, accountID) } } From 470b80c1b8f09f5383be8f52c00cd1feb95beecd Mon Sep 17 00:00:00 2001 From: Pedro Costa <550684+pnmcosta@users.noreply.github.com> Date: Tue, 8 Jul 2025 08:42:45 +0100 Subject: [PATCH 10/12] get extra settings only once per updateaccountpeers --- management/server/peer.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/management/server/peer.go b/management/server/peer.go index 77e69d370..68897f444 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1192,6 +1192,12 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account return } + extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) + return + } + for _, peer := range account.Peers { if !am.peersUpdateManager.HasChannel(peer.ID) { log.WithContext(ctx).Tracef("peer %s doesn't have a channel, skipping network map update", peer.ID) @@ -1226,12 +1232,6 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account } am.metrics.UpdateChannelMetrics().CountMergeNetworkMapDuration(time.Since(start)) - extraSetting, err := am.settingsManager.GetExtraSettings(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get flow enabled status: %v", err) - return - } - start = time.Now() update := toSyncResponse(ctx, nil, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting) am.metrics.UpdateChannelMetrics().CountToSyncResponseDuration(time.Since(start)) From 7d9ca73f6c7c0d00110c5dee38e51a5a4c8aeede Mon Sep 17 00:00:00 2001 From: Pedro Costa <550684+pnmcosta@users.noreply.github.com> Date: Tue, 8 Jul 2025 08:58:31 +0100 Subject: [PATCH 11/12] fix ephemeral test --- management/server/ephemeral_test.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/management/server/ephemeral_test.go b/management/server/ephemeral_test.go index 3cf6ae7f3..7dbecf2da 100644 --- a/management/server/ephemeral_test.go +++ b/management/server/ephemeral_test.go @@ -37,6 +37,10 @@ func (a MocAccountManager) DeletePeer(_ context.Context, accountID, peerID, user return nil //nolint:nil } +func (a MocAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) { + // noop +} + func (a MocAccountManager) GetStore() store.Store { return a.store } From 49f083a372b565a76101a9b9d8edcbeac83bcf2c Mon Sep 17 00:00:00 2001 From: Pedro Costa <550684+pnmcosta@users.noreply.github.com> Date: Tue, 8 Jul 2025 09:07:26 +0100 Subject: [PATCH 12/12] fix dns and client tests --- management/client/client_test.go | 6 ++++++ management/server/dns_test.go | 2 ++ 2 files changed, 8 insertions(+) diff --git a/management/client/client_test.go b/management/client/client_test.go index c163d1833..1847af73e 100644 --- a/management/client/client_test.go +++ b/management/client/client_test.go @@ -87,6 +87,12 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { ). Return(&types.Settings{}, nil). AnyTimes() + settingsMockManager. + EXPECT(). + GetExtraSettings(gomock.Any(), gomock.Any()). + Return(&types.ExtraSettings{}, nil). + AnyTimes() + permissionsManagerMock := permissions.NewMockManager(ctrl) permissionsManagerMock. EXPECT(). diff --git a/management/server/dns_test.go b/management/server/dns_test.go index 02bb042d7..31c944a25 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -216,6 +216,8 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { t.Cleanup(ctrl.Finish) settingsMockManager := settings.NewMockManager(ctrl) + // return empty extra settings for expected calls to UpdateAccountPeers + settingsMockManager.EXPECT().GetExtraSettings(gomock.Any(), gomock.Any()).Return(&types.ExtraSettings{}, nil).AnyTimes() permissionsManager := permissions.NewManager(store) return BuildManager(context.Background(), store, NewPeersUpdateManager(nil), nil, "", "netbird.test", eventStore, nil, false, MocIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) }