diff --git a/management/server/dns_test.go b/management/server/dns_test.go index e033c1a21..a861f6d34 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -6,9 +6,11 @@ import ( "net/netip" "reflect" "testing" + "time" nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server/telemetry" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -476,3 +478,92 @@ func TestToProtocolDNSConfigWithCache(t *testing.T) { t.Errorf("Cache should contain name server group 'group2'") } } + +func TestDNSAccountPeerUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.SaveGroup(context.Background(), account.Id, userID, &group.Group{ + ID: "group-id", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + assert.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // Saving DNS settings with unused groups should not update account peers and not send peer update + t.Run("saving dns setting with unused groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"group-id"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + _, err = manager.CreateNameServerGroup( + context.Background(), account.Id, "ns-group-1", "ns-group-1", []dns.NameServer{{ + IP: netip.MustParseAddr(peer1.IP.String()), + NSType: dns.UDPNameServerType, + Port: dns.DefaultDNSPort, + }}, + []string{"group-id"}, + true, []string{}, true, userID, false, + ) + assert.NoError(t, err) + + // Saving DNS settings with used groups should update account peers and send peer update + t.Run("saving dns setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"group-id"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + // Saving unchanged DNS settings with used groups should update account peers and not send peer update + // since there is no change in the network map + t.Run("saving unchanged dns setting with used groups", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err := manager.SaveDNSSettings(context.Background(), account.Id, userID, &DNSSettings{ + DisabledManagementGroups: []string{"group-id"}, + }) + assert.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + +} diff --git a/management/server/group.go b/management/server/group.go index 49720f347..63281a2f1 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -166,12 +166,19 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user eventsToStore = append(eventsToStore, events...) } + newGroupIDs := make([]string, 0, len(newGroups)) + for _, newGroup := range newGroups { + newGroupIDs = append(newGroupIDs, newGroup.ID) + } + account.Network.IncSerial() if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, newGroupIDs) { + am.updateAccountPeers(ctx, account) + } for _, storeEvent := range eventsToStore { storeEvent() @@ -274,8 +281,6 @@ func (am *DefaultAccountManager) DeleteGroup(ctx context.Context, accountId, use am.StoreEvent(ctx, userId, groupID, accountId, activity.GroupDeleted, group.EventMeta()) - am.updateAccountPeers(ctx, account) - return nil } @@ -318,8 +323,6 @@ func (am *DefaultAccountManager) DeleteGroups(ctx context.Context, accountId, us am.StoreEvent(ctx, userId, g.ID, accountId, activity.GroupDeleted, g.EventMeta()) } - am.updateAccountPeers(ctx, account) - return allErrors } @@ -372,7 +375,9 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr return err } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, []string{group.ID}) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -402,7 +407,9 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } } - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, []string{group.ID}) { + am.updateAccountPeers(ctx, account) + } return nil } @@ -505,3 +512,29 @@ func isGroupLinkedToUser(users map[string]*User, groupID string) (bool, *User) { } return false, nil } + +// anyGroupHasPeers checks if any of the given groups in the account have peers. +func anyGroupHasPeers(account *Account, groupIDs []string) bool { + for _, groupID := range groupIDs { + if group, exists := account.Groups[groupID]; exists && group.HasPeers() { + return true + } + } + return false +} + +func areGroupChangesAffectPeers(account *Account, groupIDs []string) bool { + for _, groupID := range groupIDs { + if linked, _ := isGroupLinkedToDns(account.NameServerGroups, groupID); linked { + return true + } + if linked, _ := isGroupLinkedToPolicy(account.Policies, groupID); linked { + return true + } + if linked, _ := isGroupLinkedToRoute(account.Routes, groupID); linked { + return true + } + } + + return false +} diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 636f7cfee..2cd934065 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -80,13 +80,13 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco account.NameServerGroups[newNSGroup.ID] = newNSGroup account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return nil, err } - am.updateAccountPeers(ctx, account) - + if anyGroupHasPeers(account, newNSGroup.Groups) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) return newNSGroup.Copy(), nil @@ -94,7 +94,6 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco // SaveNameServerGroup saves nameserver group func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accountID, userID string, nsGroupToSave *nbdns.NameServerGroup) error { - unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) defer unlock() @@ -112,16 +111,17 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun return err } + oldNSGroup := account.NameServerGroups[nsGroupToSave.ID] account.NameServerGroups[nsGroupToSave.ID] = nsGroupToSave account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) - + if anyGroupHasPeers(account, nsGroupToSave.Groups) || anyGroupHasPeers(account, oldNSGroup.Groups) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) return nil @@ -145,13 +145,13 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco delete(account.NameServerGroups, nsGroupID) account.Network.IncSerial() - err = am.Store.SaveAccount(ctx, account) - if err != nil { + if err = am.Store.SaveAccount(ctx, account); err != nil { return err } - am.updateAccountPeers(ctx, account) - + if anyGroupHasPeers(account, nsGroup.Groups) { + am.updateAccountPeers(ctx, account) + } am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) return nil diff --git a/management/server/peer.go b/management/server/peer.go index 8c241c186..1d0a1dcef 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net" + "slices" "strings" "sync" "time" @@ -219,13 +220,16 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } account.UpdatePeer(peer) - + account.Network.IncSerial() err = am.Store.SaveAccount(ctx, account) if err != nil { return nil, err } - am.updateAccountPeers(ctx, account) + expired, _ := peer.LoginExpired(account.Settings.PeerLoginExpiration) + if expired && peer.LoginExpirationEnabled { + am.updateAccountPeers(ctx, account) + } return peer, nil } @@ -299,7 +303,10 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer return err } - am.updateAccountPeers(ctx, account) + updateAccountPeers := isPeerInActiveGroup(account, peerID) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } @@ -531,7 +538,9 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s am.StoreEvent(ctx, opEvent.InitiatorID, opEvent.TargetID, opEvent.AccountID, opEvent.Activity, opEvent.Meta) - am.updateAccountPeers(ctx, account) + if areGroupChangesAffectPeers(account, groupsToAdd) { + am.updateAccountPeers(ctx, account) + } approvedPeersMap, err := am.GetValidatedPeers(account) if err != nil { @@ -577,16 +586,22 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, err } - var postureChecks []*posture.Checks - if peerNotValid { emptyMap := &NetworkMap{ Network: account.Network.Copy(), } - return peer, emptyMap, postureChecks, nil + return peer, emptyMap, nil, nil } - if isStatusChanged { + peer, peerMetaUpdated := updatePeerMeta(peer, sync.Meta, account) + if peerMetaUpdated { + err = am.Store.SaveAccount(ctx, account) + if err != nil { + return nil, nil, nil, err + } + } + + if isStatusChanged || (peerMetaUpdated && sync.UpdateAccountPeers) { am.updateAccountPeers(ctx, account) } @@ -594,7 +609,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac if err != nil { return nil, nil, nil, err } - postureChecks = am.getPeerPostureChecks(account, peer) + postureChecks := am.getPeerPostureChecks(account, peer) customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil @@ -812,51 +827,6 @@ func updatePeerLastLogin(peer *nbpeer.Peer, account *Account) { account.UpdatePeer(peer) } -// UpdatePeerSSHKey updates peer's public SSH key -func (am *DefaultAccountManager) UpdatePeerSSHKey(ctx context.Context, peerID string, sshKey string) error { - if sshKey == "" { - log.WithContext(ctx).Debugf("empty SSH key provided for peer %s, skipping update", peerID) - return nil - } - - account, err := am.Store.GetAccountByPeerID(ctx, peerID) - if err != nil { - return err - } - - unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id) - defer unlock() - - // ensure that we consider modification happened meanwhile (because we were outside the account lock when we fetched the account) - account, err = am.Store.GetAccount(ctx, account.Id) - if err != nil { - return err - } - - peer := account.GetPeer(peerID) - if peer == nil { - return status.Errorf(status.NotFound, "peer with ID %s not found", peerID) - } - - if peer.SSHKey == sshKey { - log.WithContext(ctx).Debugf("same SSH key provided for peer %s, skipping update", peerID) - return nil - } - - peer.SSHKey = sshKey - account.UpdatePeer(peer) - - err = am.Store.SaveAccount(ctx, account) - if err != nil { - return err - } - - // trigger network map update - am.updateAccountPeers(ctx, account) - - return nil -} - // GetPeer for a given accountID, peerID and userID error if not found. func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, userID string) (*nbpeer.Peer, error) { unlock := am.Store.AcquireWriteLockByUID(ctx, accountID) @@ -963,3 +933,15 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account wg.Wait() } + +// IsPeerInActiveGroup checks if the given peer is part of a group that is used +// in an active DNS, route, or ACL configuration. +func isPeerInActiveGroup(account *Account, peerID string) bool { + peerGroupIDs := make([]string, 0) + for _, group := range account.Groups { + if slices.Contains(group.Peers, peerID) { + peerGroupIDs = append(peerGroupIDs, group.ID) + } + } + return areGroupChangesAffectPeers(account, peerGroupIDs) +} diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 918436515..c93b98c8e 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -13,6 +13,7 @@ import ( "github.com/rs/xid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" nbdns "github.com/netbirdio/netbird/dns" @@ -995,3 +996,156 @@ func TestToSyncResponse(t *testing.T) { assert.Equal(t, 1, len(response.Checks)) assert.Equal(t, "/usr/bin/netbird", response.Checks[0].Files[0]) } + +func TestPeerAccountPeerUpdate(t *testing.T) { + manager, account, peer1, peer2, peer3 := setupNetworkMapTest(t) + + err := manager.DeletePolicy(context.Background(), account.Id, account.Policies[0].ID, userID) + require.NoError(t, err) + + err = manager.SaveGroup(context.Background(), account.Id, userID, &nbgroup.Group{ + ID: "group-id", + Name: "GroupA", + Peers: []string{peer1.ID, peer2.ID, peer3.ID}, + }) + require.NoError(t, err) + + updMsg := manager.peersUpdateManager.CreateChannel(context.Background(), peer1.ID) + t.Cleanup(func() { + manager.peersUpdateManager.CloseChannel(context.Background(), peer1.ID) + }) + + // create a user with auto groups + _, err = manager.SaveOrAddUser(context.Background(), account.Id, userID, &User{ + Id: "regularUser1", + AccountID: account.Id, + Role: UserRoleAdmin, + Issued: UserIssuedAPI, + AutoGroups: []string{"group-id"}, + }, true) + require.NoError(t, err) + + var peer4 *nbpeer.Peer + + // Updating not expired peer and peer expiration is enabled should not update account peers and not send peer update + t.Run("updating not expired peer and peer expiration is enabled", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + _, err := manager.UpdatePeer(context.Background(), account.Id, userID, peer2) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Adding peer with an unused group in active dns, route, acl should not update account peers and not send peer update + t.Run("adding peer with unused group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // Deleting peer with an unused group in active dns, route, acl should not update account peers and not send peer update + t.Run("deleting peer with unused group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldNotReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldNotReceiveUpdate") + } + }) + + // use the group-id in policy + err = manager.SavePolicy(context.Background(), account.Id, userID, &Policy{ + ID: "policy", + Enabled: true, + Rules: []*PolicyRule{ + { + Enabled: true, + Sources: []string{"group-id"}, + Destinations: []string{"group-id"}, + Bidirectional: true, + Action: PolicyTrafficActionAccept, + }, + }, + }) + require.NoError(t, err) + + // Adding peer with a used group in active dns, route or policy should update account peers and send peer update + t.Run("adding peer with used group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + key, err := wgtypes.GeneratePrivateKey() + require.NoError(t, err) + + expectedPeerKey := key.PublicKey().String() + peer4, _, _, err = manager.AddPeer(context.Background(), "", "regularUser1", &nbpeer.Peer{ + Key: expectedPeerKey, + LoginExpirationEnabled: true, + Meta: nbpeer.PeerSystemMeta{Hostname: expectedPeerKey}, + }) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + + //Deleting peer with a used group in active dns, route or acl should update account peers and send peer update + t.Run("deleting peer with used group", func(t *testing.T) { + done := make(chan struct{}) + go func() { + peerShouldReceiveUpdate(t, updMsg) + close(done) + }() + + err = manager.DeletePeer(context.Background(), account.Id, peer4.ID, userID) + require.NoError(t, err) + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Error("timeout waiting for peerShouldReceiveUpdate") + } + }) + +} diff --git a/management/server/user.go b/management/server/user.go index 727bc5c6b..6c7fdfe3c 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "slices" "strings" "time" @@ -473,7 +474,7 @@ func (am *DefaultAccountManager) DeleteUser(ctx context.Context, accountID, init } func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account *Account, initiatorUserID, targetUserID string) error { - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, updateAccountPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { return err } @@ -485,15 +486,22 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account } am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } return nil } -func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) error { +func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorUserID string, targetUserID string, account *Account) (bool, error) { peers, err := account.FindUserPeers(targetUserID) if err != nil { - return status.Errorf(status.Internal, "failed to find user peers") + return false, status.Errorf(status.Internal, "failed to find user peers") + } + + hadPeers := len(peers) > 0 + if !hadPeers { + return false, nil } peerIDs := make([]string, 0, len(peers)) @@ -501,7 +509,7 @@ func (am *DefaultAccountManager) deleteUserPeers(ctx context.Context, initiatorU peerIDs = append(peerIDs, peer.ID) } - return am.deletePeers(ctx, account, peerIDs, initiatorUserID) + return hadPeers, am.deletePeers(ctx, account, peerIDs, initiatorUserID) } // InviteUser resend invitations to users who haven't activated their accounts prior to the expiration period. @@ -760,6 +768,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, updatedUsers := make([]*UserInfo, 0, len(updates)) var ( expiredPeers []*nbpeer.Peer + userIDs []string eventsToStore []func() ) @@ -768,6 +777,8 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, status.Errorf(status.InvalidArgument, "provided user update is nil") } + userIDs = append(userIDs, update.Id) + oldUser := account.Users[update.Id] if oldUser == nil { if !addIfNotExists { @@ -831,7 +842,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, return nil, err } - if account.Settings.GroupsPropagationEnabled { + if areUsersLinkedToPeers(account, userIDs) && account.Settings.GroupsPropagationEnabled { am.updateAccountPeers(ctx, account) } @@ -1182,7 +1193,10 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return status.Errorf(status.PermissionDenied, "only users with admin power can delete users") } - var allErrors error + var ( + allErrors error + updateAccountPeers bool + ) deletedUsersMeta := make(map[string]map[string]any) for _, targetUserID := range targetUserIDs { @@ -1208,12 +1222,16 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account continue } - meta, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) + meta, hadPeers, err := am.prepareUserDeletion(ctx, account, initiatorUserID, targetUserID) if err != nil { allErrors = errors.Join(allErrors, fmt.Errorf("failed to delete user %s: %s", targetUserID, err)) continue } + if hadPeers && !updateAccountPeers { + updateAccountPeers = true + } + delete(account.Users, targetUserID) deletedUsersMeta[targetUserID] = meta } @@ -1223,7 +1241,9 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return fmt.Errorf("failed to delete users: %w", err) } - am.updateAccountPeers(ctx, account) + if updateAccountPeers { + am.updateAccountPeers(ctx, account) + } for targetUserID, meta := range deletedUsersMeta { am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) @@ -1232,11 +1252,11 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account return allErrors } -func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, error) { +func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, account *Account, initiatorUserID, targetUserID string) (map[string]any, bool, error) { tuEmail, tuName, err := am.getEmailAndNameOfTargetUser(ctx, account.Id, initiatorUserID, targetUserID) if err != nil { log.WithContext(ctx).Errorf("failed to resolve email address: %s", err) - return nil, err + return nil, false, err } if !isNil(am.idpManager) { @@ -1247,16 +1267,16 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun err = am.deleteUserFromIDP(ctx, targetUserID, account.Id) if err != nil { log.WithContext(ctx).Debugf("failed to delete user from IDP: %s", targetUserID) - return nil, err + return nil, false, err } } else { log.WithContext(ctx).Debugf("skipped deleting user %s from IDP, error: %v", targetUserID, err) } } - err = am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) + hadPeers, err := am.deleteUserPeers(ctx, initiatorUserID, targetUserID, account) if err != nil { - return nil, err + return nil, false, err } u, err := account.FindUser(targetUserID) @@ -1269,7 +1289,7 @@ func (am *DefaultAccountManager) prepareUserDeletion(ctx context.Context, accoun tuCreatedAt = u.CreatedAt } - return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, nil + return map[string]any{"name": tuName, "email": tuEmail, "created_at": tuCreatedAt}, hadPeers, nil } func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserData, bool) { @@ -1280,3 +1300,13 @@ func findUserInIDPUserdata(userID string, userData []*idp.UserData) (*idp.UserDa } return nil, false } + +// areUsersLinkedToPeers checks if any of the given userIDs are linked to any of the peers in the account. +func areUsersLinkedToPeers(account *Account, userIDs []string) bool { + for _, peer := range account.Peers { + if slices.Contains(userIDs, peer.UserID) { + return true + } + } + return false +}