From 3a0cf230a179e767434018814a0dee5be2c526dd Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sun, 1 Feb 2026 14:26:22 +0100 Subject: [PATCH 01/15] Disable local users for a smooth single-idp mode (#5226) Add LocalAuthDisabled option to embedded IdP configuration This adds the ability to disable local (email/password) authentication when using the embedded Dex identity provider. When disabled, users can only authenticate via external identity providers (Google, OIDC, etc.). This simplifies user login when there is only one external IdP configured. The login page will redirect directly to the IdP login page. Key changes: Added LocalAuthDisabled field to EmbeddedIdPConfig Added methods to check and toggle local auth: IsLocalAuthEnabled, HasNonLocalConnectors, DisableLocalAuth, EnableLocalAuth Validation prevents disabling local auth if no external connectors are configured Existing local users are preserved when disabled and can login again when re-enabled Operations are idempotent (disabling already disabled is a no-op) --- idp/dex/connector.go | 54 ++++ management/internals/server/modules.go | 9 +- management/server/account.go | 15 +- management/server/http/handler.go | 4 +- .../handlers/accounts/accounts_handler.go | 25 +- .../accounts/accounts_handler_test.go | 7 +- .../handlers/instance/instance_handler.go | 2 +- .../handlers/users/invites_handler_test.go | 17 ++ .../testing/testing_tools/channel/channel.go | 2 +- management/server/idp/embedded.go | 50 ++++ management/server/idp/embedded_test.go | 231 ++++++++++++++++++ management/server/instance/manager.go | 11 +- management/server/settings/manager.go | 15 +- management/server/types/settings.go | 10 + management/server/user.go | 12 + shared/management/http/api/openapi.yml | 5 + shared/management/http/api/types.gen.go | 3 + 17 files changed, 450 insertions(+), 22 deletions(-) diff --git a/idp/dex/connector.go b/idp/dex/connector.go index cad682141..ba2bb1f00 100644 --- a/idp/dex/connector.go +++ b/idp/dex/connector.go @@ -327,6 +327,60 @@ func ensureLocalConnector(ctx context.Context, stor storage.Storage) error { return nil } +// HasNonLocalConnectors checks if there are any connectors other than the local connector. +func (p *Provider) HasNonLocalConnectors(ctx context.Context) (bool, error) { + connectors, err := p.storage.ListConnectors(ctx) + if err != nil { + return false, fmt.Errorf("failed to list connectors: %w", err) + } + + p.logger.Info("checking for non-local connectors", "total_connectors", len(connectors)) + for _, conn := range connectors { + p.logger.Info("found connector in storage", "id", conn.ID, "type", conn.Type, "name", conn.Name) + if conn.ID != "local" || conn.Type != "local" { + p.logger.Info("found non-local connector", "id", conn.ID) + return true, nil + } + } + p.logger.Info("no non-local connectors found") + return false, nil +} + +// DisableLocalAuth removes the local (password) connector. +// Returns an error if no other connectors are configured. +func (p *Provider) DisableLocalAuth(ctx context.Context) error { + hasOthers, err := p.HasNonLocalConnectors(ctx) + if err != nil { + return err + } + if !hasOthers { + return fmt.Errorf("cannot disable local authentication: no other identity providers configured") + } + + // Check if local connector exists + _, err = p.storage.GetConnector(ctx, "local") + if errors.Is(err, storage.ErrNotFound) { + // Already disabled + return nil + } + if err != nil { + return fmt.Errorf("failed to check local connector: %w", err) + } + + // Delete the local connector + if err := p.storage.DeleteConnector(ctx, "local"); err != nil { + return fmt.Errorf("failed to delete local connector: %w", err) + } + + p.logger.Info("local authentication disabled") + return nil +} + +// EnableLocalAuth creates the local (password) connector if it doesn't exist. +func (p *Provider) EnableLocalAuth(ctx context.Context) error { + return ensureLocalConnector(ctx, p.storage) +} + // ensureStaticConnectors creates or updates static connectors in storage func ensureStaticConnectors(ctx context.Context, stor storage.Storage, connectors []Connector) error { for _, conn := range connectors { diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index b51e2ebb2..31badf9d0 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -69,7 +69,14 @@ func (s *BaseServer) UsersManager() users.Manager { func (s *BaseServer) SettingsManager() settings.Manager { return Create(s, func() settings.Manager { extraSettingsManager := integrations.NewManager(s.EventStore()) - return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager()) + + idpConfig := settings.IdpConfig{} + if s.Config.EmbeddedIdP != nil && s.Config.EmbeddedIdP.Enabled { + idpConfig.EmbeddedIdpEnabled = true + idpConfig.LocalAuthDisabled = s.Config.EmbeddedIdP.LocalAuthDisabled + } + + return settings.NewManager(s.Store(), s.UsersManager(), extraSettingsManager, s.PermissionsManager(), idpConfig) }) } diff --git a/management/server/account.go b/management/server/account.go index ba5f0cffa..8f9dad031 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -26,7 +26,6 @@ import ( "golang.org/x/exp/maps" nbdns "github.com/netbirdio/netbird/dns" - nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/formatter/hook" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbconfig "github.com/netbirdio/netbird/management/internals/server/config" @@ -49,6 +48,7 @@ import ( "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/util" "github.com/netbirdio/netbird/route" + nbdomain "github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/status" ) @@ -795,6 +795,19 @@ func IsEmbeddedIdp(i idp.Manager) bool { return ok } +// IsLocalAuthDisabled checks if local (email/password) authentication is disabled. +// Returns true only when using embedded IDP with local auth disabled in config. +func IsLocalAuthDisabled(ctx context.Context, i idp.Manager) bool { + if isNil(i) { + return false + } + embeddedIdp, ok := i.(*idp.EmbeddedIdPManager) + if !ok { + return false + } + return embeddedIdp.IsLocalAuthDisabled() +} + // addAccountIDToIDPAppMeta update user's app metadata in idp manager func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error { if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 32a97ff44..79431a0a3 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -129,14 +129,14 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("register integrations endpoints: %w", err) } - // Check if embedded IdP is enabled + // Check if embedded IdP is enabled for instance manager embeddedIdP, embeddedIdpEnabled := idpManager.(*idpmanager.EmbeddedIdPManager) instanceManager, err := nbinstance.NewManager(ctx, accountManager.GetStore(), embeddedIdP) if err != nil { return nil, fmt.Errorf("failed to create instance manager: %w", err) } - accounts.AddEndpoints(accountManager, settingsManager, embeddedIdpEnabled, router) + accounts.AddEndpoints(accountManager, settingsManager, router) peers.AddEndpoints(accountManager, router, networkMapController) users.AddEndpoints(accountManager, router) users.AddInvitesEndpoints(accountManager, router) diff --git a/management/server/http/handlers/accounts/accounts_handler.go b/management/server/http/handlers/accounts/accounts_handler.go index de778d59a..122c061ce 100644 --- a/management/server/http/handlers/accounts/accounts_handler.go +++ b/management/server/http/handlers/accounts/accounts_handler.go @@ -36,24 +36,22 @@ const ( // handler is a handler that handles the server.Account HTTP endpoints type handler struct { - accountManager account.Manager - settingsManager settings.Manager - embeddedIdpEnabled bool + accountManager account.Manager + settingsManager settings.Manager } -func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool, router *mux.Router) { - accountsHandler := newHandler(accountManager, settingsManager, embeddedIdpEnabled) +func AddEndpoints(accountManager account.Manager, settingsManager settings.Manager, router *mux.Router) { + accountsHandler := newHandler(accountManager, settingsManager) router.HandleFunc("/accounts/{accountId}", accountsHandler.updateAccount).Methods("PUT", "OPTIONS") router.HandleFunc("/accounts/{accountId}", accountsHandler.deleteAccount).Methods("DELETE", "OPTIONS") router.HandleFunc("/accounts", accountsHandler.getAllAccounts).Methods("GET", "OPTIONS") } // newHandler creates a new handler HTTP handler -func newHandler(accountManager account.Manager, settingsManager settings.Manager, embeddedIdpEnabled bool) *handler { +func newHandler(accountManager account.Manager, settingsManager settings.Manager) *handler { return &handler{ - accountManager: accountManager, - settingsManager: settingsManager, - embeddedIdpEnabled: embeddedIdpEnabled, + accountManager: accountManager, + settingsManager: settingsManager, } } @@ -165,7 +163,7 @@ func (h *handler) getAllAccounts(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, settings, meta, onboarding, h.embeddedIdpEnabled) + resp := toAccountResponse(accountID, settings, meta, onboarding) util.WriteJSONObject(r.Context(), w, []*api.Account{resp}) } @@ -292,7 +290,7 @@ func (h *handler) updateAccount(w http.ResponseWriter, r *http.Request) { return } - resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding, h.embeddedIdpEnabled) + resp := toAccountResponse(accountID, updatedSettings, meta, updatedOnboarding) util.WriteJSONObject(r.Context(), w, &resp) } @@ -321,7 +319,7 @@ func (h *handler) deleteAccount(w http.ResponseWriter, r *http.Request) { util.WriteJSONObject(r.Context(), w, util.EmptyObject{}) } -func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding, embeddedIdpEnabled bool) *api.Account { +func toAccountResponse(accountID string, settings *types.Settings, meta *types.AccountMeta, onboarding *types.AccountOnboarding) *api.Account { jwtAllowGroups := settings.JWTAllowGroups if jwtAllowGroups == nil { jwtAllowGroups = []string{} @@ -341,7 +339,8 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A LazyConnectionEnabled: &settings.LazyConnectionEnabled, DnsDomain: &settings.DNSDomain, AutoUpdateVersion: &settings.AutoUpdateVersion, - EmbeddedIdpEnabled: &embeddedIdpEnabled, + EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled, + LocalAuthDisabled: &settings.LocalAuthDisabled, } if settings.NetworkRange.IsValid() { diff --git a/management/server/http/handlers/accounts/accounts_handler_test.go b/management/server/http/handlers/accounts/accounts_handler_test.go index e455372c8..6cbd5908d 100644 --- a/management/server/http/handlers/accounts/accounts_handler_test.go +++ b/management/server/http/handlers/accounts/accounts_handler_test.go @@ -33,7 +33,6 @@ func initAccountsTestData(t *testing.T, account *types.Account) *handler { AnyTimes() return &handler{ - embeddedIdpEnabled: false, accountManager: &mock_server.MockAccountManager{ GetAccountSettingsFunc: func(ctx context.Context, accountID string, userID string) (*types.Settings, error) { return account.Settings, nil @@ -124,6 +123,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: true, expectedID: accountID, @@ -148,6 +148,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -172,6 +173,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr("latest"), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -196,6 +198,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -220,6 +223,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, @@ -244,6 +248,7 @@ func TestAccounts_AccountsHandler(t *testing.T) { DnsDomain: sr(""), AutoUpdateVersion: sr(""), EmbeddedIdpEnabled: br(false), + LocalAuthDisabled: br(false), }, expectedArray: false, expectedID: accountID, diff --git a/management/server/http/handlers/instance/instance_handler.go b/management/server/http/handlers/instance/instance_handler.go index 5d8baaf8d..cd9fae6b8 100644 --- a/management/server/http/handlers/instance/instance_handler.go +++ b/management/server/http/handlers/instance/instance_handler.go @@ -46,7 +46,7 @@ func (h *handler) getInstanceStatus(w http.ResponseWriter, r *http.Request) { util.WriteErrorResponse("failed to check instance status", http.StatusInternalServerError, w) return } - + log.WithContext(r.Context()).Infof("instance setup status: %v", setupRequired) util.WriteJSONObject(r.Context(), w, api.InstanceStatus{ SetupRequired: setupRequired, }) diff --git a/management/server/http/handlers/users/invites_handler_test.go b/management/server/http/handlers/users/invites_handler_test.go index 80826b9d4..529ea24d6 100644 --- a/management/server/http/handlers/users/invites_handler_test.go +++ b/management/server/http/handlers/users/invites_handler_test.go @@ -205,6 +205,14 @@ func TestCreateInvite(t *testing.T) { return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") }, }, + { + name: "local auth disabled", + requestBody: `{"email":"test@example.com","name":"Test User","role":"user","auto_groups":[]}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, accountID, initiatorUserID string, invite *types.UserInfo, expiresIn int) (*types.UserInvite, error) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + }, + }, { name: "invalid JSON", requestBody: `{invalid json}`, @@ -376,6 +384,15 @@ func TestAcceptInvite(t *testing.T) { return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") }, }, + { + name: "local auth disabled", + token: testInviteToken, + requestBody: `{"password":"SecurePass123!"}`, + expectedStatus: http.StatusPreconditionFailed, + mockFunc: func(ctx context.Context, token, password string) error { + return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + }, + }, { name: "missing token", token: "", diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index 9339c3541..1fd4c9bad 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -73,7 +73,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee proxyController := integrations.NewController(store) userManager := users.NewManager(store) permissionsManager := permissions.NewManager(store) - settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager) + settingsManager := settings.NewManager(store, userManager, integrations.NewManager(&activity.InMemoryEventStore{}), permissionsManager, settings.IdpConfig{}) peersManager := peers.NewManager(store, permissionsManager) jobManager := job.NewJobManager(nil, store, peersManager) diff --git a/management/server/idp/embedded.go b/management/server/idp/embedded.go index db7a91fa3..a27050a26 100644 --- a/management/server/idp/embedded.go +++ b/management/server/idp/embedded.go @@ -43,6 +43,11 @@ type EmbeddedIdPConfig struct { Owner *OwnerConfig // SignKeyRefreshEnabled enables automatic key rotation for signing keys SignKeyRefreshEnabled bool + // LocalAuthDisabled disables the local (email/password) authentication connector. + // When true, users cannot authenticate via email/password, only via external identity providers. + // Existing local users are preserved and will be able to login again if re-enabled. + // Cannot be enabled if no external identity provider connectors are configured. + LocalAuthDisabled bool } // EmbeddedStorageConfig holds storage configuration for the embedded IdP. @@ -105,6 +110,8 @@ func (c *EmbeddedIdPConfig) ToYAMLConfig() (*dex.YAMLConfig, error) { Issuer: "NetBird", Theme: "light", }, + // Always enable password DB initially - we disable the local connector after startup if needed. + // This ensures Dex has at least one connector during initialization. EnablePasswordDB: true, StaticClients: []storage.Client{ { @@ -192,11 +199,32 @@ func NewEmbeddedIdPManager(ctx context.Context, config *EmbeddedIdPConfig, appMe return nil, err } + log.WithContext(ctx).Debugf("initializing embedded Dex IDP with config: %+v", config) + provider, err := dex.NewProviderFromYAML(ctx, yamlConfig) if err != nil { return nil, fmt.Errorf("failed to create embedded IdP provider: %w", err) } + // If local auth is disabled, validate that other connectors exist + if config.LocalAuthDisabled { + hasOthers, err := provider.HasNonLocalConnectors(ctx) + if err != nil { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("failed to check connectors: %w", err) + } + if !hasOthers { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("cannot disable local authentication: no other identity providers configured") + } + // Ensure local connector is removed (it might exist from a previous run) + if err := provider.DisableLocalAuth(ctx); err != nil { + _ = provider.Stop(ctx) + return nil, fmt.Errorf("failed to disable local auth: %w", err) + } + log.WithContext(ctx).Info("local authentication disabled - only external identity providers can be used") + } + log.WithContext(ctx).Infof("embedded Dex IDP initialized with issuer: %s", yamlConfig.Issuer) return &EmbeddedIdPManager{ @@ -281,6 +309,8 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]* return nil, fmt.Errorf("failed to list users: %w", err) } + log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(users)) + indexedUsers := make(map[string][]*UserData) for _, user := range users { indexedUsers[UnsetAccountID] = append(indexedUsers[UnsetAccountID], &UserData{ @@ -290,11 +320,17 @@ func (m *EmbeddedIdPManager) GetAllAccounts(ctx context.Context) (map[string][]* }) } + log.WithContext(ctx).Debugf("retrieved %d users from embedded IdP", len(indexedUsers[UnsetAccountID])) + return indexedUsers, nil } // CreateUser creates a new user in the embedded IdP. func (m *EmbeddedIdPManager) CreateUser(ctx context.Context, email, name, accountID, invitedByEmail string) (*UserData, error) { + if m.config.LocalAuthDisabled { + return nil, fmt.Errorf("local user creation is disabled") + } + if m.appMetrics != nil { m.appMetrics.IDPMetrics().CountCreateUser() } @@ -364,6 +400,10 @@ func (m *EmbeddedIdPManager) GetUserByEmail(ctx context.Context, email string) ( // Unlike CreateUser which auto-generates a password, this method uses the provided password. // This is useful for instance setup where the user provides their own password. func (m *EmbeddedIdPManager) CreateUserWithPassword(ctx context.Context, email, password, name string) (*UserData, error) { + if m.config.LocalAuthDisabled { + return nil, fmt.Errorf("local user creation is disabled") + } + if m.appMetrics != nil { m.appMetrics.IDPMetrics().CountCreateUser() } @@ -553,3 +593,13 @@ func (m *EmbeddedIdPManager) GetClientIDs() []string { func (m *EmbeddedIdPManager) GetUserIDClaim() string { return defaultUserIDClaim } + +// IsLocalAuthDisabled returns whether local authentication is disabled based on configuration. +func (m *EmbeddedIdPManager) IsLocalAuthDisabled() bool { + return m.config.LocalAuthDisabled +} + +// HasNonLocalConnectors checks if there are any identity provider connectors other than local. +func (m *EmbeddedIdPManager) HasNonLocalConnectors(ctx context.Context) (bool, error) { + return m.provider.HasNonLocalConnectors(ctx) +} diff --git a/management/server/idp/embedded_test.go b/management/server/idp/embedded_test.go index d8d3009dd..4dda483fb 100644 --- a/management/server/idp/embedded_test.go +++ b/management/server/idp/embedded_test.go @@ -370,3 +370,234 @@ func TestEmbeddedIdPManager_GetLocalKeysLocation(t *testing.T) { }) } } + +func TestEmbeddedIdPManager_LocalAuthDisabled(t *testing.T) { + ctx := context.Background() + + t.Run("cannot start with local auth disabled without other connectors", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + _, err = NewEmbeddedIdPManager(ctx, config, nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "no other identity providers configured") + }) + + t.Run("local auth enabled by default", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + config := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: filepath.Join(tmpDir, "dex.db"), + }, + }, + } + + manager, err := NewEmbeddedIdPManager(ctx, config, nil) + require.NoError(t, err) + defer func() { _ = manager.Stop(ctx) }() + + // Verify local auth is enabled by default + assert.False(t, manager.IsLocalAuthDisabled()) + }) + + t.Run("start with local auth disabled when connector exists", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager with local auth enabled and add a connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + // Create a user + userData, err := manager1.CreateUser(ctx, "preserved@example.com", "Preserved User", "account1", "admin@example.com") + require.NoError(t, err) + userID := userData.ID + + // Add an external connector (Google doesn't require OIDC discovery) + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + // Stop the first manager + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Now create a new manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Verify local auth is disabled via config + assert.True(t, manager2.IsLocalAuthDisabled()) + + // Verify the user still exists in storage (just can't login via local) + lookedUp, err := manager2.GetUserDataByID(ctx, userID, AppMetadata{}) + require.NoError(t, err) + assert.Equal(t, "preserved@example.com", lookedUp.Email) + }) + + t.Run("CreateUser fails when local auth is disabled", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager and add an external connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Create manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Try to create a user - should fail + _, err = manager2.CreateUser(ctx, "newuser@example.com", "New User", "account1", "admin@example.com") + require.Error(t, err) + assert.Contains(t, err.Error(), "local user creation is disabled") + }) + + t.Run("CreateUserWithPassword fails when local auth is disabled", func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "embedded-idp-test-*") + require.NoError(t, err) + defer os.RemoveAll(tmpDir) + + dbFile := filepath.Join(tmpDir, "dex.db") + + // First, create a manager and add an external connector + config1 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager1, err := NewEmbeddedIdPManager(ctx, config1, nil) + require.NoError(t, err) + + _, err = manager1.CreateConnector(ctx, &dex.ConnectorConfig{ + ID: "google-test", + Name: "Google Test", + Type: "google", + ClientID: "test-client-id", + ClientSecret: "test-client-secret", + }) + require.NoError(t, err) + + err = manager1.Stop(ctx) + require.NoError(t, err) + + // Create manager with local auth disabled + config2 := &EmbeddedIdPConfig{ + Enabled: true, + Issuer: "http://localhost:5556/dex", + LocalAuthDisabled: true, + Storage: EmbeddedStorageConfig{ + Type: "sqlite3", + Config: EmbeddedStorageTypeConfig{ + File: dbFile, + }, + }, + } + + manager2, err := NewEmbeddedIdPManager(ctx, config2, nil) + require.NoError(t, err) + defer func() { _ = manager2.Stop(ctx) }() + + // Try to create a user with password - should fail + _, err = manager2.CreateUserWithPassword(ctx, "newuser@example.com", "SecurePass123!", "New User") + require.Error(t, err) + assert.Contains(t, err.Error(), "local user creation is disabled") + }) +} diff --git a/management/server/instance/manager.go b/management/server/instance/manager.go index 6a0509ebd..19e3abdc0 100644 --- a/management/server/instance/manager.go +++ b/management/server/instance/manager.go @@ -104,13 +104,22 @@ func NewManager(ctx context.Context, store store.Store, idpManager idp.Manager) } func (m *DefaultManager) loadSetupRequired(ctx context.Context) error { + // Check if there are any accounts in the NetBird store + numAccounts, err := m.store.GetAccountsCounter(ctx) + if err != nil { + return err + } + hasAccounts := numAccounts > 0 + + // Check if there are any users in the embedded IdP (Dex) users, err := m.embeddedIdpManager.GetAllAccounts(ctx) if err != nil { return err } + hasLocalUsers := len(users) > 0 m.setupMu.Lock() - m.setupRequired = len(users) == 0 + m.setupRequired = !(hasAccounts || hasLocalUsers) m.setupMu.Unlock() return nil diff --git a/management/server/settings/manager.go b/management/server/settings/manager.go index 2b2896572..74af0a3ef 100644 --- a/management/server/settings/manager.go +++ b/management/server/settings/manager.go @@ -24,19 +24,28 @@ type Manager interface { UpdateExtraSettings(ctx context.Context, accountID, userID string, extraSettings *types.ExtraSettings) (bool, error) } +// IdpConfig holds IdP-related configuration that is set at runtime +// and not stored in the database. +type IdpConfig struct { + EmbeddedIdpEnabled bool + LocalAuthDisabled bool +} + type managerImpl struct { store store.Store extraSettingsManager extra_settings.Manager userManager users.Manager permissionsManager permissions.Manager + idpConfig IdpConfig } -func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager) Manager { +func NewManager(store store.Store, userManager users.Manager, extraSettingsManager extra_settings.Manager, permissionsManager permissions.Manager, idpConfig IdpConfig) Manager { return &managerImpl{ store: store, extraSettingsManager: extraSettingsManager, userManager: userManager, permissionsManager: permissionsManager, + idpConfig: idpConfig, } } @@ -74,6 +83,10 @@ func (m *managerImpl) GetSettings(ctx context.Context, accountID, userID string) settings.Extra.FlowDnsCollectionEnabled = extraSettings.FlowDnsCollectionEnabled } + // Fill in IdP-related runtime settings + settings.EmbeddedIdpEnabled = m.idpConfig.EmbeddedIdpEnabled + settings.LocalAuthDisabled = m.idpConfig.LocalAuthDisabled + return settings, nil } diff --git a/management/server/types/settings.go b/management/server/types/settings.go index 867e12bef..a94e01b78 100644 --- a/management/server/types/settings.go +++ b/management/server/types/settings.go @@ -55,6 +55,14 @@ type Settings struct { // AutoUpdateVersion client auto-update version AutoUpdateVersion string `gorm:"default:'disabled'"` + + // EmbeddedIdpEnabled indicates if the embedded identity provider is enabled. + // This is a runtime-only field, not stored in the database. + EmbeddedIdpEnabled bool `gorm:"-"` + + // LocalAuthDisabled indicates if local (email/password) authentication is disabled. + // This is a runtime-only field, not stored in the database. + LocalAuthDisabled bool `gorm:"-"` } // Copy copies the Settings struct @@ -76,6 +84,8 @@ func (s *Settings) Copy() *Settings { DNSDomain: s.DNSDomain, NetworkRange: s.NetworkRange, AutoUpdateVersion: s.AutoUpdateVersion, + EmbeddedIdpEnabled: s.EmbeddedIdpEnabled, + LocalAuthDisabled: s.LocalAuthDisabled, } if s.Extra != nil { settings.Extra = s.Extra.Copy() diff --git a/management/server/user.go b/management/server/user.go index 51da7a633..48005f325 100644 --- a/management/server/user.go +++ b/management/server/user.go @@ -191,6 +191,10 @@ func (am *DefaultAccountManager) createNewIdpUser(ctx context.Context, accountID // Unlike createNewIdpUser, this method fetches user data directly from the database // since the embedded IdP usage ensures the username and email are stored locally in the User table. func (am *DefaultAccountManager) createEmbeddedIdpUser(ctx context.Context, accountID string, inviterID string, invite *types.UserInfo) (*idp.UserData, error) { + if IsLocalAuthDisabled(ctx, am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + inviter, err := am.Store.GetUserByUserID(ctx, store.LockingStrengthNone, inviterID) if err != nil { return nil, fmt.Errorf("failed to get inviter user: %w", err) @@ -1462,6 +1466,10 @@ func (am *DefaultAccountManager) CreateUserInvite(ctx context.Context, accountID return nil, status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } + if IsLocalAuthDisabled(ctx, am.idpManager) { + return nil, status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + if err := validateUserInvite(invite); err != nil { return nil, err } @@ -1621,6 +1629,10 @@ func (am *DefaultAccountManager) AcceptUserInvite(ctx context.Context, token, pa return status.Errorf(status.PreconditionFailed, "invite links are only available with embedded identity provider") } + if IsLocalAuthDisabled(ctx, am.idpManager) { + return status.Errorf(status.PreconditionFailed, "local user creation is disabled - use an external identity provider") + } + if password == "" { return status.Errorf(status.InvalidArgument, "password is required") } diff --git a/shared/management/http/api/openapi.yml b/shared/management/http/api/openapi.yml index 26d2387d1..b9a8eae3a 100644 --- a/shared/management/http/api/openapi.yml +++ b/shared/management/http/api/openapi.yml @@ -294,6 +294,11 @@ components: type: boolean readOnly: true example: false + local_auth_disabled: + description: Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field. + type: boolean + readOnly: true + example: false required: - peer_login_expiration_enabled - peer_login_expiration diff --git a/shared/management/http/api/types.gen.go b/shared/management/http/api/types.gen.go index e8c044b32..fd7c61917 100644 --- a/shared/management/http/api/types.gen.go +++ b/shared/management/http/api/types.gen.go @@ -415,6 +415,9 @@ type AccountSettings struct { // LazyConnectionEnabled Enables or disables experimental lazy connection LazyConnectionEnabled *bool `json:"lazy_connection_enabled,omitempty"` + // LocalAuthDisabled Indicates whether local (email/password) authentication is disabled. When true, users can only authenticate via external identity providers. This is a read-only field. + LocalAuthDisabled *bool `json:"local_auth_disabled,omitempty"` + // NetworkRange Allows to define a custom network range for the account in CIDR format NetworkRange *string `json:"network_range,omitempty"` From 7b830d8f72b6fa997f03d99292de814faa33a562 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Sun, 1 Feb 2026 14:37:00 +0100 Subject: [PATCH 02/15] disable sync lim (#5233) --- management/internals/shared/grpc/server.go | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 219baaf6d..6757cca13 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -77,8 +77,9 @@ type Server struct { oAuthConfigProvider idp.OAuthConfigProvider - syncSem atomic.Int32 - syncLim int32 + syncSem atomic.Int32 + syncLimEnabled bool + syncLim int32 } // NewServer creates a new Management server @@ -108,6 +109,7 @@ func NewServer( blockPeersWithSameConfig := strings.ToLower(os.Getenv(envBlockPeers)) == "true" syncLim := int32(defaultSyncLim) + syncLimEnabled := true if syncLimStr := os.Getenv(envConcurrentSyncs); syncLimStr != "" { syncLimParsed, err := strconv.Atoi(syncLimStr) if err != nil { @@ -115,6 +117,9 @@ func NewServer( } else { //nolint:gosec syncLim = int32(syncLimParsed) + if syncLim < 0 { + syncLimEnabled = false + } } } @@ -134,7 +139,8 @@ func NewServer( loginFilter: newLoginFilter(), - syncLim: syncLim, + syncLim: syncLim, + syncLimEnabled: syncLimEnabled, }, nil } @@ -212,7 +218,7 @@ func (s *Server) Job(srv proto.ManagementService_JobServer) error { // Sync validates the existence of a connecting peer, sends an initial state (all available for the connecting peers) and // notifies the connected peer of any updates (e.g. new peers under the same account) func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_SyncServer) error { - if s.syncSem.Load() >= s.syncLim { + if s.syncLimEnabled && s.syncSem.Load() >= s.syncLim { return status.Errorf(codes.ResourceExhausted, "too many concurrent sync requests, please try again later") } s.syncSem.Add(1) From 893129334376ef0cf65198dfbf5f428290ef6244 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Sun, 1 Feb 2026 15:44:27 +0100 Subject: [PATCH 03/15] [management] run cancelPeerRoutinesWithoutLock in sync (#5234) --- management/internals/shared/grpc/server.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 6757cca13..3704b3188 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -311,7 +311,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) return err } @@ -319,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) return err } @@ -490,6 +490,10 @@ func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) +} + +func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) { err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) From b20d4849720754928ce08a88d8970877a5b0daa5 Mon Sep 17 00:00:00 2001 From: Misha Bragin Date: Sun, 1 Feb 2026 16:06:36 +0100 Subject: [PATCH 04/15] [docs] Add selfhosting video (#5235) --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 8f4c04641..bca81c20b 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,8 @@ https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2 -### NetBird on Lawrence Systems (Video) -[![Watch the video](https://img.youtube.com/vi/Kwrff6h0rEw/0.jpg)](https://www.youtube.com/watch?v=Kwrff6h0rEw) +### Self-Host NetBird (Video) +[![Watch the video](https://img.youtube.com/vi/bZAgpT6nzaQ/0.jpg)](https://youtu.be/bZAgpT6nzaQ) ### Key features From 6fdc00ff4185849cf9d3e65d38971684a0ffe65e Mon Sep 17 00:00:00 2001 From: Pascal Fischer <32096965+pascal-fischer@users.noreply.github.com> Date: Tue, 3 Feb 2026 17:30:02 +0100 Subject: [PATCH 05/15] [management] adding account id validation to accessible peers handler (#5246) --- management/server/http/handler.go | 5 +++-- .../http/handlers/peers/peers_handler.go | 19 ++++++++++++++----- .../http/handlers/peers/peers_handler_test.go | 11 +++++++++-- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/management/server/http/handler.go b/management/server/http/handler.go index 79431a0a3..17355d1d9 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -9,10 +9,11 @@ import ( "time" "github.com/gorilla/mux" - idpmanager "github.com/netbirdio/netbird/management/server/idp" "github.com/rs/cors" log "github.com/sirupsen/logrus" + idpmanager "github.com/netbirdio/netbird/management/server/idp" + "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/modules/zones" @@ -137,7 +138,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks } accounts.AddEndpoints(accountManager, settingsManager, router) - peers.AddEndpoints(accountManager, router, networkMapController) + peers.AddEndpoints(accountManager, router, networkMapController, permissionsManager) users.AddEndpoints(accountManager, router) users.AddInvitesEndpoints(accountManager, router) users.AddPublicInvitesEndpoints(accountManager, router) diff --git a/management/server/http/handlers/peers/peers_handler.go b/management/server/http/handlers/peers/peers_handler.go index 53d8ab055..783cfe11b 100644 --- a/management/server/http/handlers/peers/peers_handler.go +++ b/management/server/http/handlers/peers/peers_handler.go @@ -17,6 +17,7 @@ import ( nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/groups" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/management/http/api" "github.com/netbirdio/netbird/shared/management/http/util" @@ -26,11 +27,12 @@ import ( // Handler is a handler that returns peers of the account type Handler struct { accountManager account.Manager + permissionsManager permissions.Manager networkMapController network_map.Controller } -func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller) { - peersHandler := NewHandler(accountManager, networkMapController) +func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMapController network_map.Controller, permissionsManager permissions.Manager) { + peersHandler := NewHandler(accountManager, networkMapController, permissionsManager) router.HandleFunc("/peers", peersHandler.GetAllPeers).Methods("GET", "OPTIONS") router.HandleFunc("/peers/{peerId}", peersHandler.HandlePeer). Methods("GET", "PUT", "DELETE", "OPTIONS") @@ -42,10 +44,11 @@ func AddEndpoints(accountManager account.Manager, router *mux.Router, networkMap } // NewHandler creates a new peers Handler -func NewHandler(accountManager account.Manager, networkMapController network_map.Controller) *Handler { +func NewHandler(accountManager account.Manager, networkMapController network_map.Controller, permissionsManager permissions.Manager) *Handler { return &Handler{ accountManager: accountManager, networkMapController: networkMapController, + permissionsManager: permissionsManager, } } @@ -359,13 +362,19 @@ func (h *Handler) GetAccessiblePeers(w http.ResponseWriter, r *http.Request) { return } - account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) + user, err := h.accountManager.GetUserByID(r.Context(), userID) if err != nil { util.WriteError(r.Context(), err, w) return } - user, err := h.accountManager.GetUserByID(r.Context(), userID) + err = h.permissionsManager.ValidateAccountAccess(r.Context(), accountID, user, false) + if err != nil { + util.WriteError(r.Context(), status.NewPermissionDeniedError(), w) + return + } + + account, err := h.accountManager.GetAccountByID(r.Context(), accountID, activity.SystemInitiator) if err != nil { util.WriteError(r.Context(), err, w) return diff --git a/management/server/http/handlers/peers/peers_handler_test.go b/management/server/http/handlers/peers/peers_handler_test.go index 869a39b5e..786c144fc 100644 --- a/management/server/http/handlers/peers/peers_handler_test.go +++ b/management/server/http/handlers/peers/peers_handler_test.go @@ -13,13 +13,15 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/gorilla/mux" - "go.uber.org/mock/gomock" + ugomock "go.uber.org/mock/gomock" "golang.org/x/exp/maps" "github.com/netbirdio/netbird/management/internals/controllers/network_map" nbcontext "github.com/netbirdio/netbird/management/server/context" nbpeer "github.com/netbirdio/netbird/management/server/peer" + "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/shared/auth" "github.com/netbirdio/netbird/shared/management/http/api" @@ -102,7 +104,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { }, } - ctrl := gomock.NewController(t) + ctrl := ugomock.NewController(t) networkMapController := network_map.NewMockController(ctrl) networkMapController.EXPECT(). @@ -110,6 +112,10 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { Return("domain"). AnyTimes() + ctrl2 := gomock.NewController(t) + permissionsManager := permissions.NewMockManager(ctrl2) + permissionsManager.EXPECT().ValidateAccountAccess(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + return &Handler{ accountManager: &mock_server.MockAccountManager{ UpdatePeerFunc: func(_ context.Context, accountID, userID string, update *nbpeer.Peer) (*nbpeer.Peer, error) { @@ -199,6 +205,7 @@ func initTestMetaData(t *testing.T, peers ...*nbpeer.Peer) *Handler { }, }, networkMapController: networkMapController, + permissionsManager: permissionsManager, } } From d488f583115402c62d9974e787a5b23f1b73fc32 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Wed, 4 Feb 2026 11:44:46 +0100 Subject: [PATCH 06/15] [management] fix set disconnected status for connected peer (#5247) --- management/internals/shared/grpc/server.go | 32 ++++++----- management/server/account.go | 16 +++++- management/server/account/manager.go | 2 +- management/server/account_test.go | 55 +++++++++++++++++++ management/server/mock_server/account_mock.go | 5 +- 5 files changed, 89 insertions(+), 21 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 3704b3188..befcd2adf 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -307,11 +307,13 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S return mapError(ctx, err) } + streamStartTime := time.Now().UTC() + err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) return err } @@ -319,7 +321,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) return err } @@ -336,7 +338,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, streamStartTime) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { @@ -404,7 +406,7 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt } // handleUpdates sends updates to the connected peer until the updates channel is closed. -func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) for { select { @@ -416,11 +418,11 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg if !open { log.WithContext(ctx).Debugf("updates channel for peer %s was closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return nil } log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv); err != nil { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) return err } @@ -429,7 +431,7 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg case <-srv.Context().Done(): // happens when connection drops, e.g. client disconnects log.WithContext(ctx).Debugf("stream of peer %s has been closed", peerKey.String()) - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return srv.Context().Err() } } @@ -437,16 +439,16 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg // sendUpdate encrypts the update message using the peer key and the server's wireguard key, // then sends the encrypted message to the connected peer via the sync server. -func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer) error { +func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, update *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { key, err := s.secretsManager.GetWGKey() if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } encryptedResp, err := encryption.EncryptMessage(peerKey, key, update.Update) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed processing update message") } err = srv.Send(&proto.EncryptedMessage{ @@ -454,7 +456,7 @@ func (s *Server) sendUpdate(ctx context.Context, accountID string, peerKey wgtyp Body: encryptedResp, }) if err != nil { - s.cancelPeerRoutines(ctx, accountID, peer) + s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return status.Errorf(codes.Internal, "failed sending update message") } log.WithContext(ctx).Debugf("sent an update to peer %s", peerKey.String()) @@ -486,15 +488,15 @@ func (s *Server) sendJob(ctx context.Context, peerKey wgtypes.Key, job *job.Even return nil } -func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer) { +func (s *Server) cancelPeerRoutines(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { unlock := s.acquirePeerLockByUID(ctx, peer.Key) defer unlock() - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) } -func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer) { - err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key) +func (s *Server) cancelPeerRoutinesWithoutLock(ctx context.Context, accountID string, peer *nbpeer.Peer, streamStartTime time.Time) { + err := s.accountManager.OnPeerDisconnected(ctx, accountID, peer.Key, streamStartTime) if err != nil { log.WithContext(ctx).Errorf("failed to disconnect peer %s properly: %v", peer.Key, err) } diff --git a/management/server/account.go b/management/server/account.go index 8f9dad031..4f53415f5 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1684,8 +1684,20 @@ func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID return peer, netMap, postureChecks, dnsfwdPort, nil } -func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error { - err := am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) +func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + peer, err := am.Store.GetPeerByPeerPubKey(ctx, store.LockingStrengthNone, peerPubKey) + if err != nil { + log.WithContext(ctx).Warnf("failed to get peer %s for disconnect check: %v", peerPubKey, err) + return nil + } + + if peer.Status.LastSeen.After(streamStartTime) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s > streamStart=%s), skipping disconnect", + peerPubKey, peer.Status.LastSeen.Format(time.RFC3339), streamStartTime.Format(time.RFC3339)) + return nil + } + + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index 5e9bb42a2..eed7739da 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -115,7 +115,7 @@ type Manager interface { GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) - OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string) error + OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) GetAccountIDForPeerKey(ctx context.Context, peerKey string) (string, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index 86cc69e8b..f3d98916c 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1961,6 +1961,61 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. } } +func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { + manager, _, err := createManager(t) + require.NoError(t, err, "unable to create account manager") + + accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) + require.NoError(t, err, "unable to create an account") + + key, err := wgtypes.GenerateKey() + require.NoError(t, err, "unable to generate WireGuard key") + peerPubKey := key.PublicKey().String() + + _, _, _, err = manager.AddPeer(context.Background(), "", "", userID, &nbpeer.Peer{ + Key: peerPubKey, + Meta: nbpeer.PeerSystemMeta{Hostname: "test-peer"}, + }, false) + require.NoError(t, err, "unable to add peer") + + t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err, "unable to get peer") + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := time.Now().UTC() + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.False(t, peer.Status.Connected, "peer should be disconnected") + }) + + t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + require.NoError(t, err, "unable to mark peer connected") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + + streamStartTime := peer.Status.LastSeen.Add(-1 * time.Hour) + + err = manager.OnPeerDisconnected(context.Background(), accountID, peerPubKey, streamStartTime) + require.NoError(t, err) + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, + "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") + }) +} + func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { manager, _, err := createManager(t) require.NoError(t, err, "unable to create account manager") diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index 026989898..a4754d180 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -221,9 +221,8 @@ func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID str return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } -func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string) error { - // TODO implement me - panic("implement me") +func (am *MockAccountManager) OnPeerDisconnected(_ context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error { + return nil } func (am *MockAccountManager) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) { From f7732557fa42622c85c96aa894a648ac7f5cb58f Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 5 Feb 2026 01:07:27 +0800 Subject: [PATCH 07/15] [client] Add missing bsd flags in debug bundle (#5254) --- .../routemanager/systemops/routeflags_bsd.go | 40 ++++++++++++------- .../systemops/routeflags_freebsd.go | 38 +++++++++++------- 2 files changed, 49 insertions(+), 29 deletions(-) diff --git a/client/internal/routemanager/systemops/routeflags_bsd.go b/client/internal/routemanager/systemops/routeflags_bsd.go index ad32e5029..33280bfb3 100644 --- a/client/internal/routemanager/systemops/routeflags_bsd.go +++ b/client/internal/routemanager/systemops/routeflags_bsd.go @@ -4,16 +4,17 @@ package systemops import ( "strings" - "syscall" + + "golang.org/x/sys/unix" ) // filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { - if routeMessageFlags&syscall.RTF_UP == 0 { + if routeMessageFlags&unix.RTF_UP == 0 { return true } - if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE|syscall.RTF_WASCLONED) != 0 { + if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE|unix.RTF_WASCLONED) != 0 { return true } @@ -24,42 +25,51 @@ func filterRoutesByFlags(routeMessageFlags int) bool { func formatBSDFlags(flags int) string { var flagStrs []string - if flags&syscall.RTF_UP != 0 { + if flags&unix.RTF_UP != 0 { flagStrs = append(flagStrs, "U") } - if flags&syscall.RTF_GATEWAY != 0 { + if flags&unix.RTF_GATEWAY != 0 { flagStrs = append(flagStrs, "G") } - if flags&syscall.RTF_HOST != 0 { + if flags&unix.RTF_HOST != 0 { flagStrs = append(flagStrs, "H") } - if flags&syscall.RTF_REJECT != 0 { + if flags&unix.RTF_REJECT != 0 { flagStrs = append(flagStrs, "R") } - if flags&syscall.RTF_DYNAMIC != 0 { + if flags&unix.RTF_DYNAMIC != 0 { flagStrs = append(flagStrs, "D") } - if flags&syscall.RTF_MODIFIED != 0 { + if flags&unix.RTF_MODIFIED != 0 { flagStrs = append(flagStrs, "M") } - if flags&syscall.RTF_STATIC != 0 { + if flags&unix.RTF_STATIC != 0 { flagStrs = append(flagStrs, "S") } - if flags&syscall.RTF_LLINFO != 0 { + if flags&unix.RTF_LLINFO != 0 { flagStrs = append(flagStrs, "L") } - if flags&syscall.RTF_LOCAL != 0 { + if flags&unix.RTF_LOCAL != 0 { flagStrs = append(flagStrs, "l") } - if flags&syscall.RTF_BLACKHOLE != 0 { + if flags&unix.RTF_BLACKHOLE != 0 { flagStrs = append(flagStrs, "B") } - if flags&syscall.RTF_CLONING != 0 { + if flags&unix.RTF_CLONING != 0 { flagStrs = append(flagStrs, "C") } - if flags&syscall.RTF_WASCLONED != 0 { + if flags&unix.RTF_WASCLONED != 0 { flagStrs = append(flagStrs, "W") } + if flags&unix.RTF_PROTO1 != 0 { + flagStrs = append(flagStrs, "1") + } + if flags&unix.RTF_PROTO2 != 0 { + flagStrs = append(flagStrs, "2") + } + if flags&unix.RTF_PROTO3 != 0 { + flagStrs = append(flagStrs, "3") + } if len(flagStrs) == 0 { return "-" diff --git a/client/internal/routemanager/systemops/routeflags_freebsd.go b/client/internal/routemanager/systemops/routeflags_freebsd.go index 2338fe5d8..a8c82b3ed 100644 --- a/client/internal/routemanager/systemops/routeflags_freebsd.go +++ b/client/internal/routemanager/systemops/routeflags_freebsd.go @@ -4,17 +4,18 @@ package systemops import ( "strings" - "syscall" + + "golang.org/x/sys/unix" ) // filterRoutesByFlags returns true if the route message should be ignored based on its flags. func filterRoutesByFlags(routeMessageFlags int) bool { - if routeMessageFlags&syscall.RTF_UP == 0 { + if routeMessageFlags&unix.RTF_UP == 0 { return true } - // NOTE: syscall.RTF_WASCLONED deprecated in FreeBSD 8.0 - if routeMessageFlags&(syscall.RTF_REJECT|syscall.RTF_BLACKHOLE) != 0 { + // NOTE: RTF_WASCLONED deprecated in FreeBSD 8.0 + if routeMessageFlags&(unix.RTF_REJECT|unix.RTF_BLACKHOLE) != 0 { return true } @@ -25,37 +26,46 @@ func filterRoutesByFlags(routeMessageFlags int) bool { func formatBSDFlags(flags int) string { var flagStrs []string - if flags&syscall.RTF_UP != 0 { + if flags&unix.RTF_UP != 0 { flagStrs = append(flagStrs, "U") } - if flags&syscall.RTF_GATEWAY != 0 { + if flags&unix.RTF_GATEWAY != 0 { flagStrs = append(flagStrs, "G") } - if flags&syscall.RTF_HOST != 0 { + if flags&unix.RTF_HOST != 0 { flagStrs = append(flagStrs, "H") } - if flags&syscall.RTF_REJECT != 0 { + if flags&unix.RTF_REJECT != 0 { flagStrs = append(flagStrs, "R") } - if flags&syscall.RTF_DYNAMIC != 0 { + if flags&unix.RTF_DYNAMIC != 0 { flagStrs = append(flagStrs, "D") } - if flags&syscall.RTF_MODIFIED != 0 { + if flags&unix.RTF_MODIFIED != 0 { flagStrs = append(flagStrs, "M") } - if flags&syscall.RTF_STATIC != 0 { + if flags&unix.RTF_STATIC != 0 { flagStrs = append(flagStrs, "S") } - if flags&syscall.RTF_LLINFO != 0 { + if flags&unix.RTF_LLINFO != 0 { flagStrs = append(flagStrs, "L") } - if flags&syscall.RTF_LOCAL != 0 { + if flags&unix.RTF_LOCAL != 0 { flagStrs = append(flagStrs, "l") } - if flags&syscall.RTF_BLACKHOLE != 0 { + if flags&unix.RTF_BLACKHOLE != 0 { flagStrs = append(flagStrs, "B") } // Note: RTF_CLONING and RTF_WASCLONED deprecated in FreeBSD 8.0 + if flags&unix.RTF_PROTO1 != 0 { + flagStrs = append(flagStrs, "1") + } + if flags&unix.RTF_PROTO2 != 0 { + flagStrs = append(flagStrs, "2") + } + if flags&unix.RTF_PROTO3 != 0 { + flagStrs = append(flagStrs, "3") + } if len(flagStrs) == 0 { return "-" From 194a986926cfcce284e77ee6eca84732d8976ace Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Wed, 4 Feb 2026 22:22:37 +0100 Subject: [PATCH 08/15] Cache the result of wgInterface.ToInterface() using sync.Once (#5256) Avoid repeated conversions during route setup. The toInterface helper ensures the conversion happens only once regardless of how many routes are added or removed. --- client/internal/routemanager/manager.go | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/client/internal/routemanager/manager.go b/client/internal/routemanager/manager.go index 2baa0e668..077b9521b 100644 --- a/client/internal/routemanager/manager.go +++ b/client/internal/routemanager/manager.go @@ -173,12 +173,21 @@ func (m *DefaultManager) setupAndroidRoutes(config ManagerConfig) { } func (m *DefaultManager) setupRefCounters(useNoop bool) { + var once sync.Once + var wgIface *net.Interface + toInterface := func() *net.Interface { + once.Do(func() { + wgIface = m.wgInterface.ToInterface() + }) + return wgIface + } + m.routeRefCounter = refcounter.New( func(prefix netip.Prefix, _ struct{}) (struct{}, error) { - return struct{}{}, m.sysOps.AddVPNRoute(prefix, m.wgInterface.ToInterface()) + return struct{}{}, m.sysOps.AddVPNRoute(prefix, toInterface()) }, func(prefix netip.Prefix, _ struct{}) error { - return m.sysOps.RemoveVPNRoute(prefix, m.wgInterface.ToInterface()) + return m.sysOps.RemoveVPNRoute(prefix, toInterface()) }, ) From d2f9653cea8c3b7d119c657366e967eb5869dd07 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Thu, 5 Feb 2026 12:06:28 +0100 Subject: [PATCH 09/15] Fix nil pointer panic in ICE agent during sleep/wake cycles (#5261) Add defensive nil checks in ThreadSafeAgent.Close() to prevent panic when agent field is nil. This can occur during Windows suspend/resume when network interfaces are disrupted or the pion/ice library returns nil without error. Also capture agent pointer in local variable before goroutine execution to prevent race conditions. Fixes service crashes on laptop wake-up. --- client/internal/peer/ice/agent.go | 51 +++++++++++++++++++----------- client/internal/peer/worker_ice.go | 6 ++-- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/client/internal/peer/ice/agent.go b/client/internal/peer/ice/agent.go index 79f68d279..c74b46d10 100644 --- a/client/internal/peer/ice/agent.go +++ b/client/internal/peer/ice/agent.go @@ -2,6 +2,7 @@ package ice import ( "context" + "fmt" "sync" "time" @@ -32,24 +33,6 @@ type ThreadSafeAgent struct { once sync.Once } -func (a *ThreadSafeAgent) Close() error { - var err error - a.once.Do(func() { - done := make(chan error, 1) - go func() { - done <- a.Agent.Close() - }() - - select { - case err = <-done: - case <-time.After(iceAgentCloseTimeout): - log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) - err = nil - } - }) - return err -} - func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, config Config, candidateTypes []ice.CandidateType, ufrag string, pwd string) (*ThreadSafeAgent, error) { iceKeepAlive := iceKeepAlive() iceDisconnectedTimeout := iceDisconnectedTimeout() @@ -93,9 +76,41 @@ func NewAgent(ctx context.Context, iFaceDiscover stdnet.ExternalIFaceDiscover, c return nil, err } + if agent == nil { + return nil, fmt.Errorf("ice.NewAgent returned nil agent without error") + } + return &ThreadSafeAgent{Agent: agent}, nil } +func (a *ThreadSafeAgent) Close() error { + var err error + a.once.Do(func() { + // Defensive check to prevent nil pointer dereference + // This can happen during sleep/wake transitions or memory corruption scenarios + // github.com/netbirdio/netbird/client/internal/peer/ice.(*ThreadSafeAgent).Close(0x40006883f0?) + // [signal 0xc0000005 code=0x0 addr=0x0 pc=0x7ff7e73af83c] + agent := a.Agent + if agent == nil { + log.Warnf("ICE agent is nil during close, skipping") + return + } + + done := make(chan error, 1) + go func() { + done <- agent.Close() + }() + + select { + case err = <-done: + case <-time.After(iceAgentCloseTimeout): + log.Warnf("ICE agent close timed out after %v, proceeding with cleanup", iceAgentCloseTimeout) + err = nil + } + }) + return err +} + func GenerateICECredentials() (string, string, error) { ufrag, err := randutil.GenerateCryptoRandomString(lenUFrag, runesAlpha) if err != nil { diff --git a/client/internal/peer/worker_ice.go b/client/internal/peer/worker_ice.go index b6b9d2cf4..464f57bff 100644 --- a/client/internal/peer/worker_ice.go +++ b/client/internal/peer/worker_ice.go @@ -107,8 +107,10 @@ func (w *WorkerICE) OnNewOffer(remoteOfferAnswer *OfferAnswer) { } w.log.Debugf("agent already exists, recreate the connection") w.agentDialerCancel() - if err := w.agent.Close(); err != nil { - w.log.Warnf("failed to close ICE agent: %s", err) + if w.agent != nil { + if err := w.agent.Close(); err != nil { + w.log.Warnf("failed to close ICE agent: %s", err) + } } sessionID, err := NewICESessionID() From 1b96648d4d190ec2d57b3a195004445dd5e10b16 Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Thu, 5 Feb 2026 21:34:35 +0800 Subject: [PATCH 10/15] [client] Always log dns forwader responses (#5262) --- client/internal/dnsfwd/forwarder.go | 125 ++++++++++------------- client/internal/dnsfwd/forwarder_test.go | 90 +++++++--------- 2 files changed, 93 insertions(+), 122 deletions(-) diff --git a/client/internal/dnsfwd/forwarder.go b/client/internal/dnsfwd/forwarder.go index 1230a4e46..5c7cb31fc 100644 --- a/client/internal/dnsfwd/forwarder.go +++ b/client/internal/dnsfwd/forwarder.go @@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error { return nberrors.FormatErrorOrNil(result) } -func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg { +func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) { if len(query.Question) == 0 { - return nil + return } question := query.Question[0] - logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s", - question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) + qname := strings.ToLower(question.Name) - domain := strings.ToLower(question.Name) + logger.Tracef("question: domain=%s type=%s class=%s", + qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass]) resp := query.SetReply(query) network := resutil.NetworkForQtype(question.Qtype) if network == "" { resp.Rcode = dns.RcodeNotImplemented - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } - mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, ".")) - // query doesn't match any configured domain + mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, ".")) if mostSpecificResId == "" { resp.Rcode = dns.RcodeRefused - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - } - return nil + f.writeResponse(logger, w, resp, qname, startTime) + return } ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout) defer cancel() - result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype) + result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype) if result.Err != nil { - f.handleDNSError(ctx, logger, w, question, resp, domain, result) - return nil + f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime) + return } f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries) - resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...) - f.cache.set(domain, question.Qtype, result.IPs) + resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...) + f.cache.set(qname, question.Qtype, result.IPs) - return resp + f.writeResponse(logger, w, resp, qname, startTime) +} + +func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) { + if err := w.WriteMsg(resp); err != nil { + logger.Errorf("failed to write DNS response: %v", err) + return + } + + logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", + qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) +} + +// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation. +type udpResponseWriter struct { + dns.ResponseWriter + query *dns.Msg +} + +func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error { + opt := u.query.IsEdns0() + maxSize := dns.MinMsgSize + if opt != nil { + maxSize = int(opt.UDPSize()) + } + + if resp.Len() > maxSize { + resp.Truncate(maxSize) + } + + return u.ResponseWriter.WriteMsg(resp) } func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { @@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - opt := query.IsEdns0() - maxSize := dns.MinMsgSize - if opt != nil { - // client advertised a larger EDNS0 buffer - maxSize = int(opt.UDPSize()) - } - - // if our response is too big, truncate and set the TC bit - if resp.Len() > maxSize { - resp.Truncate(maxSize) - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime) } func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { @@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) { "dns_id": fmt.Sprintf("%04x", query.Id), }) - resp := f.handleDNSQuery(logger, w, query) - if resp == nil { - return - } - - if err := w.WriteMsg(resp); err != nil { - logger.Errorf("failed to write DNS response: %v", err) - return - } - - logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s", - query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime)) + f.handleDNSQuery(logger, w, query, startTime) } func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) { @@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError( resp *dns.Msg, domain string, result resutil.LookupResult, + startTime time.Time, ) { qType := question.Qtype qTypeName := dns.TypeToString[qType] @@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError( // NotFound: cache negative result and respond if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess { f.cache.set(domain, question.Qtype, nil) - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError( logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName) resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...) resp.Rcode = dns.RcodeSuccess - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write cached DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } @@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError( verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType) if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess { resp.Rcode = verifyResult.Rcode - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) return } } @@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError( // No cache or verification failed. Log with or without the server field for more context. var dnsErr *net.DNSError if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" { - logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) + logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err) } else { logger.Warnf(errResolveFailed, domain, result.Err) } - // Write final failure response. - if writeErr := w.WriteMsg(resp); writeErr != nil { - logger.Errorf("failed to write failure DNS response: %v", writeErr) - } + f.writeResponse(logger, w, resp, domain, startTime) } // getMatchingEntries retrieves the resource IDs for a given domain. diff --git a/client/internal/dnsfwd/forwarder_test.go b/client/internal/dnsfwd/forwarder_test.go index 6416c2f21..7325ef8a7 100644 --- a/client/internal/dnsfwd/forwarder_test.go +++ b/client/internal/dnsfwd/forwarder_test.go @@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response") @@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) { mockFirewall.AssertExpectations(t) mockResolver.AssertExpectations(t) } else { - if resp != nil { - assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, - "Unauthorized domain should not return successful answers") - } + require.NotNil(t, resp, "Expected response") + assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess, + "Unauthorized domain should not return successful answers") mockFirewall.AssertNotCalled(t, "UpdateSet") mockResolver.AssertNotCalled(t, "LookupNetIP") } @@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) { dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now()) // Verify response + resp := mockWriter.GetLastResponse() if tt.shouldResolve { require.NotNil(t, resp, "Expected response for authorized domain") require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.NotEmpty(t, resp.Answer) - } else if resp != nil { + } else { + require.NotNil(t, resp, "Expected response") assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0, "Unauthorized domain should be refused or have no answers") } @@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) { query.SetQuestion("example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Verify response contains all IPs + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) require.Equal(t, dns.RcodeSuccess, resp.Rcode) require.Len(t, resp.Answer, 3, "Should have 3 answer records") @@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) { }, } - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) // Check the response written to the writer require.NotNil(t, writtenResp, "Expected response to be written") @@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) { // Second query: serve from cache after upstream failure q2 := &dns.Msg{} q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp, "expected response to be written") - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2, "expected response to be written") + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q1 := &dns.Msg{} q1.SetQuestion(mixedQuery+".", dns.TypeA) w1 := &test.MockResponseWriter{} - resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now()) + resp1 := w1.GetLastResponse() require.NotNil(t, resp1) require.Equal(t, dns.RcodeSuccess, resp1.Rcode) require.Len(t, resp1.Answer, 1) @@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) { q2 := &dns.Msg{} q2.SetQuestion("EXAMPLE.COM", dns.TypeA) - var writtenResp *dns.Msg - w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }} - _ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2) + w2 := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now()) - require.NotNil(t, writtenResp) - require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode) - require.Len(t, writtenResp.Answer, 1) + resp2 := w2.GetLastResponse() + require.NotNil(t, resp2) + require.Equal(t, dns.RcodeSuccess, resp2.Rcode) + require.Len(t, resp2.Answer, 1) mockResolver.AssertExpectations(t) } @@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) { query.SetQuestion("smtp.mail.example.com.", dns.TypeA) mockWriter := &test.MockResponseWriter{} - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) + resp := mockWriter.GetLastResponse() require.NotNil(t, resp) assert.Equal(t, dns.RcodeSuccess, resp.Rcode) @@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) { query := &dns.Msg{} query.SetQuestion(dns.Fqdn("example.com"), tt.queryType) - var writtenResp *dns.Msg - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writtenResp = m - return nil - }, - } + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) - - // If a response was returned, it means it should be written (happens in wrapper functions) - if resp != nil && writtenResp == nil { - writtenResp = resp - } - - require.NotNil(t, writtenResp, "Expected response to be written") - assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description) + resp := mockWriter.GetLastResponse() + require.NotNil(t, resp, "Expected response to be written") + assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description) if tt.expectNoAnswer { - assert.Empty(t, writtenResp.Answer, "Response should have no answer records") + assert.Empty(t, resp.Answer, "Response should have no answer records") } mockResolver.AssertExpectations(t) @@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) { query := &dns.Msg{} // Don't set any question - writeCalled := false - mockWriter := &test.MockResponseWriter{ - WriteMsgFunc: func(m *dns.Msg) error { - writeCalled = true - return nil - }, - } - resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query) + mockWriter := &test.MockResponseWriter{} + forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now()) - assert.Nil(t, resp, "Should return nil for empty query") - assert.False(t, writeCalled, "Should not write response for empty query") + assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query") } From 0119f3e9f4f87e49af92c3ee5a4a79c7eebe9a1d Mon Sep 17 00:00:00 2001 From: Viktor Liu <17948409+lixmal@users.noreply.github.com> Date: Fri, 6 Feb 2026 17:03:01 +0800 Subject: [PATCH 11/15] [client] Fix netstack detection and add wireguard port option (#5251) - Add WireguardPort option to embed.Options for custom port configuration - Fix KernelInterface detection to account for netstack mode - Skip SSH config updates when running in netstack mode - Skip interface removal wait when running in netstack mode - Use BindListener for netstack to avoid port conflicts on same host --- client/embed/embed.go | 3 +++ client/iface/iface.go | 5 +++++ client/internal/connect.go | 3 ++- client/internal/engine.go | 2 +- client/internal/engine_ssh.go | 9 +++++++++ client/internal/lazyconn/activity/manager.go | 8 +++++--- 6 files changed, 25 insertions(+), 5 deletions(-) diff --git a/client/embed/embed.go b/client/embed/embed.go index e73f37e35..2ad025ff0 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -71,6 +71,8 @@ type Options struct { DisableClientRoutes bool // BlockInbound blocks all inbound connections from peers BlockInbound bool + // WireguardPort is the port for the WireGuard interface. Use 0 for a random port. + WireguardPort *int } // validateCredentials checks that exactly one credential type is provided @@ -140,6 +142,7 @@ func New(opts Options) (*Client, error) { DisableServerRoutes: &t, DisableClientRoutes: &opts.DisableClientRoutes, BlockInbound: &opts.BlockInbound, + WireguardPort: opts.WireguardPort, } if opts.ConfigPath != "" { config, err = profilemanager.UpdateOrCreateConfig(input) diff --git a/client/iface/iface.go b/client/iface/iface.go index e5623c979..9b331d68c 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -18,6 +18,7 @@ import ( "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/iface/configurer" "github.com/netbirdio/netbird/client/iface/device" + nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/udpmux" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgproxy" @@ -228,6 +229,10 @@ func (w *WGIface) Close() error { result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) } + if nbnetstack.IsEnabled() { + return errors.FormatErrorOrNil(result) + } + if err := w.waitUntilRemoved(); err != nil { log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err) if err := w.Destroy(); err != nil { diff --git a/client/internal/connect.go b/client/internal/connect.go index 7fc3c9a96..17fc20c42 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -20,6 +20,7 @@ import ( "github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -244,7 +245,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan localPeerState := peer.LocalPeerState{ IP: loginResp.GetPeerConfig().GetAddress(), PubKey: myPrivateKey.PublicKey().String(), - KernelInterface: device.WireGuardModuleIsLoaded(), + KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(), FQDN: loginResp.GetPeerConfig().GetFqdn(), } c.statusRecorder.UpdateLocalPeerState(localPeerState) diff --git a/client/internal/engine.go b/client/internal/engine.go index 63ba1c9f2..597ac7c2d 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1017,7 +1017,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error { state := e.statusRecorder.GetLocalPeerState() state.IP = e.wgInterface.Address().String() state.PubKey = e.config.WgPrivateKey.PublicKey().String() - state.KernelInterface = device.WireGuardModuleIsLoaded() + state.KernelInterface = !e.wgInterface.IsUserspaceBind() state.FQDN = conf.GetFqdn() e.statusRecorder.UpdateLocalPeerState(state) diff --git a/client/internal/engine_ssh.go b/client/internal/engine_ssh.go index a8c05fe0a..1419bc262 100644 --- a/client/internal/engine_ssh.go +++ b/client/internal/engine_ssh.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" + "github.com/netbirdio/netbird/client/iface/netstack" nftypes "github.com/netbirdio/netbird/client/internal/netflow/types" sshauth "github.com/netbirdio/netbird/client/ssh/auth" sshconfig "github.com/netbirdio/netbird/client/ssh/config" @@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { // updateSSHClientConfig updates the SSH client configuration with peer information func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error { + if netstack.IsEnabled() { + return nil + } + peerInfo := e.extractPeerSSHInfo(remotePeers) if len(peerInfo) == 0 { log.Debug("no SSH-enabled peers found, skipping SSH config update") @@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) { // cleanupSSHConfig removes NetBird SSH client configuration on shutdown func (e *Engine) cleanupSSHConfig() { + if netstack.IsEnabled() { + return + } + configMgr := sshconfig.New() if err := configMgr.RemoveSSHClientConfig(); err != nil { diff --git a/client/internal/lazyconn/activity/manager.go b/client/internal/lazyconn/activity/manager.go index db283ec9a..1c11378c8 100644 --- a/client/internal/lazyconn/activity/manager.go +++ b/client/internal/lazyconn/activity/manager.go @@ -11,6 +11,7 @@ import ( log "github.com/sirupsen/logrus" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "github.com/netbirdio/netbird/client/iface/netstack" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/lazyconn" peerid "github.com/netbirdio/netbird/client/internal/peer/id" @@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error) return NewUDPListener(m.wgIface, peerCfg) } - // BindListener is only used on Windows and JS platforms: + // BindListener is used on Windows, JS, and netstack platforms: // - JS: Cannot listen to UDP sockets // - Windows: IP_UNICAST_IF socket option forces packets out the interface the default // gateway points to, preventing them from reaching the loopback interface. - // BindListener bypasses this by passing data directly through the bind. - if runtime.GOOS != "windows" && runtime.GOOS != "js" { + // - Netstack: Allows multiple instances on the same host without port conflicts. + // BindListener bypasses these issues by passing data directly through the bind. + if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() { return NewUDPListener(m.wgIface, peerCfg) } From c3f176f34835151da3a3bf493960b3e8a7857421 Mon Sep 17 00:00:00 2001 From: eyJhb Date: Fri, 6 Feb 2026 11:23:36 +0100 Subject: [PATCH 12/15] [client] Fix wrong URL being logged for DefaultAdminURL (#5252) - DefaultManagementURL was being logged instead of DefaultAdminURL --- client/internal/profilemanager/config.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/client/internal/profilemanager/config.go b/client/internal/profilemanager/config.go index f2fda84e0..8f3ff8b11 100644 --- a/client/internal/profilemanager/config.go +++ b/client/internal/profilemanager/config.go @@ -252,7 +252,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) { } if config.AdminURL == nil { - log.Infof("using default Admin URL %s", DefaultManagementURL) + log.Infof("using default Admin URL %s", DefaultAdminURL) config.AdminURL, err = parseURL("Admin URL", DefaultAdminURL) if err != nil { return false, err From af8f730bdac9109b02f3432efd6b6d24a251cea9 Mon Sep 17 00:00:00 2001 From: Vlad <4941176+crn4@users.noreply.github.com> Date: Fri, 6 Feb 2026 18:00:43 +0100 Subject: [PATCH 13/15] [management] check stream start time for connecting peer (#5267) --- management/internals/shared/grpc/server.go | 10 +++--- management/server/account.go | 6 ++-- management/server/account/manager.go | 4 +-- management/server/account_test.go | 33 +++++++++++++++---- management/server/mock_server/account_mock.go | 12 +++---- management/server/peer.go | 20 ++++++++--- 6 files changed, 58 insertions(+), 27 deletions(-) diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index befcd2adf..98c68ebda 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -300,20 +300,18 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S metahash := metaHash(peerMeta, realIP.String()) s.loginFilter.addLogin(peerKey.String(), metahash) - peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP) + peer, netMap, postureChecks, dnsFwdPort, err := s.accountManager.SyncAndMarkPeer(ctx, accountID, peerKey.String(), peerMeta, realIP, reqStart) if err != nil { log.WithContext(ctx).Debugf("error while syncing peer %s: %v", peerKey.String(), err) s.syncSem.Add(-1) return mapError(ctx, err) } - streamStartTime := time.Now().UTC() - err = s.sendInitialSync(ctx, peerKey, peer, netMap, postureChecks, srv, dnsFwdPort) if err != nil { log.WithContext(ctx).Debugf("error while sending initial sync for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -321,7 +319,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S if err != nil { log.WithContext(ctx).Debugf("error while notify peer connected for %s: %v", peerKey.String(), err) s.syncSem.Add(-1) - s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, streamStartTime) + s.cancelPeerRoutinesWithoutLock(ctx, accountID, peer, reqStart) return err } @@ -338,7 +336,7 @@ func (s *Server) Sync(req *proto.EncryptedMessage, srv proto.ManagementService_S s.syncSem.Add(-1) - return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, streamStartTime) + return s.handleUpdates(ctx, accountID, peerKey, peer, updates, srv, reqStart) } func (s *Server) handleHandshake(ctx context.Context, srv proto.ManagementService_JobServer) (wgtypes.Key, error) { diff --git a/management/server/account.go b/management/server/account.go index 4f53415f5..a9f59773a 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -1670,13 +1670,13 @@ func domainIsUpToDate(domain string, domainCategory string, userAuth auth.UserAu return domainCategory == types.PrivateCategory || userAuth.DomainCategory != types.PrivateCategory || domain != userAuth.Domain } -func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *DefaultAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { peer, netMap, postureChecks, dnsfwdPort, err := am.SyncPeer(ctx, types.PeerSync{WireGuardPubKey: peerPubKey, Meta: meta}, accountID) if err != nil { return nil, nil, nil, 0, fmt.Errorf("error syncing peer: %w", err) } - err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID) + err = am.MarkPeerConnected(ctx, peerPubKey, true, realIP, accountID, syncTime) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as connected %s %v", peerPubKey, err) } @@ -1697,7 +1697,7 @@ func (am *DefaultAccountManager) OnPeerDisconnected(ctx context.Context, account return nil } - err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID) + err = am.MarkPeerConnected(ctx, peerPubKey, false, nil, accountID, time.Now().UTC()) if err != nil { log.WithContext(ctx).Warnf("failed marking peer as disconnected %s %v", peerPubKey, err) } diff --git a/management/server/account/manager.go b/management/server/account/manager.go index eed7739da..1d25b0af7 100644 --- a/management/server/account/manager.go +++ b/management/server/account/manager.go @@ -58,7 +58,7 @@ type Manager interface { GetUserFromUserAuth(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsers(ctx context.Context, accountID string) ([]*types.User, error) GetPeers(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error + MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error DeletePeer(ctx context.Context, accountID, peerID, userID string) error UpdatePeer(ctx context.Context, accountID, userID string, peer *nbpeer.Peer) (*nbpeer.Peer, error) UpdatePeerIP(ctx context.Context, accountID, userID, peerID string, newIP netip.Addr) error @@ -114,7 +114,7 @@ type Manager interface { UpdateIntegratedValidator(ctx context.Context, accountID, userID, validator string, groups []string) error GroupValidation(ctx context.Context, accountId string, groups []string) (bool, error) GetValidatedPeers(ctx context.Context, accountID string) (map[string]struct{}, map[string]string, error) - SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) OnPeerDisconnected(ctx context.Context, accountID string, peerPubKey string, streamStartTime time.Time) error SyncPeerMeta(ctx context.Context, peerPubKey string, meta nbpeer.PeerSystemMeta) error FindExistingPostureCheck(accountID string, checks *posture.ChecksDefinition) (*posture.Checks, error) diff --git a/management/server/account_test.go b/management/server/account_test.go index f3d98916c..443e6344e 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -1881,7 +1881,7 @@ func TestDefaultAccountManager_UpdatePeer_PeerLoginExpiration(t *testing.T) { accountID, err := manager.GetAccountIDByUserID(context.Background(), auth.UserAuth{UserId: userID}) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") _, err = manager.UpdateAccountSettings(context.Background(), accountID, userID, &types.Settings{ @@ -1952,7 +1952,7 @@ func TestDefaultAccountManager_MarkPeerConnected_PeerLoginExpiration(t *testing. require.NoError(t, err, "unable to get the account") // when we mark peer as connected, the peer login expiration routine should trigger - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") failed := waitTimeout(wg, time.Second) @@ -1979,7 +1979,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { require.NoError(t, err, "unable to add peer") t.Run("disconnect peer when streamStartTime is after LastSeen", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) @@ -1997,7 +1997,7 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { }) t.Run("skip disconnect when LastSeen is after streamStartTime (zombie stream protection)", func(t *testing.T) { - err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) @@ -2014,6 +2014,27 @@ func TestDefaultAccountManager_OnPeerDisconnected_LastSeenCheck(t *testing.T) { require.True(t, peer.Status.Connected, "peer should remain connected because LastSeen > streamStartTime (zombie stream protection)") }) + + t.Run("skip stale connect when peer already has newer LastSeen (blocked goroutine protection)", func(t *testing.T) { + node2SyncTime := time.Now().UTC() + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node2SyncTime) + require.NoError(t, err, "node 2 should connect peer") + + peer, err := manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), "LastSeen should be node2SyncTime") + + node1StaleSyncTime := node2SyncTime.Add(-1 * time.Minute) + err = manager.MarkPeerConnected(context.Background(), peerPubKey, true, nil, accountID, node1StaleSyncTime) + require.NoError(t, err, "stale connect should not return error") + + peer, err = manager.Store.GetPeerByPeerPubKey(context.Background(), store.LockingStrengthNone, peerPubKey) + require.NoError(t, err) + require.True(t, peer.Status.Connected, "peer should still be connected") + require.Equal(t, node2SyncTime.Unix(), peer.Status.LastSeen.Unix(), + "LastSeen should NOT be overwritten by stale syncTime from blocked goroutine") + }) } func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *testing.T) { @@ -2038,7 +2059,7 @@ func TestDefaultAccountManager_UpdateAccountSettings_PeerLoginExpiration(t *test account, err := manager.Store.GetAccount(context.Background(), accountID) require.NoError(t, err, "unable to get the account") - err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID) + err = manager.MarkPeerConnected(context.Background(), key.PublicKey().String(), true, nil, accountID, time.Now().UTC()) require.NoError(t, err, "unable to mark peer connected") wg := &sync.WaitGroup{} @@ -3231,7 +3252,7 @@ func BenchmarkSyncAndMarkPeer(b *testing.B) { b.ResetTimer() start := time.Now() for i := 0; i < b.N; i++ { - _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}) + _, _, _, _, err := manager.SyncAndMarkPeer(context.Background(), account.Id, account.Peers["peer-1"].Key, nbpeer.PeerSystemMeta{Hostname: strconv.Itoa(i)}, net.IP{1, 1, 1, 1}, time.Now().UTC()) assert.NoError(b, err) } diff --git a/management/server/mock_server/account_mock.go b/management/server/mock_server/account_mock.go index a4754d180..8471d0a94 100644 --- a/management/server/mock_server/account_mock.go +++ b/management/server/mock_server/account_mock.go @@ -37,8 +37,8 @@ type MockAccountManager struct { GetUserFromUserAuthFunc func(ctx context.Context, userAuth auth.UserAuth) (*types.User, error) ListUsersFunc func(ctx context.Context, accountID string) ([]*types.User, error) GetPeersFunc func(ctx context.Context, accountID, userID, nameFilter, ipFilter string) ([]*nbpeer.Peer, error) - MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP) error - SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) + MarkPeerConnectedFunc func(ctx context.Context, peerKey string, connected bool, realIP net.IP, syncTime time.Time) error + SyncAndMarkPeerFunc func(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) DeletePeerFunc func(ctx context.Context, accountID, peerKey, userID string) error GetNetworkMapFunc func(ctx context.Context, peerKey string) (*types.NetworkMap, error) GetPeerNetworkFunc func(ctx context.Context, peerKey string) (*types.Network, error) @@ -214,9 +214,9 @@ func (am *MockAccountManager) DeleteSetupKey(ctx context.Context, accountID, use return status.Errorf(codes.Unimplemented, "method DeleteSetupKey is not implemented") } -func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { +func (am *MockAccountManager) SyncAndMarkPeer(ctx context.Context, accountID string, peerPubKey string, meta nbpeer.PeerSystemMeta, realIP net.IP, syncTime time.Time) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) { if am.SyncAndMarkPeerFunc != nil { - return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP) + return am.SyncAndMarkPeerFunc(ctx, accountID, peerPubKey, meta, realIP, syncTime) } return nil, nil, nil, 0, status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } @@ -322,9 +322,9 @@ func (am *MockAccountManager) GetAccountIDByUserID(ctx context.Context, userAuth } // MarkPeerConnected mock implementation of MarkPeerConnected from server.AccountManager interface -func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string) error { +func (am *MockAccountManager) MarkPeerConnected(ctx context.Context, peerKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { if am.MarkPeerConnectedFunc != nil { - return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP) + return am.MarkPeerConnectedFunc(ctx, peerKey, connected, realIP, syncTime) } return status.Errorf(codes.Unimplemented, "method MarkPeerConnected is not implemented") } diff --git a/management/server/peer.go b/management/server/peer.go index ab72d3051..a4bdc784d 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -103,11 +103,13 @@ func (am *DefaultAccountManager) getUserAccessiblePeers(ctx context.Context, acc } // MarkPeerConnected marks peer as connected (true) or disconnected (false) -func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string) error { +// syncTime is used as the LastSeen timestamp and for stale request detection +func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubKey string, connected bool, realIP net.IP, accountID string, syncTime time.Time) error { var peer *nbpeer.Peer var settings *types.Settings var expired bool var err error + var skipped bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { peer, err = transaction.GetPeerByPeerPubKey(ctx, store.LockingStrengthUpdate, peerPubKey) @@ -115,9 +117,19 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return err } - expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID) + if connected && !syncTime.After(peer.Status.LastSeen) { + log.WithContext(ctx).Tracef("peer %s has newer activity (lastSeen=%s >= syncTime=%s), skipping connect", + peer.ID, peer.Status.LastSeen.Format(time.RFC3339), syncTime.Format(time.RFC3339)) + skipped = true + return nil + } + + expired, err = updatePeerStatusAndLocation(ctx, am.geo, transaction, peer, connected, realIP, accountID, syncTime) return err }) + if skipped { + return nil + } if err != nil { return err } @@ -147,10 +159,10 @@ func (am *DefaultAccountManager) MarkPeerConnected(ctx context.Context, peerPubK return nil } -func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string) (bool, error) { +func updatePeerStatusAndLocation(ctx context.Context, geo geolocation.Geolocation, transaction store.Store, peer *nbpeer.Peer, connected bool, realIP net.IP, accountID string, syncTime time.Time) (bool, error) { oldStatus := peer.Status.Copy() newStatus := oldStatus - newStatus.LastSeen = time.Now().UTC() + newStatus.LastSeen = syncTime newStatus.Connected = connected // whenever peer got connected that means that it logged in successfully if newStatus.Connected { From 3be16d19a0d5167d2ceeaac321b987b17f3120f4 Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 6 Feb 2026 19:47:38 +0100 Subject: [PATCH 14/15] [management] Feature/grpc debounce msgtype (#5239) * Add gRPC update debouncing mechanism Implements backpressure handling for peer network map updates to efficiently handle rapid changes. First update is sent immediately, subsequent rapid updates are coalesced, ensuring only the latest update is sent after a 1-second quiet period. * Enhance unit test to verify peer count synchronization with debouncing and timeout handling * Debounce based on type * Refactor test to validate timer restart after pending update dispatch * Simplify timer reset for Go 1.23+ automatic channel draining Remove manual channel drain in resetTimer() since Go 1.23+ automatically drains the timer channel when Stop() returns false, making the select-case pattern unnecessary. --- .../network_map/controller/controller.go | 11 +- .../update_channel/updatechannel_test.go | 22 +- .../controllers/network_map/update_message.go | 15 +- management/internals/shared/grpc/server.go | 33 +- management/internals/shared/grpc/token_mgr.go | 10 +- .../internals/shared/grpc/update_debouncer.go | 103 +++ .../shared/grpc/update_debouncer_test.go | 587 ++++++++++++++++++ management/server/management_test.go | 59 +- 8 files changed, 818 insertions(+), 22 deletions(-) create mode 100644 management/internals/shared/grpc/update_debouncer.go create mode 100644 management/internals/shared/grpc/update_debouncer_test.go diff --git a/management/internals/controllers/network_map/controller/controller.go b/management/internals/controllers/network_map/controller/controller.go index 5ae64e9f1..3e28e1380 100644 --- a/management/internals/controllers/network_map/controller/controller.go +++ b/management/internals/controllers/network_map/controller/controller.go @@ -247,7 +247,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, p, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSetting, maps.Keys(peerGroups), dnsFwdPort) c.metrics.CountToSyncResponseDuration(time.Since(start)) - c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, p.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) }(peer) } @@ -370,7 +373,10 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe dnsFwdPort := computeForwarderPort(maps.Values(account.Peers), network_map.DnsForwarderPortMinVersion) update := grpc.ToSyncResponse(ctx, nil, c.config.HttpConfig, c.config.DeviceAuthorizationFlow, peer, nil, nil, remotePeerNetworkMap, dnsDomain, postureChecks, dnsCache, account.Settings, extraSettings, maps.Keys(peerGroups), dnsFwdPort) - c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{Update: update}) + c.peersUpdateManager.SendUpdate(ctx, peer.ID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeNetworkMap, + }) return nil } @@ -778,6 +784,7 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI }, }, }, + MessageType: network_map.MessageTypeNetworkMap, }) c.peersUpdateManager.CloseChannel(ctx, peerID) diff --git a/management/internals/controllers/network_map/update_channel/updatechannel_test.go b/management/internals/controllers/network_map/update_channel/updatechannel_test.go index afc1e2c32..c73baf81f 100644 --- a/management/internals/controllers/network_map/update_channel/updatechannel_test.go +++ b/management/internals/controllers/network_map/update_channel/updatechannel_test.go @@ -25,11 +25,14 @@ func TestCreateChannel(t *testing.T) { func TestSendUpdate(t *testing.T) { peer := "test-sendupdate" peersUpdater := NewPeersUpdateManager(nil) - update1 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 0, + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 0, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } _ = peersUpdater.CreateChannel(context.Background(), peer) if _, ok := peersUpdater.peerChannels[peer]; !ok { t.Error("Error creating the channel") @@ -45,11 +48,14 @@ func TestSendUpdate(t *testing.T) { peersUpdater.SendUpdate(context.Background(), peer, update1) } - update2 := &network_map.UpdateMessage{Update: &proto.SyncResponse{ - NetworkMap: &proto.NetworkMap{ - Serial: 10, + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: 10, + }, }, - }} + MessageType: network_map.MessageTypeNetworkMap, + } peersUpdater.SendUpdate(context.Background(), peer, update2) timeout := time.After(5 * time.Second) diff --git a/management/internals/controllers/network_map/update_message.go b/management/internals/controllers/network_map/update_message.go index 33643bcbd..0ffddf8b2 100644 --- a/management/internals/controllers/network_map/update_message.go +++ b/management/internals/controllers/network_map/update_message.go @@ -4,6 +4,19 @@ import ( "github.com/netbirdio/netbird/shared/management/proto" ) +// MessageType indicates the type of update message for debouncing strategy +type MessageType int + +const ( + // MessageTypeNetworkMap represents network map updates (peers, routes, DNS, firewall) + // These updates can be safely debounced - only the latest state matters + MessageTypeNetworkMap MessageType = iota + // MessageTypeControlConfig represents control/config updates (tokens, peer expiration) + // These updates should not be dropped as they contain time-sensitive information + MessageTypeControlConfig +) + type UpdateMessage struct { - Update *proto.SyncResponse + Update *proto.SyncResponse + MessageType MessageType } diff --git a/management/internals/shared/grpc/server.go b/management/internals/shared/grpc/server.go index 98c68ebda..ff9d7ea05 100644 --- a/management/internals/shared/grpc/server.go +++ b/management/internals/shared/grpc/server.go @@ -404,11 +404,20 @@ func (s *Server) sendJobsLoop(ctx context.Context, accountID string, peerKey wgt } // handleUpdates sends updates to the connected peer until the updates channel is closed. +// It implements a backpressure mechanism that sends the first update immediately, +// then debounces subsequent rapid updates, ensuring only the latest update is sent +// after a quiet period. func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wgtypes.Key, peer *nbpeer.Peer, updates chan *network_map.UpdateMessage, srv proto.ManagementService_SyncServer, streamStartTime time.Time) error { log.WithContext(ctx).Tracef("starting to handle updates for peer %s", peerKey.String()) + + // Create a debouncer for this peer connection + debouncer := NewUpdateDebouncer(1000 * time.Millisecond) + defer debouncer.Stop() + for { select { // condition when there are some updates + // todo set the updates channel size to 1 case update, open := <-updates: if s.appMetrics != nil { s.appMetrics.GRPCMetrics().UpdateChannelQueueLength(len(updates) + 1) @@ -419,10 +428,28 @@ func (s *Server) handleUpdates(ctx context.Context, accountID string, peerKey wg s.cancelPeerRoutines(ctx, accountID, peer, streamStartTime) return nil } + log.WithContext(ctx).Debugf("received an update for peer %s", peerKey.String()) - if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { - log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) - return err + if debouncer.ProcessUpdate(update) { + // Send immediately (first update or after quiet period) + if err := s.sendUpdate(ctx, accountID, peerKey, peer, update, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) + return err + } + } + + // Timer expired - quiet period reached, send pending updates if any + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) == 0 { + continue + } + log.WithContext(ctx).Debugf("sending %d debounced update(s) for peer %s", len(pendingUpdates), peerKey.String()) + for _, pendingUpdate := range pendingUpdates { + if err := s.sendUpdate(ctx, accountID, peerKey, peer, pendingUpdate, srv, streamStartTime); err != nil { + log.WithContext(ctx).Debugf("error while sending an update to peer %s: %v", peerKey.String(), err) + return err + } } // condition when client <-> server connection has been terminated diff --git a/management/internals/shared/grpc/token_mgr.go b/management/internals/shared/grpc/token_mgr.go index ccb32202f..65e58ad41 100644 --- a/management/internals/shared/grpc/token_mgr.go +++ b/management/internals/shared/grpc/token_mgr.go @@ -242,7 +242,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewTURNAndRelayTokens(ctx context.Cont m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new TURN credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, accountID, peerID string) { @@ -266,7 +269,10 @@ func (m *TimeBasedAuthSecretsManager) pushNewRelayTokens(ctx context.Context, ac m.extendNetbirdConfig(ctx, peerID, accountID, update) log.WithContext(ctx).Debugf("sending new relay credentials to peer %s", peerID) - m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{Update: update}) + m.updateManager.SendUpdate(ctx, peerID, &network_map.UpdateMessage{ + Update: update, + MessageType: network_map.MessageTypeControlConfig, + }) } func (m *TimeBasedAuthSecretsManager) extendNetbirdConfig(ctx context.Context, peerID, accountID string, update *proto.SyncResponse) { diff --git a/management/internals/shared/grpc/update_debouncer.go b/management/internals/shared/grpc/update_debouncer.go new file mode 100644 index 000000000..8af9c2656 --- /dev/null +++ b/management/internals/shared/grpc/update_debouncer.go @@ -0,0 +1,103 @@ +package grpc + +import ( + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" +) + +// UpdateDebouncer implements a backpressure mechanism that: +// - Sends the first update immediately +// - Coalesces rapid subsequent network map updates (only latest matters) +// - Queues control/config updates (all must be delivered) +// - Preserves the order of messages (important for control configs between network maps) +// - Ensures pending updates are sent after a quiet period +type UpdateDebouncer struct { + debounceInterval time.Duration + timer *time.Timer + pendingUpdates []*network_map.UpdateMessage // Queue that preserves order + timerC <-chan time.Time +} + +// NewUpdateDebouncer creates a new debouncer with the specified interval +func NewUpdateDebouncer(interval time.Duration) *UpdateDebouncer { + return &UpdateDebouncer{ + debounceInterval: interval, + } +} + +// ProcessUpdate handles an incoming update and returns whether it should be sent immediately +func (d *UpdateDebouncer) ProcessUpdate(update *network_map.UpdateMessage) bool { + if d.timer == nil { + // No active debounce timer, signal to send immediately + // and start the debounce period + d.startTimer() + return true + } + + // Already in debounce period, accumulate this update preserving order + // Check if we should coalesce with the last pending update + if len(d.pendingUpdates) > 0 && + update.MessageType == network_map.MessageTypeNetworkMap && + d.pendingUpdates[len(d.pendingUpdates)-1].MessageType == network_map.MessageTypeNetworkMap { + // Replace the last network map with this one (coalesce consecutive network maps) + d.pendingUpdates[len(d.pendingUpdates)-1] = update + } else { + // Append to the queue (preserves order for control configs and non-consecutive network maps) + d.pendingUpdates = append(d.pendingUpdates, update) + } + d.resetTimer() + return false +} + +// TimerChannel returns the timer channel for select statements +func (d *UpdateDebouncer) TimerChannel() <-chan time.Time { + if d.timer == nil { + return nil + } + return d.timerC +} + +// GetPendingUpdates returns and clears all pending updates after timer expiration. +// Updates are returned in the order they were received, with consecutive network maps +// already coalesced to only the latest one. +// If there were pending updates, it restarts the timer to continue debouncing. +// If there were no pending updates, it clears the timer (true quiet period). +func (d *UpdateDebouncer) GetPendingUpdates() []*network_map.UpdateMessage { + updates := d.pendingUpdates + d.pendingUpdates = nil + + if len(updates) > 0 { + // There were pending updates, so updates are still coming rapidly + // Restart the timer to continue debouncing mode + if d.timer != nil { + d.timer.Reset(d.debounceInterval) + } + } else { + // No pending updates means true quiet period - return to immediate mode + d.timer = nil + d.timerC = nil + } + + return updates +} + +// Stop stops the debouncer and cleans up resources +func (d *UpdateDebouncer) Stop() { + if d.timer != nil { + d.timer.Stop() + d.timer = nil + d.timerC = nil + } + d.pendingUpdates = nil +} + +func (d *UpdateDebouncer) startTimer() { + d.timer = time.NewTimer(d.debounceInterval) + d.timerC = d.timer.C +} + +func (d *UpdateDebouncer) resetTimer() { + d.timer.Stop() + d.timer.Reset(d.debounceInterval) +} diff --git a/management/internals/shared/grpc/update_debouncer_test.go b/management/internals/shared/grpc/update_debouncer_test.go new file mode 100644 index 000000000..075994a2d --- /dev/null +++ b/management/internals/shared/grpc/update_debouncer_test.go @@ -0,0 +1,587 @@ +package grpc + +import ( + "testing" + "time" + + "github.com/netbirdio/netbird/management/internals/controllers/network_map" + "github.com/netbirdio/netbird/shared/management/proto" +) + +func TestUpdateDebouncer_FirstUpdateSentImmediately(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + shouldSend := debouncer.ProcessUpdate(update) + + if !shouldSend { + t.Error("First update should be sent immediately") + } + + if debouncer.TimerChannel() == nil { + t.Error("Timer should be started after first update") + } +} + +func TestUpdateDebouncer_RapidUpdatesCoalesced(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update should be sent immediately + if !debouncer.ProcessUpdate(update1) { + t.Error("First update should be sent immediately") + } + + // Rapid subsequent updates should be coalesced + if debouncer.ProcessUpdate(update2) { + t.Error("Second rapid update should not be sent immediately") + } + + if debouncer.ProcessUpdate(update3) { + t.Error("Third rapid update should not be sent immediately") + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update3 { + t.Error("Should get the last update (update3)") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_LastUpdateAlwaysSent(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + debouncer.ProcessUpdate(update1) + + // Send second update within debounce period + debouncer.ProcessUpdate(update2) + + // Wait for timer + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update2 { + t.Error("Should get the last update") + } + if pendingUpdates[0] == update1 { + t.Error("Should not get the first update") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_TimerResetOnNewUpdate(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + debouncer.ProcessUpdate(update1) + + // Wait a bit, but not the full debounce period + time.Sleep(30 * time.Millisecond) + + // Send second update - should reset timer + debouncer.ProcessUpdate(update2) + + // Wait a bit more + time.Sleep(30 * time.Millisecond) + + // Send third update - should reset timer again + debouncer.ProcessUpdate(update3) + + // Now wait for the timer (should fire after last update's reset) + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != update3 { + t.Error("Should get the last update (update3)") + } + // Timer should be restarted since there was a pending update + if debouncer.TimerChannel() == nil { + t.Error("Timer should be restarted after sending pending update") + } + case <-time.After(150 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_TimerRestartsAfterPendingUpdateSent(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update3 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update sent immediately + debouncer.ProcessUpdate(update1) + + // Second update coalesced + debouncer.ProcessUpdate(update2) + + // Wait for timer to expire + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) == 0 { + t.Fatal("Should have pending update") + } + + // After sending pending update, timer is restarted, so next update is NOT immediate + if debouncer.ProcessUpdate(update3) { + t.Error("Update after debounced send should not be sent immediately (timer restarted)") + } + + // Wait for the restarted timer and verify update3 is pending + select { + case <-debouncer.TimerChannel(): + finalUpdates := debouncer.GetPendingUpdates() + if len(finalUpdates) != 1 || finalUpdates[0] != update3 { + t.Error("Should get update3 as pending") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired for restarted timer") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_StopCleansUp(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send update to start timer + debouncer.ProcessUpdate(update) + + // Stop should clean up + debouncer.Stop() + + // Multiple stops should be safe + debouncer.Stop() +} + +func TestUpdateDebouncer_HighFrequencyUpdates(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate high-frequency updates + var lastUpdate *network_map.UpdateMessage + sentImmediately := 0 + for i := 0; i < 100; i++ { + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + lastUpdate = update + if debouncer.ProcessUpdate(update) { + sentImmediately++ + } + time.Sleep(1 * time.Millisecond) // Very rapid updates + } + + // Only first update should be sent immediately + if sentImmediately != 1 { + t.Errorf("Expected only 1 update sent immediately, got %d", sentImmediately) + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0] != lastUpdate { + t.Error("Should get the very last update") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 99 { + t.Errorf("Expected serial 99, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_NoUpdatesAfterFirst(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // Send first update + if !debouncer.ProcessUpdate(update) { + t.Error("First update should be sent immediately") + } + + // Wait for timer to expire with no additional updates (true quiet period) + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) != 0 { + t.Error("Should have no pending updates") + } + // After true quiet period, timer should be cleared + if debouncer.TimerChannel() != nil { + t.Error("Timer should be cleared after quiet period") + } + case <-time.After(100 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_IntermediateUpdatesDropped(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + updates := make([]*network_map.UpdateMessage, 5) + for i := range updates { + updates[i] = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + } + + // First update sent immediately + debouncer.ProcessUpdate(updates[0]) + + // Send updates 1, 2, 3, 4 rapidly - only last one should remain pending + debouncer.ProcessUpdate(updates[1]) + debouncer.ProcessUpdate(updates[2]) + debouncer.ProcessUpdate(updates[3]) + debouncer.ProcessUpdate(updates[4]) + + // Wait for debounce + <-debouncer.TimerChannel() + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) != 1 { + t.Errorf("Should get exactly 1 pending update, got %d", len(pendingUpdates)) + } + if pendingUpdates[0].Update.NetworkMap.Serial != 4 { + t.Errorf("Expected only the last update (serial 4), got serial %d", pendingUpdates[0].Update.NetworkMap.Serial) + } +} + +func TestUpdateDebouncer_TrueQuietPeriodResetsToImmediateMode(t *testing.T) { + debouncer := NewUpdateDebouncer(30 * time.Millisecond) + defer debouncer.Stop() + + update1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + update2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{}, + MessageType: network_map.MessageTypeNetworkMap, + } + + // First update sent immediately + if !debouncer.ProcessUpdate(update1) { + t.Error("First update should be sent immediately") + } + + // Wait for timer without sending any more updates (true quiet period) + <-debouncer.TimerChannel() + pendingUpdates := debouncer.GetPendingUpdates() + + if len(pendingUpdates) != 0 { + t.Error("Should have no pending updates during quiet period") + } + + // After true quiet period, next update should be sent immediately + if !debouncer.ProcessUpdate(update2) { + t.Error("Update after true quiet period should be sent immediately") + } +} + +func TestUpdateDebouncer_ContinuousHighFrequencyStaysInDebounceMode(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate continuous high-frequency updates + for i := 0; i < 10; i++ { + update := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{ + NetworkMap: &proto.NetworkMap{ + Serial: uint64(i), + }, + }, + MessageType: network_map.MessageTypeNetworkMap, + } + + if i == 0 { + // First one sent immediately + if !debouncer.ProcessUpdate(update) { + t.Error("First update should be sent immediately") + } + } else { + // All others should be coalesced (not sent immediately) + if debouncer.ProcessUpdate(update) { + t.Errorf("Update %d should not be sent immediately", i) + } + } + + // Wait a bit but send next update before debounce expires + time.Sleep(20 * time.Millisecond) + } + + // Now wait for final debounce + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + if len(pendingUpdates) == 0 { + t.Fatal("Should have the last update pending") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 9 { + t.Errorf("Expected serial 9, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_ControlConfigMessagesQueued(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + netmapUpdate := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}}, + MessageType: network_map.MessageTypeNetworkMap, + } + tokenUpdate1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + tokenUpdate2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + + // First update sent immediately + debouncer.ProcessUpdate(netmapUpdate) + + // Send multiple control config updates - they should all be queued + debouncer.ProcessUpdate(tokenUpdate1) + debouncer.ProcessUpdate(tokenUpdate2) + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get both control config updates + if len(pendingUpdates) != 2 { + t.Errorf("Expected 2 control config updates, got %d", len(pendingUpdates)) + } + // Control configs should come first + if pendingUpdates[0] != tokenUpdate1 { + t.Error("First pending update should be tokenUpdate1") + } + if pendingUpdates[1] != tokenUpdate2 { + t.Error("Second pending update should be tokenUpdate2") + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_MixedMessageTypes(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + netmapUpdate1 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 1}}, + MessageType: network_map.MessageTypeNetworkMap, + } + netmapUpdate2 := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 2}}, + MessageType: network_map.MessageTypeNetworkMap, + } + tokenUpdate := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + + // First update sent immediately + debouncer.ProcessUpdate(netmapUpdate1) + + // Send token update and network map update + debouncer.ProcessUpdate(tokenUpdate) + debouncer.ProcessUpdate(netmapUpdate2) + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get 2 updates in order: token, then network map + if len(pendingUpdates) != 2 { + t.Errorf("Expected 2 pending updates, got %d", len(pendingUpdates)) + } + // Token update should come first (preserves order) + if pendingUpdates[0] != tokenUpdate { + t.Error("First pending update should be tokenUpdate") + } + // Network map update should come second + if pendingUpdates[1] != netmapUpdate2 { + t.Error("Second pending update should be netmapUpdate2") + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} + +func TestUpdateDebouncer_OrderPreservation(t *testing.T) { + debouncer := NewUpdateDebouncer(50 * time.Millisecond) + defer debouncer.Stop() + + // Simulate: 50 network maps -> 1 control config -> 50 network maps + // Expected result: 3 messages (netmap, controlConfig, netmap) + + // Send first network map immediately + firstNetmap := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: 0}}, + MessageType: network_map.MessageTypeNetworkMap, + } + if !debouncer.ProcessUpdate(firstNetmap) { + t.Error("First update should be sent immediately") + } + + // Send 49 more network maps (will be coalesced to last one) + var lastNetmapBatch1 *network_map.UpdateMessage + for i := 1; i < 50; i++ { + lastNetmapBatch1 = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}}, + MessageType: network_map.MessageTypeNetworkMap, + } + debouncer.ProcessUpdate(lastNetmapBatch1) + } + + // Send 1 control config + controlConfig := &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetbirdConfig: &proto.NetbirdConfig{}}, + MessageType: network_map.MessageTypeControlConfig, + } + debouncer.ProcessUpdate(controlConfig) + + // Send 50 more network maps (will be coalesced to last one) + var lastNetmapBatch2 *network_map.UpdateMessage + for i := 50; i < 100; i++ { + lastNetmapBatch2 = &network_map.UpdateMessage{ + Update: &proto.SyncResponse{NetworkMap: &proto.NetworkMap{Serial: uint64(i)}}, + MessageType: network_map.MessageTypeNetworkMap, + } + debouncer.ProcessUpdate(lastNetmapBatch2) + } + + // Wait for debounce period + select { + case <-debouncer.TimerChannel(): + pendingUpdates := debouncer.GetPendingUpdates() + // Should get exactly 3 updates: netmap, controlConfig, netmap + if len(pendingUpdates) != 3 { + t.Errorf("Expected 3 pending updates, got %d", len(pendingUpdates)) + } + // First should be the last netmap from batch 1 + if pendingUpdates[0] != lastNetmapBatch1 { + t.Error("First pending update should be last netmap from batch 1") + } + if pendingUpdates[0].Update.NetworkMap.Serial != 49 { + t.Errorf("Expected serial 49, got %d", pendingUpdates[0].Update.NetworkMap.Serial) + } + // Second should be the control config + if pendingUpdates[1] != controlConfig { + t.Error("Second pending update should be control config") + } + // Third should be the last netmap from batch 2 + if pendingUpdates[2] != lastNetmapBatch2 { + t.Error("Third pending update should be last netmap from batch 2") + } + if pendingUpdates[2].Update.NetworkMap.Serial != 99 { + t.Errorf("Expected serial 99, got %d", pendingUpdates[2].Update.NetworkMap.Serial) + } + case <-time.After(200 * time.Millisecond): + t.Error("Timer should have fired") + } +} diff --git a/management/server/management_test.go b/management/server/management_test.go index 0864baadf..de02855bf 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -610,6 +610,7 @@ func TestSync10PeersGetUpdates(t *testing.T) { initialPeers := 10 additionalPeers := 10 + expectedPeerCount := initialPeers + additionalPeers - 1 // -1 because peer doesn't see itself var peers []wgtypes.Key for i := 0; i < initialPeers; i++ { @@ -618,8 +619,19 @@ func TestSync10PeersGetUpdates(t *testing.T) { peers = append(peers, key) } + // Track the maximum peer count each peer has seen + type peerState struct { + mu sync.Mutex + maxPeerCount int + done bool + } + peerStates := make(map[string]*peerState) + for _, pk := range peers { + peerStates[pk.PublicKey().String()] = &peerState{} + } + var wg sync.WaitGroup - wg.Add(initialPeers + initialPeers*additionalPeers) + wg.Add(initialPeers) // One completion per initial peer var syncClients []mgmtProto.ManagementService_SyncClient for _, pk := range peers { @@ -643,6 +655,9 @@ func TestSync10PeersGetUpdates(t *testing.T) { syncClients = append(syncClients, s) go func(pk wgtypes.Key, syncStream mgmtProto.ManagementService_SyncClient) { + pubKey := pk.PublicKey().String() + state := peerStates[pubKey] + for { encMsg := &mgmtProto.EncryptedMessage{} err := syncStream.RecvMsg(encMsg) @@ -651,19 +666,28 @@ func TestSync10PeersGetUpdates(t *testing.T) { } decryptedBytes, decErr := encryption.Decrypt(encMsg.Body, ts.serverPubKey, pk) if decErr != nil { - t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pk.PublicKey().String(), decErr) + t.Errorf("failed to decrypt SyncResponse for peer %s: %v", pubKey, decErr) return } resp := &mgmtProto.SyncResponse{} umErr := pb.Unmarshal(decryptedBytes, resp) if umErr != nil { - t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pk.PublicKey().String(), umErr) + t.Errorf("failed to unmarshal SyncResponse for peer %s: %v", pubKey, umErr) return } - // We only count if there's a new peer update - if len(resp.GetRemotePeers()) > 0 { + + // Track the maximum peer count seen (due to debouncing, updates are coalesced) + peerCount := len(resp.GetRemotePeers()) + state.mu.Lock() + if peerCount > state.maxPeerCount { + state.maxPeerCount = peerCount + } + // Signal completion when this peer has seen all expected peers + if !state.done && state.maxPeerCount >= expectedPeerCount { + state.done = true wg.Done() } + state.mu.Unlock() } }(pk, s) } @@ -677,7 +701,30 @@ func TestSync10PeersGetUpdates(t *testing.T) { time.Sleep(time.Duration(n) * time.Millisecond) } - wg.Wait() + // Wait for debouncer to flush final updates (debounce interval is 1000ms) + time.Sleep(1500 * time.Millisecond) + + // Wait with timeout + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Success - all peers received expected peer count + case <-time.After(5 * time.Second): + // Timeout - report which peers didn't receive all updates + t.Error("Timeout waiting for all peers to receive updates") + for pubKey, state := range peerStates { + state.mu.Lock() + if state.maxPeerCount < expectedPeerCount { + t.Errorf("Peer %s only saw %d peers, expected %d", pubKey, state.maxPeerCount, expectedPeerCount) + } + state.mu.Unlock() + } + } for _, sc := range syncClients { err := sc.CloseSend() From 7bc85107eb6f76e250b874419092c3934ee8deff Mon Sep 17 00:00:00 2001 From: Zoltan Papp Date: Fri, 6 Feb 2026 19:50:48 +0100 Subject: [PATCH 15/15] Adds timing measurement to handleSync to help diagnose sync performance issues (#5228) --- client/internal/engine.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/client/internal/engine.go b/client/internal/engine.go index 597ac7c2d..4dbd5f45e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -828,6 +828,10 @@ func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdate } func (e *Engine) handleSync(update *mgmProto.SyncResponse) error { + started := time.Now() + defer func() { + log.Infof("sync finished in %s", time.Since(started)) + }() e.syncMsgMux.Lock() defer e.syncMsgMux.Unlock()