mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-02 23:26:41 +00:00
Merge branch 'main' of github.com:netbirdio/netbird into feat/local-user-totp
This commit is contained in:
@@ -742,11 +742,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
|
||||
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
|
||||
if err != nil {
|
||||
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
|
||||
}
|
||||
|
||||
for _, otherUser := range account.Users {
|
||||
if otherUser.Id == userID {
|
||||
continue
|
||||
|
||||
@@ -75,7 +75,7 @@ type Manager interface {
|
||||
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
|
||||
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
|
||||
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error
|
||||
|
||||
@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
|
||||
}
|
||||
|
||||
// GetGroupByName mocks base method.
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
|
||||
ret0, _ := ret[0].(*types.Group)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupByName indicates an expected call of GetGroupByName.
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
|
||||
}
|
||||
|
||||
// GetIdentityProvider mocks base method.
|
||||
|
||||
@@ -63,20 +63,11 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context,
|
||||
|
||||
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
|
||||
startTime := time.Now()
|
||||
ac.getAccountRequestCh <- req
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case ac.getAccountRequestCh <- req:
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case result := <-req.ResultChan:
|
||||
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
|
||||
return result.Account, result.Err
|
||||
}
|
||||
result := <-req.ResultChan
|
||||
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
|
||||
return result.Account, result.Err
|
||||
}
|
||||
|
||||
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
"github.com/prometheus/client_golang/prometheus/push"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -23,6 +22,9 @@ import (
|
||||
"go.opentelemetry.io/otel/metric/noop"
|
||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/shared/management/status"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
|
||||
@@ -1815,6 +1817,13 @@ func TestAccount_Copy(t *testing.T) {
|
||||
Targets: []*service.Target{},
|
||||
},
|
||||
},
|
||||
Domains: []*domain.Domain{
|
||||
{
|
||||
ID: "domain1",
|
||||
Domain: "test.com",
|
||||
AccountID: "account1",
|
||||
},
|
||||
},
|
||||
NetworkMapCache: &types.NetworkMapBuilder{},
|
||||
}
|
||||
account.InitOnce()
|
||||
|
||||
@@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
|
||||
}
|
||||
|
||||
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
|
||||
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
|
||||
groupName := r.URL.Query().Get("name")
|
||||
if groupName != "" {
|
||||
// Get single group by name
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
|
||||
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
|
||||
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
|
||||
if err != nil {
|
||||
util.WriteError(r.Context(), err, w)
|
||||
return
|
||||
|
||||
@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
|
||||
|
||||
return groups, nil
|
||||
},
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
|
||||
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
|
||||
if groupName == "All" {
|
||||
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
|
||||
}
|
||||
|
||||
@@ -489,6 +489,102 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasForeignKey checks whether a foreign key constraint exists on the given table and column.
|
||||
func hasForeignKey(db *gorm.DB, table, column string) bool {
|
||||
var count int64
|
||||
|
||||
switch db.Name() {
|
||||
case "postgres":
|
||||
db.Raw(`
|
||||
SELECT COUNT(*) FROM information_schema.key_column_usage kcu
|
||||
JOIN information_schema.table_constraints tc
|
||||
ON tc.constraint_name = kcu.constraint_name
|
||||
AND tc.table_schema = kcu.table_schema
|
||||
WHERE tc.constraint_type = 'FOREIGN KEY'
|
||||
AND kcu.table_name = ?
|
||||
AND kcu.column_name = ?
|
||||
`, table, column).Scan(&count)
|
||||
case "mysql":
|
||||
db.Raw(`
|
||||
SELECT COUNT(*) FROM information_schema.key_column_usage
|
||||
WHERE table_schema = DATABASE()
|
||||
AND table_name = ?
|
||||
AND column_name = ?
|
||||
AND referenced_table_name IS NOT NULL
|
||||
`, table, column).Scan(&count)
|
||||
default: // sqlite
|
||||
type fkInfo struct {
|
||||
From string
|
||||
}
|
||||
var fks []fkInfo
|
||||
db.Raw(fmt.Sprintf("PRAGMA foreign_key_list(%s)", table)).Scan(&fks)
|
||||
for _, fk := range fks {
|
||||
if fk.From == column {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return count > 0
|
||||
}
|
||||
|
||||
// CleanupOrphanedResources deletes rows from the table of model T where the foreign
|
||||
// key column (fkColumn) references a row in the table of model R that no longer exists.
|
||||
func CleanupOrphanedResources[T any, R any](ctx context.Context, db *gorm.DB, fkColumn string) error {
|
||||
var model T
|
||||
var refModel R
|
||||
|
||||
if !db.Migrator().HasTable(&model) {
|
||||
log.WithContext(ctx).Debugf("table for %T does not exist, no cleanup needed", model)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !db.Migrator().HasTable(&refModel) {
|
||||
log.WithContext(ctx).Debugf("referenced table for %T does not exist, no cleanup needed", refModel)
|
||||
return nil
|
||||
}
|
||||
|
||||
stmtT := &gorm.Statement{DB: db}
|
||||
if err := stmtT.Parse(&model); err != nil {
|
||||
return fmt.Errorf("parse model %T: %w", model, err)
|
||||
}
|
||||
childTable := stmtT.Schema.Table
|
||||
|
||||
stmtR := &gorm.Statement{DB: db}
|
||||
if err := stmtR.Parse(&refModel); err != nil {
|
||||
return fmt.Errorf("parse reference model %T: %w", refModel, err)
|
||||
}
|
||||
parentTable := stmtR.Schema.Table
|
||||
|
||||
if !db.Migrator().HasColumn(&model, fkColumn) {
|
||||
log.WithContext(ctx).Debugf("column %s does not exist in table %s, no cleanup needed", fkColumn, childTable)
|
||||
return nil
|
||||
}
|
||||
|
||||
// If a foreign key constraint already exists on the column, the DB itself
|
||||
// enforces referential integrity and orphaned rows cannot exist.
|
||||
if hasForeignKey(db, childTable, fkColumn) {
|
||||
log.WithContext(ctx).Debugf("foreign key constraint for %s already exists on %s, no cleanup needed", fkColumn, childTable)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := db.Exec(
|
||||
fmt.Sprintf(
|
||||
"DELETE FROM %s WHERE %s NOT IN (SELECT id FROM %s)",
|
||||
childTable, fkColumn, parentTable,
|
||||
),
|
||||
)
|
||||
if result.Error != nil {
|
||||
return fmt.Errorf("cleanup orphaned rows in %s: %w", childTable, result.Error)
|
||||
}
|
||||
|
||||
log.WithContext(ctx).Infof("Cleaned up %d orphaned rows from %s where %s had no matching row in %s",
|
||||
result.RowsAffected, childTable, fkColumn, parentTable)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
|
||||
if !db.Migrator().HasTable("peers") {
|
||||
log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup")
|
||||
|
||||
@@ -441,3 +441,197 @@ func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) {
|
||||
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
|
||||
require.NoError(t, err, "Should not fail when table does not exist")
|
||||
}
|
||||
|
||||
type testParent struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
}
|
||||
|
||||
func (testParent) TableName() string {
|
||||
return "test_parents"
|
||||
}
|
||||
|
||||
type testChild struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
ParentID string
|
||||
}
|
||||
|
||||
func (testChild) TableName() string {
|
||||
return "test_children"
|
||||
}
|
||||
|
||||
type testChildWithFK struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
ParentID string `gorm:"index"`
|
||||
Parent *testParent `gorm:"foreignKey:ParentID"`
|
||||
}
|
||||
|
||||
func (testChildWithFK) TableName() string {
|
||||
return "test_children"
|
||||
}
|
||||
|
||||
func setupOrphanTestDB(t *testing.T, models ...any) *gorm.DB {
|
||||
t.Helper()
|
||||
db := setupDatabase(t)
|
||||
for _, m := range models {
|
||||
_ = db.Migrator().DropTable(m)
|
||||
}
|
||||
err := db.AutoMigrate(models...)
|
||||
require.NoError(t, err, "Failed to auto-migrate tables")
|
||||
return db
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_NoChildTable(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
_ = db.Migrator().DropTable(&testChild{})
|
||||
_ = db.Migrator().DropTable(&testParent{})
|
||||
|
||||
err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
|
||||
require.NoError(t, err, "Should not fail when child table does not exist")
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_NoParentTable(t *testing.T) {
|
||||
db := setupDatabase(t)
|
||||
_ = db.Migrator().DropTable(&testParent{})
|
||||
_ = db.Migrator().DropTable(&testChild{})
|
||||
|
||||
err := db.AutoMigrate(&testChild{})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
|
||||
require.NoError(t, err, "Should not fail when parent table does not exist")
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_EmptyTables(t *testing.T) {
|
||||
db := setupOrphanTestDB(t, &testParent{}, &testChild{})
|
||||
|
||||
err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
|
||||
require.NoError(t, err, "Should not fail on empty tables")
|
||||
|
||||
var count int64
|
||||
db.Model(&testChild{}).Count(&count)
|
||||
assert.Equal(t, int64(0), count)
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_NoOrphans(t *testing.T) {
|
||||
db := setupOrphanTestDB(t, &testParent{}, &testChild{})
|
||||
|
||||
require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
|
||||
require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
|
||||
require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
|
||||
require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error)
|
||||
|
||||
err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int64
|
||||
db.Model(&testChild{}).Count(&count)
|
||||
assert.Equal(t, int64(2), count, "All children should remain when no orphans")
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_AllOrphans(t *testing.T) {
|
||||
db := setupOrphanTestDB(t, &testParent{}, &testChild{})
|
||||
|
||||
require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c1", "gone1").Error)
|
||||
require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone2").Error)
|
||||
require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c3", "gone3").Error)
|
||||
|
||||
err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int64
|
||||
db.Model(&testChild{}).Count(&count)
|
||||
assert.Equal(t, int64(0), count, "All orphaned children should be deleted")
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_MixedValidAndOrphaned(t *testing.T) {
|
||||
db := setupOrphanTestDB(t, &testParent{}, &testChild{})
|
||||
|
||||
require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
|
||||
require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
|
||||
|
||||
require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
|
||||
require.NoError(t, db.Create(&testChild{ID: "c2", ParentID: "p2"}).Error)
|
||||
require.NoError(t, db.Create(&testChild{ID: "c3", ParentID: "p1"}).Error)
|
||||
|
||||
require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c4", "gone1").Error)
|
||||
require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c5", "gone2").Error)
|
||||
|
||||
err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
|
||||
require.NoError(t, err)
|
||||
|
||||
var remaining []testChild
|
||||
require.NoError(t, db.Order("id").Find(&remaining).Error)
|
||||
|
||||
assert.Len(t, remaining, 3, "Only valid children should remain")
|
||||
assert.Equal(t, "c1", remaining[0].ID)
|
||||
assert.Equal(t, "c2", remaining[1].ID)
|
||||
assert.Equal(t, "c3", remaining[2].ID)
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_Idempotent(t *testing.T) {
|
||||
db := setupOrphanTestDB(t, &testParent{}, &testChild{})
|
||||
|
||||
require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
|
||||
require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
|
||||
require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone").Error)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
err := migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id")
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int64
|
||||
db.Model(&testChild{}).Count(&count)
|
||||
assert.Equal(t, int64(1), count)
|
||||
|
||||
err = migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id")
|
||||
require.NoError(t, err)
|
||||
|
||||
db.Model(&testChild{}).Count(&count)
|
||||
assert.Equal(t, int64(1), count, "Count should remain the same after second run")
|
||||
}
|
||||
|
||||
func TestCleanupOrphanedResources_SkipsWhenForeignKeyExists(t *testing.T) {
|
||||
engine := os.Getenv("NETBIRD_STORE_ENGINE")
|
||||
if engine != "postgres" && engine != "mysql" {
|
||||
t.Skip("FK constraint early-exit test requires postgres or mysql")
|
||||
}
|
||||
|
||||
db := setupDatabase(t)
|
||||
_ = db.Migrator().DropTable(&testChildWithFK{})
|
||||
_ = db.Migrator().DropTable(&testParent{})
|
||||
|
||||
err := db.AutoMigrate(&testParent{}, &testChildWithFK{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
|
||||
require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
|
||||
require.NoError(t, db.Create(&testChildWithFK{ID: "c1", ParentID: "p1"}).Error)
|
||||
require.NoError(t, db.Create(&testChildWithFK{ID: "c2", ParentID: "p2"}).Error)
|
||||
|
||||
switch engine {
|
||||
case "postgres":
|
||||
require.NoError(t, db.Exec("ALTER TABLE test_children DROP CONSTRAINT fk_test_children_parent").Error)
|
||||
require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error)
|
||||
require.NoError(t, db.Exec(
|
||||
"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+
|
||||
"FOREIGN KEY (parent_id) REFERENCES test_parents(id) NOT VALID",
|
||||
).Error)
|
||||
case "mysql":
|
||||
require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error)
|
||||
require.NoError(t, db.Exec("ALTER TABLE test_children DROP FOREIGN KEY fk_test_children_parent").Error)
|
||||
require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error)
|
||||
require.NoError(t, db.Exec(
|
||||
"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+
|
||||
"FOREIGN KEY (parent_id) REFERENCES test_parents(id)",
|
||||
).Error)
|
||||
require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error)
|
||||
}
|
||||
|
||||
err = migration.CleanupOrphanedResources[testChildWithFK, testParent](context.Background(), db, "parent_id")
|
||||
require.NoError(t, err)
|
||||
|
||||
var count int64
|
||||
db.Model(&testChildWithFK{}).Count(&count)
|
||||
assert.Equal(t, int64(2), count, "Both rows should survive — migration must skip when FK constraint exists")
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ type MockAccountManager struct {
|
||||
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
|
||||
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
|
||||
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
|
||||
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
|
||||
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
|
||||
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
|
||||
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
|
||||
@@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer(
|
||||
}
|
||||
|
||||
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
|
||||
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
|
||||
func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||
if am.GetGroupByNameFunc != nil {
|
||||
return am.GetGroupByNameFunc(ctx, accountID, groupName)
|
||||
return am.GetGroupByNameFunc(ctx, groupName, accountID, userID)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
|
||||
}
|
||||
|
||||
@@ -396,6 +396,11 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
|
||||
return result.Error
|
||||
}
|
||||
|
||||
result = tx.Select(clause.Associations).Delete(account.Services, "account_id = ?", account.Id)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
result = tx.Select(clause.Associations).Delete(account)
|
||||
if result.Error != nil {
|
||||
return result.Error
|
||||
@@ -1012,10 +1017,10 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
|
||||
// GetCustomDomainsCounts returns the total and validated custom domain counts.
|
||||
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
|
||||
var total, validated int64
|
||||
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
|
||||
if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
|
||||
if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return total, validated, nil
|
||||
@@ -2080,7 +2085,8 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
|
||||
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
|
||||
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
|
||||
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key
|
||||
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
|
||||
mode, listen_port, port_auto_assigned, source, source_peer, terminated
|
||||
FROM services WHERE account_id = $1`
|
||||
|
||||
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
|
||||
@@ -2097,6 +2103,9 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
var auth []byte
|
||||
var createdAt, certIssuedAt sql.NullTime
|
||||
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
|
||||
var mode, source, sourcePeer sql.NullString
|
||||
var terminated, portAutoAssigned sql.NullBool
|
||||
var listenPort sql.NullInt64
|
||||
err := row.Scan(
|
||||
&s.ID,
|
||||
&s.AccountID,
|
||||
@@ -2112,6 +2121,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
&s.RewriteRedirects,
|
||||
&sessionPrivateKey,
|
||||
&sessionPublicKey,
|
||||
&mode,
|
||||
&listenPort,
|
||||
&portAutoAssigned,
|
||||
&source,
|
||||
&sourcePeer,
|
||||
&terminated,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -2143,7 +2158,24 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
|
||||
if sessionPublicKey.Valid {
|
||||
s.SessionPublicKey = sessionPublicKey.String
|
||||
}
|
||||
|
||||
if mode.Valid {
|
||||
s.Mode = mode.String
|
||||
}
|
||||
if source.Valid {
|
||||
s.Source = source.String
|
||||
}
|
||||
if sourcePeer.Valid {
|
||||
s.SourcePeer = sourcePeer.String
|
||||
}
|
||||
if terminated.Valid {
|
||||
s.Terminated = terminated.Bool
|
||||
}
|
||||
if portAutoAssigned.Valid {
|
||||
s.PortAutoAssigned = portAutoAssigned.Bool
|
||||
}
|
||||
if listenPort.Valid {
|
||||
s.ListenPort = uint16(listenPort.Int64)
|
||||
}
|
||||
s.Targets = []*rpservice.Target{}
|
||||
return &s, nil
|
||||
})
|
||||
@@ -4410,7 +4442,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
|
||||
|
||||
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
|
||||
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -4429,7 +4461,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr
|
||||
|
||||
// GetAllProxyAccessTokens retrieves all proxy access tokens.
|
||||
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -4445,7 +4477,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc
|
||||
|
||||
// SaveProxyAccessToken saves a proxy access token to the database.
|
||||
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
|
||||
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
|
||||
if result := s.db.Create(token); result.Error != nil {
|
||||
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
|
||||
}
|
||||
return nil
|
||||
@@ -4453,7 +4485,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA
|
||||
|
||||
// RevokeProxyAccessToken revokes a proxy access token by its ID.
|
||||
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
|
||||
result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
|
||||
if result.Error != nil {
|
||||
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
|
||||
}
|
||||
@@ -4467,7 +4499,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
|
||||
|
||||
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
|
||||
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
|
||||
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
|
||||
result := s.db.Model(&types.ProxyAccessToken{}).
|
||||
Where(idQueryCondition, tokenID).
|
||||
Update("last_used", time.Now().UTC())
|
||||
if result.Error != nil {
|
||||
@@ -5136,7 +5168,7 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock
|
||||
|
||||
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
|
||||
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -5152,7 +5184,7 @@ func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength
|
||||
|
||||
// GetServicesByCluster returns all services for the given proxy cluster.
|
||||
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
|
||||
tx := s.db.WithContext(ctx)
|
||||
tx := s.db
|
||||
if lockStrength != LockingStrengthNone {
|
||||
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
|
||||
}
|
||||
@@ -5262,7 +5294,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
var logs []*accesslogs.AccessLogEntry
|
||||
var totalCount int64
|
||||
|
||||
baseQuery := s.db.WithContext(ctx).
|
||||
baseQuery := s.db.
|
||||
Model(&accesslogs.AccessLogEntry{}).
|
||||
Where(accountIDCondition, accountID)
|
||||
|
||||
@@ -5273,7 +5305,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
|
||||
}
|
||||
|
||||
query := s.db.WithContext(ctx).
|
||||
query := s.db.
|
||||
Where(accountIDCondition, accountID)
|
||||
|
||||
query = s.applyAccessLogFilters(query, filter)
|
||||
@@ -5310,7 +5342,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
|
||||
|
||||
// DeleteOldAccessLogs deletes all access logs older than the specified time
|
||||
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Where("timestamp < ?", olderThan).
|
||||
Delete(&accesslogs.AccessLogEntry{})
|
||||
|
||||
@@ -5399,7 +5431,7 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
|
||||
|
||||
// SaveProxy saves or updates a proxy in the database
|
||||
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
result := s.db.WithContext(ctx).Save(p)
|
||||
result := s.db.Save(p)
|
||||
if result.Error != nil {
|
||||
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
|
||||
return status.Errorf(status.Internal, "failed to save proxy")
|
||||
@@ -5411,7 +5443,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
|
||||
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||
now := time.Now()
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("id = ? AND status = ?", proxyID, "connected").
|
||||
Update("last_seen", now)
|
||||
@@ -5430,7 +5462,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
||||
ConnectedAt: &now,
|
||||
Status: "connected",
|
||||
}
|
||||
if err := s.db.WithContext(ctx).Save(p).Error; err != nil {
|
||||
if err := s.db.Save(p).Error; err != nil {
|
||||
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
|
||||
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
|
||||
}
|
||||
@@ -5443,7 +5475,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
|
||||
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
|
||||
var addresses []string
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Distinct("cluster_address").
|
||||
@@ -5482,6 +5514,7 @@ const proxyActiveThreshold = 2 * time.Minute
|
||||
var validCapabilityColumns = map[string]struct{}{
|
||||
"supports_custom_ports": {},
|
||||
"require_subdomain": {},
|
||||
"supports_crowdsec": {},
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
|
||||
@@ -5496,6 +5529,59 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
|
||||
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
|
||||
}
|
||||
|
||||
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
|
||||
// have CrowdSec configured. Returns nil when no proxy reported the capability.
|
||||
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
|
||||
// requires unanimous support: a single unconfigured proxy would let requests
|
||||
// bypass reputation checks.
|
||||
func (s *SqlStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||
return s.getClusterUnanimousCapability(ctx, clusterAddr, "supports_crowdsec")
|
||||
}
|
||||
|
||||
// getClusterUnanimousCapability returns an aggregated boolean capability
|
||||
// requiring all active proxies in the cluster to report true.
|
||||
func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAddr, column string) *bool {
|
||||
if _, ok := validCapabilityColumns[column]; !ok {
|
||||
log.WithContext(ctx).Errorf("invalid capability column: %s", column)
|
||||
return nil
|
||||
}
|
||||
|
||||
var result struct {
|
||||
Total int64
|
||||
Reported int64
|
||||
AllTrue bool
|
||||
}
|
||||
|
||||
// All active proxies must have reported the capability (no NULLs) and all
|
||||
// must report true. A single unreported or false proxy means the cluster
|
||||
// does not unanimously support the capability.
|
||||
err := s.db.WithContext(ctx).
|
||||
Model(&proxy.Proxy{}).
|
||||
Select("COUNT(*) AS total, "+
|
||||
"COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) AS reported, "+
|
||||
"COUNT(*) > 0 AND COUNT(*) = COUNT(CASE WHEN "+column+" = true THEN 1 END) AS all_true").
|
||||
Where("cluster_address = ? AND status = ? AND last_seen > ?",
|
||||
clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
|
||||
Scan(&result).Error
|
||||
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
|
||||
return nil
|
||||
}
|
||||
|
||||
if result.Total == 0 || result.Reported == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If any proxy has not reported (NULL), we can't confirm unanimous support.
|
||||
if result.Reported < result.Total {
|
||||
v := false
|
||||
return &v
|
||||
}
|
||||
|
||||
return &result.AllTrue
|
||||
}
|
||||
|
||||
// getClusterCapability returns an aggregated boolean capability for the given
|
||||
// cluster. It checks active (connected, recently seen) proxies and returns:
|
||||
// - *true if any proxy in the cluster has the capability set to true,
|
||||
@@ -5512,7 +5598,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
|
||||
AnyTrue bool
|
||||
}
|
||||
|
||||
err := s.db.WithContext(ctx).
|
||||
err := s.db.
|
||||
Model(&proxy.Proxy{}).
|
||||
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
|
||||
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
|
||||
@@ -5536,7 +5622,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
|
||||
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
|
||||
cutoffTime := time.Now().Add(-inactivityDuration)
|
||||
|
||||
result := s.db.WithContext(ctx).
|
||||
result := s.db.
|
||||
Where("last_seen < ?", cutoffTime).
|
||||
Delete(&proxy.Proxy{})
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
@@ -350,6 +352,35 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
account.Services = []*rpservice.Service{
|
||||
{
|
||||
ID: "service_id",
|
||||
AccountID: account.Id,
|
||||
Name: "test service",
|
||||
Domain: "svc.example.com",
|
||||
Enabled: true,
|
||||
Targets: []*rpservice.Target{
|
||||
{
|
||||
AccountID: account.Id,
|
||||
ServiceID: "service_id",
|
||||
Host: "localhost",
|
||||
Port: 8080,
|
||||
Protocol: "http",
|
||||
Enabled: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
account.Domains = []*proxydomain.Domain{
|
||||
{
|
||||
ID: "domain_id",
|
||||
Domain: "custom.example.com",
|
||||
AccountID: account.Id,
|
||||
Validated: true,
|
||||
},
|
||||
}
|
||||
|
||||
err = store.SaveAccount(context.Background(), account)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -411,6 +442,20 @@ func TestSqlite_DeleteAccount(t *testing.T) {
|
||||
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources")
|
||||
require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount")
|
||||
}
|
||||
|
||||
domains, err := store.ListCustomDomains(context.Background(), account.Id)
|
||||
require.NoError(t, err, "expecting no error after DeleteAccount when searching for custom domains")
|
||||
require.Len(t, domains, 0, "expecting no custom domains to be found after DeleteAccount")
|
||||
|
||||
var services []*rpservice.Service
|
||||
err = store.(*SqlStore).db.Model(&rpservice.Service{}).Find(&services, "account_id = ?", account.Id).Error
|
||||
require.NoError(t, err, "expecting no error after DeleteAccount when searching for services")
|
||||
require.Len(t, services, 0, "expecting no services to be found after DeleteAccount")
|
||||
|
||||
var targets []*rpservice.Target
|
||||
err = store.(*SqlStore).db.Model(&rpservice.Target{}).Find(&targets, "account_id = ?", account.Id).Error
|
||||
require.NoError(t, err, "expecting no error after DeleteAccount when searching for service targets")
|
||||
require.Len(t, targets, 0, "expecting no service targets to be found after DeleteAccount")
|
||||
}
|
||||
|
||||
func Test_GetAccount(t *testing.T) {
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
|
||||
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
|
||||
@@ -265,6 +266,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
|
||||
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
|
||||
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
|
||||
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
|
||||
&domain.Domain{},
|
||||
}
|
||||
|
||||
for i := len(models) - 1; i >= 0; i-- {
|
||||
|
||||
@@ -121,7 +121,7 @@ type Store interface {
|
||||
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
|
||||
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
|
||||
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
|
||||
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error)
|
||||
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
|
||||
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
||||
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
|
||||
@@ -289,6 +289,7 @@ type Store interface {
|
||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
|
||||
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
|
||||
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
|
||||
|
||||
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
|
||||
@@ -448,6 +449,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
|
||||
func(db *gorm.DB) error {
|
||||
return migration.RemoveDuplicatePeerKeys(ctx, db)
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CleanupOrphanedResources[rpservice.Service, types.Account](ctx, db, "account_id")
|
||||
},
|
||||
func(db *gorm.DB) error {
|
||||
return migration.CleanupOrphanedResources[domain.Domain, types.Account](ctx, db, "account_id")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -165,34 +165,19 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
// GetClusterSupportsCrowdSec mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain mocks base method.
|
||||
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
|
||||
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// Close mocks base method.
|
||||
func (m *MockStore) Close(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1389,6 +1374,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx)
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain mocks base method.
|
||||
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
|
||||
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts mocks base method.
|
||||
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
|
||||
ret0, _ := ret[0].(*bool)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
|
||||
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
|
||||
}
|
||||
|
||||
// GetCustomDomain mocks base method.
|
||||
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1466,18 +1479,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou
|
||||
}
|
||||
|
||||
// GetGroupByName mocks base method.
|
||||
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) {
|
||||
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID)
|
||||
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName)
|
||||
ret0, _ := ret[0].(*types2.Group)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetGroupByName indicates an expected call of GetGroupByName.
|
||||
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call {
|
||||
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
|
||||
}
|
||||
|
||||
// GetGroupsByIDs mocks base method.
|
||||
@@ -1974,6 +1987,21 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID)
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks mocks base method.
|
||||
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
|
||||
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// GetServiceByDomain mocks base method.
|
||||
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -2361,21 +2389,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks mocks base method.
|
||||
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
|
||||
ret0, _ := ret[0].([]string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
|
||||
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
|
||||
}
|
||||
|
||||
// IsPrimaryAccount mocks base method.
|
||||
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -183,7 +183,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
|
||||
|
||||
w := WrapResponseWriter(rw)
|
||||
|
||||
handlerDone := make(chan struct{})
|
||||
context.AfterFunc(ctx, func() {
|
||||
select {
|
||||
case <-handlerDone:
|
||||
default:
|
||||
log.Debugf("HTTP request context canceled mid-flight: %v %v (reqID=%s, after %v, cause: %v)",
|
||||
r.Method, r.URL.Path, reqID, time.Since(reqStart), context.Cause(ctx))
|
||||
}
|
||||
})
|
||||
|
||||
h.ServeHTTP(w, r.WithContext(ctx))
|
||||
close(handlerDone)
|
||||
|
||||
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
|
||||
if err == nil {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/auth"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
|
||||
@@ -101,6 +102,7 @@ type Account struct {
|
||||
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
|
||||
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
|
||||
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
|
||||
Domains []*proxydomain.Domain `gorm:"foreignKey:AccountID;references:id"`
|
||||
// Settings is a dictionary of Account settings
|
||||
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
|
||||
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
|
||||
@@ -911,6 +913,11 @@ func (a *Account) Copy() *Account {
|
||||
services = append(services, svc.Copy())
|
||||
}
|
||||
|
||||
domains := []*proxydomain.Domain{}
|
||||
for _, domain := range a.Domains {
|
||||
domains = append(domains, domain.Copy())
|
||||
}
|
||||
|
||||
return &Account{
|
||||
Id: a.Id,
|
||||
CreatedBy: a.CreatedBy,
|
||||
@@ -936,6 +943,7 @@ func (a *Account) Copy() *Account {
|
||||
Onboarding: a.Onboarding,
|
||||
NetworkMapCache: a.NetworkMapCache,
|
||||
nmapInitOnce: a.nmapInitOnce,
|
||||
Domains: domains,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
217
management/server/types/networkmap_benchmark_test.go
Normal file
217
management/server/types/networkmap_benchmark_test.go
Normal file
@@ -0,0 +1,217 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
type benchmarkScale struct {
|
||||
name string
|
||||
peers int
|
||||
groups int
|
||||
}
|
||||
|
||||
var defaultScales = []benchmarkScale{
|
||||
{"100peers_5groups", 100, 5},
|
||||
{"500peers_20groups", 500, 20},
|
||||
{"1000peers_50groups", 1000, 50},
|
||||
{"5000peers_100groups", 5000, 100},
|
||||
{"10000peers_200groups", 10000, 200},
|
||||
{"20000peers_200groups", 20000, 200},
|
||||
{"30000peers_300groups", 30000, 300},
|
||||
}
|
||||
|
||||
func skipCIBenchmark(b *testing.B) {
|
||||
if os.Getenv("CI") == "true" {
|
||||
b.Skip("Skipping benchmark in CI")
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// Single Peer Network Map Generation
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_Components benchmarks the components-based approach for a single peer.
|
||||
func BenchmarkNetworkMapGeneration_Components(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run(scale.name, func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// All Peers (UpdateAccountPeers hot path)
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps for ALL peers.
|
||||
func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
scales := []benchmarkScale{
|
||||
{"100peers_5groups", 100, 5},
|
||||
{"500peers_20groups", 500, 20},
|
||||
{"1000peers_50groups", 1000, 50},
|
||||
{"5000peers_100groups", 5000, 100},
|
||||
}
|
||||
|
||||
for _, scale := range scales {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
|
||||
peerIDs := make([]string, 0, len(account.Peers))
|
||||
for peerID := range account.Peers {
|
||||
peerIDs = append(peerIDs, peerID)
|
||||
}
|
||||
|
||||
b.Run("components/"+scale.name, func(b *testing.B) {
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for _, peerID := range peerIDs {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, peerID, nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// Sub-operations
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks components extraction.
|
||||
func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run(scale.name, func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculation from pre-built components.
|
||||
func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run(scale.name, func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(scale.peers, scale.groups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
components := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, groupIDToUserIDs)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = types.CalculateNetworkMapFromComponents(ctx, components)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapGeneration_PrecomputeMaps benchmarks precomputed map costs.
|
||||
func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
for _, scale := range defaultScales {
|
||||
b.Run("ResourcePoliciesMap/"+scale.name, func(b *testing.B) {
|
||||
account, _ := scalableTestAccount(scale.peers, scale.groups)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetResourcePoliciesMap()
|
||||
}
|
||||
})
|
||||
b.Run("ResourceRoutersMap/"+scale.name, func(b *testing.B) {
|
||||
account, _ := scalableTestAccount(scale.peers, scale.groups)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetResourceRoutersMap()
|
||||
}
|
||||
})
|
||||
b.Run("ActiveGroupUsers/"+scale.name, func(b *testing.B) {
|
||||
account, _ := scalableTestAccount(scale.peers, scale.groups)
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetActiveGroupUsers()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// Scaling Analysis
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on performance.
|
||||
func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
groupCounts := []int{1, 5, 20, 50, 100, 200, 500}
|
||||
for _, numGroups := range groupCounts {
|
||||
b.Run(fmt.Sprintf("components_%dgroups", numGroups), func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(1000, numGroups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on performance.
|
||||
func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) {
|
||||
skipCIBenchmark(b)
|
||||
peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000}
|
||||
for _, numPeers := range peerCounts {
|
||||
numGroups := numPeers / 20
|
||||
if numGroups < 1 {
|
||||
numGroups = 1
|
||||
}
|
||||
b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) {
|
||||
account, validatedPeers := scalableTestAccount(numPeers, numGroups)
|
||||
ctx := context.Background()
|
||||
resourcePolicies := account.GetResourcePoliciesMap()
|
||||
routers := account.GetResourceRoutersMap()
|
||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
1192
management/server/types/networkmap_components_correctness_test.go
Normal file
1192
management/server/types/networkmap_components_correctness_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user