mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
[management] faster server bootstrap (#3365)
Faster server bootstrap by counting accounts rather than fetching all from storage in the account manager instantiation. This change moved the deprecated need to ensure accounts have an All group to tests instead.
This commit is contained in:
@@ -15,7 +15,6 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"gorm.io/driver/mysql"
|
||||
"gorm.io/driver/postgres"
|
||||
@@ -24,6 +23,8 @@ import (
|
||||
"gorm.io/gorm/clause"
|
||||
"gorm.io/gorm/logger"
|
||||
|
||||
"github.com/netbirdio/netbird/management/server/util"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/account"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -615,6 +616,16 @@ func (s *SqlStore) GetResourceGroups(ctx context.Context, lockStrength LockingSt
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
|
||||
var count int64
|
||||
result := s.db.Model(&types.Account{}).Count(&count)
|
||||
if result.Error != nil {
|
||||
return 0, fmt.Errorf("failed to get all accounts counter: %w", result.Error)
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (s *SqlStore) GetAllAccounts(ctx context.Context) (all []*types.Account) {
|
||||
var accounts []types.Account
|
||||
result := s.db.Find(&accounts)
|
||||
@@ -1035,6 +1046,13 @@ func NewSqliteStoreFromFileStore(ctx context.Context, fileStore *FileStore, data
|
||||
}
|
||||
|
||||
for _, account := range fileStore.GetAllAccounts(ctx) {
|
||||
_, err = account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err := store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/xid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -2045,52 +2044,12 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain string) *ty
|
||||
},
|
||||
}
|
||||
|
||||
if err := addAllGroup(acc); err != nil {
|
||||
if err := acc.AddAllGroup(); err != nil {
|
||||
log.WithContext(ctx).Errorf("error adding all group to account %s: %v", acc.Id, err)
|
||||
}
|
||||
return acc
|
||||
}
|
||||
|
||||
// addAllGroup to account object if it doesn't exist
|
||||
func addAllGroup(account *types.Account) error {
|
||||
if len(account.Groups) == 0 {
|
||||
allGroup := &types.Group{
|
||||
ID: xid.New().String(),
|
||||
Name: "All",
|
||||
Issued: types.GroupIssuedAPI,
|
||||
}
|
||||
for _, peer := range account.Peers {
|
||||
allGroup.Peers = append(allGroup.Peers, peer.ID)
|
||||
}
|
||||
account.Groups = map[string]*types.Group{allGroup.ID: allGroup}
|
||||
|
||||
id := xid.New().String()
|
||||
|
||||
defaultPolicy := &types.Policy{
|
||||
ID: id,
|
||||
Name: types.DefaultRuleName,
|
||||
Description: types.DefaultRuleDescription,
|
||||
Enabled: true,
|
||||
Rules: []*types.PolicyRule{
|
||||
{
|
||||
ID: id,
|
||||
Name: types.DefaultRuleName,
|
||||
Description: types.DefaultRuleDescription,
|
||||
Enabled: true,
|
||||
Sources: []string{allGroup.ID},
|
||||
Destinations: []string{allGroup.ID},
|
||||
Bidirectional: true,
|
||||
Protocol: types.PolicyRuleProtocolALL,
|
||||
Action: types.PolicyTrafficActionAccept,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account.Policies = []*types.Policy{defaultPolicy}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSqlStore_GetAccountNetworks(t *testing.T) {
|
||||
store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store.sql", t.TempDir())
|
||||
t.Cleanup(cleanup)
|
||||
|
||||
@@ -48,6 +48,7 @@ const (
|
||||
)
|
||||
|
||||
type Store interface {
|
||||
GetAccountsCounter(ctx context.Context) (int64, error)
|
||||
GetAllAccounts(ctx context.Context) []*types.Account
|
||||
GetAccount(ctx context.Context, accountID string) (*types.Account, error)
|
||||
AccountExists(ctx context.Context, lockStrength LockingStrength, id string) (bool, error)
|
||||
@@ -352,9 +353,37 @@ func NewTestStoreFromSQL(ctx context.Context, filename string, dataDir string) (
|
||||
return nil, nil, fmt.Errorf("failed to create test store: %v", err)
|
||||
}
|
||||
|
||||
err = addAllGroupToAccount(ctx, store)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to add all group to account: %v", err)
|
||||
}
|
||||
|
||||
return getSqlStoreEngine(ctx, store, kind)
|
||||
}
|
||||
|
||||
func addAllGroupToAccount(ctx context.Context, store Store) error {
|
||||
allAccounts := store.GetAllAccounts(ctx)
|
||||
for _, account := range allAccounts {
|
||||
shouldSave := false
|
||||
|
||||
_, err := account.GetGroupAll()
|
||||
if err != nil {
|
||||
if err := account.AddAllGroup(); err != nil {
|
||||
return err
|
||||
}
|
||||
shouldSave = true
|
||||
}
|
||||
|
||||
if shouldSave {
|
||||
err = store.SaveAccount(ctx, account)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getSqlStoreEngine(ctx context.Context, store *SqlStore, kind Engine) (Store, func(), error) {
|
||||
var cleanup func()
|
||||
var err error
|
||||
|
||||
Reference in New Issue
Block a user