diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 022ea774c..97eebfe6e 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -144,7 +144,7 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin if c.experimentalNetworkMap(accountID) { account = c.getAccountFromHolderOrInit(accountID) } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + account, err = c.requestBuffer.GetAccountWithoutUsersWithBackpressure(ctx, accountID) if err != nil { return fmt.Errorf("failed to get account: %v", err) } @@ -300,7 +300,7 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe return fmt.Errorf("peer %s doesn't have a channel, skipping network map update", peerId) } - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + account, err := c.requestBuffer.GetAccountWithoutUsersWithBackpressure(ctx, accountId) if err != nil { return fmt.Errorf("failed to send out updates to peer %s: %v", peerId, err) } @@ -414,7 +414,7 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr if c.experimentalNetworkMap(accountID) { account = c.getAccountFromHolderOrInit(accountID) } else { - account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + account, err = c.requestBuffer.GetAccountWithoutUsersWithBackpressure(ctx, accountID) if err != nil { return nil, nil, nil, 0, err } @@ -506,7 +506,7 @@ func (c *Controller) recalculateNetworkMapCache(account *types.Account, validate func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error { if c.experimentalNetworkMap(accountId) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId) + account, err := c.requestBuffer.GetAccountWithoutUsersWithBackpressure(ctx, accountId) if err != nil { return err } @@ -548,7 +548,7 @@ func (c *Controller) getAccountFromHolderOrInit(accountID string) *types.Account if a != nil { return a } - account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithBackpressure) + account, err := c.holder.LoadOrStoreFunc(accountID, c.requestBuffer.GetAccountWithoutUsersWithBackpressure) if err != nil { return nil } @@ -715,7 +715,7 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { for _, peerID := range peerIDs { if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + account, err := c.requestBuffer.GetAccountWithoutUsersWithBackpressure(ctx, accountID) if err != nil { return err } @@ -761,7 +761,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI c.peersUpdateManager.CloseChannel(ctx, peerID) if c.experimentalNetworkMap(accountID) { - account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) + account, err := c.requestBuffer.GetAccountWithoutUsersWithBackpressure(ctx, accountID) if err != nil { log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err) continue diff --git a/management/server/account/request_buffer.go b/management/server/account/request_buffer.go index eced1929f..e528d9651 100644 --- a/management/server/account/request_buffer.go +++ b/management/server/account/request_buffer.go @@ -7,5 +7,5 @@ import ( ) type RequestBuffer interface { - GetAccountWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) + GetAccountWithoutUsersWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) } diff --git a/management/server/account_request_buffer.go b/management/server/account_request_buffer.go index fa6c45856..5d84b1042 100644 --- a/management/server/account_request_buffer.go +++ b/management/server/account_request_buffer.go @@ -25,11 +25,13 @@ type AccountResult struct { } type AccountRequestBuffer struct { - store store.Store - getAccountRequests map[string][]*AccountRequest - mu sync.Mutex - getAccountRequestCh chan *AccountRequest - bufferInterval time.Duration + store store.Store + getAccountRequests map[string][]*AccountRequest + getAccountWithoutUsersRequests map[string][]*AccountRequest + mu sync.Mutex + getAccountRequestCh chan *AccountRequest + getAccountWithoutUsersRequestCh chan *AccountRequest + bufferInterval time.Duration } func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountRequestBuffer { @@ -45,13 +47,16 @@ func NewAccountRequestBuffer(ctx context.Context, store store.Store) *AccountReq log.WithContext(ctx).Infof("set account request buffer interval to %s", bufferInterval) ac := AccountRequestBuffer{ - store: store, - getAccountRequests: make(map[string][]*AccountRequest), - getAccountRequestCh: make(chan *AccountRequest), - bufferInterval: bufferInterval, + store: store, + getAccountRequests: make(map[string][]*AccountRequest), + getAccountWithoutUsersRequests: make(map[string][]*AccountRequest), + getAccountRequestCh: make(chan *AccountRequest), + getAccountWithoutUsersRequestCh: make(chan *AccountRequest), + bufferInterval: bufferInterval, } go ac.processGetAccountRequests(ctx) + go ac.processGetAccountWithoutUsersRequests(ctx) return &ac } @@ -70,6 +75,21 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context, return result.Account, result.Err } +func (ac *AccountRequestBuffer) GetAccountWithoutUsersWithBackpressure(ctx context.Context, accountID string) (*types.Account, error) { + req := &AccountRequest{ + AccountID: accountID, + ResultChan: make(chan *AccountResult, 1), + } + + log.WithContext(ctx).Tracef("requesting account without users %s with backpressure", accountID) + startTime := time.Now() + ac.getAccountWithoutUsersRequestCh <- req + + result := <-req.ResultChan + log.WithContext(ctx).Tracef("got account without users with backpressure after %s", time.Since(startTime)) + return result.Account, result.Err +} + func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) { ac.mu.Lock() requests := ac.getAccountRequests[accountID] @@ -109,3 +129,43 @@ func (ac *AccountRequestBuffer) processGetAccountRequests(ctx context.Context) { } } } + +func (ac *AccountRequestBuffer) processGetAccountWithoutUsersBatch(ctx context.Context, accountID string) { + ac.mu.Lock() + requests := ac.getAccountWithoutUsersRequests[accountID] + delete(ac.getAccountWithoutUsersRequests, accountID) + ac.mu.Unlock() + + if len(requests) == 0 { + return + } + + startTime := time.Now() + account, err := ac.store.GetAccountWithoutUsers(ctx, accountID) + log.WithContext(ctx).Tracef("getting account without users %s in batch took %s", accountID, time.Since(startTime)) + result := &AccountResult{Account: account, Err: err} + + for _, req := range requests { + req.ResultChan <- result + close(req.ResultChan) + } +} + +func (ac *AccountRequestBuffer) processGetAccountWithoutUsersRequests(ctx context.Context) { + for { + select { + case req := <-ac.getAccountWithoutUsersRequestCh: + ac.mu.Lock() + ac.getAccountWithoutUsersRequests[req.AccountID] = append(ac.getAccountWithoutUsersRequests[req.AccountID], req) + if len(ac.getAccountWithoutUsersRequests[req.AccountID]) == 1 { + go func(ctx context.Context, accountID string) { + time.Sleep(ac.bufferInterval) + ac.processGetAccountWithoutUsersBatch(ctx, accountID) + }(ctx, req.AccountID) + } + ac.mu.Unlock() + case <-ctx.Done(): + return + } + } +} diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 2b8981b97..dd4f0a231 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -796,6 +796,13 @@ func (s *SqlStore) GetAccount(ctx context.Context, accountID string) (*types.Acc return s.getAccountGorm(ctx, accountID) } +func (s *SqlStore) GetAccountWithoutUsers(ctx context.Context, accountID string) (*types.Account, error) { + if s.pool != nil { + return s.getAccountWithoutUsersPgx(ctx, accountID) + } + return s.getAccountWithoutUsersGorm(ctx, accountID) +} + func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types.Account, error) { start := time.Now() defer func() { @@ -897,6 +904,94 @@ func (s *SqlStore) getAccountGorm(ctx context.Context, accountID string) (*types return &account, nil } +func (s *SqlStore) getAccountWithoutUsersGorm(ctx context.Context, accountID string) (*types.Account, error) { + start := time.Now() + defer func() { + elapsed := time.Since(start) + if elapsed > 1*time.Second { + log.WithContext(ctx).Tracef("GetAccountWithoutUsers for account %s exceeded 1s, took: %v", accountID, elapsed) + } + }() + + var account types.Account + result := s.db.Model(&account). + Preload("Policies.Rules"). + Preload("SetupKeysG"). + Preload("PeersG"). + Preload("GroupsG.GroupPeers"). + Preload("RoutesG"). + Preload("NameServerGroupsG"). + Preload("PostureChecks"). + Preload("Networks"). + Preload("NetworkRouters"). + Preload("NetworkResources"). + Preload("Onboarding"). + Take(&account, idQueryCondition, accountID) + if result.Error != nil { + log.WithContext(ctx).Errorf("error when getting account %s from the store: %s", accountID, result.Error) + if errors.Is(result.Error, gorm.ErrRecordNotFound) { + return nil, status.NewAccountNotFoundError(accountID) + } + return nil, status.NewGetAccountFromStoreError(result.Error) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for _, key := range account.SetupKeysG { + if key.UpdatedAt.IsZero() { + key.UpdatedAt = key.CreatedAt + } + if key.AutoGroups == nil { + key.AutoGroups = []string{} + } + account.SetupKeys[key.Key] = &key + } + account.SetupKeysG = nil + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for _, peer := range account.PeersG { + account.Peers[peer.ID] = &peer + } + account.PeersG = nil + + account.Users = make(map[string]*types.User) + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for _, group := range account.GroupsG { + group.Peers = make([]string, len(group.GroupPeers)) + for i, gp := range group.GroupPeers { + group.Peers[i] = gp.PeerID + } + if group.Resources == nil { + group.Resources = []types.Resource{} + } + account.Groups[group.ID] = group + } + account.GroupsG = nil + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for _, route := range account.RoutesG { + account.Routes[route.ID] = &route + } + account.RoutesG = nil + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for _, ns := range account.NameServerGroupsG { + ns.AccountID = "" + if ns.NameServers == nil { + ns.NameServers = []nbdns.NameServer{} + } + if ns.Groups == nil { + ns.Groups = []string{} + } + if ns.Domains == nil { + ns.Domains = []string{} + } + account.NameServerGroups[ns.ID] = &ns + } + account.NameServerGroupsG = nil + account.InitOnce() + return &account, nil +} + func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types.Account, error) { account, err := s.getAccount(ctx, accountID) if err != nil { @@ -1180,6 +1275,246 @@ func (s *SqlStore) getAccountPgx(ctx context.Context, accountID string) (*types. return account, nil } +func (s *SqlStore) getAccountWithoutUsersPgx(ctx context.Context, accountID string) (*types.Account, error) { + account, err := s.getAccount(ctx, accountID) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + errChan := make(chan error, 11) + + wg.Add(1) + go func() { + defer wg.Done() + keys, err := s.getSetupKeys(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.SetupKeysG = keys + }() + + wg.Add(1) + go func() { + defer wg.Done() + peers, err := s.getPeers(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PeersG = peers + }() + + wg.Add(1) + go func() { + defer wg.Done() + groups, err := s.getGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.GroupsG = groups + }() + + wg.Add(1) + go func() { + defer wg.Done() + policies, err := s.getPolicies(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Policies = policies + }() + + wg.Add(1) + go func() { + defer wg.Done() + routes, err := s.getRoutes(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.RoutesG = routes + }() + + wg.Add(1) + go func() { + defer wg.Done() + nsgs, err := s.getNameServerGroups(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NameServerGroupsG = nsgs + }() + + wg.Add(1) + go func() { + defer wg.Done() + checks, err := s.getPostureChecks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.PostureChecks = checks + }() + + wg.Add(1) + go func() { + defer wg.Done() + networks, err := s.getNetworks(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.Networks = networks + }() + + wg.Add(1) + go func() { + defer wg.Done() + routers, err := s.getNetworkRouters(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkRouters = routers + }() + + wg.Add(1) + go func() { + defer wg.Done() + resources, err := s.getNetworkResources(ctx, accountID) + if err != nil { + errChan <- err + return + } + account.NetworkResources = resources + }() + + wg.Add(1) + go func() { + defer wg.Done() + err := s.getAccountOnboarding(ctx, accountID, account) + if err != nil { + errChan <- err + return + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + var policyIDs []string + for _, p := range account.Policies { + policyIDs = append(policyIDs, p.ID) + } + var groupIDs []string + for _, g := range account.GroupsG { + groupIDs = append(groupIDs, g.ID) + } + + wg.Add(2) + errChan = make(chan error, 2) + + var rules []*types.PolicyRule + go func() { + defer wg.Done() + var err error + rules, err = s.getPolicyRules(ctx, policyIDs) + if err != nil { + errChan <- err + } + }() + + var groupPeers []types.GroupPeer + go func() { + defer wg.Done() + var err error + groupPeers, err = s.getGroupPeers(ctx, groupIDs) + if err != nil { + errChan <- err + } + }() + + wg.Wait() + close(errChan) + for e := range errChan { + if e != nil { + return nil, e + } + } + + rulesByPolicyID := make(map[string][]*types.PolicyRule) + for _, rule := range rules { + rulesByPolicyID[rule.PolicyID] = append(rulesByPolicyID[rule.PolicyID], rule) + } + + peersByGroupID := make(map[string][]string) + for _, gp := range groupPeers { + peersByGroupID[gp.GroupID] = append(peersByGroupID[gp.GroupID], gp.PeerID) + } + + account.SetupKeys = make(map[string]*types.SetupKey, len(account.SetupKeysG)) + for i := range account.SetupKeysG { + key := &account.SetupKeysG[i] + account.SetupKeys[key.Key] = key + } + + account.Peers = make(map[string]*nbpeer.Peer, len(account.PeersG)) + for i := range account.PeersG { + peer := &account.PeersG[i] + account.Peers[peer.ID] = peer + } + + account.Users = make(map[string]*types.User) + + for i := range account.Policies { + policy := account.Policies[i] + if policyRules, ok := rulesByPolicyID[policy.ID]; ok { + policy.Rules = policyRules + } + } + + account.Groups = make(map[string]*types.Group, len(account.GroupsG)) + for i := range account.GroupsG { + group := account.GroupsG[i] + if peerIDs, ok := peersByGroupID[group.ID]; ok { + group.Peers = peerIDs + } + account.Groups[group.ID] = group + } + + account.Routes = make(map[route.ID]*route.Route, len(account.RoutesG)) + for i := range account.RoutesG { + route := &account.RoutesG[i] + account.Routes[route.ID] = route + } + + account.NameServerGroups = make(map[string]*nbdns.NameServerGroup, len(account.NameServerGroupsG)) + for i := range account.NameServerGroupsG { + nsg := &account.NameServerGroupsG[i] + nsg.AccountID = "" + account.NameServerGroups[nsg.ID] = nsg + } + + account.SetupKeysG = nil + account.PeersG = nil + account.UsersG = nil + account.GroupsG = nil + account.RoutesG = nil + account.NameServerGroupsG = nil + + return account, nil +} + func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Account, error) { var account types.Account account.Network = &types.Network{} diff --git a/management/server/store/store.go b/management/server/store/store.go index 0ec7949f9..d3cac56d9 100644 --- a/management/server/store/store.go +++ b/management/server/store/store.go @@ -51,6 +51,7 @@ type Store interface { GetAccountsCounter(ctx context.Context) (int64, error) GetAllAccounts(ctx context.Context) []*types.Account GetAccount(ctx context.Context, accountID string) (*types.Account, error) + GetAccountWithoutUsers(ctx context.Context, accountID string) (*types.Account, error) GetAccountMeta(ctx context.Context, lockStrength LockingStrength, accountID string) (*types.AccountMeta, error) GetAccountOnboarding(ctx context.Context, accountID string) (*types.AccountOnboarding, error) AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)