diff --git a/management/server/holder.go b/management/server/holder.go index f1c79f4ce..517f6cc47 100644 --- a/management/server/holder.go +++ b/management/server/holder.go @@ -1,6 +1,8 @@ package server import ( + "context" + "github.com/netbirdio/netbird/management/server/types" ) @@ -18,6 +20,19 @@ func (am *DefaultAccountManager) enrichAccountFromHolder(account *types.Account) am.holder.AddAccount(account) } +func (am *DefaultAccountManager) getAccountFromHolder(accountID string) *types.Account { + a := am.holder.GetAccount(accountID) + if a != nil { + return a + } + account, err := am.Store.GetAccount(context.Background(), accountID) + if err != nil { + return nil + } + am.holder.AddAccount(account) + return account +} + func (am *DefaultAccountManager) updateAccountInHolder(account *types.Account) { am.holder.AddAccount(account) } diff --git a/management/server/networkmap.go b/management/server/networkmap.go index ff8b3f1b0..c9fa2bba8 100644 --- a/management/server/networkmap.go +++ b/management/server/networkmap.go @@ -19,13 +19,13 @@ func (am *DefaultAccountManager) initNetworkMapBuilderIfNeeded(account *types.Ac func (am *DefaultAccountManager) getPeerNetworkMapExp( ctx context.Context, - account *types.Account, + accountId string, peerId string, validatedPeers map[string]struct{}, customZone nbdns.CustomZone, metrics *telemetry.AccountManagerMetrics, ) *types.NetworkMap { - am.enrichAccountFromHolder(account) + account := am.getAccountFromHolder(accountId) return account.GetPeerNetworkMapExp(ctx, peerId, customZone, validatedPeers, metrics) } diff --git a/management/server/peer.go b/management/server/peer.go index bb8a5dd37..23bbf1be4 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -453,7 +453,7 @@ func (am *DefaultAccountManager) GetNetworkMap(ctx context.Context, peerID strin var networkMap *types.NetworkMap if am.expNewNetworkMap { - networkMap = am.getPeerNetworkMapExp(ctx, account, peerID, validatedPeers, customZone, nil) + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, customZone, nil) } else { networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, validatedPeers, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), nil) } @@ -1093,7 +1093,7 @@ func (am *DefaultAccountManager) getValidatedPeerWithMap(ctx context.Context, is var networkMap *types.NetworkMap if am.expNewNetworkMap { - networkMap = am.getPeerNetworkMapExp(ctx, account, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + networkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) } else { networkMap = account.GetPeerNetworkMap(ctx, peer.ID, customZone, approvedPeersMap, account.GetResourcePoliciesMap(), account.GetResourceRoutersMap(), am.metrics.AccountManagerMetrics()) } @@ -1305,7 +1305,7 @@ func (am *DefaultAccountManager) UpdateAccountPeers(ctx context.Context, account var remotePeerNetworkMap *types.NetworkMap if am.expNewNetworkMap { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, account, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) } else { remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) } @@ -1421,7 +1421,7 @@ func (am *DefaultAccountManager) UpdateAccountPeer(ctx context.Context, accountI var remotePeerNetworkMap *types.NetworkMap if am.expNewNetworkMap { - remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, account, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) + remotePeerNetworkMap = am.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, customZone, am.metrics.AccountManagerMetrics()) } else { remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, customZone, approvedPeersMap, resourcePolicies, routers, am.metrics.AccountManagerMetrics()) }