diff --git a/management/server/account.go b/management/server/account.go index 0f60bc91c..b9eb3348d 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1952,20 +1952,19 @@ func (am *DefaultAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.C return nil, false, status.Errorf(status.Internal, "failed to get or create new account by private domain") } -func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { - var account *types.Account +func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error { err := am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { var err error - account, err = transaction.GetAccount(ctx, accountId) + ok, domain, err := transaction.IsPrimaryAccount(ctx, accountId) if err != nil { return err } - if account.IsDomainPrimaryAccount { + if ok { return nil } - existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, account.Domain) + existingPrimaryAccountID, err := transaction.GetAccountIDByPrivateDomain(ctx, store.LockingStrengthNone, domain) // error is not a not found error if handleNotFound(err) != nil { @@ -1981,9 +1980,7 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc return status.Errorf(status.Internal, "cannot update account to primary") } - account.IsDomainPrimaryAccount = true - - if err := transaction.SaveAccount(ctx, account); err != nil { + if err := transaction.MarkAccountPrimary(ctx, accountId); err != nil { log.WithContext(ctx).WithFields(log.Fields{ "accountId": accountId, }).Errorf("failed to update account to primary: %v", err) @@ -1993,10 +1990,10 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc return nil }) if err != nil { - return nil, err + return err } - return account, nil + return nil } // propagateUserGroupMemberships propagates all account users' group memberships to their peers. @@ -2067,14 +2064,12 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t Mask: net.CIDRMask(newNetworkRange.Bits(), newNetworkRange.Addr().BitLen()), } - account, err := transaction.GetAccount(ctx, accountID) + err := transaction.UpdateAccountNetwork(ctx, accountID, newIPNet) if err != nil { return err } - account.Network.Net = newIPNet - - peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthNone, accountID, "", "") + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthUpdate, accountID, "", "") if err != nil { return err } @@ -2094,10 +2089,6 @@ func (am *DefaultAccountManager) reallocateAccountPeerIPs(ctx context.Context, t takenIPs = append(takenIPs, newIP) } - if err = transaction.SaveAccount(ctx, account); err != nil { - return err - } - for _, peer := range peers { if err = transaction.SavePeer(ctx, accountID, peer); err != nil { return status.Errorf(status.Internal, "save updated peer %s: %v", peer.ID, err) diff --git a/management/server/account/manager.go b/management/server/account/manager.go index ee82346f3..f5af68f93 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -7,7 +7,6 @@ import ( "time" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/activity" nbcache "github.com/netbirdio/netbird/management/server/cache" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -18,6 +17,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) type ExternalCacheManager nbcache.UserDataCache @@ -120,7 +120,7 @@ type Manager interface { SyncUserJWTGroups(ctx context.Context, userAuth nbcontext.UserAuth) error GetStore() store.Store GetOrCreateAccountByPrivateDomain(ctx context.Context, initiatorId, domain string) (*types.Account, bool, error) - UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) + UpdateToPrimaryAccount(ctx context.Context, accountId string) error GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) GetCurrentUserInfo(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) } diff --git a/management/server/account_test.go b/management/server/account_test.go index 0c618a8a3..252be23f7 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -3250,11 +3250,13 @@ func Test_GetCreateAccountByPrivateDomain(t *testing.T) { assert.Equal(t, 0, len(account2.Users)) assert.Equal(t, 0, len(account2.SetupKeys)) - account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) + err = manager.UpdateToPrimaryAccount(ctx, account.Id) + assert.NoError(t, err) + account, err = manager.Store.GetAccount(ctx, account.Id) assert.NoError(t, err) assert.True(t, account.IsDomainPrimaryAccount) - _, err = manager.UpdateToPrimaryAccount(ctx, account2.Id) + err = manager.UpdateToPrimaryAccount(ctx, account2.Id) assert.Error(t, err, "should not be able to update a second account to primary") } @@ -3275,7 +3277,9 @@ func Test_UpdateToPrimaryAccount(t *testing.T) { assert.False(t, account.IsDomainPrimaryAccount) assert.Equal(t, domain, account.Domain) - account, err = manager.UpdateToPrimaryAccount(ctx, account.Id) + err = manager.UpdateToPrimaryAccount(ctx, account.Id) + assert.NoError(t, err) + account, err = manager.Store.GetAccount(ctx, account.Id) assert.NoError(t, err) assert.True(t, account.IsDomainPrimaryAccount) diff --git a/management/server/integrated_validator.go b/management/server/integrated_validator.go index 73abacc36..509022015 100644 --- a/management/server/integrated_validator.go +++ b/management/server/integrated_validator.go @@ -50,23 +50,23 @@ func (am *DefaultAccountManager) UpdateIntegratedValidator(ctx context.Context, defer unlock() return am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - a, err := transaction.GetAccount(ctx, accountID) + settings, err := transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID) if err != nil { return err } var extra *types.ExtraSettings - if a.Settings.Extra != nil { - extra = a.Settings.Extra + if settings.Extra != nil { + extra = settings.Extra } else { extra = &types.ExtraSettings{} - a.Settings.Extra = extra + settings.Extra = extra } extra.IntegratedValidator = validator extra.IntegratedValidatorGroups = groups - return transaction.SaveAccount(ctx, a) + return transaction.SaveAccountSettings(ctx, accountID, settings) }) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 1ae432412..1d44068d2 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -10,7 +10,6 @@ import ( "google.golang.org/grpc/status" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" nbcontext "github.com/netbirdio/netbird/management/server/context" @@ -21,6 +20,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" ) var _ account.Manager = (*MockAccountManager)(nil) @@ -114,7 +114,7 @@ type MockAccountManager struct { DeleteSetupKeyFunc func(ctx context.Context, accountID, userID, keyID string) error BuildUserInfosForAccountFunc func(ctx context.Context, accountID, initiatorUserID string, accountUsers []*types.User) (map[string]*types.UserInfo, error) GetStoreFunc func() store.Store - UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) (*types.Account, error) + UpdateToPrimaryAccountFunc func(ctx context.Context, accountId string) error GetOwnerInfoFunc func(ctx context.Context, accountID string) (*types.UserInfo, error) GetCurrentUserInfoFunc func(ctx context.Context, userAuth nbcontext.UserAuth) (*users.UserInfoWithPermissions, error) GetAccountMetaFunc func(ctx context.Context, accountID, userID string) (*types.AccountMeta, error) @@ -933,11 +933,11 @@ func (am *MockAccountManager) GetOrCreateAccountByPrivateDomain(ctx context.Cont return nil, false, status.Errorf(codes.Unimplemented, "method GetOrCreateAccountByPrivateDomainFunc is not implemented") } -func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) (*types.Account, error) { +func (am *MockAccountManager) UpdateToPrimaryAccount(ctx context.Context, accountId string) error { if am.UpdateToPrimaryAccountFunc != nil { return am.UpdateToPrimaryAccountFunc(ctx, accountId) } - return nil, status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented") + return status.Errorf(codes.Unimplemented, "method UpdateToPrimaryAccount is not implemented") } func (am *MockAccountManager) GetOwnerInfo(ctx context.Context, accountId string) (*types.UserInfo, error) { diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8aa56f7b0..64f80776b 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2832,3 +2832,57 @@ func getDebuggingCtx(grpcCtx context.Context) (context.Context, context.CancelFu }() return ctx, cancel } + +func (s *SqlStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) { + var info types.PrimaryAccountInfo + result := s.db.Model(&types.Account{}). + Select("is_domain_primary_account, domain"). + Where(idQueryCondition, accountID). + Take(&info) + + if result.Error != nil { + return false, "", status.Errorf(status.Internal, "failed to get account info: %v", result.Error) + } + + return info.IsDomainPrimaryAccount, info.Domain, nil +} + +func (s *SqlStore) MarkAccountPrimary(ctx context.Context, accountID string) error { + result := s.db.Model(&types.Account{}). + Where(idQueryCondition, accountID). + Update("is_domain_primary_account", true) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to mark account as primary: %s", result.Error) + return status.Errorf(status.Internal, "failed to mark account as primary") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} + +type accountNetworkPatch struct { + Network *types.Network `gorm:"embedded;embeddedPrefix:network_"` +} + +func (s *SqlStore) UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error { + patch := accountNetworkPatch{ + Network: &types.Network{Net: ipNet}, + } + + result := s.db.WithContext(ctx). + Model(&types.Account{}). + Where(idQueryCondition, accountID). + Updates(&patch) + + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to update account network: %v", result.Error) + return status.Errorf(status.Internal, "failed to update account network") + } + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + return nil +} diff --git a/management/server/store/store.go b/management/server/store/store.go index da4459256..9e0c04853 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -202,6 +202,9 @@ type Store interface { GetPeerByIP(ctx context.Context, lockStrength LockingStrength, accountID string, ip net.IP) (*nbpeer.Peer, error) GetPeerIdByLabel(ctx context.Context, lockStrength LockingStrength, accountID string, hostname string) (string, error) GetAccountGroupPeers(ctx context.Context, lockStrength LockingStrength, accountID string) (map[string]map[string]struct{}, error) + IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) + MarkAccountPrimary(ctx context.Context, accountID string) error + UpdateAccountNetwork(ctx context.Context, accountID string, ipNet net.IPNet) error } const ( diff --git a/management/server/types/account.go b/management/server/types/account.go index 17a838aae..9ac2568a0 100644 --- a/management/server/types/account.go +++ b/management/server/types/account.go @@ -16,16 +16,16 @@ import ( log "github.com/sirupsen/logrus" nbdns "github.com/netbirdio/netbird/dns" - "github.com/netbirdio/netbird/shared/management/domain" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types" networkTypes "github.com/netbirdio/netbird/management/server/networks/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" "github.com/netbirdio/netbird/management/server/posture" - "github.com/netbirdio/netbird/shared/management/status" "github.com/netbirdio/netbird/management/server/telemetry" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/shared/management/status" ) const ( @@ -89,6 +89,12 @@ type Account struct { Onboarding AccountOnboarding `gorm:"foreignKey:AccountID;references:id;constraint:OnDelete:CASCADE"` } +// this class is used by gorm only +type PrimaryAccountInfo struct { + IsDomainPrimaryAccount bool + Domain string +} + // Subclass used in gorm to only load network and not whole account type AccountNetwork struct { Network *Network `gorm:"embedded;embeddedPrefix:network_"`