[management] Propagate user groups when group propagation setting is re-enabled (#3912)

This commit is contained in:
Bethuel Mmbaga
2025-06-11 14:32:16 +03:00
committed by GitHub
parent 75feb0da8b
commit 4ee1635baa
9 changed files with 352 additions and 111 deletions

View File

@@ -277,29 +277,11 @@ func (am *DefaultAccountManager) GetIdpManager() idp.Manager {
// UpdateAccountSettings updates Account settings. // UpdateAccountSettings updates Account settings.
// Only users with role UserRoleAdmin can update the account. // Only users with role UserRoleAdmin can update the account.
// User that performs the update has to belong to the account. // User that performs the update has to belong to the account.
// Returns an updated Account // Returns an updated Settings
func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
}
if newSettings.PeerLoginExpiration < time.Hour {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
return nil, status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) unlock := am.Store.AcquireWriteLockByUID(ctx, accountID)
defer unlock() defer unlock()
account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return nil, err
}
allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update) allowed, err := am.permissionsManager.ValidateUserPermissions(ctx, accountID, userID, modules.Settings, operations.Update)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to validate user permissions: %w", err) return nil, fmt.Errorf("failed to validate user permissions: %w", err)
@@ -309,12 +291,118 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
return nil, status.NewPermissionDeniedError() return nil, status.NewPermissionDeniedError()
} }
err = am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, account.Settings.Extra, account.Peers, userID, accountID) var oldSettings *types.Settings
var updateAccountPeers bool
var groupChangesAffectPeers bool
err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error {
var groupsUpdated bool
oldSettings, err = transaction.GetAccountSettings(ctx, store.LockingStrengthUpdate, accountID)
if err != nil {
return err
}
if err = am.validateSettingsUpdate(ctx, transaction, newSettings, oldSettings, userID, accountID); err != nil {
return err
}
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled ||
oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled ||
oldSettings.DNSDomain != newSettings.DNSDomain {
updateAccountPeers = true
}
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled && newSettings.GroupsPropagationEnabled {
groupsUpdated, groupChangesAffectPeers, err = propagateUserGroupMemberships(ctx, transaction, accountID)
if err != nil {
return err
}
}
if updateAccountPeers || groupsUpdated {
if err = transaction.IncrementNetworkSerial(ctx, store.LockingStrengthUpdate, accountID); err != nil {
return err
}
}
return transaction.SaveAccountSettings(ctx, store.LockingStrengthUpdate, accountID, newSettings)
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
oldSettings := account.Settings extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra)
if err != nil {
return nil, err
}
am.handleRoutingPeerDNSResolutionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleLazyConnectionSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
return nil, err
}
if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
}
if updateAccountPeers || extraSettingsChanged || groupChangesAffectPeers {
go am.UpdateAccountPeers(ctx, accountID)
}
return newSettings, nil
}
func (am *DefaultAccountManager) validateSettingsUpdate(ctx context.Context, transaction store.Store, newSettings, oldSettings *types.Settings, userID, accountID string) error {
halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
}
if newSettings.PeerLoginExpiration < time.Hour {
return status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
}
if newSettings.DNSDomain != "" && !isDomainValid(newSettings.DNSDomain) {
return status.Errorf(status.InvalidArgument, "invalid domain \"%s\" provided for DNS domain", newSettings.DNSDomain)
}
peers, err := transaction.GetAccountPeers(ctx, store.LockingStrengthShare, accountID, "", "")
if err != nil {
return err
}
peersMap := make(map[string]*nbpeer.Peer, len(peers))
for _, peer := range peers {
peersMap[peer.ID] = peer
}
return am.integratedPeerValidator.ValidateExtraSettings(ctx, newSettings.Extra, oldSettings.Extra, peersMap, userID, accountID)
}
func (am *DefaultAccountManager) handleRoutingPeerDNSResolutionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled {
if newSettings.RoutingPeerDNSResolutionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil)
}
}
}
func (am *DefaultAccountManager) handleLazyConnectionSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled {
if newSettings.LazyConnectionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil)
}
}
}
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled { if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
event := activity.AccountPeerLoginExpirationEnabled event := activity.AccountPeerLoginExpirationEnabled
if !newSettings.PeerLoginExpirationEnabled { if !newSettings.PeerLoginExpirationEnabled {
@@ -330,82 +418,21 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerLoginExpirationDurationUpdated, nil)
am.checkAndSchedulePeerLoginExpiration(ctx, accountID) am.checkAndSchedulePeerLoginExpiration(ctx, accountID)
} }
updateAccountPeers := false
if oldSettings.RoutingPeerDNSResolutionEnabled != newSettings.RoutingPeerDNSResolutionEnabled {
if newSettings.RoutingPeerDNSResolutionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountRoutingPeerDNSResolutionDisabled, nil)
}
updateAccountPeers = true
}
if oldSettings.LazyConnectionEnabled != newSettings.LazyConnectionEnabled {
if newSettings.LazyConnectionEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionEnabled, nil)
} else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountLazyConnectionDisabled, nil)
}
updateAccountPeers = true
}
if oldSettings.DNSDomain != newSettings.DNSDomain {
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountDNSDomainUpdated, nil)
updateAccountPeers = true
}
err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, err
}
err = am.handleGroupsPropagationSettings(ctx, oldSettings, newSettings, userID, accountID)
if err != nil {
return nil, fmt.Errorf("groups propagation failed: %w", err)
}
account.UpdateSettings(newSettings)
if updateAccountPeers {
account.Network.Serial++
}
err = am.Store.SaveAccount(ctx, account)
if err != nil {
return nil, err
}
extraSettingsChanged, err := am.settingsManager.UpdateExtraSettings(ctx, accountID, userID, newSettings.Extra)
if err != nil {
return nil, err
}
if updateAccountPeers || extraSettingsChanged {
go am.UpdateAccountPeers(ctx, accountID)
}
return account, nil
} }
func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleGroupsPropagationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled { if oldSettings.GroupsPropagationEnabled != newSettings.GroupsPropagationEnabled {
if newSettings.GroupsPropagationEnabled { if newSettings.GroupsPropagationEnabled {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationEnabled, nil)
// Todo: retroactively add user groups to all peers
} else { } else {
am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.UserGroupPropagationDisabled, nil)
} }
} }
return nil
} }
func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error { func (am *DefaultAccountManager) handleInactivityExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) error {
if newSettings.PeerInactivityExpirationEnabled { if newSettings.PeerInactivityExpirationEnabled {
if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration { if oldSettings.PeerInactivityExpiration != newSettings.PeerInactivityExpiration {
oldSettings.PeerInactivityExpiration = newSettings.PeerInactivityExpiration
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil) am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountPeerInactivityExpirationDurationUpdated, nil)
am.checkAndSchedulePeerInactivityExpiration(ctx, accountID) am.checkAndSchedulePeerInactivityExpiration(ctx, accountID)
} }
@@ -1853,3 +1880,57 @@ func (am *DefaultAccountManager) UpdateToPrimaryAccount(ctx context.Context, acc
return account, nil return account, nil
} }
// propagateUserGroupMemberships propagates all account users' group memberships to their peers.
// Returns true if any groups were modified, true if those updates affect peers and an error.
func propagateUserGroupMemberships(ctx context.Context, transaction store.Store, accountID string) (groupsUpdated bool, peersAffected bool, err error) {
groups, err := transaction.GetAccountGroups(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsMap := make(map[string]*types.Group, len(groups))
for _, group := range groups {
groupsMap[group.ID] = group
}
users, err := transaction.GetAccountUsers(ctx, store.LockingStrengthShare, accountID)
if err != nil {
return false, false, err
}
groupsToUpdate := make(map[string]*types.Group)
for _, user := range users {
userPeers, err := transaction.GetUserPeers(ctx, store.LockingStrengthShare, accountID, user.Id)
if err != nil {
return false, false, err
}
updatedGroups, err := updateUserPeersInGroups(groupsMap, userPeers, user.AutoGroups, nil)
if err != nil {
return false, false, err
}
for _, group := range updatedGroups {
groupsToUpdate[group.ID] = group
groupsMap[group.ID] = group
}
}
if len(groupsToUpdate) == 0 {
return false, false, nil
}
peersAffected, err = areGroupChangesAffectPeers(ctx, transaction, accountID, maps.Keys(groupsToUpdate))
if err != nil {
return false, false, err
}
err = transaction.SaveGroups(ctx, store.LockingStrengthUpdate, accountID, maps.Values(groupsToUpdate))
if err != nil {
return false, false, err
}
return true, peersAffected, nil
}

View File

@@ -88,7 +88,7 @@ type Manager interface {
GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error) GetDNSSettings(ctx context.Context, accountID string, userID string) (*types.DNSSettings, error)
SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error SaveDNSSettings(ctx context.Context, accountID string, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API LoginPeer(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API SyncPeer(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) // used by peer gRPC API
GetAllConnectedPeers() (map[string]struct{}, error) GetAllConnectedPeers() (map[string]struct{}, error)

View File

@@ -1805,9 +1805,10 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID)
require.NoError(t, err, "unable to mark peer connected") require.NoError(t, err, "unable to mark peer connected")
account, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
}) })
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1825,11 +1826,11 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) {
// disable expiration first // disable expiration first
update := peer.Copy() update := peer.Copy()
update.LoginExpirationEnabled = false update.LoginExpirationEnabled = false
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
// enabling expiration should trigger the routine // enabling expiration should trigger the routine
update.LoginExpirationEnabled = true update.LoginExpirationEnabled = true
_, err = manager.UpdatePeer(context.Background(), account.Id, userID, update) _, err = manager.UpdatePeer(context.Background(), accountID, userID, update)
require.NoError(t, err, "unable to update peer") require.NoError(t, err, "unable to update peer")
failed := waitTimeout(wg, time.Second) failed := waitTimeout(wg, time.Second)
@@ -1856,6 +1857,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing.
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
}) })
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1919,9 +1921,10 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
}, },
} }
// enabling PeerLoginExpirationEnabled should trigger the expiration job // enabling PeerLoginExpirationEnabled should trigger the expiration job
account, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: true, PeerLoginExpirationEnabled: true,
Extra: &types.ExtraSettings{},
}) })
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
@@ -1935,6 +1938,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test
_, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{ _, err = manager.UpdateAccountSettings(context.Background(), account.Id, userID, &types.Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
}) })
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
failed = waitTimeout(wg, time.Second) failed = waitTimeout(wg, time.Second)
@@ -1950,13 +1954,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "") accountID, err := manager.GetAccountIDByUserID(context.Background(), userID, "")
require.NoError(t, err, "unable to create an account") require.NoError(t, err, "unable to create an account")
updated, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ updatedSettings, err := manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour, PeerLoginExpiration: time.Hour,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
}) })
require.NoError(t, err, "expecting to update account settings successfully but got error") require.NoError(t, err, "expecting to update account settings successfully but got error")
assert.False(t, updated.Settings.PeerLoginExpirationEnabled) assert.False(t, updatedSettings.PeerLoginExpirationEnabled)
assert.Equal(t, updated.Settings.PeerLoginExpiration, time.Hour) assert.Equal(t, updatedSettings.PeerLoginExpiration, time.Hour)
settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID) settings, err := manager.Store.GetAccountSettings(context.Background(), store.LockingStrengthShare, accountID)
require.NoError(t, err, "unable to get account settings") require.NoError(t, err, "unable to get account settings")
@@ -1967,12 +1972,14 @@ func TestDefaultAccountManager_UpdateAccountSettings(t *testing.T) {
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Second, PeerLoginExpiration: time.Second,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
}) })
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration less than one hour")
_, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{
PeerLoginExpiration: time.Hour * 24 * 181, PeerLoginExpiration: time.Hour * 24 * 181,
PeerLoginExpirationEnabled: false, PeerLoginExpirationEnabled: false,
Extra: &types.ExtraSettings{},
}) })
require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days") require.Error(t, err, "expecting to fail when providing PeerLoginExpiration more than 180 days")
} }
@@ -3319,3 +3326,120 @@ func TestDefaultAccountManager_IsCacheCold(t *testing.T) {
}) })
}) })
} }
func TestPropagateUserGroupMemberships(t *testing.T) {
manager, err := createManager(t)
require.NoError(t, err)
ctx := context.Background()
initiatorId := "test-user"
domain := "example.com"
account, err := manager.GetOrCreateAccountByUser(ctx, initiatorId, domain)
require.NoError(t, err)
peer1 := &nbpeer.Peer{ID: "peer1", AccountID: account.Id, UserID: initiatorId}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer1)
require.NoError(t, err)
peer2 := &nbpeer.Peer{ID: "peer2", AccountID: account.Id, UserID: initiatorId}
err = manager.Store.AddPeerToAccount(ctx, store.LockingStrengthUpdate, peer2)
require.NoError(t, err)
t.Run("should skip propagation when the user has no groups", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
})
t.Run("should update membership but no account peers update for unused groups", func(t *testing.T) {
group1 := &types.Group{ID: "group1", Name: "Group 1", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group1))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group1.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.True(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
group, err := manager.Store.GetGroupByID(ctx, store.LockingStrengthShare, account.Id, group1.ID)
require.NoError(t, err)
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
})
t.Run("should update membership and account peers for used groups", func(t *testing.T) {
group2 := &types.Group{ID: "group2", Name: "Group 2", AccountID: account.Id}
require.NoError(t, manager.Store.SaveGroup(ctx, store.LockingStrengthUpdate, group2))
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
user.AutoGroups = append(user.AutoGroups, group2.ID)
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
_, err = manager.SavePolicy(context.Background(), account.Id, initiatorId, &types.Policy{
Name: "Group1 Policy",
AccountID: account.Id,
Enabled: true,
Rules: []*types.PolicyRule{
{
Enabled: true,
Sources: []string{"group1"},
Destinations: []string{"group2"},
Bidirectional: true,
Action: types.PolicyTrafficActionAccept,
},
},
}, true)
require.NoError(t, err)
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.True(t, groupsUpdated)
assert.True(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
}
})
t.Run("should not update membership or account peers when no changes", func(t *testing.T) {
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
})
t.Run("should not remove peers when groups are removed from user", func(t *testing.T) {
user, err := manager.Store.GetUserByUserID(ctx, store.LockingStrengthShare, initiatorId)
require.NoError(t, err)
user.AutoGroups = []string{"group1"}
require.NoError(t, manager.Store.SaveUser(ctx, store.LockingStrengthUpdate, user))
groupsUpdated, groupChangesAffectPeers, err := propagateUserGroupMemberships(ctx, manager.Store, account.Id)
require.NoError(t, err)
assert.False(t, groupsUpdated)
assert.False(t, groupChangesAffectPeers)
groups, err := manager.Store.GetGroupsByIDs(ctx, store.LockingStrengthShare, account.Id, []string{"group1", "group2"})
require.NoError(t, err)
for _, group := range groups {
assert.Len(t, group.Peers, 2)
assert.Contains(t, group.Peers, "peer1")
assert.Contains(t, group.Peers, "peer2")
}
})
}

