diff --git a/management/server/account.go b/management/server/account.go index 63879802a..82f5ee4a3 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -277,29 +277,11 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager { // UpdateAccountSettings updates Account settings. // Only users with role UserRoleAdmin can update the account. // User that performs the update has to belong to the account. -// Returns an updated Account -func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { - halfYearLimit := 180 * 24 * time.Hour - if newSettings.PeerLoginExpiration > halfYearLimit { - return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") - } - - if newSettings.PeerLoginExpiration < time.Hour { - return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") - } - - if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { - return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) - } - +// Returns an updated Settings +func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() - account, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - return nil, err - } - allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) if err != nil { return nil, fmt.Errorf("failed to validate user permissions: %w", err) @@ -309,12 +291,118 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco return nil, status.NewPermissionDeniedError() } - err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) + var oldSettings *types.Settings + var updateAccountPeers bool + var groupChangesAffectPeers bool + + err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { + var groupsUpdated bool + + oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID) + if err != nil { + return err + } + + if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil { + return err + } + + if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled || + oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled || + oldSettings.DNSDomain != newSettings.DNSDomain { + updateAccountPeers = true + } + + if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled { + groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID) + if err != nil { + return err + } + } + + if updateAccountPeers || groupsUpdated { + if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil { + return err + } + } + + return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings) + }) if err != nil { return nil, err } - oldSettings := account.Settings + extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra) + if err != nil { + return nil, err + } + + am.handleRoutingPeerDNSResolutionSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) + am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) + if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil { + return nil, err + } + if oldSettings.DNSDomain != newSettings.DNSDomain { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil) + } + + if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers { + go am.UpdateAccountPeers(ctx, accountID) + } + + return newSettings, nil +} + +func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error { + halfYearLimit := 180 * 24 * time.Hour + if newSettings.PeerLoginExpiration > halfYearLimit { + return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") + } + + if newSettings.PeerLoginExpiration < time.Hour { + return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") + } + + if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) { + return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain) + } + + peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "") + if err != nil { + return err + } + + peersMap := make(map[string]*nbpeer.Peer, len(peers)) + for _, peer := range peers { + peersMap[peer.ID] = peer + } + + return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID) +} + +func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled { + if newSettings.RoutingPeerDNSResolutionEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil) + } + } +} + +func (am *DefaultAccountManager) handleLazyConnectionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { + if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled { + if newSettings.LazyConnectionEnabled { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil) + } else { + am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil) + } + } +} + +func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { event := activity.AccountPeerLoginExpirationEnabled if !newSettings.PeerLoginExpirationEnabled { @@ -330,82 +418,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.checkAndSchedulePeerLoginExpiration(ctx, accountID) } - - updateAccountPeers := false - if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled { - if newSettings.RoutingPeerDNSResolutionEnabled { - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil) - } else { - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil) - } - updateAccountPeers = true - } - - if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled { - if newSettings.LazyConnectionEnabled { - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil) - } else { - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil) - } - updateAccountPeers = true - } - - if oldSettings.DNSDomain != newSettings.DNSDomain { - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil) - updateAccountPeers = true - } - - err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) - if err != nil { - return nil, err - } - - err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID) - if err != nil { - return nil, fmt.Errorf("groups propagation failed: %w", err) - } - - account.UpdateSettings(newSettings) - - if updateAccountPeers { - account.Network.Serial++ - } - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return nil, err - } - - extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra) - if err != nil { - return nil, err - } - - if updateAccountPeers || extraSettingsChanged { - go am.UpdateAccountPeers(ctx, accountID) - } - - return account, nil } -func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { +func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { if newSettings.GroupsPropagationEnabled { am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) - // Todo: retroactively add user groups to all peers } else { am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil) } } - - return nil } func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { if newSettings.PeerInactivityExpirationEnabled { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { - oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration - am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) } @@ -1853,3 +1880,57 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc return account, nil } + +// propagateUserGroupMemberships propagates all account users' group memberships to their peers. +// Returns true if any groups were modified, true if those updates affect peers and an error. +func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) { + groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, false, err + } + + groupsMap := make(map[string]*types.Group, len(groups)) + for _, group := range groups { + groupsMap[group.ID] = group + } + + users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID) + if err != nil { + return false, false, err + } + + groupsToUpdate := make(map[string]*types.Group) + + for _, user := range users { + userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id) + if err != nil { + return false, false, err + } + + updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil) + if err != nil { + return false, false, err + } + + for _, group := range updatedGroups { + groupsToUpdate[group.ID] = group + groupsMap[group.ID] = group + } + } + + if len(groupsToUpdate) == 0 { + return false, false, nil + } + + peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate)) + if err != nil { + return false, false, err + } + + err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate)) + if err != nil { + return false, false, err + } + + return true, peersAffected, nil +} diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 030bd94ef..de5031c03 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -88,7 +88,7 @@ type Manager interface { GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API GetAllConnectedPeers() (map[string]struct{}, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 5ada28ca3..ba0191c03 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1805,9 +1805,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) require.NoError(t, err, "unable to mark peer connected") - account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, + Extra: &types.ExtraSettings{}, }) require.NoError(t, err, "expecting to update account settings successfully but got error") @@ -1825,11 +1826,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { // disable expiration first update := peer.Copy() update.LoginExpirationEnabled = false - _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), accountID, userID, update) require.NoError(t, err, "unable to update peer") // enabling expiration should trigger the routine update.LoginExpirationEnabled = true - _, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) + _, err = manager.UpdatePeer(context.Background(), accountID, userID, update) require.NoError(t, err, "unable to update peer") failed := waitTimeout(wg, time.Second) @@ -1856,6 +1857,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, + Extra: &types.ExtraSettings{}, }) require.NoError(t, err, "expecting to update account settings successfully but got error") @@ -1919,9 +1921,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test }, } // enabling PeerLoginExpirationEnabled should trigger the expiration job - account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ + _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: true, + Extra: &types.ExtraSettings{}, }) require.NoError(t, err, "expecting to update account settings successfully but got error") @@ -1935,6 +1938,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, + Extra: &types.ExtraSettings{}, }) require.NoError(t, err, "expecting to update account settings successfully but got error") failed = waitTimeout(wg, time.Second) @@ -1950,13 +1954,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") require.NoError(t, err, "unable to create an account") - updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ + updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour, PeerLoginExpirationEnabled: false, + Extra: &types.ExtraSettings{}, }) require.NoError(t, err, "expecting to update account settings successfully but got error") - assert.False(t, updated.Settings.PeerLoginExpirationEnabled) - assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) + assert.False(t, updatedSettings.PeerLoginExpirationEnabled) + assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) require.NoError(t, err, "unable to get account settings") @@ -1967,12 +1972,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) { _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Second, PeerLoginExpirationEnabled: false, + Extra: &types.ExtraSettings{}, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpirationEnabled: false, + Extra: &types.ExtraSettings{}, }) require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") } @@ -3319,3 +3326,120 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) { }) }) } + +func TestPropagateUserGroupMemberships(t *testing.T) { + manager, err := createManager(t) + require.NoError(t, err) + + ctx := context.Background() + initiatorId := "test-user" + domain := "example.com" + + account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain) + require.NoError(t, err) + + peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId} + err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1) + require.NoError(t, err) + + peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId} + err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2) + require.NoError(t, err) + + t.Run("should skip propagation when the user has no groups", func(t *testing.T) { + groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + require.NoError(t, err) + assert.False(t, groupsUpdated) + assert.False(t, groupChangesAffectPeers) + }) + + t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) { + group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id} + require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1)) + + user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) + require.NoError(t, err) + + user.AutoGroups = append(user.AutoGroups, group1.ID) + require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) + + groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + require.NoError(t, err) + assert.True(t, groupsUpdated) + assert.False(t, groupChangesAffectPeers) + + group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID) + require.NoError(t, err) + assert.Len(t, group.Peers, 2) + assert.Contains(t, group.Peers, "peer1") + assert.Contains(t, group.Peers, "peer2") + }) + + t.Run("should update membership and account peers for used groups", func(t *testing.T) { + group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id} + require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2)) + + user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) + require.NoError(t, err) + + user.AutoGroups = append(user.AutoGroups, group2.ID) + require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) + + _, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{ + Name: "Group1 Policy", + AccountID: account.Id, + Enabled: true, + Rules: []*types.PolicyRule{ + { + Enabled: true, + Sources: []string{"group1"}, + Destinations: []string{"group2"}, + Bidirectional: true, + Action: types.PolicyTrafficActionAccept, + }, + }, + }, true) + require.NoError(t, err) + + groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + require.NoError(t, err) + assert.True(t, groupsUpdated) + assert.True(t, groupChangesAffectPeers) + + groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) + require.NoError(t, err) + for _, group := range groups { + assert.Len(t, group.Peers, 2) + assert.Contains(t, group.Peers, "peer1") + assert.Contains(t, group.Peers, "peer2") + } + }) + + t.Run("should not update membership or account peers when no changes", func(t *testing.T) { + groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + require.NoError(t, err) + assert.False(t, groupsUpdated) + assert.False(t, groupChangesAffectPeers) + }) + + t.Run("should not remove peers when groups are removed from user", func(t *testing.T) { + user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId) + require.NoError(t, err) + + user.AutoGroups = []string{"group1"} + require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user)) + + groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id) + require.NoError(t, err) + assert.False(t, groupsUpdated) + assert.False(t, groupChangesAffectPeers) + + groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"}) + require.NoError(t, err) + for _, group := range groups { + assert.Len(t, group.Peers, 2) + assert.Contains(t, group.Peers, "peer1") + assert.Contains(t, group.Peers, "peer2") + } + }) +} diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index 638524e31..dfc782b3f 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -126,7 +126,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled } - updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) + updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) if err != nil { util.WriteError(r.Context(), err, w) return @@ -138,7 +138,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings, meta) + resp := toAccountResponse(accountID, updatedSettings, meta) util.WriteJSONObject(r.Context(), w, &resp) } diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index fec5140f4..a18798743 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -36,7 +36,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil }, - UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { + UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { halfYearLimit := 180 * 24 * time.Hour if newSettings.PeerLoginExpiration > halfYearLimit { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") @@ -46,9 +46,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") } - accCopy := account.Copy() - accCopy.UpdateSettings(newSettings) - return accCopy, nil + return newSettings, nil }, GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { return account.Copy(), nil diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index ed47d3914..3caa6744a 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -90,7 +90,7 @@ type MockAccountManager struct { GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) - UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) + UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error @@ -662,7 +662,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us } // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface -func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { +func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) { if am.UpdateAccountSettingsFunc != nil { return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 6c3104ef0..d81890775 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -2163,6 +2163,22 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre return nil } +// SaveAccountSettings stores the account settings in DB. +func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error { + result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}). + Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings}) + if result.Error != nil { + log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error) + return status.Errorf(status.Internal, "failed to save account settings to store") + } + + if result.RowsAffected == 0 { + return status.NewAccountNotFoundError(accountID) + } + + return nil +} + func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { tx := s.db if lockStrength != LockingStrengthNone { diff --git a/management/server/store/store.go b/management/server/store/store.go index fff809247..c7b103454 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -72,6 +72,7 @@ type Store interface { DeleteAccount(ctx context.Context, account *types.Account) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error + SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) diff --git a/management/server/user.go b/management/server/user.go index 6d780cda3..a1f1c46d5 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -1153,8 +1153,9 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp if !ok { return nil, errors.New("group not found") } - addUserPeersToGroup(userPeerIDMap, group) - groupsToUpdate = append(groupsToUpdate, group) + if changed := addUserPeersToGroup(userPeerIDMap, group); changed { + groupsToUpdate = append(groupsToUpdate, group) + } } for _, gid := range groupsToRemove { @@ -1162,45 +1163,65 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp if !ok { return nil, errors.New("group not found") } - removeUserPeersFromGroup(userPeerIDMap, group) - groupsToUpdate = append(groupsToUpdate, group) + if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed { + groupsToUpdate = append(groupsToUpdate, group) + } } return groupsToUpdate, nil } // addUserPeersToGroup adds the user's peers to the group. -func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) { +func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool { groupPeers := make(map[string]struct{}, len(group.Peers)) for _, pid := range group.Peers { groupPeers[pid] = struct{}{} } + changed := false for pid := range userPeerIDs { - groupPeers[pid] = struct{}{} + if _, exists := groupPeers[pid]; !exists { + groupPeers[pid] = struct{}{} + changed = true + } } group.Peers = make([]string, 0, len(groupPeers)) for pid := range groupPeers { group.Peers = append(group.Peers, pid) } + + if changed { + group.Peers = make([]string, 0, len(groupPeers)) + for pid := range groupPeers { + group.Peers = append(group.Peers, pid) + } + } + return changed } // removeUserPeersFromGroup removes user's peers from the group. -func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) { +func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool { // skip removing peers from group All if group.Name == "All" { - return + return false } updatedPeers := make([]string, 0, len(group.Peers)) + changed := false + for _, pid := range group.Peers { - if _, found := userPeerIDs[pid]; !found { - updatedPeers = append(updatedPeers, pid) + if _, owned := userPeerIDs[pid]; owned { + changed = true + continue } + updatedPeers = append(updatedPeers, pid) } - group.Peers = updatedPeers + if changed { + group.Peers = updatedPeers + } + return changed } func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {