Refactor update account peers

Signed-off-by: bcmmbaga <bethuelmbaga12@gmail.com>
This commit is contained in:
bcmmbaga
2024-10-31 22:30:13 +03:00
parent 8cacdae70c
commit e73b5da42b
10 changed files with 49 additions and 168 deletions

View File

@@ -1165,11 +1165,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
am.handlePeerLoginExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
return newSettings, nil
}
@@ -2122,12 +2118,8 @@ func (am *DefaultAccountManager) syncJWTGroups(ctx context.Context, accountID st
}
if removedGroupAffectsPeers || newGroupsAffectsPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
log.WithContext(ctx).Tracef("user %s: JWT group membership changed, updating account peers", claims.UserId)
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
}
@@ -2362,12 +2354,7 @@ func (am *DefaultAccountManager) CheckUserAccessByJWTGroups(ctx context.Context,
func (am *DefaultAccountManager) onPeersInvalidated(ctx context.Context, accountID string) {
log.WithContext(ctx).Debugf("validated peers has been invalidated for account %s", accountID)
updatedAccount, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
return
}
am.updateAccountPeers(ctx, updatedAccount)
am.updateAccountPeers(ctx, accountID)
}
func (am *DefaultAccountManager) FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) {

View File

@@ -178,11 +178,7 @@ func (am *DefaultAccountManager) SaveDNSSettings(ctx context.Context, accountID
}
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -157,11 +157,7 @@ func (am *DefaultAccountManager) SaveGroups(ctx context.Context, accountID, user
}
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -371,11 +367,7 @@ func (am *DefaultAccountManager) GroupAddPeer(ctx context.Context, accountID, gr
}
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -421,11 +413,7 @@ func (am *DefaultAccountManager) GroupDeletePeer(ctx context.Context, accountID,
}
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -89,11 +89,7 @@ func (am *DefaultAccountManager) CreateNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, newNSGroup.ID, accountID, activity.NameserverGroupCreated, newNSGroup.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return newNSGroup.Copy(), nil
@@ -146,11 +142,7 @@ func (am *DefaultAccountManager) SaveNameServerGroup(ctx context.Context, accoun
am.StoreEvent(ctx, userID, nsGroupToSave.ID, accountID, activity.NameserverGroupUpdated, nsGroupToSave.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -195,11 +187,7 @@ func (am *DefaultAccountManager) DeleteNameServerGroup(ctx context.Context, acco
am.StoreEvent(ctx, userID, nsGroup.ID, accountID, activity.NameserverGroupDeleted, nsGroup.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -162,11 +162,7 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK
if expired {
// we need to update other peers because when peer login expires all other peers are notified to disconnect from
// the expired one. Here we notify them that connection is now allowed again.
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -309,11 +305,7 @@ func (am *DefaultAccountManager) UpdatePeer(ctx context.Context, accountID, user
}
if peerLabelUpdated {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return peer, nil
@@ -387,11 +379,7 @@ func (am *DefaultAccountManager) DeletePeer(ctx context.Context, accountID, peer
}
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -637,28 +625,11 @@ func (am *DefaultAccountManager) AddPeer(ctx context.Context, setupKey, userID s
return nil, nil, nil, err
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, fmt.Errorf("error getting account: %w", err)
}
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
approvedPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
postureChecks, err := am.getPeerPostureChecks(ctx, account.Id, newPeer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
networkMap := account.GetPeerNetworkMap(ctx, newPeer.ID, customZone, approvedPeersMap, am.metrics.AccountManagerMetrics())
return newPeer, networkMap, postureChecks, nil
return am.getValidatedPeerWithMap(ctx, false, accountID, newPeer)
}
func (am *DefaultAccountManager) getFreeIP(ctx context.Context, store Store, accountID string) (net.IP, error) {
@@ -718,20 +689,6 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
return nil, nil, nil, err
}
var postureChecks []*posture.Checks
if peerNotValid {
network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &NetworkMap{
Network: network.Copy(),
}
return peer, emptyMap, postureChecks, nil
}
updated := peer.UpdateMetaIfNew(sync.Meta)
if updated {
err = am.Store.SavePeer(ctx, LockingStrengthUpdate, accountID, peer)
@@ -740,27 +697,11 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync PeerSync, ac
}
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if isStatusChanged || (updated && sync.UpdateAccountPeers) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
validPeersMap, err := am.GetValidatedPeers(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
postureChecks, err = am.getPeerPostureChecks(ctx, accountID, peer.ID)
if err != nil {
return nil, nil, nil, err
}
customZone := account.GetPeersCustomZone(ctx, am.dnsDomain)
return peer, account.GetPeerNetworkMap(ctx, peer.ID, customZone, validPeersMap, am.metrics.AccountManagerMetrics()), postureChecks, nil
return am.getValidatedPeerWithMap(ctx, peerNotValid, accountID, peer)
}
// LoginPeer logs in or registers a peer.
@@ -875,16 +816,11 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login PeerLogin)
unlockPeer()
unlockPeer = nil
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
if updateRemotePeers || isStatusChanged {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, account, peer)
return am.getValidatedPeerWithMap(ctx, isRequiresApproval, accountID, peer)
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
@@ -916,14 +852,24 @@ func (am *DefaultAccountManager) checkIFPeerNeedsLoginWithoutLock(ctx context.Co
return nil
}
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, account *Account, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, peer *nbpeer.Peer) (*nbpeer.Peer, *NetworkMap, []*posture.Checks, error) {
if isRequiresApproval {
network, err := am.Store.GetAccountNetwork(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, nil, nil, err
}
emptyMap := &NetworkMap{
Network: account.Network.Copy(),
Network: network.Copy(),
}
return peer, emptyMap, nil, nil
}
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, err
}
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)
if err != nil {
return nil, nil, nil, err
@@ -1052,7 +998,7 @@ func (am *DefaultAccountManager) GetPeer(ctx context.Context, accountID, peerID,
// updateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account *Account) {
func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, accountID string) {
start := time.Now()
defer func() {
if am.metrics != nil {
@@ -1060,6 +1006,12 @@ func (am *DefaultAccountManager) updateAccountPeers(ctx context.Context, account
}
}()
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to send out updates to peers. failed to get account: %v", err)
return
}
peers := account.GetPeers()
approvedPeersMap, err := am.GetValidatedPeers(ctx, account.Id)

View File

@@ -876,7 +876,7 @@ func BenchmarkUpdateAccountPeers(b *testing.B) {
start := time.Now()
for i := 0; i < b.N; i++ {
manager.updateAccountPeers(ctx, account)
manager.updateAccountPeers(ctx, accountID)
}
duration := time.Since(start)

View File

@@ -414,11 +414,7 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user
am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -462,11 +458,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po
am.StoreEvent(ctx, userID, policyID, accountID, activity.PolicyRemoved, policy.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -79,12 +79,7 @@ func (am *DefaultAccountManager) SavePostureChecks(ctx context.Context, accountI
am.StoreEvent(ctx, userID, postureChecks.ID, accountID, action, postureChecks.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -293,11 +293,7 @@ func (am *DefaultAccountManager) CreateRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(newRoute.ID), accountID, activity.RouteCreated, newRoute.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return &newRoute, nil
@@ -417,11 +413,7 @@ func (am *DefaultAccountManager) SaveRoute(ctx context.Context, accountID, userI
am.StoreEvent(ctx, userID, string(routeToSave.ID), accountID, activity.RouteUpdated, routeToSave.EventMeta())
if oldRouteAffectsPeers || newRouteAffectsPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
@@ -462,11 +454,7 @@ func (am *DefaultAccountManager) DeleteRoute(ctx context.Context, accountID stri
am.StoreEvent(ctx, userID, string(route.ID), accountID, activity.RouteRemoved, route.EventMeta())
if updateAccountPeers {
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf(errGetAccountFmt, err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil

View File

@@ -492,7 +492,7 @@ func (am *DefaultAccountManager) deleteRegularUser(ctx context.Context, account
am.StoreEvent(ctx, initiatorUserID, targetUserID, account.Id, activity.UserDeleted, meta)
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, account.Id)
}
return nil
@@ -797,7 +797,7 @@ func (am *DefaultAccountManager) SaveOrAddUsers(ctx context.Context, accountID,
}
if account.Settings.GroupsPropagationEnabled && areUsersLinkedToPeers(account, userIDs) {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
for _, storeEvent := range eventsToStore {
@@ -1088,12 +1088,7 @@ func (am *DefaultAccountManager) expireAndUpdatePeers(ctx context.Context, accou
if len(peerIDs) != 0 {
// this will trigger peer disconnect from the management service
am.peersUpdateManager.CloseChannels(ctx, peerIDs)
account, err := am.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("error getting account: %w", err)
}
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
return nil
}
@@ -1201,7 +1196,7 @@ func (am *DefaultAccountManager) DeleteRegularUsers(ctx context.Context, account
}
if updateAccountPeers {
am.updateAccountPeers(ctx, account)
am.updateAccountPeers(ctx, accountID)
}
for targetUserID, meta := range deletedUsersMeta {