From e73b5da42b795d18e06265fb400a8e8e1df605c2 Mon Sep 17 00:00:00 2001 From: bcmmbaga Date: Thu, 31 Oct 2024 22:30:13 +0300 Subject: [PATCH] Refactor update account peers Signed-off-by: bcmmbaga --- management/server/account.go | 19 +---- management/server/dns.go | 6 +- management/server/group.go | 18 +---- management/server/nameserver.go | 18 +---- management/server/peer.go | 104 ++++++++-------------------- management/server/peer_test.go | 2 +- management/server/policy.go | 12 +--- management/server/posture_checks.go | 7 +- management/server/route.go | 18 +---- management/server/user.go | 13 ++-- 10 files changed, 49 insertions(+), 168 deletions(-) diff --git a/management/server/account.go b/management/server/account.go index 676659b56..72c866289 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1165,11 +1165,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID) - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) return newSettings, nil } @@ -2122,12 +2118,8 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st } if removedGroupAffectsPeers || newGroupsAffectsPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId) - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } } @@ -2362,12 +2354,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context, func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) { log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID) - updatedAccount, err := am.Store.GetAccount(ctx, accountID) - if err != nil { - log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) - return - } - am.updateAccountPeers(ctx, updatedAccount) + am.updateAccountPeers(ctx, accountID) } func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) { diff --git a/management/server/dns.go b/management/server/dns.go index ace6c680d..719b73307 100644 --- a/management/server/dns.go +++ b/management/server/dns.go @@ -178,11 +178,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID } if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/group.go b/management/server/group.go index 2584be24a..ec3fc4680 100644 --- a/management/server/group.go +++ b/management/server/group.go @@ -157,11 +157,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user } if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -371,11 +367,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr } if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -421,11 +413,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID, } if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/nameserver.go b/management/server/nameserver.go index 48ff35987..883150510 100644 --- a/management/server/nameserver.go +++ b/management/server/nameserver.go @@ -89,11 +89,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return newNSGroup.Copy(), nil @@ -146,11 +142,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -195,11 +187,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/peer.go b/management/server/peer.go index eaa119e11..19c673818 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -162,11 +162,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK if expired { // we need to update other peers because when peer login expires all other peers are notified to disconnect from // the expired one. Here we notify them that connection is now allowed again. - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -309,11 +305,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user } if peerLabelUpdated { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return peer, nil @@ -387,11 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer } if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -637,28 +625,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s return nil, nil, nil, err } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, fmt.Errorf("error getting account: %w", err) - } - if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } - approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID) - if err != nil { - return nil, nil, nil, err - } - - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics()) - return newPeer, networkMap, postureChecks, nil + return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer) } func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) { @@ -718,20 +689,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac return nil, nil, nil, err } - var postureChecks []*posture.Checks - - if peerNotValid { - network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID) - if err != nil { - return nil, nil, nil, err - } - - emptyMap := &NetworkMap{ - Network: network.Copy(), - } - return peer, emptyMap, postureChecks, nil - } - updated := peer.UpdateMetaIfNew(sync.Meta) if updated { err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer) @@ -740,27 +697,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac } } - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - if isStatusChanged || (updated && sync.UpdateAccountPeers) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } - validPeersMap, err := am.GetValidatedPeers(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - - postureChecks, err = am.getPeerPostureChecks(ctx, accountID, peer.ID) - if err != nil { - return nil, nil, nil, err - } - - customZone := account.GetPeersCustomZone(ctx, am.dnsDomain) - return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil + return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer) } // LoginPeer logs in or registers a peer. @@ -875,16 +816,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin) unlockPeer() unlockPeer = nil - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, nil, nil, err - } - if updateRemotePeers || isStatusChanged { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } - return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer) + return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer) } // checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO @@ -916,14 +852,24 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co return nil } -func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { +func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) { if isRequiresApproval { + network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID) + if err != nil { + return nil, nil, nil, err + } + emptyMap := &NetworkMap{ - Network: account.Network.Copy(), + Network: network.Copy(), } return peer, emptyMap, nil, nil } + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + return nil, nil, nil, err + } + approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) if err != nil { return nil, nil, nil, err @@ -1052,7 +998,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID, // updateAccountPeers updates all peers that belong to an account. // Should be called when changes have to be synced to peers. -func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) { +func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) { start := time.Now() defer func() { if am.metrics != nil { @@ -1060,6 +1006,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account } }() + account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + if err != nil { + log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err) + return + } + peers := account.GetPeers() approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index c0ae4e178..b48e94273 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -876,7 +876,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) { start := time.Now() for i := 0; i < b.N; i++ { - manager.updateAccountPeers(ctx, account) + manager.updateAccountPeers(ctx, accountID) } duration := time.Since(start) diff --git a/management/server/policy.go b/management/server/policy.go index b75853f65..1a9f3b8e2 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -414,11 +414,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -462,11 +458,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/posture_checks.go b/management/server/posture_checks.go index d75b99ffa..48065d5ad 100644 --- a/management/server/posture_checks.go +++ b/management/server/posture_checks.go @@ -79,12 +79,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("failed to get account: %w", err) - } - - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/route.go b/management/server/route.go index 9b5229092..60914e8e1 100644 --- a/management/server/route.go +++ b/management/server/route.go @@ -293,11 +293,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return nil, fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return &newRoute, nil @@ -417,11 +413,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta()) if oldRouteAffectsPeers || newRouteAffectsPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil @@ -462,11 +454,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta()) if updateAccountPeers { - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf(errGetAccountFmt, err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil diff --git a/management/server/user.go b/management/server/user.go index ac42b600b..97886e2b9 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -492,7 +492,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta) if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, account.Id) } return nil @@ -797,7 +797,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID, } if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } for _, storeEvent := range eventsToStore { @@ -1088,12 +1088,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou if len(peerIDs) != 0 { // this will trigger peer disconnect from the management service am.peersUpdateManager.CloseChannels(ctx, peerIDs) - - account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID) - if err != nil { - return fmt.Errorf("error getting account: %w", err) - } - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } return nil } @@ -1201,7 +1196,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account } if updateAccountPeers { - am.updateAccountPeers(ctx, account) + am.updateAccountPeers(ctx, accountID) } for targetUserID, meta := range deletedUsersMeta {