Merge branch 'main' of github.com:netbirdio/netbird into feat/local-user-totp

This commit is contained in:
jnfrati
2026-04-15 17:30:38 +02:00
160 changed files with 9330 additions and 1754 deletions

View File

@@ -742,11 +742,6 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
return status.Errorf(status.Internal, "failed to build user infos for account %s: %v", accountID, err)
}
err = am.serviceManager.DeleteAllServices(ctx, accountID, userID)
if err != nil {
return status.Errorf(status.Internal, "failed to delete service %s: %v", accountID, err)
}
for _, otherUser := range account.Users {
if otherUser.Id == userID {
continue

View File

@@ -75,7 +75,7 @@ type Manager interface {
GetUsersFromAccount(ctx context.Context, accountID, userID string) (map[string]*types.UserInfo, error)
GetGroup(ctx context.Context, accountId, groupID, userID string) (*types.Group, error)
GetAllGroups(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error)
GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
CreateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
UpdateGroup(ctx context.Context, accountID, userID string, group *types.Group) error
CreateGroups(ctx context.Context, accountID, userID string, newGroups []*types.Group) error

View File

@@ -736,18 +736,18 @@ func (mr *MockManagerMockRecorder) GetGroup(ctx, accountId, groupID, userID inte
}
// GetGroupByName mocks base method.
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (m *MockManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, groupName, accountID, userID)
ret0, _ := ret[0].(*types.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID interface{}) *gomock.Call {
func (mr *MockManagerMockRecorder) GetGroupByName(ctx, groupName, accountID, userID interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockManager)(nil).GetGroupByName), ctx, groupName, accountID, userID)
}
// GetIdentityProvider mocks base method.

View File

@@ -63,20 +63,11 @@ func (ac *AccountRequestBuffer) GetAccountWithBackpressure(ctx context.Context,
log.WithContext(ctx).Tracef("requesting account %s with backpressure", accountID)
startTime := time.Now()
ac.getAccountRequestCh <- req
select {
case <-ctx.Done():
return nil, ctx.Err()
case ac.getAccountRequestCh <- req:
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case result := <-req.ResultChan:
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
result := <-req.ResultChan
log.WithContext(ctx).Tracef("got account with backpressure after %s", time.Since(startTime))
return result.Account, result.Err
}
func (ac *AccountRequestBuffer) processGetAccountBatch(ctx context.Context, accountID string) {

View File

@@ -15,7 +15,6 @@ import (
"time"
"github.com/golang/mock/gomock"
"github.com/netbirdio/netbird/shared/management/status"
"github.com/prometheus/client_golang/prometheus/push"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
@@ -23,6 +22,9 @@ import (
"go.opentelemetry.io/otel/metric/noop"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/shared/management/status"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
@@ -1815,6 +1817,13 @@ func TestAccount_Copy(t *testing.T) {
Targets: []*service.Target{},
},
},
Domains: []*domain.Domain{
{
ID: "domain1",
Domain: "test.com",
AccountID: "account1",
},
},
NetworkMapCache: &types.NetworkMapBuilder{},
}
account.InitOnce()

View File

@@ -61,7 +61,10 @@ func (am *DefaultAccountManager) GetAllGroups(ctx context.Context, accountID, us
}
// GetGroupByName filters all groups in an account by name and returns the one with the most peers
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID string) (*types.Group, error) {
func (am *DefaultAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if err := am.CheckGroupPermissions(ctx, accountID, userID); err != nil {
return nil, err
}
return am.Store.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
}

View File

@@ -52,7 +52,7 @@ func (h *handler) getAllGroups(w http.ResponseWriter, r *http.Request) {
groupName := r.URL.Query().Get("name")
if groupName != "" {
// Get single group by name
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID)
group, err := h.accountManager.GetGroupByName(r.Context(), groupName, accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return
@@ -118,7 +118,7 @@ func (h *handler) updateGroup(w http.ResponseWriter, r *http.Request) {
return
}
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID)
allGroup, err := h.accountManager.GetGroupByName(r.Context(), "All", accountID, userID)
if err != nil {
util.WriteError(r.Context(), err, w)
return

View File

@@ -71,7 +71,7 @@ func initGroupTestData(initGroups ...*types.Group) *handler {
return groups, nil
},
GetGroupByNameFunc: func(ctx context.Context, groupName, _ string) (*types.Group, error) {
GetGroupByNameFunc: func(ctx context.Context, groupName, _, _ string) (*types.Group, error) {
if groupName == "All" {
return &types.Group{ID: "id-all", Name: "All", Issued: types.GroupIssuedAPI}, nil
}

View File

@@ -489,6 +489,102 @@ func MigrateJsonToTable[T any](ctx context.Context, db *gorm.DB, columnName stri
return nil
}
// hasForeignKey checks whether a foreign key constraint exists on the given table and column.
//
// The lookup is dialect-specific: Postgres and MySQL are answered from
// information_schema, while the SQLite fallback inspects PRAGMA foreign_key_list.
// Query errors are not surfaced — they leave count at zero, so the function
// reports "no FK" on failure and callers proceed with manual cleanup.
func hasForeignKey(db *gorm.DB, table, column string) bool {
	var count int64
	switch db.Name() {
	case "postgres":
		db.Raw(`
		SELECT COUNT(*) FROM information_schema.key_column_usage kcu
		JOIN information_schema.table_constraints tc
		ON tc.constraint_name = kcu.constraint_name
		AND tc.table_schema = kcu.table_schema
		WHERE tc.constraint_type = 'FOREIGN KEY'
		AND kcu.table_name = ?
		AND kcu.column_name = ?
		`, table, column).Scan(&count)
	case "mysql":
		db.Raw(`
		SELECT COUNT(*) FROM information_schema.key_column_usage
		WHERE table_schema = DATABASE()
		AND table_name = ?
		AND column_name = ?
		AND referenced_table_name IS NOT NULL
		`, table, column).Scan(&count)
	default: // sqlite
		// PRAGMA statements do not accept bind parameters, so the table name is
		// interpolated directly; callers pass schema-derived names, not user input.
		type fkInfo struct {
			From string
		}
		var fks []fkInfo
		db.Raw(fmt.Sprintf("PRAGMA foreign_key_list(%s)", table)).Scan(&fks)
		for _, fk := range fks {
			if fk.From == column {
				return true
			}
		}
		return false
	}
	return count > 0
}
// CleanupOrphanedResources deletes rows from the table of model T where the foreign
// key column (fkColumn) references a row in the table of model R that no longer exists.
//
// The function is a no-op when either table or the FK column is missing, and
// when a real foreign key constraint already guards the column (the database
// then enforces referential integrity itself). Rows whose fkColumn is NULL are
// left untouched, since `NULL NOT IN (...)` never evaluates to true in SQL.
// NOTE(review): assumes the parent table's primary key column is named "id".
func CleanupOrphanedResources[T any, R any](ctx context.Context, db *gorm.DB, fkColumn string) error {
	var model T
	var refModel R
	if !db.Migrator().HasTable(&model) {
		log.WithContext(ctx).Debugf("table for %T does not exist, no cleanup needed", model)
		return nil
	}
	if !db.Migrator().HasTable(&refModel) {
		log.WithContext(ctx).Debugf("referenced table for %T does not exist, no cleanup needed", refModel)
		return nil
	}
	// Resolve the concrete table names from GORM's schema metadata.
	stmtT := &gorm.Statement{DB: db}
	if err := stmtT.Parse(&model); err != nil {
		return fmt.Errorf("parse model %T: %w", model, err)
	}
	childTable := stmtT.Schema.Table
	stmtR := &gorm.Statement{DB: db}
	if err := stmtR.Parse(&refModel); err != nil {
		return fmt.Errorf("parse reference model %T: %w", refModel, err)
	}
	parentTable := stmtR.Schema.Table
	if !db.Migrator().HasColumn(&model, fkColumn) {
		log.WithContext(ctx).Debugf("column %s does not exist in table %s, no cleanup needed", fkColumn, childTable)
		return nil
	}
	// If a foreign key constraint already exists on the column, the DB itself
	// enforces referential integrity and orphaned rows cannot exist.
	if hasForeignKey(db, childTable, fkColumn) {
		log.WithContext(ctx).Debugf("foreign key constraint for %s already exists on %s, no cleanup needed", fkColumn, childTable)
		return nil
	}
	// Identifiers cannot be bound as SQL parameters; table/column names come
	// from the GORM schema and the caller, not from user input.
	result := db.Exec(
		fmt.Sprintf(
			"DELETE FROM %s WHERE %s NOT IN (SELECT id FROM %s)",
			childTable, fkColumn, parentTable,
		),
	)
	if result.Error != nil {
		return fmt.Errorf("cleanup orphaned rows in %s: %w", childTable, result.Error)
	}
	log.WithContext(ctx).Infof("Cleaned up %d orphaned rows from %s where %s had no matching row in %s",
		result.RowsAffected, childTable, fkColumn, parentTable)
	return nil
}
func RemoveDuplicatePeerKeys(ctx context.Context, db *gorm.DB) error {
if !db.Migrator().HasTable("peers") {
log.WithContext(ctx).Debug("peers table does not exist, skipping duplicate key cleanup")

View File

@@ -441,3 +441,197 @@ func TestRemoveDuplicatePeerKeys_NoTable(t *testing.T) {
err := migration.RemoveDuplicatePeerKeys(context.Background(), db)
require.NoError(t, err, "Should not fail when table does not exist")
}
// testParent is a minimal parent model used to exercise orphan cleanup.
type testParent struct {
	ID string `gorm:"primaryKey"`
}

func (testParent) TableName() string {
	return "test_parents"
}

// testChild references testParent via ParentID without declaring any GORM
// association, so no DB-level foreign key constraint is created.
type testChild struct {
	ID       string `gorm:"primaryKey"`
	ParentID string
}

func (testChild) TableName() string {
	return "test_children"
}

// testChildWithFK maps to the same table as testChild but declares a GORM
// association, so AutoMigrate creates a real foreign key constraint on ParentID.
type testChildWithFK struct {
	ID       string      `gorm:"primaryKey"`
	ParentID string      `gorm:"index"`
	Parent   *testParent `gorm:"foreignKey:ParentID"`
}

func (testChildWithFK) TableName() string {
	return "test_children"
}
// setupOrphanTestDB opens a fresh test database, drops any leftover tables for
// the given models, and auto-migrates them from a clean slate.
func setupOrphanTestDB(t *testing.T, models ...any) *gorm.DB {
	t.Helper()
	db := setupDatabase(t)
	migrator := db.Migrator()
	for _, model := range models {
		// Drop errors are ignored on purpose: the table may not exist yet.
		_ = migrator.DropTable(model)
	}
	require.NoError(t, db.AutoMigrate(models...), "Failed to auto-migrate tables")
	return db
}
// TestCleanupOrphanedResources_NoChildTable verifies the cleanup is a no-op
// when the child table is absent.
func TestCleanupOrphanedResources_NoChildTable(t *testing.T) {
	db := setupDatabase(t)
	for _, model := range []any{&testChild{}, &testParent{}} {
		_ = db.Migrator().DropTable(model)
	}
	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err, "Should not fail when child table does not exist")
}
// TestCleanupOrphanedResources_NoParentTable verifies the cleanup is a no-op
// when only the child table exists.
func TestCleanupOrphanedResources_NoParentTable(t *testing.T) {
	db := setupDatabase(t)
	for _, model := range []any{&testParent{}, &testChild{}} {
		_ = db.Migrator().DropTable(model)
	}
	require.NoError(t, db.AutoMigrate(&testChild{}))
	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err, "Should not fail when parent table does not exist")
}
// TestCleanupOrphanedResources_EmptyTables verifies the cleanup succeeds when
// both tables exist but contain no rows.
func TestCleanupOrphanedResources_EmptyTables(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err, "Should not fail on empty tables")
	var remaining int64
	db.Model(&testChild{}).Count(&remaining)
	assert.Equal(t, int64(0), remaining)
}
// TestCleanupOrphanedResources_NoOrphans verifies that children whose parents
// exist are never deleted.
func TestCleanupOrphanedResources_NoOrphans(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
	for _, parentID := range []string{"p1", "p2"} {
		require.NoError(t, db.Create(&testParent{ID: parentID}).Error)
	}
	children := []testChild{{ID: "c1", ParentID: "p1"}, {ID: "c2", ParentID: "p2"}}
	for i := range children {
		require.NoError(t, db.Create(&children[i]).Error)
	}
	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err)
	var remaining int64
	db.Model(&testChild{}).Count(&remaining)
	assert.Equal(t, int64(2), remaining, "All children should remain when no orphans")
}
// TestCleanupOrphanedResources_AllOrphans verifies that every child pointing
// at a missing parent is removed. Rows are inserted with raw SQL so no GORM
// association logic can interfere.
func TestCleanupOrphanedResources_AllOrphans(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
	const insertChild = "INSERT INTO test_children (id, parent_id) VALUES (?, ?)"
	orphans := [][2]string{{"c1", "gone1"}, {"c2", "gone2"}, {"c3", "gone3"}}
	for _, row := range orphans {
		require.NoError(t, db.Exec(insertChild, row[0], row[1]).Error)
	}
	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err)
	var remaining int64
	db.Model(&testChild{}).Count(&remaining)
	assert.Equal(t, int64(0), remaining, "All orphaned children should be deleted")
}
// TestCleanupOrphanedResources_MixedValidAndOrphaned verifies that only
// orphaned children are deleted while valid ones survive.
func TestCleanupOrphanedResources_MixedValidAndOrphaned(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
	for _, parentID := range []string{"p1", "p2"} {
		require.NoError(t, db.Create(&testParent{ID: parentID}).Error)
	}
	valid := []testChild{{ID: "c1", ParentID: "p1"}, {ID: "c2", ParentID: "p2"}, {ID: "c3", ParentID: "p1"}}
	for i := range valid {
		require.NoError(t, db.Create(&valid[i]).Error)
	}
	const insertChild = "INSERT INTO test_children (id, parent_id) VALUES (?, ?)"
	require.NoError(t, db.Exec(insertChild, "c4", "gone1").Error)
	require.NoError(t, db.Exec(insertChild, "c5", "gone2").Error)
	err := migration.CleanupOrphanedResources[testChild, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err)
	var remaining []testChild
	require.NoError(t, db.Order("id").Find(&remaining).Error)
	assert.Len(t, remaining, 3, "Only valid children should remain")
	for i, wantID := range []string{"c1", "c2", "c3"} {
		assert.Equal(t, wantID, remaining[i].ID)
	}
}
// TestCleanupOrphanedResources_Idempotent verifies a second run deletes
// nothing further once the orphans are gone.
func TestCleanupOrphanedResources_Idempotent(t *testing.T) {
	db := setupOrphanTestDB(t, &testParent{}, &testChild{})
	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
	require.NoError(t, db.Create(&testChild{ID: "c1", ParentID: "p1"}).Error)
	require.NoError(t, db.Exec("INSERT INTO test_children (id, parent_id) VALUES (?, ?)", "c2", "gone").Error)
	ctx := context.Background()
	countChildren := func() int64 {
		var n int64
		db.Model(&testChild{}).Count(&n)
		return n
	}
	require.NoError(t, migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id"))
	assert.Equal(t, int64(1), countChildren())
	require.NoError(t, migration.CleanupOrphanedResources[testChild, testParent](ctx, db, "parent_id"))
	assert.Equal(t, int64(1), countChildren(), "Count should remain the same after second run")
}
// TestCleanupOrphanedResources_SkipsWhenForeignKeyExists verifies the
// FK-constraint early exit: when a real foreign key exists on parent_id, the
// cleanup must delete nothing — even if an orphan was introduced while the
// constraint was temporarily dropped.
func TestCleanupOrphanedResources_SkipsWhenForeignKeyExists(t *testing.T) {
	engine := os.Getenv("NETBIRD_STORE_ENGINE")
	if engine != "postgres" && engine != "mysql" {
		t.Skip("FK constraint early-exit test requires postgres or mysql")
	}
	db := setupDatabase(t)
	_ = db.Migrator().DropTable(&testChildWithFK{})
	_ = db.Migrator().DropTable(&testParent{})
	err := db.AutoMigrate(&testParent{}, &testChildWithFK{})
	require.NoError(t, err)
	require.NoError(t, db.Create(&testParent{ID: "p1"}).Error)
	require.NoError(t, db.Create(&testParent{ID: "p2"}).Error)
	require.NoError(t, db.Create(&testChildWithFK{ID: "c1", ParentID: "p1"}).Error)
	require.NoError(t, db.Create(&testChildWithFK{ID: "c2", ParentID: "p2"}).Error)
	// Turn c2 into an orphan (its parent p2 is deleted) while ending up with
	// the FK constraint back in place, using engine-specific DDL.
	switch engine {
	case "postgres":
		// NOT VALID re-adds the constraint without validating existing rows,
		// so the orphaned c2 survives the ADD CONSTRAINT.
		require.NoError(t, db.Exec("ALTER TABLE test_children DROP CONSTRAINT fk_test_children_parent").Error)
		require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error)
		require.NoError(t, db.Exec(
			"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+
				"FOREIGN KEY (parent_id) REFERENCES test_parents(id) NOT VALID",
		).Error)
	case "mysql":
		// FK checks are disabled around the DDL so the ADD CONSTRAINT
		// succeeds despite the orphaned row.
		require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error)
		require.NoError(t, db.Exec("ALTER TABLE test_children DROP FOREIGN KEY fk_test_children_parent").Error)
		require.NoError(t, db.Exec("DELETE FROM test_parents WHERE id = ?", "p2").Error)
		require.NoError(t, db.Exec(
			"ALTER TABLE test_children ADD CONSTRAINT fk_test_children_parent "+
				"FOREIGN KEY (parent_id) REFERENCES test_parents(id)",
		).Error)
		require.NoError(t, db.Exec("SET FOREIGN_KEY_CHECKS = 1").Error)
	}
	err = migration.CleanupOrphanedResources[testChildWithFK, testParent](context.Background(), db, "parent_id")
	require.NoError(t, err)
	var count int64
	db.Model(&testChildWithFK{}).Count(&count)
	assert.Equal(t, int64(2), count, "Both rows should survive — migration must skip when FK constraint exists")
}

View File

@@ -46,7 +46,7 @@ type MockAccountManager struct {
AddPeerFunc func(ctx context.Context, accountID string, setupKey string, userId string, peer *nbpeer.Peer, temporary bool) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, error)
GetGroupFunc func(ctx context.Context, accountID, groupID, userID string) (*types.Group, error)
GetAllGroupsFunc func(ctx context.Context, accountID, userID string) ([]*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, accountID, groupName string) (*types.Group, error)
GetGroupByNameFunc func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error)
SaveGroupFunc func(ctx context.Context, accountID, userID string, group *types.Group, create bool) error
SaveGroupsFunc func(ctx context.Context, accountID, userID string, groups []*types.Group, create bool) error
DeleteGroupFunc func(ctx context.Context, accountID, userId, groupID string) error
@@ -406,9 +406,9 @@ func (am *MockAccountManager) AddPeer(
}
// GetGroupByName mock implementation of GetGroupByName from server.AccountManager interface
func (am *MockAccountManager) GetGroupByName(ctx context.Context, accountID, groupName string) (*types.Group, error) {
func (am *MockAccountManager) GetGroupByName(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
if am.GetGroupByNameFunc != nil {
return am.GetGroupByNameFunc(ctx, accountID, groupName)
return am.GetGroupByNameFunc(ctx, groupName, accountID, userID)
}
return nil, status.Errorf(codes.Unimplemented, "method GetGroupByName is not implemented")
}

View File

@@ -396,6 +396,11 @@ func (s *SqlStore) DeleteAccount(ctx context.Context, account *types.Account) er
return result.Error
}
result = tx.Select(clause.Associations).Delete(account.Services, "account_id = ?", account.Id)
if result.Error != nil {
return result.Error
}
result = tx.Select(clause.Associations).Delete(account)
if result.Error != nil {
return result.Error
@@ -1012,10 +1017,10 @@ func (s *SqlStore) GetAccountsCounter(ctx context.Context) (int64, error) {
// GetCustomDomainsCounts returns the total and validated custom domain counts.
func (s *SqlStore) GetCustomDomainsCounts(ctx context.Context) (int64, int64, error) {
var total, validated int64
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Count(&total).Error; err != nil {
if err := s.db.Model(&domain.Domain{}).Count(&total).Error; err != nil {
return 0, 0, err
}
if err := s.db.WithContext(ctx).Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
if err := s.db.Model(&domain.Domain{}).Where("validated = ?", true).Count(&validated).Error; err != nil {
return 0, 0, err
}
return total, validated, nil
@@ -2080,7 +2085,8 @@ func (s *SqlStore) getPostureChecks(ctx context.Context, accountID string) ([]*p
func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpservice.Service, error) {
const serviceQuery = `SELECT id, account_id, name, domain, enabled, auth,
meta_created_at, meta_certificate_issued_at, meta_status, proxy_cluster,
pass_host_header, rewrite_redirects, session_private_key, session_public_key
pass_host_header, rewrite_redirects, session_private_key, session_public_key,
mode, listen_port, port_auto_assigned, source, source_peer, terminated
FROM services WHERE account_id = $1`
const targetsQuery = `SELECT id, account_id, service_id, path, host, port, protocol,
@@ -2097,6 +2103,9 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
var auth []byte
var createdAt, certIssuedAt sql.NullTime
var status, proxyCluster, sessionPrivateKey, sessionPublicKey sql.NullString
var mode, source, sourcePeer sql.NullString
var terminated, portAutoAssigned sql.NullBool
var listenPort sql.NullInt64
err := row.Scan(
&s.ID,
&s.AccountID,
@@ -2112,6 +2121,12 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
&s.RewriteRedirects,
&sessionPrivateKey,
&sessionPublicKey,
&mode,
&listenPort,
&portAutoAssigned,
&source,
&sourcePeer,
&terminated,
)
if err != nil {
return nil, err
@@ -2143,7 +2158,24 @@ func (s *SqlStore) getServices(ctx context.Context, accountID string) ([]*rpserv
if sessionPublicKey.Valid {
s.SessionPublicKey = sessionPublicKey.String
}
if mode.Valid {
s.Mode = mode.String
}
if source.Valid {
s.Source = source.String
}
if sourcePeer.Valid {
s.SourcePeer = sourcePeer.String
}
if terminated.Valid {
s.Terminated = terminated.Bool
}
if portAutoAssigned.Valid {
s.PortAutoAssigned = portAutoAssigned.Bool
}
if listenPort.Valid {
s.ListenPort = uint16(listenPort.Int64)
}
s.Targets = []*rpservice.Target{}
return &s, nil
})
@@ -4410,7 +4442,7 @@ func (s *SqlStore) DeletePAT(ctx context.Context, userID, patID string) error {
// GetProxyAccessTokenByHashedToken retrieves a proxy access token by its hashed value.
func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStrength LockingStrength, hashedToken types.HashedProxyToken) (*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4429,7 +4461,7 @@ func (s *SqlStore) GetProxyAccessTokenByHashedToken(ctx context.Context, lockStr
// GetAllProxyAccessTokens retrieves all proxy access tokens.
func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength LockingStrength) ([]*types.ProxyAccessToken, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -4445,7 +4477,7 @@ func (s *SqlStore) GetAllProxyAccessTokens(ctx context.Context, lockStrength Loc
// SaveProxyAccessToken saves a proxy access token to the database.
func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyAccessToken) error {
if result := s.db.WithContext(ctx).Create(token); result.Error != nil {
if result := s.db.Create(token); result.Error != nil {
return status.Errorf(status.Internal, "save proxy access token: %v", result.Error)
}
return nil
@@ -4453,7 +4485,7 @@ func (s *SqlStore) SaveProxyAccessToken(ctx context.Context, token *types.ProxyA
// RevokeProxyAccessToken revokes a proxy access token by its ID.
func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
result := s.db.Model(&types.ProxyAccessToken{}).Where(idQueryCondition, tokenID).Update("revoked", true)
if result.Error != nil {
return status.Errorf(status.Internal, "revoke proxy access token: %v", result.Error)
}
@@ -4467,7 +4499,7 @@ func (s *SqlStore) RevokeProxyAccessToken(ctx context.Context, tokenID string) e
// MarkProxyAccessTokenUsed updates the last used timestamp for a proxy access token.
func (s *SqlStore) MarkProxyAccessTokenUsed(ctx context.Context, tokenID string) error {
result := s.db.WithContext(ctx).Model(&types.ProxyAccessToken{}).
result := s.db.Model(&types.ProxyAccessToken{}).
Where(idQueryCondition, tokenID).
Update("last_used", time.Now().UTC())
if result.Error != nil {
@@ -5136,7 +5168,7 @@ func (s *SqlStore) EphemeralServiceExists(ctx context.Context, lockStrength Lock
// GetServicesByClusterAndPort returns services matching the given proxy cluster, mode, and listen port.
func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength LockingStrength, proxyCluster string, mode string, listenPort uint16) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -5152,7 +5184,7 @@ func (s *SqlStore) GetServicesByClusterAndPort(ctx context.Context, lockStrength
// GetServicesByCluster returns all services for the given proxy cluster.
func (s *SqlStore) GetServicesByCluster(ctx context.Context, lockStrength LockingStrength, proxyCluster string) ([]*rpservice.Service, error) {
tx := s.db.WithContext(ctx)
tx := s.db
if lockStrength != LockingStrengthNone {
tx = tx.Clauses(clause.Locking{Strength: string(lockStrength)})
}
@@ -5262,7 +5294,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
var logs []*accesslogs.AccessLogEntry
var totalCount int64
baseQuery := s.db.WithContext(ctx).
baseQuery := s.db.
Model(&accesslogs.AccessLogEntry{}).
Where(accountIDCondition, accountID)
@@ -5273,7 +5305,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
return nil, 0, status.Errorf(status.Internal, "failed to count access logs")
}
query := s.db.WithContext(ctx).
query := s.db.
Where(accountIDCondition, accountID)
query = s.applyAccessLogFilters(query, filter)
@@ -5310,7 +5342,7 @@ func (s *SqlStore) GetAccountAccessLogs(ctx context.Context, lockStrength Lockin
// DeleteOldAccessLogs deletes all access logs older than the specified time
func (s *SqlStore) DeleteOldAccessLogs(ctx context.Context, olderThan time.Time) (int64, error) {
result := s.db.WithContext(ctx).
result := s.db.
Where("timestamp < ?", olderThan).
Delete(&accesslogs.AccessLogEntry{})
@@ -5399,7 +5431,7 @@ func (s *SqlStore) GetServiceTargetByTargetID(ctx context.Context, lockStrength
// SaveProxy saves or updates a proxy in the database
func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
result := s.db.WithContext(ctx).Save(p)
result := s.db.Save(p)
if result.Error != nil {
log.WithContext(ctx).Errorf("failed to save proxy: %v", result.Error)
return status.Errorf(status.Internal, "failed to save proxy")
@@ -5411,7 +5443,7 @@ func (s *SqlStore) SaveProxy(ctx context.Context, p *proxy.Proxy) error {
func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
now := time.Now()
result := s.db.WithContext(ctx).
result := s.db.
Model(&proxy.Proxy{}).
Where("id = ? AND status = ?", proxyID, "connected").
Update("last_seen", now)
@@ -5430,7 +5462,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
ConnectedAt: &now,
Status: "connected",
}
if err := s.db.WithContext(ctx).Save(p).Error; err != nil {
if err := s.db.Save(p).Error; err != nil {
log.WithContext(ctx).Errorf("failed to create proxy on heartbeat: %v", err)
return status.Errorf(status.Internal, "failed to create proxy on heartbeat")
}
@@ -5443,7 +5475,7 @@ func (s *SqlStore) UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAdd
func (s *SqlStore) GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error) {
var addresses []string
result := s.db.WithContext(ctx).
result := s.db.
Model(&proxy.Proxy{}).
Where("status = ? AND last_seen > ?", "connected", time.Now().Add(-proxyActiveThreshold)).
Distinct("cluster_address").
@@ -5482,6 +5514,7 @@ const proxyActiveThreshold = 2 * time.Minute
var validCapabilityColumns = map[string]struct{}{
"supports_custom_ports": {},
"require_subdomain": {},
"supports_crowdsec": {},
}
// GetClusterSupportsCustomPorts returns whether any active proxy in the cluster
@@ -5496,6 +5529,59 @@ func (s *SqlStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr s
return s.getClusterCapability(ctx, clusterAddr, "require_subdomain")
}
// GetClusterSupportsCrowdSec returns whether all active proxies in the cluster
// have CrowdSec configured. Returns nil when no proxy reported the capability.
// Unlike other capabilities that use ANY-true (for rolling upgrades), CrowdSec
// requires unanimous support: a single unconfigured proxy would let requests
// bypass reputation checks.
func (s *SqlStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
	// "supports_crowdsec" is registered in validCapabilityColumns, so the
	// helper's column allowlist check passes.
	return s.getClusterUnanimousCapability(ctx, clusterAddr, "supports_crowdsec")
}
// getClusterUnanimousCapability returns an aggregated boolean capability
// requiring all active proxies in the cluster to report true.
//
// Return contract:
//   - nil when there are no active proxies, none reported the capability,
//     the column is not allowlisted, or the query fails;
//   - pointer to false when at least one active proxy reported false or left
//     the column NULL;
//   - pointer to true only when every active proxy reported true.
func (s *SqlStore) getClusterUnanimousCapability(ctx context.Context, clusterAddr, column string) *bool {
	// column is interpolated into the SELECT below, so it must come from the
	// fixed allowlist to rule out SQL injection.
	if _, ok := validCapabilityColumns[column]; !ok {
		log.WithContext(ctx).Errorf("invalid capability column: %s", column)
		return nil
	}
	var result struct {
		Total    int64
		Reported int64
		AllTrue  bool
	}
	// All active proxies must have reported the capability (no NULLs) and all
	// must report true. A single unreported or false proxy means the cluster
	// does not unanimously support the capability.
	err := s.db.WithContext(ctx).
		Model(&proxy.Proxy{}).
		Select("COUNT(*) AS total, "+
			"COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) AS reported, "+
			"COUNT(*) > 0 AND COUNT(*) = COUNT(CASE WHEN "+column+" = true THEN 1 END) AS all_true").
		Where("cluster_address = ? AND status = ? AND last_seen > ?",
			clusterAddr, "connected", time.Now().Add(-proxyActiveThreshold)).
		Scan(&result).Error
	if err != nil {
		log.WithContext(ctx).Errorf("query cluster capability %s for %s: %v", column, clusterAddr, err)
		return nil
	}
	if result.Total == 0 || result.Reported == 0 {
		return nil
	}
	// If any proxy has not reported (NULL), we can't confirm unanimous support.
	if result.Reported < result.Total {
		v := false
		return &v
	}
	return &result.AllTrue
}
// getClusterCapability returns an aggregated boolean capability for the given
// cluster. It checks active (connected, recently seen) proxies and returns:
// - *true if any proxy in the cluster has the capability set to true,
@@ -5512,7 +5598,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
AnyTrue bool
}
err := s.db.WithContext(ctx).
err := s.db.
Model(&proxy.Proxy{}).
Select("COUNT(CASE WHEN "+column+" IS NOT NULL THEN 1 END) > 0 AS has_capability, "+
"COALESCE(MAX(CASE WHEN "+column+" = true THEN 1 ELSE 0 END), 0) = 1 AS any_true").
@@ -5536,7 +5622,7 @@ func (s *SqlStore) getClusterCapability(ctx context.Context, clusterAddr, column
func (s *SqlStore) CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error {
cutoffTime := time.Now().Add(-inactivityDuration)
result := s.db.WithContext(ctx).
result := s.db.
Where("last_seen < ?", cutoffTime).
Delete(&proxy.Proxy{})

View File

@@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
@@ -350,6 +352,35 @@ func TestSqlite_DeleteAccount(t *testing.T) {
},
}
account.Services = []*rpservice.Service{
{
ID: "service_id",
AccountID: account.Id,
Name: "test service",
Domain: "svc.example.com",
Enabled: true,
Targets: []*rpservice.Target{
{
AccountID: account.Id,
ServiceID: "service_id",
Host: "localhost",
Port: 8080,
Protocol: "http",
Enabled: true,
},
},
},
}
account.Domains = []*proxydomain.Domain{
{
ID: "domain_id",
Domain: "custom.example.com",
AccountID: account.Id,
Validated: true,
},
}
err = store.SaveAccount(context.Background(), account)
require.NoError(t, err)
@@ -411,6 +442,20 @@ func TestSqlite_DeleteAccount(t *testing.T) {
require.NoError(t, err, "expecting no error after removing DeleteAccount when searching for network resources")
require.Len(t, resources, 0, "expecting no network resources to be found after DeleteAccount")
}
domains, err := store.ListCustomDomains(context.Background(), account.Id)
require.NoError(t, err, "expecting no error after DeleteAccount when searching for custom domains")
require.Len(t, domains, 0, "expecting no custom domains to be found after DeleteAccount")
var services []*rpservice.Service
err = store.(*SqlStore).db.Model(&rpservice.Service{}).Find(&services, "account_id = ?", account.Id).Error
require.NoError(t, err, "expecting no error after DeleteAccount when searching for services")
require.Len(t, services, 0, "expecting no services to be found after DeleteAccount")
var targets []*rpservice.Target
err = store.(*SqlStore).db.Model(&rpservice.Target{}).Find(&targets, "account_id = ?", account.Id).Error
require.NoError(t, err, "expecting no error after DeleteAccount when searching for service targets")
require.Len(t, targets, 0, "expecting no service targets to be found after DeleteAccount")
}
func Test_GetAccount(t *testing.T) {

View File

@@ -20,6 +20,7 @@ import (
"github.com/stretchr/testify/assert"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types"
routerTypes "github.com/netbirdio/netbird/management/server/networks/routers/types"
@@ -265,6 +266,7 @@ func setupBenchmarkDB(b testing.TB) (*SqlStore, func(), string) {
&nbdns.NameServerGroup{}, &posture.Checks{}, &networkTypes.Network{},
&routerTypes.NetworkRouter{}, &resourceTypes.NetworkResource{},
&types.AccountOnboarding{}, &service.Service{}, &service.Target{},
&domain.Domain{},
}
for i := len(models) - 1; i >= 0; i-- {

View File

@@ -121,7 +121,7 @@ type Store interface {
GetAccountGroups(ctx context.Context, lockStrength LockingStrength, accountID string) ([]*types.Group, error)
GetResourceGroups(ctx context.Context, lockStrength LockingStrength, accountID, resourceID string) ([]*types.Group, error)
GetGroupByID(ctx context.Context, lockStrength LockingStrength, accountID, groupID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types.Group, error)
GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types.Group, error)
GetGroupsByIDs(ctx context.Context, lockStrength LockingStrength, accountID string, groupIDs []string) (map[string]*types.Group, error)
CreateGroups(ctx context.Context, accountID string, groups []*types.Group) error
UpdateGroups(ctx context.Context, accountID string, groups []*types.Group) error
@@ -289,6 +289,7 @@ type Store interface {
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool
GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool
CleanupStaleProxies(ctx context.Context, inactivityDuration time.Duration) error
GetCustomDomainsCounts(ctx context.Context) (total int64, validated int64, err error)
@@ -448,6 +449,12 @@ func getMigrationsPreAuto(ctx context.Context) []migrationFunc {
func(db *gorm.DB) error {
return migration.RemoveDuplicatePeerKeys(ctx, db)
},
func(db *gorm.DB) error {
return migration.CleanupOrphanedResources[rpservice.Service, types.Account](ctx, db, "account_id")
},
func(db *gorm.DB) error {
return migration.CleanupOrphanedResources[domain.Domain, types.Account](ctx, db, "account_id")
},
}
}

View File

@@ -165,34 +165,19 @@ func (mr *MockStoreMockRecorder) CleanupStaleProxies(ctx, inactivityDuration int
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CleanupStaleProxies", reflect.TypeOf((*MockStore)(nil).CleanupStaleProxies), ctx, inactivityDuration)
}
// GetClusterSupportsCustomPorts mocks base method.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
// GetClusterSupportsCrowdSec mocks base method.
func (m *MockStore) GetClusterSupportsCrowdSec(ctx context.Context, clusterAddr string) *bool {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
ret := m.ctrl.Call(m, "GetClusterSupportsCrowdSec", ctx, clusterAddr)
ret0, _ := ret[0].(*bool)
return ret0
}
// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
// GetClusterSupportsCrowdSec indicates an expected call of GetClusterSupportsCrowdSec.
func (mr *MockStoreMockRecorder) GetClusterSupportsCrowdSec(ctx, clusterAddr interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCrowdSec", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCrowdSec), ctx, clusterAddr)
}
// GetClusterRequireSubdomain mocks base method.
// Generated gomock boilerplate: delegates to the controller, which matches
// the call against recorded expectations.
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
	// Type-assert the configured return; nil when the expectation set no *bool value.
	ret0, _ := ret[0].(*bool)
	return ret0
}

// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// Close mocks base method.
func (m *MockStore) Close(ctx context.Context) error {
m.ctrl.T.Helper()
@@ -1389,6 +1374,34 @@ func (mr *MockStoreMockRecorder) GetAnyAccountID(ctx interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAnyAccountID", reflect.TypeOf((*MockStore)(nil).GetAnyAccountID), ctx)
}
// GetClusterRequireSubdomain mocks base method.
// Generated gomock boilerplate: delegates to the controller, which matches
// the call against recorded expectations.
func (m *MockStore) GetClusterRequireSubdomain(ctx context.Context, clusterAddr string) *bool {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetClusterRequireSubdomain", ctx, clusterAddr)
	// Type-assert the configured return; nil when the expectation set no *bool value.
	ret0, _ := ret[0].(*bool)
	return ret0
}

// GetClusterRequireSubdomain indicates an expected call of GetClusterRequireSubdomain.
func (mr *MockStoreMockRecorder) GetClusterRequireSubdomain(ctx, clusterAddr interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterRequireSubdomain", reflect.TypeOf((*MockStore)(nil).GetClusterRequireSubdomain), ctx, clusterAddr)
}
// GetClusterSupportsCustomPorts mocks base method.
// Generated gomock boilerplate: delegates to the controller, which matches
// the call against recorded expectations.
func (m *MockStore) GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetClusterSupportsCustomPorts", ctx, clusterAddr)
	// Type-assert the configured return; nil when the expectation set no *bool value.
	ret0, _ := ret[0].(*bool)
	return ret0
}

// GetClusterSupportsCustomPorts indicates an expected call of GetClusterSupportsCustomPorts.
func (mr *MockStoreMockRecorder) GetClusterSupportsCustomPorts(ctx, clusterAddr interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetClusterSupportsCustomPorts", reflect.TypeOf((*MockStore)(nil).GetClusterSupportsCustomPorts), ctx, clusterAddr)
}
// GetCustomDomain mocks base method.
func (m *MockStore) GetCustomDomain(ctx context.Context, accountID, domainID string) (*domain.Domain, error) {
m.ctrl.T.Helper()
@@ -1466,18 +1479,18 @@ func (mr *MockStoreMockRecorder) GetGroupByID(ctx, lockStrength, accountID, grou
}
// GetGroupByName mocks base method.
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, groupName, accountID string) (*types2.Group, error) {
func (m *MockStore) GetGroupByName(ctx context.Context, lockStrength LockingStrength, accountID, groupName string) (*types2.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, groupName, accountID)
ret := m.ctrl.Call(m, "GetGroupByName", ctx, lockStrength, accountID, groupName)
ret0, _ := ret[0].(*types2.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupByName indicates an expected call of GetGroupByName.
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, groupName, accountID interface{}) *gomock.Call {
func (mr *MockStoreMockRecorder) GetGroupByName(ctx, lockStrength, accountID, groupName interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, groupName, accountID)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupByName", reflect.TypeOf((*MockStore)(nil).GetGroupByName), ctx, lockStrength, accountID, groupName)
}
// GetGroupsByIDs mocks base method.
@@ -1974,6 +1987,21 @@ func (mr *MockStoreMockRecorder) GetRouteByID(ctx, lockStrength, accountID, rout
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRouteByID", reflect.TypeOf((*MockStore)(nil).GetRouteByID), ctx, lockStrength, accountID, routeID)
}
// GetRoutingPeerNetworks mocks base method.
// Generated gomock boilerplate: delegates to the controller, which matches
// the call against recorded expectations.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
	// Type-assert the configured returns; zero values when the expectation set none.
	ret0, _ := ret[0].([]string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// GetServiceByDomain mocks base method.
func (m *MockStore) GetServiceByDomain(ctx context.Context, domain string) (*service.Service, error) {
m.ctrl.T.Helper()
@@ -2361,21 +2389,6 @@ func (mr *MockStoreMockRecorder) IncrementSetupKeyUsage(ctx, setupKeyID interfac
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "IncrementSetupKeyUsage", reflect.TypeOf((*MockStore)(nil).IncrementSetupKeyUsage), ctx, setupKeyID)
}
// GetRoutingPeerNetworks mocks base method.
// Generated gomock boilerplate: delegates to the controller, which matches
// the call against recorded expectations.
func (m *MockStore) GetRoutingPeerNetworks(ctx context.Context, accountID, peerID string) ([]string, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "GetRoutingPeerNetworks", ctx, accountID, peerID)
	// Type-assert the configured returns; zero values when the expectation set none.
	ret0, _ := ret[0].([]string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// GetRoutingPeerNetworks indicates an expected call of GetRoutingPeerNetworks.
func (mr *MockStoreMockRecorder) GetRoutingPeerNetworks(ctx, accountID, peerID interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetRoutingPeerNetworks", reflect.TypeOf((*MockStore)(nil).GetRoutingPeerNetworks), ctx, accountID, peerID)
}
// IsPrimaryAccount mocks base method.
func (m *MockStore) IsPrimaryAccount(ctx context.Context, accountID string) (bool, string, error) {
m.ctrl.T.Helper()

View File

@@ -183,7 +183,18 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler {
w := WrapResponseWriter(rw)
handlerDone := make(chan struct{})
context.AfterFunc(ctx, func() {
select {
case <-handlerDone:
default:
log.Debugf("HTTP request context canceled mid-flight: %v %v (reqID=%s, after %v, cause: %v)",
r.Method, r.URL.Path, reqID, time.Since(reqStart), context.Cause(ctx))
}
})
h.ServeHTTP(w, r.WithContext(ctx))
close(handlerDone)
userAuth, err := nbContext.GetUserAuthFromContext(r.Context())
if err == nil {

View File

@@ -18,6 +18,7 @@ import (
"github.com/netbirdio/netbird/client/ssh/auth"
nbdns "github.com/netbirdio/netbird/dns"
proxydomain "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/domain"
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/modules/zones/records"
@@ -101,6 +102,7 @@ type Account struct {
DNSSettings DNSSettings `gorm:"embedded;embeddedPrefix:dns_settings_"`
PostureChecks []*posture.Checks `gorm:"foreignKey:AccountID;references:id"`
Services []*service.Service `gorm:"foreignKey:AccountID;references:id"`
Domains []*proxydomain.Domain `gorm:"foreignKey:AccountID;references:id"`
// Settings is a dictionary of Account settings
Settings *Settings `gorm:"embedded;embeddedPrefix:settings_"`
Networks []*networkTypes.Network `gorm:"foreignKey:AccountID;references:id"`
@@ -911,6 +913,11 @@ func (a *Account) Copy() *Account {
services = append(services, svc.Copy())
}
domains := []*proxydomain.Domain{}
for _, domain := range a.Domains {
domains = append(domains, domain.Copy())
}
return &Account{
Id: a.Id,
CreatedBy: a.CreatedBy,
@@ -936,6 +943,7 @@ func (a *Account) Copy() *Account {
Onboarding: a.Onboarding,
NetworkMapCache: a.NetworkMapCache,
nmapInitOnce: a.nmapInitOnce,
Domains: domains,
}
}

View File

@@ -0,0 +1,217 @@
package types_test
import (
"context"
"fmt"
"os"
"testing"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server/types"
)
// benchmarkScale describes one benchmark dimension: a sub-benchmark name plus
// the peer and group counts passed to scalableTestAccount.
type benchmarkScale struct {
	name   string // sub-benchmark label, e.g. "100peers_5groups"
	peers  int    // peer count handed to scalableTestAccount
	groups int    // group count handed to scalableTestAccount
}

// defaultScales are the account sizes exercised by most benchmarks below,
// ranging from small (100 peers) to very large (30k peers) accounts.
var defaultScales = []benchmarkScale{
	{"100peers_5groups", 100, 5},
	{"500peers_20groups", 500, 20},
	{"1000peers_50groups", 1000, 50},
	{"5000peers_100groups", 5000, 100},
	{"10000peers_200groups", 10000, 200},
	{"20000peers_200groups", 20000, 200},
	{"30000peers_300groups", 30000, 300},
}
// skipCIBenchmark skips the calling benchmark when the CI environment
// variable is set to "true".
func skipCIBenchmark(b *testing.B) {
	if ci := os.Getenv("CI"); ci == "true" {
		b.Skip("Skipping benchmark in CI")
	}
}
// ──────────────────────────────────────────────────────────────────────────────
// Single Peer Network Map Generation
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_Components benchmarks the components-based
// network-map generation for a single peer ("peer-0") across the default scales.
func BenchmarkNetworkMapGeneration_Components(b *testing.B) {
	skipCIBenchmark(b)
	for _, sc := range defaultScales {
		b.Run(sc.name, func(b *testing.B) {
			account, validated := scalableTestAccount(sc.peers, sc.groups)
			ctx := context.Background()
			// Precompute the lookup maps once; only map generation is timed.
			policies := account.GetResourcePoliciesMap()
			resourceRouters := account.GetResourceRoutersMap()
			groupUsers := account.GetActiveGroupUsers()
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validated, policies, resourceRouters, nil, groupUsers)
			}
		})
	}
}
// ──────────────────────────────────────────────────────────────────────────────
// All Peers (UpdateAccountPeers hot path)
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_AllPeers benchmarks generating network maps
// for ALL peers of an account (the UpdateAccountPeers hot path).
func BenchmarkNetworkMapGeneration_AllPeers(b *testing.B) {
	skipCIBenchmark(b)
	scales := []benchmarkScale{
		{"100peers_5groups", 100, 5},
		{"500peers_20groups", 500, 20},
		{"1000peers_50groups", 1000, 50},
		{"5000peers_100groups", 5000, 100},
	}
	for _, sc := range scales {
		account, validated := scalableTestAccount(sc.peers, sc.groups)
		ctx := context.Background()
		// Snapshot the peer IDs so the timed loop does not iterate the map itself.
		ids := make([]string, 0, len(account.Peers))
		for id := range account.Peers {
			ids = append(ids, id)
		}
		b.Run("components/"+sc.name, func(b *testing.B) {
			policies := account.GetResourcePoliciesMap()
			resourceRouters := account.GetResourceRoutersMap()
			groupUsers := account.GetActiveGroupUsers()
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				for _, id := range ids {
					_ = account.GetPeerNetworkMapFromComponents(ctx, id, nbdns.CustomZone{}, nil, validated, policies, resourceRouters, nil, groupUsers)
				}
			}
		})
	}
}
// ──────────────────────────────────────────────────────────────────────────────
// Sub-operations
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_ComponentsCreation benchmarks extracting the
// per-peer network-map components (without the final map calculation).
func BenchmarkNetworkMapGeneration_ComponentsCreation(b *testing.B) {
	skipCIBenchmark(b)
	for _, sc := range defaultScales {
		b.Run(sc.name, func(b *testing.B) {
			account, validated := scalableTestAccount(sc.peers, sc.groups)
			ctx := context.Background()
			policies := account.GetResourcePoliciesMap()
			resourceRouters := account.GetResourceRoutersMap()
			groupUsers := account.GetActiveGroupUsers()
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validated, policies, resourceRouters, groupUsers)
			}
		})
	}
}
// BenchmarkNetworkMapGeneration_ComponentsCalculation benchmarks calculating
// a network map from components that were built once, up front.
func BenchmarkNetworkMapGeneration_ComponentsCalculation(b *testing.B) {
	skipCIBenchmark(b)
	for _, sc := range defaultScales {
		b.Run(sc.name, func(b *testing.B) {
			account, validated := scalableTestAccount(sc.peers, sc.groups)
			ctx := context.Background()
			policies := account.GetResourcePoliciesMap()
			resourceRouters := account.GetResourceRoutersMap()
			groupUsers := account.GetActiveGroupUsers()
			// Component extraction happens outside the timed region.
			comps := account.GetPeerNetworkMapComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validated, policies, resourceRouters, groupUsers)
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = types.CalculateNetworkMapFromComponents(ctx, comps)
			}
		})
	}
}
// BenchmarkNetworkMapGeneration_PrecomputeMaps measures the cost of building
// each precomputed lookup map used by network-map generation.
func BenchmarkNetworkMapGeneration_PrecomputeMaps(b *testing.B) {
	skipCIBenchmark(b)
	for _, sc := range defaultScales {
		b.Run("ResourcePoliciesMap/"+sc.name, func(b *testing.B) {
			account, _ := scalableTestAccount(sc.peers, sc.groups)
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = account.GetResourcePoliciesMap()
			}
		})
		b.Run("ResourceRoutersMap/"+sc.name, func(b *testing.B) {
			account, _ := scalableTestAccount(sc.peers, sc.groups)
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = account.GetResourceRoutersMap()
			}
		})
		b.Run("ActiveGroupUsers/"+sc.name, func(b *testing.B) {
			account, _ := scalableTestAccount(sc.peers, sc.groups)
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = account.GetActiveGroupUsers()
			}
		})
	}
}
// ──────────────────────────────────────────────────────────────────────────────
// Scaling Analysis
// ──────────────────────────────────────────────────────────────────────────────
// BenchmarkNetworkMapGeneration_GroupScaling tests group count impact on
// performance while holding the peer count fixed at 1000.
func BenchmarkNetworkMapGeneration_GroupScaling(b *testing.B) {
	skipCIBenchmark(b)
	for _, groups := range []int{1, 5, 20, 50, 100, 200, 500} {
		b.Run(fmt.Sprintf("components_%dgroups", groups), func(b *testing.B) {
			account, validated := scalableTestAccount(1000, groups)
			ctx := context.Background()
			policies := account.GetResourcePoliciesMap()
			resourceRouters := account.GetResourceRoutersMap()
			groupUsers := account.GetActiveGroupUsers()
			b.ReportAllocs()
			b.ResetTimer()
			for i := 0; i < b.N; i++ {
				_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validated, policies, resourceRouters, nil, groupUsers)
			}
		})
	}
}
// BenchmarkNetworkMapGeneration_PeerScaling tests peer count impact on
// performance. The group count grows with the peer count (roughly one group
// per 20 peers, never fewer than one).
func BenchmarkNetworkMapGeneration_PeerScaling(b *testing.B) {
	skipCIBenchmark(b)
	peerCounts := []int{50, 100, 500, 1000, 2000, 5000, 10000, 20000, 30000}
	for _, numPeers := range peerCounts {
		// Built-in max (Go 1.21+) replaces the manual clamp and guarantees
		// at least one group for tiny accounts (numPeers < 20).
		numGroups := max(numPeers/20, 1)
		b.Run(fmt.Sprintf("components_%dpeers", numPeers), func(b *testing.B) {
			account, validatedPeers := scalableTestAccount(numPeers, numGroups)
			ctx := context.Background()
			// Precompute the lookup maps once; only map generation is timed.
			resourcePolicies := account.GetResourcePoliciesMap()
			routers := account.GetResourceRoutersMap()
			groupIDToUserIDs := account.GetActiveGroupUsers()
			b.ReportAllocs()
			b.ResetTimer()
			for range b.N {
				_ = account.GetPeerNetworkMapFromComponents(ctx, "peer-0", nbdns.CustomZone{}, nil, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
			}
		})
	}
}

File diff suppressed because it is too large Load Diff