View File

@@ -126,7 +126,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled settings.LazyConnectionEnabled = *req.Settings.LazyConnectionEnabled
} }
updatedAccount, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings) updatedSettings, err := h.accountManager.UpdateAccountSettings(r.Context(), accountID, userID, settings)
if err != nil { if err != nil {
util.WriteError(r.Context(), err, w) util.WriteError(r.Context(), err, w)
return return
@@ -138,7 +138,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) {
return return
} }
resp := toAccountResponse(updatedAccount.Id, updatedAccount.Settings, meta) resp := toAccountResponse(accountID, updatedSettings, meta)
util.WriteJSONObject(r.Context(), w, &resp) util.WriteJSONObject(r.Context(), w, &resp)
} }

View File

@@ -36,7 +36,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) {
return account.Settings, nil return account.Settings, nil
}, },
UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { UpdateAccountSettingsFunc: func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
halfYearLimit := 180 * 24 * time.Hour halfYearLimit := 180 * 24 * time.Hour
if newSettings.PeerLoginExpiration > halfYearLimit { if newSettings.PeerLoginExpiration > halfYearLimit {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be larger than 180 days")
@@ -46,9 +46,7 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler {
return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour") return nil, status.Errorf(status.InvalidArgument, "peer login expiration can't be smaller than one hour")
} }
accCopy := account.Copy() return newSettings, nil
accCopy.UpdateSettings(newSettings)
return accCopy, nil
}, },
GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) { GetAccountByIDFunc: func(ctx context.Context, accountID string, userID string) (*types.Account, error) {
return account.Copy(), nil return account.Copy(), nil

View File

@@ -90,7 +90,7 @@ type MockAccountManager struct {
GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error) GetDNSSettingsFunc func(ctx context.Context, accountID, userID string) (*types.DNSSettings, error)
SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error SaveDNSSettingsFunc func(ctx context.Context, accountID, userID string, dnsSettingsToSave *types.DNSSettings) error
GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) GetPeerFunc func(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error)
UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) UpdateAccountSettingsFunc func(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error)
LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) LoginPeerFunc func(ctx context.Context, login types.PeerLogin) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error) SyncPeerFunc func(ctx context.Context, sync types.PeerSync, accountID string) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error InviteUserFunc func(ctx context.Context, accountID string, initiatorUserID string, targetUserEmail string) error
@@ -662,7 +662,7 @@ func (am *MockAccountManager) GetPeer(ctx context.Context, accountID, peerID, us
} }
// UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface // UpdateAccountSettings mocks UpdateAccountSettings of the AccountManager interface
func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Account, error) { func (am *MockAccountManager) UpdateAccountSettings(ctx context.Context, accountID, userID string, newSettings *types.Settings) (*types.Settings, error) {
if am.UpdateAccountSettingsFunc != nil { if am.UpdateAccountSettingsFunc != nil {
return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings) return am.UpdateAccountSettingsFunc(ctx, accountID, userID, newSettings)
} }

View File

@@ -2163,6 +2163,22 @@ func (s *SqlStore) SaveDNSSettings(ctx context.Context, lockStrength LockingStre
return nil return nil
} }
// SaveAccountSettings stores the account settings in DB.
func (s *SqlStore) SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error {
result := s.db.Clauses(clause.Locking{Strength: string(lockStrength)}).Model(&types.Account{}).
Select("*").Where(idQueryCondition, accountID).Updates(&types.AccountSettings{Settings: settings})
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save account settings to store: %v", result.Error)
return status.Errorf(status.Internal, "failed to save account settings to store")
}
if result.RowsAffected == 0 {
return status.NewAccountNotFoundError(accountID)
}
return nil
}
func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) { func (s *SqlStore) GetAccountNetworks(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*networkTypes.Network, error) {
tx := s.db tx := s.db
if lockStrength != LockingStrengthNone { if lockStrength != LockingStrengthNone {

View File

@@ -72,6 +72,7 @@ type Store interface {
DeleteAccount(ctx context.Context, account *types.Account) error DeleteAccount(ctx context.Context, account *types.Account) error
UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error UpdateAccountDomainAttributes(ctx context.Context, accountID string, domain string, category string, isPrimaryDomain bool) error
SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error SaveDNSSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.DNSSettings) error
SaveAccountSettings(ctx context.Context, lockStrength LockingStrength, accountID string, settings *types.Settings) error
CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error) CountAccountsByPrivateDomain(ctx context.Context, domain string) (int64, error)
GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error) GetUserByPATID(ctx context.Context, lockStrength LockingStrength, patID string) (*types.User, error)

View File

@@ -1153,8 +1153,9 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp
if !ok { if !ok {
return nil, errors.New("group not found") return nil, errors.New("group not found")
} }
addUserPeersToGroup(userPeerIDMap, group) if changed := addUserPeersToGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group) groupsToUpdate = append(groupsToUpdate, group)
}
} }
for _, gid := range groupsToRemove { for _, gid := range groupsToRemove {
@@ -1162,45 +1163,65 @@ func updateUserPeersInGroups(accountGroups map[string]*types.Group, peers []*nbp
if !ok { if !ok {
return nil, errors.New("group not found") return nil, errors.New("group not found")
} }
removeUserPeersFromGroup(userPeerIDMap, group) if changed := removeUserPeersFromGroup(userPeerIDMap, group); changed {
groupsToUpdate = append(groupsToUpdate, group) groupsToUpdate = append(groupsToUpdate, group)
}
} }
return groupsToUpdate, nil return groupsToUpdate, nil
} }
// addUserPeersToGroup adds the user's peers to the group. // addUserPeersToGroup adds the user's peers to the group.
func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) { func addUserPeersToGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
groupPeers := make(map[string]struct{}, len(group.Peers)) groupPeers := make(map[string]struct{}, len(group.Peers))
for _, pid := range group.Peers { for _, pid := range group.Peers {
groupPeers[pid] = struct{}{} groupPeers[pid] = struct{}{}
} }
changed := false
for pid := range userPeerIDs { for pid := range userPeerIDs {
groupPeers[pid] = struct{}{} if _, exists := groupPeers[pid]; !exists {
groupPeers[pid] = struct{}{}
changed = true
}
} }
group.Peers = make([]string, 0, len(groupPeers)) group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers { for pid := range groupPeers {
group.Peers = append(group.Peers, pid) group.Peers = append(group.Peers, pid)
} }
if changed {
group.Peers = make([]string, 0, len(groupPeers))
for pid := range groupPeers {
group.Peers = append(group.Peers, pid)
}
}
return changed
} }
// removeUserPeersFromGroup removes user's peers from the group. // removeUserPeersFromGroup removes user's peers from the group.
func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) { func removeUserPeersFromGroup(userPeerIDs map[string]struct{}, group *types.Group) bool {
// skip removing peers from group All // skip removing peers from group All
if group.Name == "All" { if group.Name == "All" {
return return false
} }
updatedPeers := make([]string, 0, len(group.Peers)) updatedPeers := make([]string, 0, len(group.Peers))
changed := false
for _, pid := range group.Peers { for _, pid := range group.Peers {
if _, found := userPeerIDs[pid]; !found { if _, owned := userPeerIDs[pid]; owned {
updatedPeers = append(updatedPeers, pid) changed = true
continue
} }
updatedPeers = append(updatedPeers, pid)
} }
group.Peers = updatedPeers if changed {
group.Peers = updatedPeers
}
return changed
} }
func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) {