diff --git a/Makefile b/Makefile index 43379e115..5d52b94fa 100644 --- a/Makefile +++ b/Makefile @@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint $(GOLANGCI_LINT): @echo "Installing golangci-lint..." @mkdir -p ./bin - @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest + @GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest # Lint only changed files (fast, for pre-push) lint: $(GOLANGCI_LINT) diff --git a/client/android/client.go b/client/android/client.go index d35bf4279..37e17a363 100644 --- a/client/android/client.go +++ b/client/android/client.go @@ -8,6 +8,7 @@ import ( "os" "slices" "sync" + "time" "golang.org/x/exp/maps" @@ -15,6 +16,7 @@ import ( "github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/internal" + "github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/listener" "github.com/netbirdio/netbird/client/internal/peer" @@ -26,6 +28,7 @@ import ( "github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/shared/management/domain" + types "github.com/netbirdio/netbird/upload-server/types" ) // ConnectionListener exports internal Listener for mobile @@ -68,7 +71,30 @@ type Client struct { uiVersion string networkChangeListener listener.NetworkChangeListener + stateMu sync.RWMutex connectClient *internal.ConnectClient + config *profilemanager.Config + cacheDir string +} + +func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) { + c.stateMu.Lock() + defer c.stateMu.Unlock() + c.config = cfg + c.cacheDir = cacheDir + c.connectClient = cc +} + +func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) { + c.stateMu.RLock() + defer c.stateMu.RUnlock() + return c.config, c.cacheDir, c.connectClient +} + +func (c *Client) getConnectClient() *internal.ConnectClient { + c.stateMu.RLock() + defer c.stateMu.RUnlock() + return c.connectClient } // NewClient instantiates a new Client @@ -93,6 +119,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid cfgFile := platformFiles.ConfigurationFilePath() stateFile := platformFiles.StateFilePath() + cacheDir := platformFiles.CacheDir() log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile) @@ -124,8 +151,9 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + c.setState(cfg, cacheDir, connectClient) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir) } // RunWithoutLogin applies this type of run function when the backend has been started without UI (i.e. after reboot). 
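The stateMu/setState/stateSnapshot additions above are the copy-under-RWMutex pattern: the writer swaps all related fields under one write lock, and readers take a consistent snapshot before working lock-free, so a caller can never observe a config from one Run paired with a cacheDir from another. A minimal sketch of the pattern, using hypothetical stand-in types rather than the real profilemanager.Config:

    package main

    import (
        "fmt"
        "sync"
    )

    // config stands in for profilemanager.Config (hypothetical stub).
    type config struct{ mgmtURL string }

    type client struct {
        mu       sync.RWMutex
        cfg      *config
        cacheDir string
    }

    // setState swaps every related field under a single write lock so a
    // reader can never observe a half-updated combination.
    func (c *client) setState(cfg *config, cacheDir string) {
        c.mu.Lock()
        defer c.mu.Unlock()
        c.cfg = cfg
        c.cacheDir = cacheDir
    }

    // stateSnapshot hands back a consistent view; callers then work on the
    // returned values without holding the lock.
    func (c *client) stateSnapshot() (*config, string) {
        c.mu.RLock()
        defer c.mu.RUnlock()
        return c.cfg, c.cacheDir
    }

    func main() {
        c := &client{}
        c.setState(&config{mgmtURL: "https://mgmt.example.com"}, "/tmp/cache")
        cfg, dir := c.stateSnapshot()
        fmt.Println(cfg.mgmtURL, dir)
    }
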
@@ -135,6 +163,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR cfgFile := platformFiles.ConfigurationFilePath() stateFile := platformFiles.StateFilePath() + cacheDir := platformFiles.CacheDir() log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile) @@ -157,8 +186,9 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR // todo do not throw error in case of cancelled context ctx = internal.CtxInitState(ctx) - c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder) - return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile) + connectClient := internal.NewConnectClient(ctx, cfg, c.recorder) + c.setState(cfg, cacheDir, connectClient) + return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir) } // Stop the internal client and free the resources @@ -173,11 +203,12 @@ func (c *Client) Stop() { } func (c *Client) RenewTun(fd int) error { - if c.connectClient == nil { + cc := c.getConnectClient() + if cc == nil { return fmt.Errorf("engine not running") } - e := c.connectClient.Engine() + e := cc.Engine() if e == nil { return fmt.Errorf("engine not initialized") } @@ -185,6 +216,73 @@ func (c *Client) RenewTun(fd int) error { return e.RenewTun(fd) } +// DebugBundle generates a debug bundle, uploads it, and returns the upload key. +// It works both with and without a running engine. +func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) { + cfg, cacheDir, cc := c.stateSnapshot() + + // If the engine hasn't been started, load config from disk + if cfg == nil { + var err error + cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{ + ConfigPath: platformFiles.ConfigurationFilePath(), + }) + if err != nil { + return "", fmt.Errorf("load config: %w", err) + } + cacheDir = platformFiles.CacheDir() + } + + deps := debug.GeneratorDependencies{ + InternalConfig: cfg, + StatusRecorder: c.recorder, + TempDir: cacheDir, + } + + if cc != nil { + resp, err := cc.GetLatestSyncResponse() + if err != nil { + log.Warnf("get latest sync response: %v", err) + } + deps.SyncResponse = resp + + if e := cc.Engine(); e != nil { + if cm := e.GetClientMetrics(); cm != nil { + deps.ClientMetrics = cm + } + } + } + + bundleGenerator := debug.NewBundleGenerator( + deps, + debug.BundleConfig{ + Anonymize: anonymize, + IncludeSystemInfo: true, + }, + ) + + path, err := bundleGenerator.Generate() + if err != nil { + return "", fmt.Errorf("generate debug bundle: %w", err) + } + defer func() { + if err := os.Remove(path); err != nil { + log.Errorf("failed to remove debug bundle file: %v", err) + } + }() + + uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path) + if err != nil { + return "", fmt.Errorf("upload debug bundle: %w", err) + } + + log.Infof("debug bundle uploaded with key %s", key) + return key, nil +} + // SetTraceLogLevel configure the logger to trace level func (c *Client) SetTraceLogLevel() { log.SetLevel(log.TraceLevel) @@ -214,12 +312,13 @@ func (c *Client) PeersList() *PeerInfoArray { } func (c *Client) Networks() *NetworkArray { - if c.connectClient == nil { + cc := c.getConnectClient() + if cc == nil { log.Error("not 
connected") return nil } - engine := c.connectClient.Engine() + engine := cc.Engine() if engine == nil { log.Error("could not get engine") return nil @@ -300,7 +399,7 @@ func (c *Client) toggleRoute(command routeCommand) error { } func (c *Client) getRouteManager() (routemanager.Manager, error) { - client := c.connectClient + client := c.getConnectClient() if client == nil { return nil, fmt.Errorf("not connected") } diff --git a/client/android/platform_files.go b/client/android/platform_files.go index f0c369750..3be40c0bd 100644 --- a/client/android/platform_files.go +++ b/client/android/platform_files.go @@ -7,4 +7,5 @@ package android type PlatformFiles interface { ConfigurationFilePath() string StateFilePath() string + CacheDir() string } diff --git a/client/cmd/root.go b/client/cmd/root.go index dc21abaca..29d4328a1 100644 --- a/client/cmd/root.go +++ b/client/cmd/root.go @@ -76,6 +76,7 @@ var ( profilesDisabled bool updateSettingsDisabled bool captureEnabled bool + networksDisabled bool rootCmd = &cobra.Command{ Use: "netbird", diff --git a/client/cmd/service.go b/client/cmd/service.go index 5addb644a..56d8a8726 100644 --- a/client/cmd/service.go +++ b/client/cmd/service.go @@ -45,10 +45,13 @@ func init() { serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles") serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings") serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture") + serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks") rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name") serviceEnvDesc := `Sets extra environment variables for the service. ` + `You can specify a comma-separated list of KEY=VALUE pairs. ` + + `New keys are merged with previously saved env vars; existing keys are overwritten. ` + + `Use --service-env "" to clear all saved env vars. ` + `E.g. 
--service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value` installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc) diff --git a/client/cmd/service_controller.go b/client/cmd/service_controller.go index 0a2b14f35..88121c067 100644 --- a/client/cmd/service_controller.go +++ b/client/cmd/service_controller.go @@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error { } } - serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled) + serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled) if err := serverInstance.Start(); err != nil { log.Fatalf("failed to start daemon: %v", err) } diff --git a/client/cmd/service_installer.go b/client/cmd/service_installer.go index e3011b5e9..2d45fa063 100644 --- a/client/cmd/service_installer.go +++ b/client/cmd/service_installer.go @@ -63,6 +63,10 @@ func buildServiceArguments() []string { args = append(args, "--enable-capture") } + if networksDisabled { + args = append(args, "--disable-networks") + } + return args } diff --git a/client/cmd/service_params.go b/client/cmd/service_params.go index 8bb22045e..192e0ac60 100644 --- a/client/cmd/service_params.go +++ b/client/cmd/service_params.go @@ -29,6 +29,7 @@ type serviceParams struct { DisableProfiles bool `json:"disable_profiles,omitempty"` DisableUpdateSettings bool `json:"disable_update_settings,omitempty"` EnableCapture bool `json:"enable_capture,omitempty"` + DisableNetworks bool `json:"disable_networks,omitempty"` ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"` } @@ -80,11 +81,12 @@ func currentServiceParams() *serviceParams { DisableProfiles: profilesDisabled, DisableUpdateSettings: updateSettingsDisabled, EnableCapture: captureEnabled, + DisableNetworks: networksDisabled, } if len(serviceEnvVars) > 0 { parsed, err := parseServiceEnvVars(serviceEnvVars) - if err == nil && len(parsed) > 0 { + if err == nil { params.ServiceEnvVars = parsed } } @@ -148,31 +150,46 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) { captureEnabled = params.EnableCapture } + if !serviceCmd.PersistentFlags().Changed("disable-networks") { + networksDisabled = params.DisableNetworks + } + applyServiceEnvParams(cmd, params) } // applyServiceEnvParams merges saved service environment variables. -// If --service-env was explicitly set, explicit values win on key conflict -// but saved keys not in the explicit set are carried over. +// If --service-env was explicitly set with values, explicit values win on key +// conflict but saved keys not in the explicit set are carried over. +// If --service-env was explicitly set to empty, all saved env vars are cleared. // If --service-env was not set, saved env vars are used entirely. func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) { - if len(params.ServiceEnvVars) == 0 { - return - } - if !cmd.Flags().Changed("service-env") { - // No explicit env vars: rebuild serviceEnvVars from saved params. - serviceEnvVars = envMapToSlice(params.ServiceEnvVars) + if len(params.ServiceEnvVars) > 0 { + // No explicit env vars: rebuild serviceEnvVars from saved params. + serviceEnvVars = envMapToSlice(params.ServiceEnvVars) + } return } - // Explicit env vars were provided: merge saved values underneath. + // Flag was explicitly set: parse what the user provided. 
explicit, err := parseServiceEnvVars(serviceEnvVars) if err != nil { cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err) return } + // If the user passed an empty value (e.g. --service-env ""), clear all + // saved env vars rather than merging. + if len(explicit) == 0 { + serviceEnvVars = nil + return + } + + if len(params.ServiceEnvVars) == 0 { + return + } + + // Merge saved values underneath explicit ones. merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit)) maps.Copy(merged, params.ServiceEnvVars) maps.Copy(merged, explicit) // explicit wins on conflict diff --git a/client/cmd/service_params_test.go b/client/cmd/service_params_test.go index 0c2f699cd..f338c12f4 100644 --- a/client/cmd/service_params_test.go +++ b/client/cmd/service_params_test.go @@ -327,6 +327,41 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) { assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result) } +func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + cmd := &cobra.Command{} + cmd.Flags().StringSlice("service-env", nil, "") + require.NoError(t, cmd.Flags().Set("service-env", "")) + + saved := &serviceParams{ + ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"}, + } + + applyServiceEnvParams(cmd, saved) + + assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars") +} + +func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) { + origServiceEnvVars := serviceEnvVars + t.Cleanup(func() { serviceEnvVars = origServiceEnvVars }) + + // Simulate --service-env "" which produces [""] in the slice. + serviceEnvVars = []string{""} + + params := currentServiceParams() + + // After parsing, the empty string is skipped, resulting in an empty map. + // The map should still be set (not nil) so it overwrites saved values. + assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil") + assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string") +} + // TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are // referenced in both currentServiceParams() and applyServiceParams(). If a new field is // added to serviceParams but not wired into these functions, this test fails. 
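These tests pin down the precedence that applyServiceEnvParams gets from its maps.Copy ordering: saved values are copied first and explicit values second, so an explicit value wins any key conflict while saved-only keys survive. A runnable sketch of just that ordering, with made-up keys:

    package main

    import (
        "fmt"
        "maps"
    )

    func main() {
        saved := map[string]string{"NB_LOG_LEVEL": "info", "CUSTOM_VAR": "value"}
        explicit := map[string]string{"NB_LOG_LEVEL": "debug"}

        // Copy saved first and explicit second: on a key conflict the later
        // copy (explicit) overwrites, while saved-only keys are carried over.
        merged := make(map[string]string, len(saved)+len(explicit))
        maps.Copy(merged, saved)
        maps.Copy(merged, explicit)

        fmt.Println(merged) // map[CUSTOM_VAR:value NB_LOG_LEVEL:debug]
    }
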
@@ -501,6 +536,7 @@ func fieldToGlobalVar(field string) string { "DisableProfiles": "profilesDisabled", "DisableUpdateSettings": "updateSettingsDisabled", "EnableCapture": "captureEnabled", + "DisableNetworks": "networksDisabled", "ServiceEnvVars": "serviceEnvVars", } if v, ok := m[field]; ok { diff --git a/client/cmd/testutil_test.go b/client/cmd/testutil_test.go index 5c6926f04..2503d527e 100644 --- a/client/cmd/testutil_test.go +++ b/client/cmd/testutil_test.go @@ -13,6 +13,8 @@ import ( "github.com/netbirdio/management-integrations/integrations" + nbcache "github.com/netbirdio/netbird/management/server/cache" + "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller" "github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel" "github.com/netbirdio/netbird/management/internals/modules/peers" @@ -100,9 +102,16 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp jobManager := job.NewJobManager(nil, store, peersmanager) - iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -113,12 +122,11 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp Return(&types.Settings{}, nil). AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) } @@ -152,7 +160,7 @@ func startClientDaemon( s := grpc.NewServer() server := client.New(ctx, - "", "", false, false, false) + "", "", false, false, false, false) if err := server.Start(); err != nil { t.Fatal(err) } diff --git a/client/firewall/firewalld/firewalld.go b/client/firewall/firewalld/firewalld.go new file mode 100644 index 000000000..188ea61dd --- /dev/null +++ b/client/firewall/firewalld/firewalld.go @@ -0,0 +1,11 @@ +// Package firewalld integrates with the firewalld daemon so NetBird can place +// its wg interface into firewalld's "trusted" zone. This is required because +// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent +// versions, which returns EPERM to any other process that tries to insert +// rules into them. 
The workaround mirrors what Tailscale does: let firewalld +// itself add the accept rules to its own chains by trusting the interface. +package firewalld + +// TrustedZone is the firewalld zone name used for interfaces whose traffic +// should bypass firewalld filtering. +const TrustedZone = "trusted" diff --git a/client/firewall/firewalld/firewalld_linux.go b/client/firewall/firewalld/firewalld_linux.go new file mode 100644 index 000000000..924a04b0a --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux.go @@ -0,0 +1,260 @@ +//go:build linux + +package firewalld + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" + "sync" + "time" + + "github.com/godbus/dbus/v5" + log "github.com/sirupsen/logrus" +) + +const ( + dbusDest = "org.fedoraproject.FirewallD1" + dbusPath = "/org/fedoraproject/FirewallD1" + dbusRootIface = "org.fedoraproject.FirewallD1" + dbusZoneIface = "org.fedoraproject.FirewallD1.zone" + + errZoneAlreadySet = "ZONE_ALREADY_SET" + errAlreadyEnabled = "ALREADY_ENABLED" + errUnknownIface = "UNKNOWN_INTERFACE" + errNotEnabled = "NOT_ENABLED" + + // callTimeout bounds each individual DBus or firewall-cmd invocation. + // A fresh context is created for each call so a slow DBus probe can't + // exhaust the deadline before the firewall-cmd fallback gets to run. + callTimeout = 3 * time.Second +) + +var ( + errDBusUnavailable = errors.New("firewalld dbus unavailable") + + // trustLogOnce ensures the "added to trusted zone" message is logged at + // Info level only for the first successful add per process; repeat adds + // from other init paths are quieter. + trustLogOnce sync.Once + + parentCtxMu sync.RWMutex + parentCtx context.Context = context.Background() +) + +// SetParentContext installs a parent context whose cancellation aborts any +// in-flight TrustInterface call. It does not affect UntrustInterface, which +// always uses a fresh Background-rooted timeout so cleanup can still run +// during engine shutdown when the engine context is already cancelled. +func SetParentContext(ctx context.Context) { + parentCtxMu.Lock() + parentCtx = ctx + parentCtxMu.Unlock() +} + +func getParentContext() context.Context { + parentCtxMu.RLock() + defer parentCtxMu.RUnlock() + return parentCtx +} + +// TrustInterface places iface into firewalld's trusted zone if firewalld is +// running. It is idempotent and best-effort: errors are returned so callers +// can log, but a non-running firewalld is not an error. Only the first +// successful call per process logs at Info. Respects the parent context set +// via SetParentContext so startup-time cancellation unblocks it. +func TrustInterface(iface string) error { + parent := getParentContext() + if !isRunning(parent) { + return nil + } + if err := addTrusted(parent, iface); err != nil { + return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err) + } + trustLogOnce.Do(func() { + log.Infof("added %s to firewalld trusted zone", iface) + }) + log.Debugf("firewalld: ensured %s is in trusted zone", iface) + return nil +} + +// UntrustInterface removes iface from firewalld's trusted zone if firewalld +// is running. Idempotent. Uses a Background-rooted timeout so it still runs +// during shutdown after the engine context has been cancelled. 
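+// Errors are returned rather than swallowed so each caller can decide: the +// iptables manager aggregates them into its shutdown multierror (keeping +// ShutdownState persisted so cleanup is retried), while the userspace +// firewall only logs a warning. 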
+func UntrustInterface(iface string) error { + if !isRunning(context.Background()) { + return nil + } + if err := removeTrusted(context.Background(), iface); err != nil { + return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err) + } + return nil +} + +func newCallContext(parent context.Context) (context.Context, context.CancelFunc) { + return context.WithTimeout(parent, callTimeout) +} + +func isRunning(parent context.Context) bool { + ctx, cancel := newCallContext(parent) + ok, err := isRunningDBus(ctx) + cancel() + if err == nil { + return ok + } + if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) { + ctx, cancel = newCallContext(parent) + defer cancel() + return isRunningCLI(ctx) + } + return false +} + +func addTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := addDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return addCLI(ctx, iface) +} + +func removeTrusted(parent context.Context, iface string) error { + ctx, cancel := newCallContext(parent) + err := removeDBus(ctx, iface) + cancel() + if err == nil { + return nil + } + if !errors.Is(err, errDBusUnavailable) { + log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err) + } + ctx, cancel = newCallContext(parent) + defer cancel() + return removeCLI(ctx, iface) +} + +func isRunningDBus(ctx context.Context) (bool, error) { + conn, err := dbus.SystemBus() + if err != nil { + return false, fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + var zone string + if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil { + return false, fmt.Errorf("firewalld getDefaultZone: %w", err) + } + return true, nil +} + +func isRunningCLI(ctx context.Context) bool { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return false + } + return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil +} + +func addDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errAlreadyEnabled) { + return nil + } + + if dbusErrContains(call.Err, errZoneAlreadySet) { + move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface) + if move.Err != nil { + return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err) + } + return nil + } + + return fmt.Errorf("firewalld addInterface: %w", call.Err) +} + +func removeDBus(ctx context.Context, iface string) error { + conn, err := dbus.SystemBus() + if err != nil { + return fmt.Errorf("%w: %v", errDBusUnavailable, err) + } + obj := conn.Object(dbusDest, dbusPath) + + call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface) + if call.Err == nil { + return nil + } + + if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) { + return nil + } + + return fmt.Errorf("firewalld removeInterface: %w", call.Err) +} + +func addCLI(ctx context.Context, iface string) error { + if _, err := 
exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + // --change-interface (no --permanent) binds the interface for the + // current runtime only; we do not want membership to persist across + // reboots because netbird re-asserts it on every startup. + out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface, + ).CombinedOutput() + if err != nil { + return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out))) + } + return nil +} + +func removeCLI(ctx context.Context, iface string) error { + if _, err := exec.LookPath("firewall-cmd"); err != nil { + return fmt.Errorf("firewall-cmd not available: %w", err) + } + + out, err := exec.CommandContext(ctx, + "firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface, + ).CombinedOutput() + if err != nil { + msg := strings.TrimSpace(string(out)) + if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) { + return nil + } + return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg) + } + return nil +} + +func dbusErrContains(err error, code string) bool { + if err == nil { + return false + } + var de dbus.Error + if errors.As(err, &de) { + for _, b := range de.Body { + if s, ok := b.(string); ok && strings.Contains(s, code) { + return true + } + } + } + return strings.Contains(err.Error(), code) +} diff --git a/client/firewall/firewalld/firewalld_linux_test.go b/client/firewall/firewalld/firewalld_linux_test.go new file mode 100644 index 000000000..d812745fc --- /dev/null +++ b/client/firewall/firewalld/firewalld_linux_test.go @@ -0,0 +1,49 @@ +//go:build linux + +package firewalld + +import ( + "errors" + "testing" + + "github.com/godbus/dbus/v5" +) + +func TestDBusErrContains(t *testing.T) { + tests := []struct { + name string + err error + code string + want bool + }{ + {"nil error", nil, errZoneAlreadySet, false}, + {"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true}, + {"plain error miss", errors.New("something else"), errZoneAlreadySet, false}, + { + "dbus.Error body match", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}}, + errZoneAlreadySet, + true, + }, + { + "dbus.Error body miss", + dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}}, + errAlreadyEnabled, + false, + }, + { + "dbus.Error non-string body falls back to Error()", + dbus.Error{Name: "x", Body: []any{123}}, + "x", + true, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := dbusErrContains(tc.err, tc.code) + if got != tc.want { + t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want) + } + }) + } +} diff --git a/client/firewall/firewalld/firewalld_other.go b/client/firewall/firewalld/firewalld_other.go new file mode 100644 index 000000000..cfa28221d --- /dev/null +++ b/client/firewall/firewalld/firewalld_other.go @@ -0,0 +1,25 @@ +//go:build !linux + +package firewalld + +import "context" + +// SetParentContext is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func SetParentContext(context.Context) { + // intentionally empty: firewalld is a Linux-only daemon +} + +// TrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. 
+func TrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} + +// UntrustInterface is a no-op on non-Linux platforms because firewalld only +// runs on Linux. +func UntrustInterface(string) error { + // intentionally empty: firewalld is a Linux-only daemon + return nil +} diff --git a/client/firewall/iptables/acl_linux.go b/client/firewall/iptables/acl_linux.go index d83798f09..e629f7881 100644 --- a/client/firewall/iptables/acl_linux.go +++ b/client/firewall/iptables/acl_linux.go @@ -21,6 +21,10 @@ const ( // rules chains contains the effective ACL rules chainNameInputRules = "NETBIRD-ACL-INPUT" + + // mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent + // external DNAT from bypassing ACL rules. + mangleFwdKey = "MANGLE-FORWARD" ) type aclEntries map[string][][]string @@ -274,6 +278,12 @@ func (m *aclManager) cleanChains() error { } } + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err) + } + } + for _, ipsetName := range m.ipsetStore.ipsetNames() { if err := m.flushIPSet(ipsetName); err != nil { if errors.Is(err, ipset.ErrSetNotExist) { @@ -303,6 +313,10 @@ func (m *aclManager) createDefaultChains() error { } for chainName, rules := range m.entries { + // mangle FORWARD guard rules are handled separately below + if chainName == mangleFwdKey { + continue + } for _, rule := range rules { if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil { log.Debugf("failed to create input chain jump rule: %s", err) @@ -322,6 +336,13 @@ func (m *aclManager) createDefaultChains() error { } clear(m.optionalEntries) + // Insert mangle FORWARD guard rules to prevent external DNAT bypass. + for _, rule := range m.entries[mangleFwdKey] { + if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil { + log.Errorf("failed to add mangle FORWARD guard rule: %v", err) + } + } + return nil } @@ -343,6 +364,22 @@ func (m *aclManager) seedInitialEntries() { m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT}) m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN}) + + // Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it + // traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD + // can be inserted above ours. Mangle runs before filter, so these guard rules enforce the + // ACL mark check where it cannot be overridden. 
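+	// Rule order matters: the RELATED,ESTABLISHED accept below is seeded +	// before the DROP, so replies on flows the ACL already approved keep +	// flowing and only new DNATed packets that are missing netbird's +	// prerouting fwmark get dropped. 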
+ m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", + "-j", "ACCEPT", + }) + m.appendToEntries(mangleFwdKey, []string{ + "-i", m.wgIface.Name(), + "-m", "conntrack", "--ctstate", "DNAT", + "-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), + "-j", "DROP", + }) } func (m *aclManager) seedInitialOptionalEntries() { diff --git a/client/firewall/iptables/manager_linux.go b/client/firewall/iptables/manager_linux.go index a1d4467d5..7d8cd7f8c 100644 --- a/client/firewall/iptables/manager_linux.go +++ b/client/firewall/iptables/manager_linux.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -86,6 +87,12 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error { log.Warnf("raw table not available, notrack rules will be disabled: %v", err) } + // Trust after all fatal init steps so a later failure doesn't leave the + // interface in firewalld's trusted zone without a corresponding Close. + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + // persist early to ensure cleanup of chains go func() { if err := stateManager.PersistState(context.Background()); err != nil { @@ -191,6 +198,12 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err)) } + // Appending to merr intentionally blocks DeleteState below so ShutdownState + // stays persisted and the crash-recovery path retries firewalld cleanup. 
+ if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + // attempt to delete state only if all other operations succeeded if merr == nil { if err := stateManager.DeleteState(&ShutdownState{}); err != nil { @@ -217,6 +230,11 @@ func (m *Manager) AllowNetbird() error { if err != nil { return fmt.Errorf("allow netbird interface traffic: %w", err) } + + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + return nil } diff --git a/client/firewall/nftables/manager_linux.go b/client/firewall/nftables/manager_linux.go index 0b5b61e04..8cd5cc6b3 100644 --- a/client/firewall/nftables/manager_linux.go +++ b/client/firewall/nftables/manager_linux.go @@ -14,6 +14,7 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/sys/unix" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/internal/statemanager" @@ -217,6 +218,10 @@ func (m *Manager) AllowNetbird() error { return fmt.Errorf("flush allow input netbird rules: %w", err) } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + return nil } diff --git a/client/firewall/nftables/router_linux.go b/client/firewall/nftables/router_linux.go index 904daf7cb..8cc0d2792 100644 --- a/client/firewall/nftables/router_linux.go +++ b/client/firewall/nftables/router_linux.go @@ -19,6 +19,7 @@ import ( "golang.org/x/sys/unix" nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewall "github.com/netbirdio/netbird/client/firewall/manager" nbid "github.com/netbirdio/netbird/client/internal/acl/id" "github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate" @@ -40,6 +41,8 @@ const ( chainNameForward = "FORWARD" chainNameMangleForward = "netbird-mangle-forward" + firewalldTableName = "firewalld" + userDataAcceptForwardRuleIif = "frwacceptiif" userDataAcceptForwardRuleOif = "frwacceptoif" userDataAcceptInputRule = "inputaccept" @@ -133,6 +136,10 @@ func (r *router) Reset() error { merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err)) } + if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil { + merr = multierror.Append(merr, err) + } + if err := r.removeNatPreroutingRules(); err != nil { merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err)) } @@ -280,6 +287,10 @@ func (r *router) createContainers() error { log.Errorf("failed to add accept rules for the forward chain: %s", err) } + if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } + if err := r.refreshRulesMap(); err != nil { log.Errorf("failed to refresh rules: %s", err) } @@ -1319,6 +1330,13 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool { return false } + // Skip firewalld-owned chains. Firewalld creates its chains with the + // NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM. + // We delegate acceptance to firewalld by trusting the interface instead. 
+ if chain.Table.Name == firewalldTableName { + return false + } + // Skip all iptables-managed tables in the ip family if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) { return false diff --git a/client/firewall/uspfilter/allow_netbird.go b/client/firewall/uspfilter/allow_netbird.go index 6a6533344..b120cdf12 100644 --- a/client/firewall/uspfilter/allow_netbird.go +++ b/client/firewall/uspfilter/allow_netbird.go @@ -3,6 +3,9 @@ package uspfilter import ( + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/firewall/firewalld" "github.com/netbirdio/netbird/client/internal/statemanager" ) @@ -16,6 +19,9 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error { if m.nativeFirewall != nil { return m.nativeFirewall.Close(stateManager) } + if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to untrust interface in firewalld: %v", err) + } return nil } @@ -24,5 +30,8 @@ func (m *Manager) AllowNetbird() error { if m.nativeFirewall != nil { return m.nativeFirewall.AllowNetbird() } + if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil { + log.Warnf("failed to trust interface in firewalld: %v", err) + } return nil } diff --git a/client/firewall/uspfilter/common/iface.go b/client/firewall/uspfilter/common/iface.go index 7296953db..9c06eb3f7 100644 --- a/client/firewall/uspfilter/common/iface.go +++ b/client/firewall/uspfilter/common/iface.go @@ -9,6 +9,7 @@ import ( // IFaceMapper defines subset methods of interface required for manager type IFaceMapper interface { + Name() string SetFilter(device.PacketFilter) error Address() wgaddr.Address GetWGDevice() *wgdevice.Device diff --git a/client/firewall/uspfilter/filter_test.go b/client/firewall/uspfilter/filter_test.go index 39e8efa2c..5fb9fef0e 100644 --- a/client/firewall/uspfilter/filter_test.go +++ b/client/firewall/uspfilter/filter_test.go @@ -31,12 +31,20 @@ var logger = log.NewFromLogrus(logrus.StandardLogger()) var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger() type IFaceMock struct { + NameFunc func() string SetFilterFunc func(device.PacketFilter) error AddressFunc func() wgaddr.Address GetWGDeviceFunc func() *wgdevice.Device GetDeviceFunc func() *device.FilteredDevice } +func (i *IFaceMock) Name() string { + if i.NameFunc == nil { + return "wgtest" + } + return i.NameFunc() +} + func (i *IFaceMock) GetWGDevice() *wgdevice.Device { if i.GetWGDeviceFunc == nil { return nil diff --git a/client/iface/bind/ice_bind_test.go b/client/iface/bind/ice_bind_test.go index 1fdd955c9..f49e68508 100644 --- a/client/iface/bind/ice_bind_test.go +++ b/client/iface/bind/ice_bind_test.go @@ -239,8 +239,12 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) { ipv6Count++ } - assert.Equal(t, packetsPerFamily, ipv4Count) - assert.Equal(t, packetsPerFamily, ipv6Count) + // Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The + // routing-correctness checks above are the real assertions; the counts + // are a sanity bound to catch a totally silent path. 
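+	// Integer math truncates: with packetsPerFamily = 100, for example, the +	// floor below is exactly 80 delivered packets per address family. 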
+ minDelivered := packetsPerFamily * 80 / 100 + assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold") + assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold") } func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) { diff --git a/client/iface/iface.go b/client/iface/iface.go index 9b331d68c..655dd1682 100644 --- a/client/iface/iface.go +++ b/client/iface/iface.go @@ -217,7 +217,6 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error // Close closes the tunnel interface func (w *WGIface) Close() error { w.mu.Lock() - defer w.mu.Unlock() var result *multierror.Error @@ -225,7 +224,15 @@ func (w *WGIface) Close() error { result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err)) } - if err := w.tun.Close(); err != nil { + // Release w.mu before calling w.tun.Close(): the underlying + // wireguard-go device.Close() waits for its send/receive goroutines + // to drain. Some of those goroutines re-enter WGIface methods that + // take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so + // holding the mutex here would deadlock the shutdown path. + tun := w.tun + w.mu.Unlock() + + if err := tun.Close(); err != nil { result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err)) } diff --git a/client/iface/iface_close_test.go b/client/iface/iface_close_test.go new file mode 100644 index 000000000..171e15d0a --- /dev/null +++ b/client/iface/iface_close_test.go @@ -0,0 +1,113 @@ +//go:build !android + +package iface + +import ( + "errors" + "sync" + "testing" + "time" + + wgdevice "golang.zx2c4.com/wireguard/device" + "golang.zx2c4.com/wireguard/tun/netstack" + + "github.com/netbirdio/netbird/client/iface/device" + "github.com/netbirdio/netbird/client/iface/udpmux" + "github.com/netbirdio/netbird/client/iface/wgaddr" + "github.com/netbirdio/netbird/client/iface/wgproxy" +) + +// fakeTunDevice implements WGTunDevice and lets the test control when +// Close() returns. It mimics the wireguard-go shutdown path, which blocks +// until its goroutines drain. Some of those goroutines (e.g. the packet +// filter DNS hook in client/internal/dns) call back into WGIface, so if +// WGIface.Close() held w.mu across tun.Close() the shutdown would +// deadlock. 
+type fakeTunDevice struct { + closeStarted chan struct{} + unblockClose chan struct{} +} + +func (f *fakeTunDevice) Create() (device.WGConfigurer, error) { + return nil, errors.New("not implemented") +} +func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) { + return nil, errors.New("not implemented") +} +func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil } +func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} } +func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU } +func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" } +func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil } +func (f *fakeTunDevice) Device() *wgdevice.Device { return nil } +func (f *fakeTunDevice) GetNet() *netstack.Net { return nil } +func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil } + +func (f *fakeTunDevice) Close() error { + close(f.closeStarted) + <-f.unblockClose + return nil +} + +type fakeProxyFactory struct{} + +func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil } +func (fakeProxyFactory) GetProxyPort() uint16 { return 0 } +func (fakeProxyFactory) Free() error { return nil } + +// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock +// that surfaces as a macOS test-timeout in +// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu +// while waiting for the wireguard-go device goroutines to finish, and +// one of those goroutines (the DNS filter hook) calls back into +// WGIface.GetDevice() which needs the same mutex. The fix is to drop +// the lock before tun.Close() returns control. +func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) { + tun := &fakeTunDevice{ + closeStarted: make(chan struct{}), + unblockClose: make(chan struct{}), + } + w := &WGIface{ + tun: tun, + wgProxyFactory: fakeProxyFactory{}, + } + + closeDone := make(chan error, 1) + go func() { + closeDone <- w.Close() + }() + + select { + case <-tun.closeStarted: + case <-time.After(2 * time.Second): + close(tun.unblockClose) + t.Fatal("tun.Close() was never invoked") + } + + // Simulate the WireGuard read goroutine calling back into WGIface + // via the packet filter's DNS hook. If Close() still held w.mu + // during tun.Close(), this would block until the test timeout. 
+ getDeviceDone := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + _ = w.GetDevice() + close(getDeviceDone) + }() + + select { + case <-getDeviceDone: + case <-time.After(2 * time.Second): + close(tun.unblockClose) + wg.Wait() + t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun") + } + + close(tun.unblockClose) + select { + case <-closeDone: + case <-time.After(2 * time.Second): + t.Fatal("WGIface.Close() never returned after the tun was unblocked") + } +} diff --git a/client/iface/udpmux/universal.go b/client/iface/udpmux/universal.go index 43bfedaaa..89a7eefb9 100644 --- a/client/iface/udpmux/universal.go +++ b/client/iface/udpmux/universal.go @@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error { } if u.address.Network.Contains(a) { - log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address) + log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address) } @@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error { u.addrCache.Store(addr.String(), isRouted) if isRouted { // Extra log, as the error only shows up with ICE logging enabled - log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix) + log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix) return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix) } } diff --git a/client/internal/connect.go b/client/internal/connect.go index bc2bd84d9..ac498f719 100644 --- a/client/internal/connect.go +++ b/client/internal/connect.go @@ -94,6 +94,7 @@ func (c *ConnectClient) RunOnAndroid( dnsAddresses []netip.AddrPort, dnsReadyListener dns.ReadyListener, stateFilePath string, + cacheDir string, ) error { // in case of non Android os these variables will be nil mobileDependency := MobileDependency{ @@ -103,6 +104,7 @@ func (c *ConnectClient) RunOnAndroid( HostDNSAddresses: dnsAddresses, DnsReadyListener: dnsReadyListener, StateFilePath: stateFilePath, + TempDir: cacheDir, } return c.run(mobileDependency, nil, "") } @@ -338,6 +340,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan log.Error(err) return wrapErr(err) } + engineConfig.TempDir = mobileDependency.TempDir relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU) c.statusRecorder.SetRelayMgr(relayManager) diff --git a/client/internal/debug/debug.go b/client/internal/debug/debug.go index bad481519..90560d028 100644 --- a/client/internal/debug/debug.go +++ b/client/internal/debug/debug.go @@ -16,7 +16,6 @@ import ( "path/filepath" "runtime" "runtime/pprof" - "slices" "sort" "strings" "time" @@ -31,7 +30,6 @@ import ( "github.com/netbirdio/netbird/client/internal/updater/installer" nbstatus "github.com/netbirdio/netbird/client/status" mgmProto "github.com/netbirdio/netbird/shared/management/proto" - "github.com/netbirdio/netbird/util" ) const readmeContent = `Netbird debug bundle @@ -235,6 +233,7 @@ type BundleGenerator struct { statusRecorder *peer.Status syncResponse *mgmProto.SyncResponse logPath string + tempDir string cpuProfile []byte capturePath string refreshStatus func() // Optional callback to refresh status before bundle generation @@ -258,6 +257,7 @@ type GeneratorDependencies struct { StatusRecorder *peer.Status 
SyncResponse *mgmProto.SyncResponse LogPath string + TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used. CPUProfile []byte CapturePath string RefreshStatus func() @@ -278,6 +278,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen statusRecorder: deps.StatusRecorder, syncResponse: deps.SyncResponse, logPath: deps.LogPath, + tempDir: deps.TempDir, cpuProfile: deps.CPUProfile, capturePath: deps.CapturePath, refreshStatus: deps.RefreshStatus, @@ -291,7 +292,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen // Generate creates a debug bundle and returns the location. func (g *BundleGenerator) Generate() (resp string, err error) { - bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip") + bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip") if err != nil { return "", fmt.Errorf("create zip file: %w", err) } @@ -381,15 +382,8 @@ func (g *BundleGenerator) createArchive() error { log.Errorf("failed to add wg show output: %v", err) } - if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { - if err := g.addLogfile(); err != nil { - log.Errorf("failed to add log file to debug bundle: %v", err) - if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("failed to add systemd logs as fallback: %v", err) - } - } - } else if err := g.trySystemdLogFallback(); err != nil { - log.Errorf("failed to add systemd logs: %v", err) + if err := g.addPlatformLog(); err != nil { + log.Errorf("failed to add logs to debug bundle: %v", err) } if err := g.addUpdateLogs(); err != nil { diff --git a/client/internal/debug/debug_android.go b/client/internal/debug/debug_android.go new file mode 100644 index 000000000..a4e2b3e98 --- /dev/null +++ b/client/internal/debug/debug_android.go @@ -0,0 +1,41 @@ +//go:build android + +package debug + +import ( + "fmt" + "io" + "os/exec" + + log "github.com/sirupsen/logrus" +) + +func (g *BundleGenerator) addPlatformLog() error { + cmd := exec.Command("/system/bin/logcat", "-d") + stdout, err := cmd.StdoutPipe() + if err != nil { + return fmt.Errorf("logcat stdout pipe: %w", err) + } + + if err := cmd.Start(); err != nil { + return fmt.Errorf("start logcat: %w", err) + } + + var logReader io.Reader = stdout + if g.anonymize { + var pw *io.PipeWriter + logReader, pw = io.Pipe() + go anonymizeLog(stdout, pw, g.anonymizer) + } + + if err := g.addFileToZip(logReader, "logcat.txt"); err != nil { + return fmt.Errorf("add logcat to zip: %w", err) + } + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("wait logcat: %w", err) + } + + log.Debug("added logcat output to debug bundle") + return nil +} diff --git a/client/internal/debug/debug_nonandroid.go b/client/internal/debug/debug_nonandroid.go new file mode 100644 index 000000000..117238dec --- /dev/null +++ b/client/internal/debug/debug_nonandroid.go @@ -0,0 +1,25 @@ +//go:build !android + +package debug + +import ( + "slices" + + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/util" +) + +func (g *BundleGenerator) addPlatformLog() error { + if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) { + if err := g.addLogfile(); err != nil { + log.Errorf("failed to add log file to debug bundle: %v", err) + if err := g.trySystemdLogFallback(); err != nil { + return err + } + } + } else if err := g.trySystemdLogFallback(); err != nil { + return err + } + return nil +} diff --git a/client/internal/debug/upload_test.go b/client/internal/debug/upload_test.go index 
e833c196d..f224b8d3f 100644 --- a/client/internal/debug/upload_test.go +++ b/client/internal/debug/upload_test.go @@ -3,10 +3,12 @@ package debug import ( "context" "errors" + "net" "net/http" "os" "path/filepath" "testing" + "time" "github.com/stretchr/testify/require" @@ -19,8 +21,10 @@ func TestUpload(t *testing.T) { t.Skip("Skipping upload test on docker ci") } testDir := t.TempDir() - testURL := "http://localhost:8080" + addr := reserveLoopbackPort(t) + testURL := "http://" + addr t.Setenv("SERVER_URL", testURL) + t.Setenv("SERVER_ADDRESS", addr) t.Setenv("STORE_DIR", testDir) srv := server.NewServer() go func() { @@ -33,6 +37,7 @@ func TestUpload(t *testing.T) { t.Errorf("Failed to stop server: %v", err) } }) + waitForServer(t, addr) file := filepath.Join(t.TempDir(), "tmpfile") fileContent := []byte("test file content") @@ -47,3 +52,30 @@ func TestUpload(t *testing.T) { require.NoError(t, err) require.Equal(t, fileContent, createdFileContent) } + +// reserveLoopbackPort binds an ephemeral port on loopback to learn a free +// address, then releases it so the server under test can rebind. The close/ +// rebind window is racy in theory; on loopback with a kernel-assigned port +// it's essentially never contended in practice. +func reserveLoopbackPort(t *testing.T) string { + t.Helper() + l, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + addr := l.Addr().String() + require.NoError(t, l.Close()) + return addr +} + +func waitForServer(t *testing.T, addr string) { + t.Helper() + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond) + if err == nil { + _ = c.Close() + return + } + time.Sleep(20 * time.Millisecond) + } + t.Fatalf("server did not start listening on %s in time", addr) +} diff --git a/client/internal/dns/file_parser_unix.go b/client/internal/dns/file_parser_unix.go index 8dacb4e51..50ba74c0c 100644 --- a/client/internal/dns/file_parser_unix.go +++ b/client/internal/dns/file_parser_unix.go @@ -13,6 +13,7 @@ import ( const ( defaultResolvConfPath = "/etc/resolv.conf" + nsswitchConfPath = "/etc/nsswitch.conf" ) type resolvConf struct { diff --git a/client/internal/dns/handler_chain.go b/client/internal/dns/handler_chain.go index 6fbdedc59..57e7722d4 100644 --- a/client/internal/dns/handler_chain.go +++ b/client/internal/dns/handler_chain.go @@ -1,7 +1,10 @@ package dns import ( + "context" "fmt" + "math" + "net" "slices" "strconv" "strings" @@ -192,6 +195,12 @@ func (c *HandlerChain) logHandlers() { } func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + c.dispatch(w, r, math.MaxInt) +} + +// dispatch routes a DNS request through the chain, skipping handlers with +// priority > maxPriority. Shared by ServeDNS and ResolveInternal. +func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) { if len(r.Question) == 0 { return } @@ -216,6 +225,9 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { // Try handlers in priority order for _, entry := range handlers { + if entry.Priority > maxPriority { + continue + } if !c.isHandlerMatch(qname, entry) { continue } @@ -273,6 +285,55 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q cw.response.Len(), meta, time.Since(startTime)) } +// ResolveInternal runs an in-process DNS query against the chain, skipping any +// handler with priority > maxPriority. Used by internal callers (e.g. 
the mgmt +// cache refresher) that must bypass themselves to avoid loops. Honors ctx +// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own +// (bounded by the invoked handler's internal timeout). +func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) { + if len(r.Question) == 0 { + return nil, fmt.Errorf("empty question") + } + + base := &internalResponseWriter{} + done := make(chan struct{}) + go func() { + c.dispatch(base, r, maxPriority) + close(done) + }() + + select { + case <-done: + case <-ctx.Done(): + // Prefer a completed response if dispatch finished concurrently with cancellation. + select { + case <-done: + default: + return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err()) + } + } + + if base.response == nil || base.response.Rcode == dns.RcodeRefused { + return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d", + strings.ToLower(r.Question[0].Name), maxPriority) + } + return base.response, nil +} + +// HasRootHandlerAtOrBelow reports whether any "." handler is registered at +// priority ≤ maxPriority. +func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + c.mu.RLock() + defer c.mu.RUnlock() + + for _, h := range c.handlers { + if h.Pattern == "." && h.Priority <= maxPriority { + return true + } + } + return false +} + func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { switch { case entry.Pattern == ".": @@ -291,3 +352,36 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { } } } + +// internalResponseWriter captures a dns.Msg for in-process chain queries. +type internalResponseWriter struct { + response *dns.Msg +} + +func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil } +func (w *internalResponseWriter) LocalAddr() net.Addr { return nil } +func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil } + +// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg +// still surface their answer to ResolveInternal. +func (w *internalResponseWriter) Write(p []byte) (int, error) { + msg := new(dns.Msg) + if err := msg.Unpack(p); err != nil { + return 0, err + } + w.response = msg + return len(p), nil +} + +func (w *internalResponseWriter) Close() error { return nil } +func (w *internalResponseWriter) TsigStatus() error { return nil } + +// TsigTimersOnly is part of dns.ResponseWriter. +func (w *internalResponseWriter) TsigTimersOnly(bool) { + // no-op: in-process queries carry no TSIG state. +} + +// Hijack is part of dns.ResponseWriter. +func (w *internalResponseWriter) Hijack() { + // no-op: in-process queries have no underlying connection to hand off. +} diff --git a/client/internal/dns/handler_chain_test.go b/client/internal/dns/handler_chain_test.go index fa9525069..034a760dc 100644 --- a/client/internal/dns/handler_chain_test.go +++ b/client/internal/dns/handler_chain_test.go @@ -1,11 +1,15 @@ package dns_test import ( + "context" + "net" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" nbdns "github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns/test" @@ -1042,3 +1046,163 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) { }) } } + +// answeringHandler writes a fixed A record to ack the query. Used to verify +// which handler ResolveInternal dispatches to. 
+type answeringHandler struct { + name string + ip string +} + +func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + _ = w.WriteMsg(resp) +} + +func (h *answeringHandler) String() string { return h.name } + +func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) { + chain := nbdns.NewHandlerChain() + + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + low := &answeringHandler{name: "low", ip: "10.0.0.2"} + + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + chain.AddHandler("example.com.", low, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + assert.NotNil(t, resp) + assert.Equal(t, 1, len(resp.Answer)) + a, ok := resp.Answer[0].(*dns.A) + assert.True(t, ok) + assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream") +} + +func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) { + chain := nbdns.NewHandlerChain() + high := &answeringHandler{name: "high", ip: "10.0.0.1"} + chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + _, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.Error(t, err, "no handler at or below maxPriority should error") +} + +// rawWriteHandler packs a response and calls ResponseWriter.Write directly +// (instead of WriteMsg), exercising the internalResponseWriter.Write path. +type rawWriteHandler struct { + ip string +} + +func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + resp := &dns.Msg{} + resp.SetReply(r) + resp.Answer = []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP(h.ip).To4(), + }} + packed, err := resp.Pack() + if err != nil { + return + } + _, _ = w.Write(packed) +} + +func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) { + chain := nbdns.NewHandlerChain() + chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream) + assert.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Answer, 1) + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok) + assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer") +} + +func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) { + chain := nbdns.NewHandlerChain() + _, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream) + assert.Error(t, err) +} + +// hangingHandler blocks indefinitely until closed, simulating a wedged upstream. 
+type hangingHandler struct { + block chan struct{} +} + +func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { + <-h.block + resp := &dns.Msg{} + resp.SetReply(r) + _ = w.WriteMsg(resp) +} + +func (h *hangingHandler) String() string { return "hangingHandler" } + +func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &hangingHandler{block: make(chan struct{})} + defer close(h.block) + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + + r := new(dns.Msg) + r.SetQuestion("example.com.", dns.TypeA) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + start := time.Now() + _, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream) + elapsed := time.Since(start) + + assert.Error(t, err) + assert.ErrorIs(t, err, context.DeadlineExceeded) + assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline") +} + +func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) { + chain := nbdns.NewHandlerChain() + h := &answeringHandler{name: "h", ip: "10.0.0.1"} + + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain") + + chain.AddHandler("example.com.", h, nbdns.PriorityUpstream) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count") + + chain.AddHandler(".", h, nbdns.PriorityMgmtCache) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded") + + chain.AddHandler(".", h, nbdns.PriorityDefault) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included") + + chain.RemoveHandler(".", nbdns.PriorityDefault) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) + + // Primary nsgroup case: root handler lands at PriorityUpstream. + chain.AddHandler(".", h, nbdns.PriorityUpstream) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included") + chain.RemoveHandler(".", nbdns.PriorityUpstream) + + // Fallback case: original /etc/resolv.conf entries land at PriorityFallback. 
+ chain.AddHandler(".", h, nbdns.PriorityFallback) + assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included") + chain.RemoveHandler(".", nbdns.PriorityFallback) + assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream)) +} diff --git a/client/internal/dns/host_unix.go b/client/internal/dns/host_unix.go index 422fed4e5..d7301d725 100644 --- a/client/internal/dns/host_unix.go +++ b/client/internal/dns/host_unix.go @@ -46,12 +46,12 @@ type restoreHostManager interface { } func newHostManager(wgInterface string) (hostManager, error) { - osManager, err := getOSDNSManagerType() + osManager, reason, err := getOSDNSManagerType() if err != nil { return nil, fmt.Errorf("get os dns manager type: %w", err) } - log.Infof("System DNS manager discovered: %s", osManager) + log.Infof("System DNS manager discovered: %s (%s)", osManager, reason) mgr, err := newHostManagerFromType(wgInterface, osManager) // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value if err != nil { @@ -74,17 +74,49 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor } } -func getOSDNSManagerType() (osManagerType, error) { +func getOSDNSManagerType() (osManagerType, string, error) { + resolved := isSystemdResolvedRunning() + nss := isLibnssResolveUsed() + stub := checkStub() + + // Prefer systemd-resolved whenever it owns libc resolution, regardless of + // who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups + // that go through nss-resolve, and in foreign mode they can loop back + // through resolved as an upstream. + if resolved && (nss || stub) { + return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil + } + + mgr, reason, rejected, err := scanResolvConfHeader() + if err != nil { + return 0, "", err + } + if reason != "" { + return mgr, reason, nil + } + + fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub) + if len(rejected) > 0 { + fallback += "; rejected: " + strings.Join(rejected, ", ") + } + return fileManager, fallback, nil +} + +// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the +// matching manager. If reason is empty the caller should pick file mode and +// use rejected for diagnostics. 
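+//
+// Matching is substring-based on each header comment line; for example
+// (header lines illustrative):
+//
+//	# Generated by NetworkManager  -> networkManager (when reachable on
+//	                                  D-Bus and the version is supported)
+//	# ... generated by resolvconf  -> resolvConfManager, or systemdManager
+//	                                  in systemd compatibility mode
+//	nameserver 9.9.9.9             -> first non-comment line ends the scan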
+func scanResolvConfHeader() (osManagerType, string, []string, error) { file, err := os.Open(defaultResolvConfPath) if err != nil { - return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) + return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) } defer func() { - if err := file.Close(); err != nil { - log.Errorf("close file %s: %s", defaultResolvConfPath, err) + if cerr := file.Close(); cerr != nil { + log.Errorf("close file %s: %s", defaultResolvConfPath, cerr) } }() + var rejected []string scanner := bufio.NewScanner(file) for scanner.Scan() { text := scanner.Text() @@ -92,41 +124,48 @@ func getOSDNSManagerType() (osManagerType, error) { continue } if text[0] != '#' { - return fileManager, nil + break } - if strings.Contains(text, fileGeneratedResolvConfContentHeader) { - return netbirdManager, nil - } - if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { - return networkManager, nil - } - if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() { - if checkStub() { - return systemdManager, nil - } else { - return fileManager, nil - } - } - if strings.Contains(text, "resolvconf") { - if isSystemdResolveConfMode() { - return systemdManager, nil - } - - return resolvConfManager, nil + if mgr, reason, rej := matchResolvConfHeader(text); reason != "" { + return mgr, reason, nil, nil + } else if rej != "" { + rejected = append(rejected, rej) } } if err := scanner.Err(); err != nil && err != io.EOF { - return 0, fmt.Errorf("scan: %w", err) + return 0, "", nil, fmt.Errorf("scan: %w", err) } - - return fileManager, nil + return 0, "", rejected, nil } -// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager. +// matchResolvConfHeader inspects a single comment line. Returns either a +// definitive (manager, reason) or a non-empty rejected diagnostic. +func matchResolvConfHeader(text string) (osManagerType, string, string) { + if strings.Contains(text, fileGeneratedResolvConfContentHeader) { + return netbirdManager, "netbird-managed resolv.conf header detected", "" + } + if strings.Contains(text, "NetworkManager") { + if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() { + return networkManager, "NetworkManager header + supported version on dbus", "" + } + return 0, "", "NetworkManager header (no dbus or unsupported version)" + } + if strings.Contains(text, "resolvconf") { + if isSystemdResolveConfMode() { + return systemdManager, "resolvconf header in systemd-resolved compatibility mode", "" + } + return resolvConfManager, "resolvconf header detected", "" + } + return 0, "", "" +} + +// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed +// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping +// into file mode while resolved is active. func checkStub() bool { rConf, err := parseDefaultResolvConf() if err != nil { - log.Warnf("failed to parse resolv conf: %s", err) + log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err) return true } @@ -139,3 +178,36 @@ func checkStub() bool { return false } + +// isLibnssResolveUsed reports whether nss-resolve is listed before dns on +// the hosts: line of /etc/nsswitch.conf. 
When it is, libc lookups are +// delegated to systemd-resolved regardless of /etc/resolv.conf. +func isLibnssResolveUsed() bool { + bs, err := os.ReadFile(nsswitchConfPath) + if err != nil { + log.Debugf("read %s: %v", nsswitchConfPath, err) + return false + } + return parseNsswitchResolveAhead(bs) +} + +func parseNsswitchResolveAhead(data []byte) bool { + for _, line := range strings.Split(string(data), "\n") { + if i := strings.IndexByte(line, '#'); i >= 0 { + line = line[:i] + } + fields := strings.Fields(line) + if len(fields) < 2 || fields[0] != "hosts:" { + continue + } + for _, module := range fields[1:] { + switch module { + case "dns": + return false + case "resolve": + return true + } + } + } + return false +} diff --git a/client/internal/dns/host_unix_test.go b/client/internal/dns/host_unix_test.go new file mode 100644 index 000000000..e936281d3 --- /dev/null +++ b/client/internal/dns/host_unix_test.go @@ -0,0 +1,76 @@ +//go:build (linux && !android) || freebsd + +package dns + +import "testing" + +func TestParseNsswitchResolveAhead(t *testing.T) { + tests := []struct { + name string + in string + want bool + }{ + { + name: "resolve before dns with action token", + in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n", + want: true, + }, + { + name: "dns before resolve", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n", + want: false, + }, + { + name: "debian default with only dns", + in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n", + want: false, + }, + { + name: "neither resolve nor dns", + in: "hosts: files myhostname\n", + want: false, + }, + { + name: "no hosts line", + in: "passwd: files systemd\ngroup: files systemd\n", + want: false, + }, + { + name: "empty", + in: "", + want: false, + }, + { + name: "comments and blank lines ignored", + in: "# comment\n\n# another\nhosts: resolve dns\n", + want: true, + }, + { + name: "trailing inline comment", + in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n", + want: true, + }, + { + name: "hosts token must be the first field", + in: " hosts: resolve dns\n", + want: true, + }, + { + name: "other db line mentioning resolve is ignored", + in: "networks: resolve\nhosts: dns\n", + want: false, + }, + { + name: "only resolve, no dns", + in: "hosts: files resolve\n", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want { + t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/client/internal/dns/mgmt/mgmt.go b/client/internal/dns/mgmt/mgmt.go index 314af51d9..988e427fb 100644 --- a/client/internal/dns/mgmt/mgmt.go +++ b/client/internal/dns/mgmt/mgmt.go @@ -2,40 +2,83 @@ package mgmt import ( "context" + "errors" "fmt" "net" - "net/netip" "net/url" + "os" + "slices" "strings" "sync" + "sync/atomic" "time" "github.com/miekg/dns" log "github.com/sirupsen/logrus" + "golang.org/x/sync/singleflight" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" + "github.com/netbirdio/netbird/client/internal/dns/resutil" "github.com/netbirdio/netbird/shared/management/domain" ) -const dnsTimeout = 5 * time.Second +const ( + dnsTimeout = 5 * time.Second + defaultTTL = 300 * time.Second + refreshBackoff = 30 * time.Second -// Resolver caches critical NetBird infrastructure domains + // envMgmtCacheTTL overrides defaultTTL for integration/dev testing. 
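+	// Values are parsed with time.ParseDuration (e.g. NB_MGMT_CACHE_TTL=45s);
+	// empty, malformed, or non-positive values fall back to defaultTTL.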
+ envMgmtCacheTTL = "NB_MGMT_CACHE_TTL" +) + +// ChainResolver lets the cache refresh stale entries through the DNS handler +// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the +// system resolver. +type ChainResolver interface { + ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) + HasRootHandlerAtOrBelow(maxPriority int) bool +} + +// cachedRecord holds DNS records plus timestamps used for TTL refresh. +// records and cachedAt are set at construction and treated as immutable; +// lastFailedRefresh and consecFailures are mutable and must be accessed under +// Resolver.mutex. +type cachedRecord struct { + records []dns.RR + cachedAt time.Time + lastFailedRefresh time.Time + consecFailures int +} + +// Resolver caches critical NetBird infrastructure domains. +// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex. type Resolver struct { - records map[dns.Question][]dns.RR + records map[dns.Question]*cachedRecord mgmtDomain *domain.Domain serverDomains *dnsconfig.ServerDomains mutex sync.RWMutex -} -type ipsResponse struct { - ips []netip.Addr - err error + chain ChainResolver + chainMaxPriority int + refreshGroup singleflight.Group + + // refreshing tracks questions whose refresh is running via the OS + // fallback path. A ServeDNS hit for a question in this map indicates + // the OS resolver routed the recursive query back to us (loop). Only + // the OS path arms this so chain-path refreshes don't produce false + // positives. The atomic bool is CAS-flipped once per refresh to + // throttle the warning log. + refreshing map[dns.Question]*atomic.Bool + + cacheTTL time.Duration } // NewResolver creates a new management domains cache resolver. func NewResolver() *Resolver { return &Resolver{ - records: make(map[dns.Question][]dns.RR), + records: make(map[dns.Question]*cachedRecord), + refreshing: make(map[dns.Question]*atomic.Bool), + cacheTTL: resolveCacheTTL(), } } @@ -44,7 +87,19 @@ func (m *Resolver) String() string { return "MgmtCacheResolver" } -// ServeDNS implements dns.Handler interface. +// SetChainResolver wires the handler chain used to refresh stale cache entries. +// maxPriority caps which handlers may answer refresh queries (typically +// PriorityUpstream, so upstream/default/fallback handlers are consulted and +// mgmt/route/local handlers are skipped). +func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) { + m.mutex.Lock() + m.chain = chain + m.chainMaxPriority = maxPriority + m.mutex.Unlock() +} + +// ServeDNS serves cached A/AAAA records. Stale entries are returned +// immediately and refreshed asynchronously (stale-while-revalidate). 
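+//
+// The per-question flow, sketched (condition names illustrative):
+//
+//	cached, found := m.records[q]   // under RLock
+//	!found                          -> continueToNext (defer to the chain)
+//	stale && !inBackoff && no inflight OS refresh
+//	                                -> scheduleRefresh (deduped via singleflight)
+//	always                          -> answer immediately with cloned records
+//	                                   stamped with the remaining TTL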
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { if len(r.Question) == 0 { m.continueToNext(w, r) @@ -60,7 +115,14 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { } m.mutex.RLock() - records, found := m.records[question] + cached, found := m.records[question] + inflight := m.refreshing[question] + var shouldRefresh bool + if found { + stale := time.Since(cached.cachedAt) > m.cacheTTL + inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff + shouldRefresh = stale && !inBackoff + } m.mutex.RUnlock() if !found { @@ -68,12 +130,23 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { return } + if inflight != nil && inflight.CompareAndSwap(false, true) { + log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)", + question.Name) + } + + // Skip scheduling a refresh goroutine if one is already inflight for + // this question; singleflight would dedup anyway but skipping avoids + // a parked goroutine per stale hit under bursty load. + if shouldRefresh && inflight == nil { + m.scheduleRefresh(question, cached) + } + resp := &dns.Msg{} resp.SetReply(r) resp.Authoritative = false resp.RecursionAvailable = true - - resp.Answer = append(resp.Answer, records...) + resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt)) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) @@ -98,101 +171,260 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) { } } -// AddDomain manually adds a domain to cache by resolving it. +// AddDomain resolves a domain and stores its A/AAAA records in the cache. +// A family that resolves NODATA (nil err, zero records) evicts any stale +// entry for that qtype. 
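+//
+// Typical use is seeding the cache at startup (domain value illustrative):
+//
+//	r := NewResolver()
+//	if err := r.AddDomain(ctx, domain.Domain("api.netbird.io")); err != nil {
+//		log.Warnf("seed mgmt cache: %v", err)
+//	}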
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) ctx, cancel := context.WithTimeout(ctx, dnsTimeout) defer cancel() - ips, err := lookupIPWithExtraTimeout(ctx, d) - if err != nil { - return err + aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) + + if errA != nil && errAAAA != nil { + return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA)) } - var aRecords, aaaaRecords []dns.RR - for _, ip := range ips { - if ip.Is4() { - rr := &dns.A{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeA, - Class: dns.ClassINET, - Ttl: 300, - }, - A: ip.AsSlice(), - } - aRecords = append(aRecords, rr) - } else if ip.Is6() { - rr := &dns.AAAA{ - Hdr: dns.RR_Header{ - Name: dnsName, - Rrtype: dns.TypeAAAA, - Class: dns.ClassINET, - Ttl: 300, - }, - AAAA: ip.AsSlice(), - } - aaaaRecords = append(aaaaRecords, rr) + if len(aRecords) == 0 && len(aaaaRecords) == 0 { + if err := errors.Join(errA, errAAAA); err != nil { + return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) } + return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString()) } + now := time.Now() m.mutex.Lock() + defer m.mutex.Unlock() - if len(aRecords) > 0 { - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - m.records[aQuestion] = aRecords - } + m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) + m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) - if len(aaaaRecords) > 0 { - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - m.records[aaaaQuestion] = aaaaRecords - } - - m.mutex.Unlock() - - log.Debugf("added domain=%s with %d A records and %d AAAA records", + log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", d.SafeString(), len(aRecords), len(aaaaRecords)) return nil } -func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) { - log.Infof("looking up IP for mgmt domain=%s", d.SafeString()) - defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString()) - resultChan := make(chan *ipsResponse, 1) +// applyFamilyRecords writes records, evicts on NODATA, leaves the cache +// untouched on error. Caller holds m.mutex. +func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { + q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} + switch { + case len(records) > 0: + m.records[q] = &cachedRecord{records: records, cachedAt: now} + case err == nil: + delete(m.records, q) + } +} - go func() { - ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString()) - resultChan <- &ipsResponse{ - err: err, - ips: ips, +// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per +// unique in-flight key; bursty stale hits share its channel. expected is the +// cachedRecord pointer observed by the caller; the refresh only mutates the +// cache if that pointer is still the one stored, so a stale in-flight refresh +// can't clobber a newer entry written by AddDomain or a competing refresh. 
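+//
+// The pinning check at write time reduces to (held under m.mutex):
+//
+//	if m.records[question] == expected {
+//		m.records[question] = &cachedRecord{records: fresh, cachedAt: now}
+//	}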
+func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) { + key := question.Name + "|" + dns.TypeToString[question.Qtype] + _ = m.refreshGroup.DoChan(key, func() (any, error) { + return nil, m.refreshQuestion(question, expected) + }) +} + +// refreshQuestion replaces the cached records on success, or marks the entry +// failed (arming the backoff) on failure. While this runs, ServeDNS can detect +// a resolver loop by spotting a query for this same question arriving on us. +// expected pins the cache entry observed at schedule time; mutations only apply +// if m.records[question] still points at it. +func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error { + ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout) + defer cancel() + + d, err := domain.FromString(strings.TrimSuffix(question.Name, ".")) + if err != nil { + m.markRefreshFailed(question, expected) + return fmt.Errorf("parse domain: %w", err) + } + + records, err := m.lookupRecords(ctx, d, question) + if err != nil { + fails := m.markRefreshFailed(question, expected) + logf := log.Warnf + if fails == 0 || fails > 1 { + logf = log.Debugf } - }() - - var resp *ipsResponse - - select { - case <-time.After(dnsTimeout + time.Millisecond*500): - log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString()) - return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString()) - case <-ctx.Done(): - return nil, ctx.Err() - case resp = <-resultChan: + logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", + d.SafeString(), dns.TypeToString[question.Qtype], err, fails) + return err } - if resp.err != nil { - return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err) + // NOERROR/NODATA: family gone upstream, evict so we stop serving stale. + if len(records) == 0 { + m.mutex.Lock() + if m.records[question] == expected { + delete(m.records, question) + m.mutex.Unlock() + log.Infof("removed mgmt cache domain=%s type=%s: no records returned", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.mutex.Unlock() + log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil } - return resp.ips, nil + + now := time.Now() + m.mutex.Lock() + if m.records[question] != expected { + m.mutex.Unlock() + log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil + } + m.records[question] = &cachedRecord{records: records, cachedAt: now} + m.mutex.Unlock() + + log.Infof("refreshed mgmt cache domain=%s type=%s", + d.SafeString(), dns.TypeToString[question.Qtype]) + return nil +} + +func (m *Resolver) markRefreshing(question dns.Question) { + m.mutex.Lock() + m.refreshing[question] = &atomic.Bool{} + m.mutex.Unlock() +} + +func (m *Resolver) clearRefreshing(question dns.Question) { + m.mutex.Lock() + delete(m.refreshing, question) + m.mutex.Unlock() +} + +// markRefreshFailed arms the backoff and returns the new consecutive-failure +// count so callers can downgrade subsequent failure logs to debug. 
+func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int { + m.mutex.Lock() + defer m.mutex.Unlock() + c, ok := m.records[question] + if !ok || c != expected { + return 0 + } + c.lastFailedRefresh = time.Now() + c.consecFailures++ + return c.consecFailures +} + +// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let +// callers tell records, NODATA (nil err, no records), and failure apart. +func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA) + return + } + + // TODO: drop once every supported OS registers a fallback resolver. Safe + // today: no root handler at priority ≤ PriorityUpstream means NetBird is + // not the system resolver, so net.DefaultResolver will not loop back. + aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA) + aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA) + return +} + +// lookupRecords resolves a single record type via chain or OS. The OS branch +// arms the loop detector for the duration of its call so that ServeDNS can +// spot the OS resolver routing the recursive query back to us. +func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) { + m.mutex.RLock() + chain := m.chain + maxPriority := m.chainMaxPriority + m.mutex.RUnlock() + + if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) { + return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype) + } + + // TODO: drop once every supported OS registers a fallback resolver. + m.markRefreshing(q) + defer m.clearRefreshing(q) + + return m.osLookup(ctx, d, q.Name, q.Qtype) +} + +// lookupViaChain resolves via the handler chain and rewrites each RR to use +// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache +// target-owned records or upstream TTLs. NODATA returns (nil, nil). +func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) { + msg := &dns.Msg{} + msg.SetQuestion(dnsName, qtype) + msg.RecursionDesired = true + + resp, err := chain.ResolveInternal(ctx, msg, maxPriority) + if err != nil { + return nil, fmt.Errorf("chain resolve: %w", err) + } + if resp == nil { + return nil, fmt.Errorf("chain resolve returned nil response") + } + if resp.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode]) + } + + ttl := uint32(m.cacheTTL.Seconds()) + owners := cnameOwners(dnsName, resp.Answer) + var filtered []dns.RR + for _, rr := range resp.Answer { + h := rr.Header() + if h.Class != dns.ClassINET || h.Rrtype != qtype { + continue + } + if !owners[strings.ToLower(dns.Fqdn(h.Name))] { + continue + } + if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil { + filtered = append(filtered, cp) + } + } + return filtered, nil +} + +// osLookup resolves a single family via net.DefaultResolver using resutil, +// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA +// returns (nil, nil). 
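+//
+// The resulting per-family contract, sketched:
+//
+//	answers present          -> (rrs, nil)
+//	NOERROR, empty (NODATA)  -> (nil, nil)
+//	NXDOMAIN or lookup error -> (nil, err)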
+func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) { + network := resutil.NetworkForQtype(qtype) + if network == "" { + return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype]) + } + + log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype]) + + result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype) + if result.Rcode == dns.RcodeSuccess { + return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil + } + + if result.Err != nil { + return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err) + } + return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode]) +} + +// responseTTL returns the remaining cache lifetime in seconds (rounded up), +// so downstream resolvers don't cache an answer for longer than we will. +func (m *Resolver) responseTTL(cachedAt time.Time) uint32 { + remaining := m.cacheTTL - time.Since(cachedAt) + if remaining <= 0 { + return 0 + } + return uint32((remaining + time.Second - 1) / time.Second) } // PopulateFromConfig extracts and caches domains from the client configuration. @@ -224,19 +456,12 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error { m.mutex.Lock() defer m.mutex.Unlock() - aQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeA, - Qclass: dns.ClassINET, - } - delete(m.records, aQuestion) - - aaaaQuestion := dns.Question{ - Name: dnsName, - Qtype: dns.TypeAAAA, - Qclass: dns.ClassINET, - } - delete(m.records, aaaaQuestion) + qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} + qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} + delete(m.records, qA) + delete(m.records, qAAAA) + delete(m.refreshing, qA) + delete(m.refreshing, qAAAA) log.Debugf("removed domain=%s from cache", d.SafeString()) return nil @@ -394,3 +619,73 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve return domains } + +// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non +// A/AAAA records return nil. +func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR { + switch r := rr.(type) { + case *dns.A: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.A = slices.Clone(r.A) + return &cp + case *dns.AAAA: + cp := *r + cp.Hdr.Name = owner + cp.Hdr.Ttl = ttl + cp.AAAA = slices.Clone(r.AAAA) + return &cp + } + return nil +} + +// cloneRecordsWithTTL clones A/AAAA records preserving their owner and +// stamping ttl so the response shares no memory with the cached slice. +func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR { + out := make([]dns.RR, 0, len(records)) + for _, rr := range records { + if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil { + out = append(out, cp) + } + } + return out +} + +// cnameOwners returns dnsName plus every target reachable by following CNAMEs +// in answer, iterating until fixed point so out-of-order chains resolve. 
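+//
+// For example, with answers in any order:
+//
+//	mgmt.example.com.  CNAME  lb.example.net.
+//	lb.example.net.    CNAME  edge.example.net.
+//	edge.example.net.  A      192.0.2.10
+//
+// owners becomes {mgmt.example.com., lb.example.net., edge.example.net.},
+// so the A record owned by edge.example.net. passes the ownership filter
+// in lookupViaChain.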
+func cnameOwners(dnsName string, answer []dns.RR) map[string]bool { + owners := map[string]bool{dnsName: true} + for { + added := false + for _, rr := range answer { + cname, ok := rr.(*dns.CNAME) + if !ok { + continue + } + name := strings.ToLower(dns.Fqdn(cname.Hdr.Name)) + if !owners[name] { + continue + } + target := strings.ToLower(dns.Fqdn(cname.Target)) + if !owners[target] { + owners[target] = true + added = true + } + } + if !added { + return owners + } + } +} + +// resolveCacheTTL reads the cache TTL override env var; invalid or empty +// values fall back to defaultTTL. Called once per Resolver from NewResolver. +func resolveCacheTTL() time.Duration { + if v := os.Getenv(envMgmtCacheTTL); v != "" { + if d, err := time.ParseDuration(v); err == nil && d > 0 { + return d + } + } + return defaultTTL +} diff --git a/client/internal/dns/mgmt/mgmt_refresh_test.go b/client/internal/dns/mgmt/mgmt_refresh_test.go new file mode 100644 index 000000000..9faa5a0b8 --- /dev/null +++ b/client/internal/dns/mgmt/mgmt_refresh_test.go @@ -0,0 +1,408 @@ +package mgmt + +import ( + "context" + "errors" + "net" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/miekg/dns" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/netbirdio/netbird/client/internal/dns/test" + "github.com/netbirdio/netbird/shared/management/domain" +) + +type fakeChain struct { + mu sync.Mutex + calls map[string]int + answers map[string][]dns.RR + err error + hasRoot bool + onLookup func() +} + +func newFakeChain() *fakeChain { + return &fakeChain{ + calls: map[string]int{}, + answers: map[string][]dns.RR{}, + hasRoot: true, + } +} + +func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool { + f.mu.Lock() + defer f.mu.Unlock() + return f.hasRoot +} + +func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) { + f.mu.Lock() + q := msg.Question[0] + key := q.Name + "|" + dns.TypeToString[q.Qtype] + f.calls[key]++ + answers := f.answers[key] + err := f.err + onLookup := f.onLookup + f.mu.Unlock() + + if onLookup != nil { + onLookup() + } + if err != nil { + return nil, err + } + resp := &dns.Msg{} + resp.SetReply(msg) + resp.Answer = answers + return resp, nil +} + +func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) { + f.mu.Lock() + defer f.mu.Unlock() + key := name + "|" + dns.TypeToString[qtype] + hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60} + switch qtype { + case dns.TypeA: + f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}} + case dns.TypeAAAA: + f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}} + } +} + +func (f *fakeChain) callCount(name string, qtype uint16) int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls[name+"|"+dns.TypeToString[qtype]] +} + +// waitFor polls the predicate until it returns true or the deadline passes. 
+func waitFor(t *testing.T, d time.Duration, fn func() bool) { + t.Helper() + deadline := time.Now().Add(d) + for time.Now().Before(deadline) { + if fn() { + return + } + time.Sleep(5 * time.Millisecond) + } + t.Fatalf("condition not met within %s", d) +} + +func queryA(t *testing.T, r *Resolver, name string) *dns.Msg { + t.Helper() + msg := new(dns.Msg) + msg.SetQuestion(name, dns.TypeA) + w := &test.MockResponseWriter{} + r.ServeDNS(w, msg) + return w.GetLastResponse() +} + +func firstA(t *testing.T, resp *dns.Msg) string { + t.Helper() + require.NotNil(t, resp) + require.Greater(t, len(resp.Answer), 0, "expected at least one answer") + a, ok := resp.Answer[0].(*dns.A) + require.True(t, ok, "expected A record") + return a.A.String() +} + +func TestResolver_CacheTTLGatesRefresh(t *testing.T) { + // Same cached entry age, different cacheTTL values: the shorter TTL must + // trigger a background refresh, the longer one must not. Proves that the + // per-Resolver cacheTTL field actually drives the stale decision. + cachedAt := time.Now().Add(-100 * time.Millisecond) + + newRec := func() *cachedRecord { + return &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: cachedAt, + } + } + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + + t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = 10 * time.Millisecond + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount(q.Name, dns.TypeA) >= 1 + }) + }) + + t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) { + r := NewResolver() + r.cacheTTL = time.Hour + chain := newFakeChain() + chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + r.records[q] = newRec() + + resp := queryA(t, r, q.Name) + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(50 * time.Millisecond) + assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh") + }) +} + +func TestResolver_ServeFresh_NoRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), // fresh + } + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp)) + + time.Sleep(20 * time.Millisecond) + assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh") +} + +func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: 
q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), // stale + } + + // First query: serves stale immediately. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1 + }) + + // Next query should now return the refreshed IP. + waitFor(t, time.Second, func() bool { + resp := queryA(t, r, "mgmt.example.com.") + return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2" + }) +} + +func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + + var inflight atomic.Int32 + var maxInflight atomic.Int32 + chain.onLookup = func() { + cur := inflight.Add(1) + defer inflight.Add(-1) + for { + prev := maxInflight.Load() + if cur <= prev || maxInflight.CompareAndSwap(prev, cur) { + break + } + } + time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide + } + + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(1) + go func() { + defer wg.Done() + queryA(t, r, "mgmt.example.com.") + }() + } + wg.Wait() + + waitFor(t, 2*time.Second, func() bool { + return inflight.Load() == 0 + }) + + calls := chain.callCount("mgmt.example.com.", dns.TypeA) + assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls) + assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently") +} + +func TestResolver_RefreshFailureArmsBackoff(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.err = errors.New("boom") + r.SetChainResolver(chain, 50) + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now().Add(-2 * defaultTTL), + } + + // First stale hit triggers a refresh attempt that fails. + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails") + + waitFor(t, time.Second, func() bool { + return chain.callCount("mgmt.example.com.", dns.TypeA) == 1 + }) + waitFor(t, time.Second, func() bool { + r.mutex.RLock() + defer r.mutex.RUnlock() + c, ok := r.records[q] + return ok && !c.lastFailedRefresh.IsZero() + }) + + // Subsequent stale hits within backoff window should not schedule more refreshes. 
+	for i := 0; i < 10; i++ {
+		queryA(t, r, "mgmt.example.com.")
+	}
+	time.Sleep(50 * time.Millisecond)
+	assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
+}
+
+func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
+	r := NewResolver()
+	chain := newFakeChain()
+	chain.hasRoot = false
+	chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
+	r.SetChainResolver(chain, 50)
+
+	// With hasRoot=false the chain must not be consulted. Use a short
+	// deadline so the OS fallback returns quickly without waiting on a
+	// real network call in CI.
+	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
+	defer cancel()
+	_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
+
+	assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
+		"chain must not be used when no root handler is registered at the bound priority")
+}
+
+func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
+	// ServeDNS being invoked for a question while a refresh for that question
+	// is inflight indicates a resolver loop (OS resolver sent the recursive
+	// query back to us). The per-question refreshing flag (the atomic.Bool in
+	// the refreshing map) must be set.
+	r := NewResolver()
+
+	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
+	r.records[q] = &cachedRecord{
+		records: []dns.RR{&dns.A{
+			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
+			A:   net.ParseIP("10.0.0.1").To4(),
+		}},
+		cachedAt: time.Now(),
+	}
+
+	// Simulate an inflight refresh.
+	r.markRefreshing(q)
+	defer r.clearRefreshing(q)
+
+	resp := queryA(t, r, "mgmt.example.com.")
+	assert.Equal(t, "10.0.0.1", firstA(t, resp), "cached entry must still be served to avoid breaking external queries")
+
+	r.mutex.RLock()
+	inflight := r.refreshing[q]
+	r.mutex.RUnlock()
+	require.NotNil(t, inflight)
+	assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
+}
+
+func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
+	r := NewResolver()
+
+	q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
+	r.records[q] = &cachedRecord{
+		records: []dns.RR{&dns.A{
+			Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
+			A:   net.ParseIP("10.0.0.1").To4(),
+		}},
+		cachedAt: time.Now(),
+	}
+
+	r.markRefreshing(q)
+	defer r.clearRefreshing(q)
+
+	// Multiple ServeDNS calls during the same refresh must not re-set the flag
+	// (CompareAndSwap from false -> true returns true only on the first call).
+ for range 5 { + queryA(t, r, "mgmt.example.com.") + } + + r.mutex.RLock() + inflight := r.refreshing[q] + r.mutex.RUnlock() + assert.True(t, inflight.Load()) +} + +func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) { + r := NewResolver() + + q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET} + r.records[q] = &cachedRecord{ + records: []dns.RR{&dns.A{ + Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60}, + A: net.ParseIP("10.0.0.1").To4(), + }}, + cachedAt: time.Now(), + } + + queryA(t, r, "mgmt.example.com.") + + r.mutex.RLock() + _, ok := r.refreshing[q] + r.mutex.RUnlock() + assert.False(t, ok, "no refresh inflight means no loop tracking") +} + +func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) { + r := NewResolver() + chain := newFakeChain() + chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2") + chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2") + r.SetChainResolver(chain, 50) + + require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com"))) + + resp := queryA(t, r, "mgmt.example.com.") + assert.Equal(t, "10.0.0.2", firstA(t, resp)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA)) + assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA)) +} diff --git a/client/internal/dns/mgmt/mgmt_test.go b/client/internal/dns/mgmt/mgmt_test.go index 9e8a746f3..276cbba0a 100644 --- a/client/internal/dns/mgmt/mgmt_test.go +++ b/client/internal/dns/mgmt/mgmt_test.go @@ -6,6 +6,7 @@ import ( "net/url" "strings" "testing" + "time" "github.com/miekg/dns" "github.com/stretchr/testify/assert" @@ -23,6 +24,60 @@ func TestResolver_NewResolver(t *testing.T) { assert.False(t, resolver.MatchSubdomains()) } +func TestResolveCacheTTL(t *testing.T) { + tests := []struct { + name string + value string + want time.Duration + }{ + {"unset falls back to default", "", defaultTTL}, + {"valid duration", "45s", 45 * time.Second}, + {"valid minutes", "2m", 2 * time.Minute}, + {"malformed falls back to default", "not-a-duration", defaultTTL}, + {"zero falls back to default", "0s", defaultTTL}, + {"negative falls back to default", "-5s", defaultTTL}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Setenv(envMgmtCacheTTL, tc.value) + got := resolveCacheTTL() + assert.Equal(t, tc.want, got, "parsed TTL should match") + }) + } +} + +func TestNewResolver_CacheTTLFromEnv(t *testing.T) { + t.Setenv(envMgmtCacheTTL, "7s") + r := NewResolver() + assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env") +} + +func TestResolver_ResponseTTL(t *testing.T) { + now := time.Now() + tests := []struct { + name string + cacheTTL time.Duration + cachedAt time.Time + wantMin uint32 + wantMax uint32 + }{ + {"fresh entry returns full TTL", 60 * time.Second, now, 59, 60}, + {"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31}, + {"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0}, + {"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + r := &Resolver{cacheTTL: tc.cacheTTL} + got := r.responseTTL(tc.cachedAt) + assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin") + assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax") + }) + } +} + func TestResolver_ExtractDomainFromURL(t *testing.T) 
{ tests := []struct { name string diff --git a/client/internal/dns/server.go b/client/internal/dns/server.go index f7865047b..d4f54dec5 100644 --- a/client/internal/dns/server.go +++ b/client/internal/dns/server.go @@ -212,6 +212,7 @@ func newDefaultServer( ctx, stop := context.WithCancel(ctx) mgmtCacheResolver := mgmt.NewResolver() + mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream) defaultServer := &DefaultServer{ ctx: ctx, diff --git a/client/internal/engine.go b/client/internal/engine.go index ef643872f..c7e1abc53 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -26,6 +26,7 @@ import ( nberrors "github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/firewall" + "github.com/netbirdio/netbird/client/firewall/firewalld" firewallManager "github.com/netbirdio/netbird/client/firewall/manager" "github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder" "github.com/netbirdio/netbird/client/iface" @@ -142,6 +143,7 @@ type EngineConfig struct { ProfileConfig *profilemanager.Config LogPath string + TempDir string } // EngineServices holds the external service dependencies required by the Engine. @@ -573,7 +575,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL) e.connMgr.Start(e.ctx) e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg) - e.srWatcher.Start() + e.srWatcher.Start(peer.IsForceRelayed()) e.receiveSignalEvents() e.receiveManagementEvents() @@ -607,6 +609,8 @@ func (e *Engine) createFirewall() error { return nil } + firewalld.SetParentContext(e.ctx) + var err error e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU) if err != nil { @@ -1099,6 +1103,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR StatusRecorder: e.statusRecorder, SyncResponse: syncResponse, LogPath: e.config.LogPath, + TempDir: e.config.TempDir, ClientMetrics: e.clientMetrics, RefreshStatus: func() { e.RunHealthProbes(true) diff --git a/client/internal/engine_test.go b/client/internal/engine_test.go index 1f6fe384a..9fa4e51b2 100644 --- a/client/internal/engine_test.go +++ b/client/internal/engine_test.go @@ -55,6 +55,7 @@ import ( nbdns "github.com/netbirdio/netbird/dns" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -1634,7 +1635,12 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri peersManager := peers.NewManager(store, permissionsManager) jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -1656,7 +1662,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri updateManager := 
update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, "", err } diff --git a/client/internal/mobile_dependency.go b/client/internal/mobile_dependency.go index 7c95e2b99..310d61a25 100644 --- a/client/internal/mobile_dependency.go +++ b/client/internal/mobile_dependency.go @@ -22,4 +22,8 @@ type MobileDependency struct { DnsManager dns.IosDnsManager FileDescriptor int32 StateFilePath string + + // TempDir is a writable directory for temporary files (e.g., debug bundle zip). + // On Android, this should be set to the app's cache directory. + TempDir string } diff --git a/client/internal/netflow/conntrack/conntrack.go b/client/internal/netflow/conntrack/conntrack.go index a4ffa3a25..2420b1fdf 100644 --- a/client/internal/netflow/conntrack/conntrack.go +++ b/client/internal/netflow/conntrack/conntrack.go @@ -7,7 +7,9 @@ import ( "fmt" "net/netip" "sync" + "time" + "github.com/cenkalti/backoff/v4" "github.com/google/uuid" log "github.com/sirupsen/logrus" nfct "github.com/ti-mo/conntrack" @@ -17,31 +19,64 @@ import ( nbnet "github.com/netbirdio/netbird/client/net" ) -const defaultChannelSize = 100 +const ( + defaultChannelSize = 100 + reconnectInitInterval = 5 * time.Second + reconnectMaxInterval = 5 * time.Minute + reconnectRandomization = 0.5 +) + +// listener abstracts a netlink conntrack connection for testability. +type listener interface { + Listen(evChan chan<- nfct.Event, numWorkers uint8, groups []netfilter.NetlinkGroup) (chan error, error) + Close() error +} // ConnTrack manages kernel-based conntrack events type ConnTrack struct { flowLogger nftypes.FlowLogger iface nftypes.IFaceMapper - conn *nfct.Conn + conn listener mux sync.Mutex + dial func() (listener, error) instanceID uuid.UUID started bool done chan struct{} sysctlModified bool } +// DialFunc is a constructor for netlink conntrack connections. +type DialFunc func() (listener, error) + +// Option configures a ConnTrack instance. +type Option func(*ConnTrack) + +// WithDialer overrides the default netlink dialer, primarily for testing. 
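+//
+// For example, tests substitute a fake netlink connection
+// (see conntrack_test.go in this package):
+//
+//	ct := New(flowLogger, iface, WithDialer(func() (listener, error) {
+//		return newMockListener(), nil
+//	}))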
+func WithDialer(dial DialFunc) Option { + return func(c *ConnTrack) { + c.dial = dial + } +} + +func defaultDial() (listener, error) { + return nfct.Dial(nil) +} + // New creates a new connection tracker that interfaces with the kernel's conntrack system -func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper) *ConnTrack { - return &ConnTrack{ +func New(flowLogger nftypes.FlowLogger, iface nftypes.IFaceMapper, opts ...Option) *ConnTrack { + ct := &ConnTrack{ flowLogger: flowLogger, iface: iface, instanceID: uuid.New(), - started: false, + dial: defaultDial, done: make(chan struct{}, 1), } + for _, opt := range opts { + opt(ct) + } + return ct } // Start begins tracking connections by listening for conntrack events. This method is idempotent. @@ -59,8 +94,9 @@ func (c *ConnTrack) Start(enableCounters bool) error { c.EnableAccounting() } - conn, err := nfct.Dial(nil) + conn, err := c.dial() if err != nil { + c.RestoreAccounting() return fmt.Errorf("dial conntrack: %w", err) } c.conn = conn @@ -76,9 +112,16 @@ func (c *ConnTrack) Start(enableCounters bool) error { log.Errorf("Error closing conntrack connection: %v", err) } c.conn = nil + c.RestoreAccounting() return fmt.Errorf("start conntrack listener: %w", err) } + // Drain any stale stop signal from a previous cycle. + select { + case <-c.done: + default: + } + c.started = true go c.receiverRoutine(events, errChan) @@ -92,17 +135,98 @@ func (c *ConnTrack) receiverRoutine(events chan nfct.Event, errChan chan error) case event := <-events: c.handleEvent(event) case err := <-errChan: - log.Errorf("Error from conntrack event listener: %v", err) - if err := c.conn.Close(); err != nil { - log.Errorf("Error closing conntrack connection: %v", err) + if events, errChan = c.handleListenerError(err); events == nil { + return } - return case <-c.done: return } } } +// handleListenerError closes the failed connection and attempts to reconnect. +// Returns new channels on success, or nil if shutdown was requested. +func (c *ConnTrack) handleListenerError(err error) (chan nfct.Event, chan error) { + log.Warnf("conntrack event listener failed: %v", err) + c.closeConn() + return c.reconnect() +} + +func (c *ConnTrack) closeConn() { + c.mux.Lock() + defer c.mux.Unlock() + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + log.Debugf("close conntrack connection: %v", err) + } + c.conn = nil + } +} + +// reconnect attempts to re-establish the conntrack netlink listener with exponential backoff. +// Returns new channels on success, or nil if shutdown was requested. 
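+//
+// With the constants above, nominal delays grow 5s, 7.5s, 11.25s, ...,
+// capped at 5m, each jittered by up to ±50% (RandomizationFactor 0.5),
+// and retries never give up (MaxElapsedTime 0).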
+func (c *ConnTrack) reconnect() (chan nfct.Event, chan error) { + bo := &backoff.ExponentialBackOff{ + InitialInterval: reconnectInitInterval, + RandomizationFactor: reconnectRandomization, + Multiplier: backoff.DefaultMultiplier, + MaxInterval: reconnectMaxInterval, + MaxElapsedTime: 0, // retry indefinitely + Clock: backoff.SystemClock, + } + bo.Reset() + + for { + delay := bo.NextBackOff() + log.Infof("reconnecting conntrack listener in %s", delay) + + select { + case <-c.done: + c.mux.Lock() + c.started = false + c.mux.Unlock() + return nil, nil + case <-time.After(delay): + } + + conn, err := c.dial() + if err != nil { + log.Warnf("reconnect conntrack dial: %v", err) + continue + } + + events := make(chan nfct.Event, defaultChannelSize) + errChan, err := conn.Listen(events, 1, []netfilter.NetlinkGroup{ + netfilter.GroupCTNew, + netfilter.GroupCTDestroy, + }) + if err != nil { + log.Warnf("reconnect conntrack listen: %v", err) + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + continue + } + + c.mux.Lock() + if !c.started { + // Stop() ran while we were reconnecting. + c.mux.Unlock() + if closeErr := conn.Close(); closeErr != nil { + log.Debugf("close conntrack connection: %v", closeErr) + } + return nil, nil + } + c.conn = conn + c.mux.Unlock() + + log.Infof("conntrack listener reconnected successfully") + + return events, errChan + } +} + // Stop stops the connection tracking. This method is idempotent. func (c *ConnTrack) Stop() { c.mux.Lock() @@ -136,23 +260,27 @@ func (c *ConnTrack) Close() error { c.mux.Lock() defer c.mux.Unlock() - if c.started { - select { - case c.done <- struct{}{}: - default: - } + if !c.started { + return nil } + select { + case c.done <- struct{}{}: + default: + } + + c.started = false + + var closeErr error if c.conn != nil { - err := c.conn.Close() + closeErr = c.conn.Close() c.conn = nil - c.started = false + } - c.RestoreAccounting() + c.RestoreAccounting() - if err != nil { - return fmt.Errorf("close conntrack: %w", err) - } + if closeErr != nil { + return fmt.Errorf("close conntrack: %w", closeErr) } return nil diff --git a/client/internal/netflow/conntrack/conntrack_test.go b/client/internal/netflow/conntrack/conntrack_test.go new file mode 100644 index 000000000..35ceec90d --- /dev/null +++ b/client/internal/netflow/conntrack/conntrack_test.go @@ -0,0 +1,224 @@ +//go:build linux && !android + +package conntrack + +import ( + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + nfct "github.com/ti-mo/conntrack" + "github.com/ti-mo/netfilter" +) + +type mockListener struct { + errChan chan error + closed atomic.Bool + closedCh chan struct{} +} + +func newMockListener() *mockListener { + return &mockListener{ + errChan: make(chan error, 1), + closedCh: make(chan struct{}), + } +} + +func (m *mockListener) Listen(evChan chan<- nfct.Event, _ uint8, _ []netfilter.NetlinkGroup) (chan error, error) { + return m.errChan, nil +} + +func (m *mockListener) Close() error { + if m.closed.CompareAndSwap(false, true) { + close(m.closedCh) + } + return nil +} + +func TestReconnectAfterError(t *testing.T) { + first := newMockListener() + second := newMockListener() + third := newMockListener() + listeners := []*mockListener{first, second, third} + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := int(callCount.Add(1)) - 1 + return listeners[n], nil + })) + + err := ct.Start(false) + 
require.NoError(t, err) + + // Inject an error on the first listener. + first.errChan <- assert.AnError + + // Wait for reconnect to complete. + require.Eventually(t, func() bool { + return callCount.Load() >= 2 + }, 15*time.Second, 100*time.Millisecond, "reconnect should dial a new connection") + + // The first connection must have been closed. + select { + case <-first.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("first connection was not closed") + } + + // Verify the receiver is still running by injecting and handling a second error. + second.errChan <- assert.AnError + + require.Eventually(t, func() bool { + return callCount.Load() >= 3 + }, 15*time.Second, 100*time.Millisecond, "second reconnect should succeed") + + ct.Stop() +} + +func TestStopDuringReconnectBackoff(t *testing.T) { + mock := newMockListener() + + ct := New(nil, nil, WithDialer(func() (listener, error) { + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger an error so the receiver enters reconnect. + mock.errChan <- assert.AnError + + // Wait for the error handler to close the old listener before calling Stop. + select { + case <-mock.closedCh: + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for reconnect to start") + } + + // Stop while reconnecting. + ct.Stop() + + ct.mux.Lock() + assert.False(t, ct.started, "started should be false after Stop") + assert.Nil(t, ct.conn, "conn should be nil after Stop") + ct.mux.Unlock() +} + +func TestStopRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + // Second dial: signal that we're in progress, wait for test to call Stop. + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Trigger error to enter reconnect. + first.errChan <- assert.AnError + + // Wait for reconnect's second dial to begin. + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Stop while dial is in progress (conn is nil at this point). + ct.Stop() + + // Let the dial complete. reconnect should detect started==false and close the new conn. + close(dialProceed) + + // The second connection should be closed (not leaked). + select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Stop") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestCloseRaceWithReconnectDial(t *testing.T) { + first := newMockListener() + dialStarted := make(chan struct{}) + dialProceed := make(chan struct{}) + second := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + n := callCount.Add(1) + if n == 1 { + return first, nil + } + close(dialStarted) + <-dialProceed + return second, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + first.errChan <- assert.AnError + + select { + case <-dialStarted: + case <-time.After(15 * time.Second): + t.Fatal("timed out waiting for reconnect dial") + } + + // Close while dial is in progress (conn is nil). + require.NoError(t, ct.Close()) + + close(dialProceed) + + // The second connection should be closed (not leaked). 
+ select { + case <-second.closedCh: + case <-time.After(2 * time.Second): + t.Fatal("second connection was leaked after Close") + } + + ct.mux.Lock() + assert.False(t, ct.started) + assert.Nil(t, ct.conn) + ct.mux.Unlock() +} + +func TestStartIsIdempotent(t *testing.T) { + mock := newMockListener() + callCount := atomic.Int32{} + + ct := New(nil, nil, WithDialer(func() (listener, error) { + callCount.Add(1) + return mock, nil + })) + + err := ct.Start(false) + require.NoError(t, err) + + // Second Start should be a no-op. + err = ct.Start(false) + require.NoError(t, err) + + assert.Equal(t, int32(1), callCount.Load(), "dial should only be called once") + + ct.Stop() +} diff --git a/client/internal/peer/conn.go b/client/internal/peer/conn.go index 8d1585b3f..1e416bfe7 100644 --- a/client/internal/peer/conn.go +++ b/client/internal/peer/conn.go @@ -185,17 +185,20 @@ func (conn *Conn) Open(engineCtx context.Context) error { conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager) - relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() - workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) - if err != nil { - return err + forceRelay := IsForceRelayed() + if !forceRelay { + relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally() + workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally) + if err != nil { + return err + } + conn.workerICE = workerICE } - conn.workerICE = workerICE conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages) conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer) - if !isForceRelayed() { + if !forceRelay { conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer) } @@ -251,7 +254,9 @@ func (conn *Conn) Close(signalToRemote bool) { conn.wgWatcherCancel() } conn.workerRelay.CloseConn() - conn.workerICE.Close() + if conn.workerICE != nil { + conn.workerICE.Close() + } if conn.wgProxyRelay != nil { err := conn.wgProxyRelay.CloseConn() @@ -294,7 +299,9 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) { // OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer. func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) { conn.dumpState.RemoteCandidate() - conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + if conn.workerICE != nil { + conn.workerICE.OnRemoteCandidate(candidate, haRoutes) + } } // SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established @@ -712,33 +719,35 @@ func (conn *Conn) evalStatus() ConnStatus { return StatusConnecting } -func (conn *Conn) isConnectedOnAllWay() (connected bool) { - // would be better to protect this with a mutex, but it could cause deadlock with Close function - +// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports. 
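+// The guard's reconnect loop polls this function on each tick to decide
+// whether a new offer must be sent.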
+// +// The result is a tri-state: +// - ConnStatusConnected: all available transports are up +// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting +// - ConnStatusDisconnected: no working transport +func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) { defer func() { - if !connected { + if status == guard.ConnStatusDisconnected { conn.logTraceConnState() } }() - // For JS platform: only relay connection is supported - if runtime.GOOS == "js" { - return conn.statusRelay.Get() == worker.StatusConnected + iceWorkerCreated := conn.workerICE != nil + + var iceInProgress bool + if iceWorkerCreated { + iceInProgress = conn.workerICE.InProgress() } - // For non-JS platforms: check ICE connection status - if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() { - return false - } - - // If relay is supported with peer, it must also be connected - if conn.workerRelay.IsRelayConnectionSupportedWithPeer() { - if conn.statusRelay.Get() == worker.StatusDisconnected { - return false - } - } - - return true + return evalConnStatus(connStatusInputs{ + forceRelay: IsForceRelayed(), + peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(), + relayConnected: conn.statusRelay.Get() == worker.StatusConnected, + remoteSupportsICE: conn.handshaker.RemoteICESupported(), + iceWorkerCreated: iceWorkerCreated, + iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected, + iceInProgress: iceInProgress, + }) } func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) { @@ -926,3 +935,43 @@ func isController(config ConnConfig) bool { func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool { return remoteRosenpassPubKey != nil } + +func evalConnStatus(in connStatusInputs) guard.ConnStatus { + // "Relay up and needed" — the peer uses relay and the transport is connected. + relayUsedAndUp := in.peerUsesRelay && in.relayConnected + + // Force-relay mode: ICE never runs. Relay is the only transport and must be up. + if in.forceRelay { + return boolToConnStatus(relayUsedAndUp) + } + + // Remote peer doesn't support ICE, or we haven't created the worker yet: + // relay is the only possible transport. + if !in.remoteSupportsICE || !in.iceWorkerCreated { + return boolToConnStatus(relayUsedAndUp) + } + + // ICE counts as "up" when the status is anything other than Disconnected, OR + // when a negotiation is currently in progress (so we don't spam offers while one is in flight). + iceUp := in.iceStatusConnecting || in.iceInProgress + + // Relay side is acceptable if the peer doesn't rely on relay, or relay is connected. + relayOK := !in.peerUsesRelay || in.relayConnected + + switch { + case iceUp && relayOK: + return guard.ConnStatusConnected + case relayUsedAndUp: + // Relay is up but ICE is down — partially connected. + return guard.ConnStatusPartiallyConnected + default: + return guard.ConnStatusDisconnected + } +} + +func boolToConnStatus(connected bool) guard.ConnStatus { + if connected { + return guard.ConnStatusConnected + } + return guard.ConnStatusDisconnected +} diff --git a/client/internal/peer/conn_status.go b/client/internal/peer/conn_status.go index 73acc5ef5..b43e245f3 100644 --- a/client/internal/peer/conn_status.go +++ b/client/internal/peer/conn_status.go @@ -13,6 +13,20 @@ const ( StatusConnected ) +// connStatusInputs is the primitive-valued snapshot of the state that drives the +// tri-state connection classification. 
Extracted so the decision logic can be unit-tested +// without constructing full Worker/Handshaker objects. +type connStatusInputs struct { + forceRelay bool // NB_FORCE_RELAY or JS/WASM + peerUsesRelay bool // remote peer advertises relay support AND local has relay + relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay) + remoteSupportsICE bool // remote peer sent ICE credentials + iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode) + iceStatusConnecting bool // statusICE is anything other than Disconnected + iceInProgress bool // a negotiation is currently in flight +} + + // ConnStatus describe the status of a peer's connection type ConnStatus int32 diff --git a/client/internal/peer/conn_status_eval_test.go b/client/internal/peer/conn_status_eval_test.go new file mode 100644 index 000000000..66393cafe --- /dev/null +++ b/client/internal/peer/conn_status_eval_test.go @@ -0,0 +1,201 @@ +package peer + +import ( + "testing" + + "github.com/netbirdio/netbird/client/internal/peer/guard" +) + +func TestEvalConnStatus_ForceRelay(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "force relay, peer uses relay, relay up", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "force relay, peer uses relay, relay down", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: true, + relayConnected: false, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "force relay, peer does NOT use relay - disconnected forever", + in: connStatusInputs{ + forceRelay: true, + peerUsesRelay: false, + relayConnected: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_ICEUnavailable(t *testing.T) { + tests := []struct { + name string + in connStatusInputs + want guard.ConnStatus + }{ + { + name: "remote does not support ICE, peer uses relay, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer uses relay, relay down", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE worker not yet created, relay up", + in: connStatusInputs{ + peerUsesRelay: true, + relayConnected: true, + remoteSupportsICE: true, + iceWorkerCreated: false, + }, + want: guard.ConnStatusConnected, + }, + { + name: "remote does not support ICE, peer does not use relay", + in: connStatusInputs{ + peerUsesRelay: false, + relayConnected: false, + remoteSupportsICE: false, + iceWorkerCreated: true, + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := evalConnStatus(tc.in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v", got, tc.want) + } + }) + } +} + +func TestEvalConnStatus_FullyAvailable(t *testing.T) { + base := connStatusInputs{ + remoteSupportsICE: true, + iceWorkerCreated: true, + } + + tests := []struct { + name string + mutator func(*connStatusInputs) + want guard.ConnStatus + }{ + { + 
name: "ICE connected, relay connected, peer uses relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE connected, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE InProgress only, peer does NOT use relay", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.iceStatusConnecting = false + in.iceInProgress = true + }, + want: guard.ConnStatusConnected, + }, + { + name: "ICE down, relay up, peer uses relay -> partial", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = true + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusPartiallyConnected, + }, + { + name: "ICE down, peer does NOT use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = false + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = true + in.relayConnected = false + in.iceStatusConnecting = true + }, + // relayOK = false (peer uses relay but it's down), iceUp = true + // first switch arm fails (relayOK false), relayUsedAndUp = false (relay down), + // falls into default: Disconnected. + want: guard.ConnStatusDisconnected, + }, + { + name: "ICE down, relay up but peer does not use relay -> disconnected", + mutator: func(in *connStatusInputs) { + in.peerUsesRelay = false + in.relayConnected = true // not actually used since peer doesn't rely on it + in.iceStatusConnecting = false + in.iceInProgress = false + }, + want: guard.ConnStatusDisconnected, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + in := base + tc.mutator(&in) + if got := evalConnStatus(in); got != tc.want { + t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in) + } + }) + } +} diff --git a/client/internal/peer/env.go b/client/internal/peer/env.go index 7f500c410..b4ba9ad7b 100644 --- a/client/internal/peer/env.go +++ b/client/internal/peer/env.go @@ -10,7 +10,7 @@ const ( EnvKeyNBForceRelay = "NB_FORCE_RELAY" ) -func isForceRelayed() bool { +func IsForceRelayed() bool { if runtime.GOOS == "js" { return true } diff --git a/client/internal/peer/guard/guard.go b/client/internal/peer/guard/guard.go index d93403730..2e5efbcc5 100644 --- a/client/internal/peer/guard/guard.go +++ b/client/internal/peer/guard/guard.go @@ -8,7 +8,19 @@ import ( log "github.com/sirupsen/logrus" ) -type isConnectedFunc func() bool +// ConnStatus represents the connection state as seen by the guard. +type ConnStatus int + +const ( + // ConnStatusDisconnected means neither ICE nor Relay is connected. + ConnStatusDisconnected ConnStatus = iota + // ConnStatusPartiallyConnected means Relay is connected but ICE is not. + ConnStatusPartiallyConnected + // ConnStatusConnected means all required connections are established. + ConnStatusConnected +) + +type connStatusFunc func() ConnStatus // Guard is responsible for the reconnection logic. // It will trigger to send an offer to the peer then has connection issues. 
@@ -20,14 +32,14 @@ type isConnectedFunc func() bool // - ICE candidate changes type Guard struct { log *log.Entry - isConnectedOnAllWay isConnectedFunc + isConnectedOnAllWay connStatusFunc timeout time.Duration srWatcher *SRWatcher relayedConnDisconnected chan struct{} iCEConnDisconnected chan struct{} } -func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { +func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard { return &Guard{ log: log, isConnectedOnAllWay: isConnectedFn, @@ -57,8 +69,17 @@ func (g *Guard) SetICEConnDisconnected() { } } -// reconnectLoopWithRetry periodically check the connection status. -// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported +// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity. +// +// Behavior depends on the connection state reported by isConnectedOnAllWay: +// - Connected: no action, the peer is fully reachable. +// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling +// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all. +// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches +// to one attempt per hour. This limits signaling traffic when relay already provides connectivity. +// +// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry +// counter and backoff ticker, giving ICE a fresh chance after network conditions change. func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { srReconnectedChan := g.srWatcher.NewListener() defer g.srWatcher.RemoveListener(srReconnectedChan) @@ -68,36 +89,47 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) { tickerChannel := ticker.C + iceState := &iceRetryState{log: g.log} + defer iceState.reset() + for { select { - case t := <-tickerChannel: - if t.IsZero() { - g.log.Infof("retry timed out, stop periodic offer sending") - // after backoff timeout the ticker.C will be closed. 
We need to a dummy channel to avoid loop - tickerChannel = make(<-chan time.Time) - continue + case <-tickerChannel: + switch g.isConnectedOnAllWay() { + case ConnStatusConnected: + // all good, nothing to do + case ConnStatusDisconnected: + callback() + case ConnStatusPartiallyConnected: + if iceState.shouldRetry() { + callback() + } else { + iceState.enterHourlyMode() + ticker.Stop() + tickerChannel = iceState.hourlyC() + } } - if !g.isConnectedOnAllWay() { - callback() - } case <-g.relayedConnDisconnected: g.log.Debugf("Relay connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-g.iCEConnDisconnected: g.log.Debugf("ICE connection changed, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-srReconnectedChan: g.log.Debugf("has network changes, reset reconnection ticker") ticker.Stop() - ticker = g.prepareExponentTicker(ctx) + ticker = g.newReconnectTicker(ctx) tickerChannel = ticker.C + iceState.reset() case <-ctx.Done(): g.log.Debugf("context is done, stop reconnect loop") @@ -120,7 +152,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker { return backoff.NewTicker(bo) } -func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker { +func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker { bo := backoff.WithContext(&backoff.ExponentialBackOff{ InitialInterval: 800 * time.Millisecond, RandomizationFactor: 0.1, diff --git a/client/internal/peer/guard/ice_retry_state.go b/client/internal/peer/guard/ice_retry_state.go new file mode 100644 index 000000000..01dc1bf2d --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state.go @@ -0,0 +1,61 @@ +package guard + +import ( + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + // maxICERetries is the maximum number of ICE offer attempts when relay is connected + maxICERetries = 3 + // iceRetryInterval is the periodic retry interval after ICE retries are exhausted + iceRetryInterval = 1 * time.Hour +) + +// iceRetryState tracks the limited ICE retry attempts when relay is already connected. +// After maxICERetries attempts it switches to a periodic hourly retry. +type iceRetryState struct { + log *log.Entry + retries int + hourly *time.Ticker +} + +func (s *iceRetryState) reset() { + s.retries = 0 + if s.hourly != nil { + s.hourly.Stop() + s.hourly = nil + } +} + +// shouldRetry reports whether the caller should send another ICE offer on this tick. +// Returns false when the per-cycle retry budget is exhausted and the caller must switch +// to the hourly ticker via enterHourlyMode + hourlyC. +func (s *iceRetryState) shouldRetry() bool { + if s.hourly != nil { + s.log.Debugf("hourly ICE retry attempt") + return true + } + + s.retries++ + if s.retries <= maxICERetries { + s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries) + return true + } + + return false +} + +// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false. 
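+//
+// Sketch of the call pattern used by the reconnect loop in guard.go:
+//
+//	if state.shouldRetry() {
+//		callback()
+//	} else {
+//		state.enterHourlyMode()
+//		ticker.Stop()
+//		tickerChannel = state.hourlyC()
+//	}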
+func (s *iceRetryState) enterHourlyMode() { + s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries) + s.hourly = time.NewTicker(iceRetryInterval) +} + +func (s *iceRetryState) hourlyC() <-chan time.Time { + if s.hourly == nil { + return nil + } + return s.hourly.C +} diff --git a/client/internal/peer/guard/ice_retry_state_test.go b/client/internal/peer/guard/ice_retry_state_test.go new file mode 100644 index 000000000..6a5b5a76f --- /dev/null +++ b/client/internal/peer/guard/ice_retry_state_test.go @@ -0,0 +1,103 @@ +package guard + +import ( + "testing" + + log "github.com/sirupsen/logrus" +) + +func newTestRetryState() *iceRetryState { + return &iceRetryState{log: log.NewEntry(log.StandardLogger())} +} + +func TestICERetryState_AllowsInitialBudget(t *testing.T) { + s := newTestRetryState() + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries) + } + } +} + +func TestICERetryState_ExhaustsAfterBudget(t *testing.T) { + s := newTestRetryState() + + for i := 0; i < maxICERetries; i++ { + _ = s.shouldRetry() + } + + if s.shouldRetry() { + t.Fatalf("shouldRetry returned true after budget exhausted, want false") + } +} + +func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) { + s := newTestRetryState() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode") + } +} + +func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + + s.enterHourlyMode() + defer s.reset() + + if s.hourlyC() == nil { + t.Fatalf("hourlyC returned nil after enterHourlyMode") + } +} + +func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) { + s := newTestRetryState() + s.enterHourlyMode() + defer s.reset() + + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false in hourly mode, want true") + } + + // Subsequent calls also return true — we keep retrying on each hourly tick. 
+ if !s.shouldRetry() { + t.Fatalf("second shouldRetry returned false in hourly mode, want true") + } +} + +func TestICERetryState_ResetRestoresBudget(t *testing.T) { + s := newTestRetryState() + for i := 0; i < maxICERetries+1; i++ { + _ = s.shouldRetry() + } + s.enterHourlyMode() + + s.reset() + + if s.hourlyC() != nil { + t.Fatalf("hourlyC returned non-nil channel after reset") + } + if s.retries != 0 { + t.Fatalf("retries = %d after reset, want 0", s.retries) + } + + for i := 1; i <= maxICERetries; i++ { + if !s.shouldRetry() { + t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i) + } + } +} + +func TestICERetryState_ResetIsIdempotent(t *testing.T) { + s := newTestRetryState() + s.reset() + s.reset() // second call must not panic or re-stop a nil ticker + + if s.hourlyC() != nil { + t.Fatalf("hourlyC non-nil after double reset") + } +} diff --git a/client/internal/peer/guard/sr_watcher.go b/client/internal/peer/guard/sr_watcher.go index 6f4f5ad4f..0befd7438 100644 --- a/client/internal/peer/guard/sr_watcher.go +++ b/client/internal/peer/guard/sr_watcher.go @@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove return srw } -func (w *SRWatcher) Start() { +func (w *SRWatcher) Start(disableICEMonitor bool) { w.mu.Lock() defer w.mu.Unlock() @@ -50,8 +50,10 @@ func (w *SRWatcher) Start() { ctx, cancel := context.WithCancel(context.Background()) w.cancelIceMonitor = cancel - iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) - go iceMonitor.Start(ctx, w.onICEChanged) + if !disableICEMonitor { + iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod()) + go iceMonitor.Start(ctx, w.onICEChanged) + } w.signalClient.SetOnReconnectedListener(w.onReconnected) w.relayManager.SetOnReconnectedListener(w.onReconnected) diff --git a/client/internal/peer/handshaker.go b/client/internal/peer/handshaker.go index 9b50cecd1..741dfce60 100644 --- a/client/internal/peer/handshaker.go +++ b/client/internal/peer/handshaker.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" log "github.com/sirupsen/logrus" @@ -43,6 +44,10 @@ type OfferAnswer struct { SessionID *ICESessionID } +func (o *OfferAnswer) hasICECredentials() bool { + return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != "" +} + type Handshaker struct { mu sync.Mutex log *log.Entry @@ -59,6 +64,10 @@ type Handshaker struct { relayListener *AsyncOfferListener iceListener func(remoteOfferAnswer *OfferAnswer) + // remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers. + // When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses. 
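+	// It is re-evaluated on every received offer and answer via updateRemoteICEState.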
+ remoteICESupported atomic.Bool + // remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection remoteOffersCh chan OfferAnswer // remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection @@ -66,7 +75,7 @@ type Handshaker struct { } func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker { - return &Handshaker{ + h := &Handshaker{ log: log, config: config, signaler: signaler, @@ -76,6 +85,13 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W remoteOffersCh: make(chan OfferAnswer), remoteAnswerCh: make(chan OfferAnswer), } + // assume remote supports ICE until we learn otherwise from received offers + h.remoteICESupported.Store(ice != nil) + return h +} + +func (h *Handshaker) RemoteICESupported() bool { + return h.remoteICESupported.Load() } func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) { @@ -90,18 +106,20 @@ func (h *Handshaker) Listen(ctx context.Context) { for { select { case remoteOfferAnswer := <-h.remoteOffersCh: - h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } @@ -110,18 +128,20 @@ func (h *Handshaker) Listen(ctx context.Context) { continue } case remoteOfferAnswer := <-h.remoteAnswerCh: - h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString()) + h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials()) // Record signaling received for reconnection attempts if h.metricsStages != nil { h.metricsStages.RecordSignalingReceived() } + h.updateRemoteICEState(&remoteOfferAnswer) + if h.relayListener != nil { h.relayListener.Notify(&remoteOfferAnswer) } - if h.iceListener != nil { + if h.iceListener != nil && h.RemoteICESupported() { h.iceListener(&remoteOfferAnswer) } case <-ctx.Done(): @@ -183,15 +203,18 @@ func (h *Handshaker) sendAnswer() error { } func (h *Handshaker) buildOfferAnswer() OfferAnswer { - uFrag, pwd := h.ice.GetLocalUserCredentials() - sid := h.ice.SessionID() answer := OfferAnswer{ - IceCredentials: IceCredentials{uFrag, pwd}, WgListenPort: h.config.LocalWgPort, Version: version.NetbirdVersion(), RosenpassPubKey: h.config.RosenpassConfig.PubKey, RosenpassAddr: h.config.RosenpassConfig.Addr, - SessionID: &sid, + } + + if h.ice != nil && h.RemoteICESupported() { + uFrag, pwd := 
h.ice.GetLocalUserCredentials() + sid := h.ice.SessionID() + answer.IceCredentials = IceCredentials{uFrag, pwd} + answer.SessionID = &sid } if addr, err := h.relay.RelayInstanceAddress(); err == nil { @@ -200,3 +223,18 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer { return answer } + +func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) { + hasICE := offer.hasICECredentials() + prev := h.remoteICESupported.Swap(hasICE) + if prev != hasICE { + if hasICE { + h.log.Infof("remote peer started sending ICE credentials") + } else { + h.log.Infof("remote peer stopped sending ICE credentials") + if h.ice != nil { + h.ice.Close() + } + } + } +} diff --git a/client/internal/peer/signaler.go b/client/internal/peer/signaler.go index b28906625..f6eb87cca 100644 --- a/client/internal/peer/signaler.go +++ b/client/internal/peer/signaler.go @@ -46,9 +46,13 @@ func (s *Signaler) Ready() bool { // SignalOfferAnswer signals either an offer or an answer to remote peer func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error { - sessionIDBytes, err := offerAnswer.SessionID.Bytes() - if err != nil { - log.Warnf("failed to get session ID bytes: %v", err) + var sessionIDBytes []byte + if offerAnswer.SessionID != nil { + var err error + sessionIDBytes, err = offerAnswer.SessionID.Bytes() + if err != nil { + log.Warnf("failed to get session ID bytes: %v", err) + } } msg, err := signal.MarshalCredential( s.wgPrivateKey, diff --git a/client/internal/portforward/env.go b/client/internal/portforward/env.go index 444a6b478..ba83c79bf 100644 --- a/client/internal/portforward/env.go +++ b/client/internal/portforward/env.go @@ -8,18 +8,27 @@ import ( ) const ( - envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" + envDisableNATMapper = "NB_DISABLE_NAT_MAPPER" + envDisablePCPHealthCheck = "NB_DISABLE_PCP_HEALTH_CHECK" ) func isDisabledByEnv() bool { - val := os.Getenv(envDisableNATMapper) + return parseBoolEnv(envDisableNATMapper) +} + +func isHealthCheckDisabled() bool { + return parseBoolEnv(envDisablePCPHealthCheck) +} + +func parseBoolEnv(key string) bool { + val := os.Getenv(key) if val == "" { return false } disabled, err := strconv.ParseBool(val) if err != nil { - log.Warnf("failed to parse %s: %v", envDisableNATMapper, err) + log.Warnf("failed to parse %s: %v", key, err) return false } return disabled diff --git a/client/internal/portforward/manager.go b/client/internal/portforward/manager.go index bf7533af9..b0680160c 100644 --- a/client/internal/portforward/manager.go +++ b/client/internal/portforward/manager.go @@ -12,12 +12,15 @@ import ( "github.com/libp2p/go-nat" log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" ) const ( - defaultMappingTTL = 2 * time.Hour - discoveryTimeout = 10 * time.Second - mappingDescription = "NetBird" + defaultMappingTTL = 2 * time.Hour + healthCheckInterval = 1 * time.Minute + discoveryTimeout = 10 * time.Second + mappingDescription = "NetBird" ) // upnpErrPermanentLeaseOnly matches UPnP error 725 in SOAP fault XML, @@ -154,7 +157,7 @@ func (m *Manager) setup(ctx context.Context) (nat.NAT, *Mapping, error) { discoverCtx, discoverCancel := context.WithTimeout(ctx, discoveryTimeout) defer discoverCancel() - gateway, err := nat.DiscoverGateway(discoverCtx) + gateway, err := discoverGateway(discoverCtx) if err != nil { return nil, nil, fmt.Errorf("discover gateway: %w", err) } @@ -189,7 +192,6 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) 
(*Mapping, externalIP, err := gateway.GetExternalAddress() if err != nil { log.Debugf("failed to get external address: %v", err) - // todo return with err? } mapping := &Mapping{ @@ -208,27 +210,87 @@ func (m *Manager) createMapping(ctx context.Context, gateway nat.NAT) (*Mapping, func (m *Manager) renewLoop(ctx context.Context, gateway nat.NAT, ttl time.Duration) { if ttl == 0 { - // Permanent mappings don't expire, just wait for cancellation. - <-ctx.Done() + // Permanent mappings don't expire, just wait for cancellation + // but still run health checks for PCP gateways. + m.permanentLeaseLoop(ctx, gateway) return } - ticker := time.NewTicker(ttl / 2) - defer ticker.Stop() + renewTicker := time.NewTicker(ttl / 2) + healthTicker := time.NewTicker(healthCheckInterval) + defer renewTicker.Stop() + defer healthTicker.Stop() for { select { case <-ctx.Done(): return - case <-ticker.C: + case <-renewTicker.C: if err := m.renewMapping(ctx, gateway); err != nil { log.Warnf("failed to renew port mapping: %v", err) continue } + case <-healthTicker.C: + if m.checkHealthAndRecreate(ctx, gateway) { + renewTicker.Reset(ttl / 2) + } } } } +func (m *Manager) permanentLeaseLoop(ctx context.Context, gateway nat.NAT) { + healthTicker := time.NewTicker(healthCheckInterval) + defer healthTicker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-healthTicker.C: + m.checkHealthAndRecreate(ctx, gateway) + } + } +} + +func (m *Manager) checkHealthAndRecreate(ctx context.Context, gateway nat.NAT) bool { + if isHealthCheckDisabled() { + return false + } + + m.mappingLock.Lock() + hasMapping := m.mapping != nil + m.mappingLock.Unlock() + + if !hasMapping { + return false + } + + pcpNAT, ok := gateway.(*pcp.NAT) + if !ok { + return false + } + + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + + epoch, serverRestarted, err := pcpNAT.CheckServerHealth(ctx) + if err != nil { + log.Debugf("PCP health check failed: %v", err) + return false + } + + if serverRestarted { + log.Warnf("PCP server restart detected (epoch=%d), recreating port mapping", epoch) + if err := m.renewMapping(ctx, gateway); err != nil { + log.Errorf("failed to recreate port mapping after server restart: %v", err) + return false + } + return true + } + + return false +} + func (m *Manager) renewMapping(ctx context.Context, gateway nat.NAT) error { ctx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() diff --git a/client/internal/portforward/pcp/client.go b/client/internal/portforward/pcp/client.go new file mode 100644 index 000000000..f6d243ef9 --- /dev/null +++ b/client/internal/portforward/pcp/client.go @@ -0,0 +1,408 @@ +package pcp + +import ( + "context" + "crypto/rand" + "errors" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" +) + +const ( + defaultTimeout = 3 * time.Second + responseBufferSize = 128 + + // RFC 6887 Section 8.1.1 retry timing + initialRetryDelay = 3 * time.Second + maxRetryDelay = 1024 * time.Second + maxRetries = 4 // 3s + 6s + 12s + 24s = 45s total worst case +) + +// Client is a PCP protocol client. +// All methods are safe for concurrent use. +type Client struct { + gateway netip.Addr + timeout time.Duration + + mu sync.Mutex + // localIP caches the resolved local IP address. + localIP netip.Addr + // lastEpoch is the last observed server epoch value. + lastEpoch uint32 + // epochTime tracks when lastEpoch was received for state loss detection. 
+ epochTime time.Time + // externalIP caches the external IP from the last successful MAP response. + externalIP netip.Addr + // epochStateLost is set when epoch indicates server restart. + epochStateLost bool +} + +// NewClient creates a new PCP client for the gateway at the given IP. +func NewClient(gateway net.IP) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: defaultTimeout, + } +} + +// NewClientWithTimeout creates a new PCP client with a custom timeout. +func NewClientWithTimeout(gateway net.IP, timeout time.Duration) *Client { + addr, ok := netip.AddrFromSlice(gateway) + if !ok { + log.Debugf("invalid gateway IP: %v", gateway) + } + return &Client{ + gateway: addr.Unmap(), + timeout: timeout, + } +} + +// SetLocalIP sets the local IP address to use in PCP requests. +func (c *Client) SetLocalIP(ip net.IP) { + addr, ok := netip.AddrFromSlice(ip) + if !ok { + log.Debugf("invalid local IP: %v", ip) + } + c.mu.Lock() + c.localIP = addr.Unmap() + c.mu.Unlock() +} + +// Gateway returns the gateway IP address. +func (c *Client) Gateway() net.IP { + return c.gateway.AsSlice() +} + +// Announce sends a PCP ANNOUNCE request to discover PCP support. +// Returns the server's epoch time on success. +func (c *Client) Announce(ctx context.Context) (epoch uint32, err error) { + localIP, err := c.getLocalIP() + if err != nil { + return 0, fmt.Errorf("get local IP: %w", err) + } + + req := buildAnnounceRequest(localIP) + resp, err := c.sendRequest(ctx, req) + if err != nil { + return 0, fmt.Errorf("send announce: %w", err) + } + + parsed, err := parseResponse(resp) + if err != nil { + return 0, fmt.Errorf("parse announce response: %w", err) + } + + if parsed.ResultCode != ResultSuccess { + return 0, fmt.Errorf("PCP ANNOUNCE failed: %s", ResultCodeString(parsed.ResultCode)) + } + + c.mu.Lock() + if c.updateEpochLocked(parsed.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.mu.Unlock() + return parsed.Epoch, nil +} + +// AddPortMapping requests a port mapping from the PCP server. +func (c *Client) AddPortMapping(ctx context.Context, protocol string, internalPort int, lifetime time.Duration) (*MapResponse, error) { + return c.addPortMappingWithHint(ctx, protocol, internalPort, internalPort, netip.Addr{}, lifetime) +} + +// AddPortMappingWithHint requests a port mapping with suggested external port and IP. +// Use lifetime <= 0 to delete a mapping. 
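+//
+// For example (illustrative), mapping WireGuard's listen port for one hour:
+//
+//	resp, err := c.AddPortMappingWithHint(ctx, "udp", 51820, 51820, nil, time.Hour)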
+func (c *Client) AddPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP net.IP, lifetime time.Duration) (*MapResponse, error) { + var extIP netip.Addr + if suggestedExtIP != nil { + var ok bool + extIP, ok = netip.AddrFromSlice(suggestedExtIP) + if !ok { + log.Debugf("invalid suggested external IP: %v", suggestedExtIP) + } + extIP = extIP.Unmap() + } + return c.addPortMappingWithHint(ctx, protocol, internalPort, suggestedExtPort, extIP, lifetime) +} + +func (c *Client) addPortMappingWithHint(ctx context.Context, protocol string, internalPort, suggestedExtPort int, suggestedExtIP netip.Addr, lifetime time.Duration) (*MapResponse, error) { + localIP, err := c.getLocalIP() + if err != nil { + return nil, fmt.Errorf("get local IP: %w", err) + } + + proto, err := protocolNumber(protocol) + if err != nil { + return nil, fmt.Errorf("parse protocol: %w", err) + } + + var nonce [12]byte + if _, err := rand.Read(nonce[:]); err != nil { + return nil, fmt.Errorf("generate nonce: %w", err) + } + + // Convert lifetime to seconds. Lifetime 0 means delete, so only apply + // default for positive durations that round to 0 seconds. + var lifetimeSec uint32 + if lifetime > 0 { + lifetimeSec = uint32(lifetime.Seconds()) + if lifetimeSec == 0 { + lifetimeSec = DefaultLifetime + } + } + + req := buildMapRequest(localIP, nonce, proto, uint16(internalPort), uint16(suggestedExtPort), suggestedExtIP, lifetimeSec) + + resp, err := c.sendRequest(ctx, req) + if err != nil { + return nil, fmt.Errorf("send map request: %w", err) + } + + mapResp, err := parseMapResponse(resp) + if err != nil { + return nil, fmt.Errorf("parse map response: %w", err) + } + + if mapResp.Nonce != nonce { + return nil, fmt.Errorf("nonce mismatch in response") + } + + if mapResp.Protocol != proto { + return nil, fmt.Errorf("protocol mismatch: requested %d, got %d", proto, mapResp.Protocol) + } + if mapResp.InternalPort != uint16(internalPort) { + return nil, fmt.Errorf("internal port mismatch: requested %d, got %d", internalPort, mapResp.InternalPort) + } + + if mapResp.ResultCode != ResultSuccess { + return nil, &Error{ + Code: mapResp.ResultCode, + Message: ResultCodeString(mapResp.ResultCode), + } + } + + c.mu.Lock() + if c.updateEpochLocked(mapResp.Epoch) { + log.Warnf("PCP server epoch indicates state loss - mappings may need refresh") + } + c.cacheExternalIPLocked(mapResp.ExternalIP) + c.mu.Unlock() + return mapResp, nil +} + +// DeletePortMapping removes a port mapping by requesting zero lifetime. +func (c *Client) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + if _, err := c.addPortMappingWithHint(ctx, protocol, internalPort, 0, netip.Addr{}, 0); err != nil { + var pcpErr *Error + if errors.As(err, &pcpErr) && pcpErr.Code == ResultNotAuthorized { + return nil + } + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// GetExternalAddress returns the external IP address. +// First checks for a cached value from previous MAP responses. +// If not cached, creates a short-lived mapping to discover the external IP. +func (c *Client) GetExternalAddress(ctx context.Context) (net.IP, error) { + c.mu.Lock() + if c.externalIP.IsValid() { + ip := c.externalIP.AsSlice() + c.mu.Unlock() + return ip, nil + } + c.mu.Unlock() + + // Use an ephemeral port in the dynamic range (49152-65535). + // Port 0 is not valid with UDP/TCP protocols per RFC 6887. 
+ ephemeralPort := 49152 + int(uint16(time.Now().UnixNano()))%(65535-49152) + + // Use minimal lifetime (1 second) for discovery. + resp, err := c.AddPortMapping(ctx, "udp", ephemeralPort, time.Second) + if err != nil { + return nil, fmt.Errorf("create temporary mapping: %w", err) + } + + if err := c.DeletePortMapping(ctx, "udp", ephemeralPort); err != nil { + log.Debugf("cleanup temporary PCP mapping: %v", err) + } + + return resp.ExternalIP.AsSlice(), nil +} + +// LastEpoch returns the last observed server epoch value. +// A decrease in epoch indicates the server may have restarted and mappings may be lost. +func (c *Client) LastEpoch() uint32 { + c.mu.Lock() + defer c.mu.Unlock() + return c.lastEpoch +} + +// EpochStateLost returns true if epoch state loss was detected and clears the flag. +func (c *Client) EpochStateLost() bool { + c.mu.Lock() + defer c.mu.Unlock() + lost := c.epochStateLost + c.epochStateLost = false + return lost +} + +// updateEpoch updates the epoch tracking and detects potential state loss. +// Returns true if state loss was detected (server likely restarted). +// Caller must hold c.mu. +func (c *Client) updateEpochLocked(newEpoch uint32) bool { + now := time.Now() + stateLost := false + + // RFC 6887 Section 8.5: Detect invalid epoch indicating server state loss. + // client_delta = time since last response + // server_delta = epoch change since last response + // Invalid if: client_delta+2 < server_delta - server_delta/16 + // OR: server_delta+2 < client_delta - client_delta/16 + // The +2 handles quantization, /16 (6.25%) handles clock drift. + if !c.epochTime.IsZero() && c.lastEpoch > 0 { + clientDelta := uint32(now.Sub(c.epochTime).Seconds()) + serverDelta := newEpoch - c.lastEpoch + + // Check for epoch going backwards or jumping unexpectedly. + // Subtraction is safe: serverDelta/16 is always <= serverDelta. + if clientDelta+2 < serverDelta-(serverDelta/16) || + serverDelta+2 < clientDelta-(clientDelta/16) { + stateLost = true + c.epochStateLost = true + } + } + + c.lastEpoch = newEpoch + c.epochTime = now + return stateLost +} + +// cacheExternalIP stores the external IP from a successful MAP response. +// Caller must hold c.mu. +func (c *Client) cacheExternalIPLocked(ip netip.Addr) { + if ip.IsValid() && !ip.IsUnspecified() { + c.externalIP = ip + } +} + +// sendRequest sends a PCP request with retries per RFC 6887 Section 8.1.1. +func (c *Client) sendRequest(ctx context.Context, req []byte) ([]byte, error) { + addr := &net.UDPAddr{IP: c.gateway.AsSlice(), Port: Port} + + var lastErr error + delay := initialRetryDelay + + for range maxRetries { + resp, err := c.sendOnce(ctx, addr, req) + if err == nil { + return resp, nil + } + lastErr = err + + if ctx.Err() != nil { + return nil, ctx.Err() + } + + // RFC 6887 Section 8.1.1: RT = (1 + RAND) * MIN(2 * RTprev, MRT) + // RAND is random between -0.1 and +0.1 + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(retryDelayWithJitter(delay)): + } + delay = min(delay*2, maxRetryDelay) + } + + return nil, fmt.Errorf("PCP request failed after %d retries: %w", maxRetries, lastErr) +} + +// retryDelayWithJitter applies RFC 6887 jitter: multiply by (1 + RAND) where RAND is [-0.1, +0.1]. 
+func retryDelayWithJitter(d time.Duration) time.Duration { + var b [1]byte + _, _ = rand.Read(b[:]) + // Convert byte to range [-0.1, +0.1]: (b/255 * 0.2) - 0.1 + jitter := (float64(b[0])/255.0)*0.2 - 0.1 + return time.Duration(float64(d) * (1 + jitter)) +} + +func (c *Client) sendOnce(ctx context.Context, addr *net.UDPAddr, req []byte) ([]byte, error) { + // Use ListenUDP instead of DialUDP to validate response source address per RFC 6887 §8.3. + conn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, fmt.Errorf("listen: %w", err) + } + defer func() { + if err := conn.Close(); err != nil { + log.Debugf("close UDP connection: %v", err) + } + }() + + timeout := c.timeout + if deadline, ok := ctx.Deadline(); ok { + if remaining := time.Until(deadline); remaining < timeout { + timeout = remaining + } + } + + if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + return nil, fmt.Errorf("set deadline: %w", err) + } + + if _, err := conn.WriteToUDP(req, addr); err != nil { + return nil, fmt.Errorf("write: %w", err) + } + + resp := make([]byte, responseBufferSize) + n, from, err := conn.ReadFromUDP(resp) + if err != nil { + return nil, fmt.Errorf("read: %w", err) + } + + // RFC 6887 §8.3: Validate response came from expected PCP server. + if !from.IP.Equal(addr.IP) { + return nil, fmt.Errorf("response from unexpected source %s (expected %s)", from.IP, addr.IP) + } + + return resp[:n], nil +} + +func (c *Client) getLocalIP() (netip.Addr, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.localIP.IsValid() { + return netip.Addr{}, fmt.Errorf("local IP not set for gateway %s", c.gateway) + } + return c.localIP, nil +} + +func protocolNumber(protocol string) (uint8, error) { + switch protocol { + case "udp", "UDP": + return ProtoUDP, nil + case "tcp", "TCP": + return ProtoTCP, nil + default: + return 0, fmt.Errorf("unsupported protocol: %s", protocol) + } +} + +// Error represents a PCP error response. 
+type Error struct { + Code uint8 + Message string +} + +func (e *Error) Error() string { + return fmt.Sprintf("PCP error: %s (%d)", e.Message, e.Code) +} diff --git a/client/internal/portforward/pcp/client_test.go b/client/internal/portforward/pcp/client_test.go new file mode 100644 index 000000000..79f44a426 --- /dev/null +++ b/client/internal/portforward/pcp/client_test.go @@ -0,0 +1,187 @@ +package pcp + +import ( + "context" + "net" + "net/netip" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestAddrConversion(t *testing.T) { + tests := []struct { + name string + addr netip.Addr + }{ + {"IPv4", netip.MustParseAddr("192.168.1.100")}, + {"IPv4 loopback", netip.MustParseAddr("127.0.0.1")}, + {"IPv6", netip.MustParseAddr("2001:db8::1")}, + {"IPv6 loopback", netip.MustParseAddr("::1")}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b16 := addrTo16(tt.addr) + + recovered := addrFrom16(b16) + assert.Equal(t, tt.addr, recovered, "address should round-trip") + }) + } +} + +func TestBuildAnnounceRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + req := buildAnnounceRequest(clientIP) + + require.Len(t, req, headerSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpAnnounce), req[1], "opcode") + + // Check client IP is properly encoded as IPv4-mapped IPv6 + assert.Equal(t, byte(0xff), req[18], "IPv4-mapped prefix byte 10") + assert.Equal(t, byte(0xff), req[19], "IPv4-mapped prefix byte 11") + assert.Equal(t, byte(192), req[20], "IP octet 1") + assert.Equal(t, byte(168), req[21], "IP octet 2") + assert.Equal(t, byte(1), req[22], "IP octet 3") + assert.Equal(t, byte(100), req[23], "IP octet 4") +} + +func TestBuildMapRequest(t *testing.T) { + clientIP := netip.MustParseAddr("192.168.1.100") + nonce := [12]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12} + req := buildMapRequest(clientIP, nonce, ProtoUDP, 51820, 51820, netip.Addr{}, 3600) + + require.Len(t, req, mapRequestSize) + assert.Equal(t, byte(Version), req[0], "version") + assert.Equal(t, byte(OpMap), req[1], "opcode") + + // Lifetime at bytes 4-7 + assert.Equal(t, uint32(3600), (uint32(req[4])<<24)|(uint32(req[5])<<16)|(uint32(req[6])<<8)|uint32(req[7]), "lifetime") + + // Nonce at bytes 24-35 + assert.Equal(t, nonce[:], req[24:36], "nonce") + + // Protocol at byte 36 + assert.Equal(t, byte(ProtoUDP), req[36], "protocol") + + // Internal port at bytes 40-41 + assert.Equal(t, uint16(51820), (uint16(req[40])<<8)|uint16(req[41]), "internal port") + + // External port at bytes 42-43 + assert.Equal(t, uint16(51820), (uint16(req[42])<<8)|uint16(req[43]), "external port") +} + +func TestParseResponse(t *testing.T) { + // Construct a valid ANNOUNCE response + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce | OpReply + // Result code = 0 (success) + // Lifetime = 0 + // Epoch = 12345 + resp[8] = 0 + resp[9] = 0 + resp[10] = 0x30 + resp[11] = 0x39 + + parsed, err := parseResponse(resp) + require.NoError(t, err) + assert.Equal(t, uint8(Version), parsed.Version) + assert.Equal(t, uint8(OpAnnounce|OpReply), parsed.Opcode) + assert.Equal(t, uint8(ResultSuccess), parsed.ResultCode) + assert.Equal(t, uint32(12345), parsed.Epoch) +} + +func TestParseResponseErrors(t *testing.T) { + t.Run("too short", func(t *testing.T) { + _, err := parseResponse([]byte{1, 2, 3}) + assert.Error(t, err) + }) + + t.Run("wrong version", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = 
1 // Wrong version + resp[1] = OpReply + _, err := parseResponse(resp) + assert.Error(t, err) + }) + + t.Run("missing reply bit", func(t *testing.T) { + resp := make([]byte, headerSize) + resp[0] = Version + resp[1] = OpAnnounce // Missing OpReply bit + _, err := parseResponse(resp) + assert.Error(t, err) + }) +} + +func TestResultCodeString(t *testing.T) { + assert.Equal(t, "SUCCESS", ResultCodeString(ResultSuccess)) + assert.Equal(t, "NOT_AUTHORIZED", ResultCodeString(ResultNotAuthorized)) + assert.Equal(t, "ADDRESS_MISMATCH", ResultCodeString(ResultAddressMismatch)) + assert.Contains(t, ResultCodeString(255), "UNKNOWN") +} + +func TestProtocolNumber(t *testing.T) { + proto, err := protocolNumber("udp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + proto, err = protocolNumber("tcp") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoTCP), proto) + + proto, err = protocolNumber("UDP") + require.NoError(t, err) + assert.Equal(t, uint8(ProtoUDP), proto) + + _, err = protocolNumber("icmp") + assert.Error(t, err) +} + +func TestClientCreation(t *testing.T) { + gateway := netip.MustParseAddr("192.168.1.1").AsSlice() + + client := NewClient(gateway) + assert.Equal(t, net.IP(gateway), client.Gateway()) + assert.Equal(t, defaultTimeout, client.timeout) + + clientWithTimeout := NewClientWithTimeout(gateway, 5*time.Second) + assert.Equal(t, 5*time.Second, clientWithTimeout.timeout) +} + +func TestNATType(t *testing.T) { + n := NewNAT(netip.MustParseAddr("192.168.1.1").AsSlice(), netip.MustParseAddr("192.168.1.100").AsSlice()) + assert.Equal(t, "PCP", n.Type()) +} + +// Integration test - skipped unless PCP_TEST_GATEWAY env is set +func TestClientIntegration(t *testing.T) { + t.Skip("Integration test - run manually with PCP_TEST_GATEWAY=") + + gateway := netip.MustParseAddr("10.0.1.1").AsSlice() // Change to your test gateway + localIP := netip.MustParseAddr("10.0.1.100").AsSlice() // Change to your local IP + + client := NewClient(gateway) + client.SetLocalIP(localIP) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // Test ANNOUNCE + epoch, err := client.Announce(ctx) + require.NoError(t, err) + t.Logf("Server epoch: %d", epoch) + + // Test MAP + resp, err := client.AddPortMapping(ctx, "udp", 51820, 1*time.Hour) + require.NoError(t, err) + t.Logf("Mapping: internal=%d external=%d externalIP=%s", + resp.InternalPort, resp.ExternalPort, resp.ExternalIP) + + // Cleanup + err = client.DeletePortMapping(ctx, "udp", 51820) + require.NoError(t, err) +} diff --git a/client/internal/portforward/pcp/nat.go b/client/internal/portforward/pcp/nat.go new file mode 100644 index 000000000..1dc24274b --- /dev/null +++ b/client/internal/portforward/pcp/nat.go @@ -0,0 +1,209 @@ +package pcp + +import ( + "context" + "fmt" + "net" + "net/netip" + "sync" + "time" + + log "github.com/sirupsen/logrus" + + "github.com/libp2p/go-nat" + "github.com/libp2p/go-netroute" +) + +var _ nat.NAT = (*NAT)(nil) + +// NAT implements the go-nat NAT interface using PCP. +// Supports dual-stack (IPv4 and IPv6) when available. +// All methods are safe for concurrent use. +// +// TODO: IPv6 pinholes use the local IPv6 address. If the address changes +// (e.g., due to SLAAC rotation or network change), the pinhole becomes stale +// and needs to be recreated with the new address. +type NAT struct { + client *Client + + mu sync.RWMutex + // client6 is the IPv6 PCP client, nil if IPv6 is unavailable. 
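+	// discoverIPv6 populates it when an IPv6 default route and a responding PCP server are found.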
+ client6 *Client + // localIP6 caches the local IPv6 address used for PCP requests. + localIP6 netip.Addr +} + +// NewNAT creates a new NAT instance backed by PCP. +func NewNAT(gateway, localIP net.IP) *NAT { + client := NewClient(gateway) + client.SetLocalIP(localIP) + return &NAT{ + client: client, + } +} + +// Type returns "PCP" as the NAT type. +func (n *NAT) Type() string { + return "PCP" +} + +// GetDeviceAddress returns the gateway IP address. +func (n *NAT) GetDeviceAddress() (net.IP, error) { + return n.client.Gateway(), nil +} + +// GetExternalAddress returns the external IP address. +func (n *NAT) GetExternalAddress() (net.IP, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + return n.client.GetExternalAddress(ctx) +} + +// GetInternalAddress returns the local IP address used to communicate with the gateway. +func (n *NAT) GetInternalAddress() (net.IP, error) { + addr, err := n.client.getLocalIP() + if err != nil { + return nil, err + } + return addr.AsSlice(), nil +} + +// AddPortMapping creates a port mapping on both IPv4 and IPv6 (if available). +func (n *NAT) AddPortMapping(ctx context.Context, protocol string, internalPort int, _ string, timeout time.Duration) (int, error) { + resp, err := n.client.AddPortMapping(ctx, protocol, internalPort, timeout) + if err != nil { + return 0, fmt.Errorf("add mapping: %w", err) + } + + n.mu.RLock() + client6 := n.client6 + localIP6 := n.localIP6 + n.mu.RUnlock() + + if client6 == nil { + return int(resp.ExternalPort), nil + } + + if _, err := client6.AddPortMapping(ctx, protocol, internalPort, timeout); err != nil { + log.Warnf("IPv6 PCP mapping failed (continuing with IPv4): %v", err) + return int(resp.ExternalPort), nil + } + + log.Infof("created IPv6 PCP pinhole: %s:%d", localIP6, internalPort) + return int(resp.ExternalPort), nil +} + +// DeletePortMapping removes a port mapping from both IPv4 and IPv6. +func (n *NAT) DeletePortMapping(ctx context.Context, protocol string, internalPort int) error { + err := n.client.DeletePortMapping(ctx, protocol, internalPort) + + n.mu.RLock() + client6 := n.client6 + n.mu.RUnlock() + + if client6 != nil { + if err6 := client6.DeletePortMapping(ctx, protocol, internalPort); err6 != nil { + log.Warnf("IPv6 PCP delete mapping failed: %v", err6) + } + } + + if err != nil { + return fmt.Errorf("delete mapping: %w", err) + } + return nil +} + +// CheckServerHealth sends an ANNOUNCE to verify the server is still responsive. +// Returns the current epoch and whether the server may have restarted (epoch state loss detected). +func (n *NAT) CheckServerHealth(ctx context.Context) (epoch uint32, serverRestarted bool, err error) { + epoch, err = n.client.Announce(ctx) + if err != nil { + return 0, false, fmt.Errorf("announce: %w", err) + } + return epoch, n.client.EpochStateLost(), nil +} + +// DiscoverPCP attempts to discover a PCP-capable gateway. +// Returns a NAT interface if PCP is supported, or an error otherwise. +// Discovers both IPv4 and IPv6 gateways when available. 
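+//
+// A minimal call-site sketch (hypothetical usage, not part of this change;
+// the timeout value is an assumption):
+//
+//	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+//	defer cancel()
+//	gw, err := DiscoverPCP(ctx)
+//	if err != nil {
+//		// fall back to NAT-PMP/UPnP, as portforward's defaultDiscoverGateway does
+//	}
+//	extPort, err := gw.AddPortMapping(ctx, "udp", 51820, "", time.Hour)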
+func DiscoverPCP(ctx context.Context) (nat.NAT, error) { + gateway, localIP, err := getDefaultGateway() + if err != nil { + return nil, fmt.Errorf("get default gateway: %w", err) + } + + client := NewClient(gateway) + client.SetLocalIP(localIP) + if _, err := client.Announce(ctx); err != nil { + return nil, fmt.Errorf("PCP announce: %w", err) + } + + result := &NAT{client: client} + discoverIPv6(ctx, result) + + return result, nil +} + +func discoverIPv6(ctx context.Context, result *NAT) { + gateway6, localIP6, err := getDefaultGateway6() + if err != nil { + log.Debugf("IPv6 gateway discovery failed: %v", err) + return + } + + client6 := NewClient(gateway6) + client6.SetLocalIP(localIP6) + if _, err := client6.Announce(ctx); err != nil { + log.Debugf("PCP IPv6 announce failed: %v", err) + return + } + + addr, ok := netip.AddrFromSlice(localIP6) + if !ok { + log.Debugf("invalid IPv6 local IP: %v", localIP6) + return + } + result.mu.Lock() + result.client6 = client6 + result.localIP6 = addr + result.mu.Unlock() + log.Debugf("PCP IPv6 gateway discovered: %s (local: %s)", gateway6, localIP6) +} + +// getDefaultGateway returns the default IPv4 gateway and local IP using the system routing table. +func getDefaultGateway() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv4zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} + +// getDefaultGateway6 returns the default IPv6 gateway IP address using the system routing table. +func getDefaultGateway6() (gateway net.IP, localIP net.IP, err error) { + router, err := netroute.New() + if err != nil { + return nil, nil, err + } + + _, gateway, localIP, err = router.Route(net.IPv6zero) + if err != nil { + return nil, nil, err + } + + if gateway == nil { + return nil, nil, nat.ErrNoNATFound + } + + return gateway, localIP, nil +} diff --git a/client/internal/portforward/pcp/protocol.go b/client/internal/portforward/pcp/protocol.go new file mode 100644 index 000000000..d81c50c8c --- /dev/null +++ b/client/internal/portforward/pcp/protocol.go @@ -0,0 +1,225 @@ +// Package pcp implements the Port Control Protocol (RFC 6887). +// +// # Implemented Features +// +// - ANNOUNCE opcode: Discovers PCP server support +// - MAP opcode: Creates/deletes port mappings (IPv4 NAT) and firewall pinholes (IPv6) +// - Dual-stack: Simultaneous IPv4 and IPv6 support via separate clients +// - Nonce validation: Prevents response spoofing +// - Epoch tracking: Detects server restarts per Section 8.5 +// - RFC-compliant retry timing: 3s initial, exponential backoff to 1024s max (Section 8.1.1) +// +// # Not Implemented +// +// - PEER opcode: For outbound peer connections (not needed for inbound NAT traversal) +// - THIRD_PARTY option: For managing mappings on behalf of other devices +// - PREFER_FAILURE option: Requires exact external port or fail (IPv4 NAT only, not needed for IPv6 pinholing) +// - FILTER option: To restrict remote peer addresses +// +// These optional features are omitted because the primary use case is simple +// port forwarding for WireGuard, which only requires MAP with default behavior. +package pcp + +import ( + "encoding/binary" + "fmt" + "net/netip" +) + +const ( + // Version is the PCP protocol version (RFC 6887). + Version = 2 + + // Port is the standard PCP server port. 
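+	// Requests are sent to the gateway on this port. RFC 6887 also assigns
+	// 5350 as the client-side port for unsolicited multicast ANNOUNCEs, which
+	// this package does not listen for.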
+ Port = 5351 + + // DefaultLifetime is the default requested mapping lifetime in seconds. + DefaultLifetime = 7200 // 2 hours + + // Header sizes + headerSize = 24 + mapPayloadSize = 36 + mapRequestSize = headerSize + mapPayloadSize // 60 bytes +) + +// Opcodes +const ( + OpAnnounce = 0 + OpMap = 1 + OpPeer = 2 + OpReply = 0x80 // OR'd with opcode in responses +) + +// Protocol numbers for MAP requests +const ( + ProtoUDP = 17 + ProtoTCP = 6 +) + +// Result codes (RFC 6887 Section 7.4) +const ( + ResultSuccess = 0 + ResultUnsuppVersion = 1 + ResultNotAuthorized = 2 + ResultMalformedRequest = 3 + ResultUnsuppOpcode = 4 + ResultUnsuppOption = 5 + ResultMalformedOption = 6 + ResultNetworkFailure = 7 + ResultNoResources = 8 + ResultUnsuppProtocol = 9 + ResultUserExQuota = 10 + ResultCannotProvideExt = 11 + ResultAddressMismatch = 12 + ResultExcessiveRemotePeers = 13 +) + +// ResultCodeString returns a human-readable string for a result code. +func ResultCodeString(code uint8) string { + switch code { + case ResultSuccess: + return "SUCCESS" + case ResultUnsuppVersion: + return "UNSUPP_VERSION" + case ResultNotAuthorized: + return "NOT_AUTHORIZED" + case ResultMalformedRequest: + return "MALFORMED_REQUEST" + case ResultUnsuppOpcode: + return "UNSUPP_OPCODE" + case ResultUnsuppOption: + return "UNSUPP_OPTION" + case ResultMalformedOption: + return "MALFORMED_OPTION" + case ResultNetworkFailure: + return "NETWORK_FAILURE" + case ResultNoResources: + return "NO_RESOURCES" + case ResultUnsuppProtocol: + return "UNSUPP_PROTOCOL" + case ResultUserExQuota: + return "USER_EX_QUOTA" + case ResultCannotProvideExt: + return "CANNOT_PROVIDE_EXTERNAL" + case ResultAddressMismatch: + return "ADDRESS_MISMATCH" + case ResultExcessiveRemotePeers: + return "EXCESSIVE_REMOTE_PEERS" + default: + return fmt.Sprintf("UNKNOWN(%d)", code) + } +} + +// Response represents a parsed PCP response header. +type Response struct { + Version uint8 + Opcode uint8 + ResultCode uint8 + Lifetime uint32 + Epoch uint32 +} + +// MapResponse contains the full response to a MAP request. +type MapResponse struct { + Response + Nonce [12]byte + Protocol uint8 + InternalPort uint16 + ExternalPort uint16 + ExternalIP netip.Addr +} + +// addrTo16 converts an address to its 16-byte IPv4-mapped IPv6 representation. +func addrTo16(addr netip.Addr) [16]byte { + if addr.Is4() { + return netip.AddrFrom4(addr.As4()).As16() + } + return addr.As16() +} + +// addrFrom16 extracts an address from a 16-byte representation, unmapping IPv4. +func addrFrom16(b [16]byte) netip.Addr { + return netip.AddrFrom16(b).Unmap() +} + +// buildAnnounceRequest creates a PCP ANNOUNCE request packet. +func buildAnnounceRequest(clientIP netip.Addr) []byte { + req := make([]byte, headerSize) + req[0] = Version + req[1] = OpAnnounce + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + return req +} + +// buildMapRequest creates a PCP MAP request packet. 
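+// Wire layout of the 60-byte request (offsets in bytes, matching the encoding
+// below):
+//
+//	0     version         1      opcode (MAP)
+//	4-7   lifetime        8-23   client IP (IPv4-mapped if v4)
+//	24-35 mapping nonce   36     protocol
+//	40-41 internal port   42-43  suggested external port
+//	44-59 suggested external IP (left zero when unset)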
+func buildMapRequest(clientIP netip.Addr, nonce [12]byte, protocol uint8, internalPort, suggestedExtPort uint16, suggestedExtIP netip.Addr, lifetime uint32) []byte { + req := make([]byte, mapRequestSize) + + // Header + req[0] = Version + req[1] = OpMap + binary.BigEndian.PutUint32(req[4:8], lifetime) + mapped := addrTo16(clientIP) + copy(req[8:24], mapped[:]) + + // MAP payload + copy(req[24:36], nonce[:]) + req[36] = protocol + binary.BigEndian.PutUint16(req[40:42], internalPort) + binary.BigEndian.PutUint16(req[42:44], suggestedExtPort) + if suggestedExtIP.IsValid() { + extMapped := addrTo16(suggestedExtIP) + copy(req[44:60], extMapped[:]) + } + + return req +} + +// parseResponse parses the common PCP response header. +func parseResponse(data []byte) (*Response, error) { + if len(data) < headerSize { + return nil, fmt.Errorf("response too short: %d bytes", len(data)) + } + + resp := &Response{ + Version: data[0], + Opcode: data[1], + ResultCode: data[3], // Byte 2 is reserved, byte 3 is result code (RFC 6887 §7.2) + Lifetime: binary.BigEndian.Uint32(data[4:8]), + Epoch: binary.BigEndian.Uint32(data[8:12]), + } + + if resp.Version != Version { + return nil, fmt.Errorf("unsupported PCP version: %d", resp.Version) + } + + if resp.Opcode&OpReply == 0 { + return nil, fmt.Errorf("response missing reply bit: opcode=0x%02x", resp.Opcode) + } + + return resp, nil +} + +// parseMapResponse parses a complete MAP response. +func parseMapResponse(data []byte) (*MapResponse, error) { + if len(data) < mapRequestSize { + return nil, fmt.Errorf("MAP response too short: %d bytes", len(data)) + } + + resp, err := parseResponse(data) + if err != nil { + return nil, fmt.Errorf("parse header: %w", err) + } + + mapResp := &MapResponse{ + Response: *resp, + Protocol: data[36], + InternalPort: binary.BigEndian.Uint16(data[40:42]), + ExternalPort: binary.BigEndian.Uint16(data[42:44]), + ExternalIP: addrFrom16([16]byte(data[44:60])), + } + copy(mapResp.Nonce[:], data[24:36]) + + return mapResp, nil +} diff --git a/client/internal/portforward/state.go b/client/internal/portforward/state.go new file mode 100644 index 000000000..b1315cdc0 --- /dev/null +++ b/client/internal/portforward/state.go @@ -0,0 +1,63 @@ +//go:build !js + +package portforward + +import ( + "context" + "fmt" + + "github.com/libp2p/go-nat" + log "github.com/sirupsen/logrus" + + "github.com/netbirdio/netbird/client/internal/portforward/pcp" +) + +// discoverGateway is the function used for NAT gateway discovery. +// It can be replaced in tests to avoid real network operations. +// Tries PCP first, then falls back to NAT-PMP/UPnP. 
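+//
+// Tests can swap it out, for example (fakeNAT is a hypothetical stub):
+//
+//	discoverGateway = func(context.Context) (nat.NAT, error) { return &fakeNAT{}, nil }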
+var discoverGateway = defaultDiscoverGateway + +func defaultDiscoverGateway(ctx context.Context) (nat.NAT, error) { + pcpGateway, err := pcp.DiscoverPCP(ctx) + if err == nil { + return pcpGateway, nil + } + log.Debugf("PCP discovery failed: %v, trying NAT-PMP/UPnP", err) + + return nat.DiscoverGateway(ctx) +} + +// State is persisted only for crash recovery cleanup +type State struct { + InternalPort uint16 `json:"internal_port,omitempty"` + Protocol string `json:"protocol,omitempty"` +} + +func (s *State) Name() string { + return "port_forward_state" +} + +// Cleanup implements statemanager.CleanableState for crash recovery +func (s *State) Cleanup() error { + if s.InternalPort == 0 { + return nil + } + + log.Infof("cleaning up stale port mapping for port %d", s.InternalPort) + + ctx, cancel := context.WithTimeout(context.Background(), discoveryTimeout) + defer cancel() + + gateway, err := discoverGateway(ctx) + if err != nil { + // Discovery failure is not an error - gateway may not exist + log.Debugf("cleanup: no gateway found: %v", err) + return nil + } + + if err := gateway.DeletePortMapping(ctx, s.Protocol, int(s.InternalPort)); err != nil { + return fmt.Errorf("delete port mapping: %w", err) + } + + return nil +} diff --git a/client/internal/routemanager/systemops/systemops_bsd_other.go b/client/internal/routemanager/systemops/systemops_bsd_other.go new file mode 100644 index 000000000..3f09219aa --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_bsd_other.go @@ -0,0 +1,10 @@ +//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin + +package systemops + +// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They +// always fall through to the ref-counter exclusion-route path; these stubs +// exist only so systemops_unix.go compiles. +func (r *SysOps) setupAdvancedRouting() error { return nil } +func (r *SysOps) cleanupAdvancedRouting() error { return nil } +func (r *SysOps) flushPlatformExtras() error { return nil } diff --git a/client/internal/routemanager/systemops/systemops_darwin.go b/client/internal/routemanager/systemops/systemops_darwin.go new file mode 100644 index 000000000..d6875ff95 --- /dev/null +++ b/client/internal/routemanager/systemops/systemops_darwin.go @@ -0,0 +1,241 @@ +//go:build darwin && !ios + +package systemops + +import ( + "errors" + "fmt" + "net/netip" + "os" + "time" + + "github.com/hashicorp/go-multierror" + log "github.com/sirupsen/logrus" + "golang.org/x/net/route" + "golang.org/x/sys/unix" + + nberrors "github.com/netbirdio/netbird/client/errors" + "github.com/netbirdio/netbird/client/internal/routemanager/vars" + nbnet "github.com/netbirdio/netbird/client/net" +) + +// scopedRouteBudget bounds retries for the scoped default route. Installing or +// deleting it matters enough that we're willing to spend longer waiting for the +// kernel reply than for per-prefix exclusion routes. +const scopedRouteBudget = 5 * time.Second + +// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family +// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can +// resolve gateway'd destinations while the VPN's split default owns the +// unscoped table. +// +// Timing note: this runs during routeManager.Init, which happens before the +// VPN interface is created and before any peer routes propagate. 
The initial +// mgmt / signal / relay TCP dials always fire before this runs, so those +// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route +// lookup, which at that point correctly picks the physical default. Those +// already-established TCP flows keep their originally-selected interface for +// their lifetime on Darwin because the kernel caches the egress route +// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default +// afterwards does not migrate them since the original en0 default stays in +// the table. Any subsequent reconnect via nbnet.NewDialer picks up the +// populated bound-iface cache and gets IP_BOUND_IF set cleanly. +func (r *SysOps) setupAdvancedRouting() error { + // Drop any previously-cached egress interface before reinstalling. On a + // refresh, a family that no longer resolves would otherwise keep the stale + // binding, causing new sockets to scope to an interface without a matching + // scoped default. + nbnet.ClearBoundInterfaces() + + if err := r.flushScopedDefaults(); err != nil { + log.Warnf("flush residual scoped defaults: %v", err) + } + + var merr *multierror.Error + installed := 0 + + for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} { + ok, err := r.installScopedDefaultFor(unspec) + if err != nil { + merr = multierror.Append(merr, err) + continue + } + if ok { + installed++ + } + } + + if installed == 0 && merr != nil { + return nberrors.FormatErrorOrNil(merr) + } + if merr != nil { + log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr)) + } + return nil +} + +// installScopedDefaultFor resolves the physical default nexthop for the given +// address family, installs a scoped default via it, and caches the iface for +// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds. +func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) { + nexthop, err := GetNextHop(unspec) + if err != nil { + if errors.Is(err, vars.ErrRouteNotFound) { + return false, nil + } + return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err) + } + if nexthop.Intf == nil { + return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec) + } + + if err := r.addScopedDefault(unspec, nexthop); err != nil { + return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err) + } + + af := unix.AF_INET + if unspec.Is6() { + af = unix.AF_INET6 + } + nbnet.SetBoundInterface(af, nexthop.Intf) + via := "point-to-point" + if nexthop.IP.IsValid() { + via = nexthop.IP.String() + } + log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec)) + return true, nil +} + +func (r *SysOps) cleanupAdvancedRouting() error { + nbnet.ClearBoundInterfaces() + return r.flushScopedDefaults() +} + +// flushPlatformExtras runs darwin-specific residual cleanup hooked into the +// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get +// removed on the next boot regardless of whether a profile is brought up. +func (r *SysOps) flushPlatformExtras() error { + return r.flushScopedDefaults() +} + +// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag. +// Safe to call at startup to clear residual entries from a prior session. 
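+// Roughly the programmatic equivalent of running
+// `route -n delete -ifscope <ifname> default` for each tagged entry (macOS
+// route(8) syntax, given here only as an informal analogy).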
+func (r *SysOps) flushScopedDefaults() error { + rib, err := retryFetchRIB() + if err != nil { + return fmt.Errorf("fetch routing table: %w", err) + } + + msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) + if err != nil { + return fmt.Errorf("parse routing table: %w", err) + } + + var merr *multierror.Error + removed := 0 + + for _, msg := range msgs { + rtMsg, ok := msg.(*route.RouteMessage) + if !ok { + continue + } + if rtMsg.Flags&routeProtoFlag == 0 { + continue + } + if rtMsg.Flags&unix.RTF_IFSCOPE == 0 { + continue + } + + info, err := MsgToRoute(rtMsg) + if err != nil { + log.Debugf("skip scoped flush: %v", err) + continue + } + if !info.Dst.IsValid() || info.Dst.Bits() != 0 { + continue + } + + if err := r.deleteScopedRoute(rtMsg); err != nil { + merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w", + info.Dst, rtMsg.Index, err)) + continue + } + removed++ + log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index) + } + + if removed > 0 { + log.Infof("flushed %d residual scoped default route(s)", removed) + } + return nberrors.FormatErrorOrNil(merr) +} + +func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error { + return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop) +} + +func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error { + // Preserve identifying flags from the stored route (including RTF_GATEWAY + // only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE. + keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag + del := &route.RouteMessage{ + Type: unix.RTM_DELETE, + Flags: rtMsg.Flags & keep, + Version: unix.RTM_VERSION, + Seq: r.getSeq(), + Index: rtMsg.Index, + Addrs: rtMsg.Addrs, + } + return r.writeRouteMessage(del, scopedRouteBudget) +} + +func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error { + flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag + + msg := &route.RouteMessage{ + Type: action, + Flags: flags, + Version: unix.RTM_VERSION, + ID: uintptr(os.Getpid()), + Seq: r.getSeq(), + Index: nexthop.Intf.Index, + } + + const numAddrs = unix.RTAX_NETMASK + 1 + addrs := make([]route.Addr, numAddrs) + + dst, err := addrToRouteAddr(unspec) + if err != nil { + return fmt.Errorf("build destination: %w", err) + } + mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0)) + if err != nil { + return fmt.Errorf("build netmask: %w", err) + } + addrs[unix.RTAX_DST] = dst + addrs[unix.RTAX_NETMASK] = mask + + if nexthop.IP.IsValid() { + msg.Flags |= unix.RTF_GATEWAY + gw, err := addrToRouteAddr(nexthop.IP.Unmap()) + if err != nil { + return fmt.Errorf("build gateway: %w", err) + } + addrs[unix.RTAX_GATEWAY] = gw + } else { + addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{ + Index: nexthop.Intf.Index, + Name: nexthop.Intf.Name, + } + } + msg.Addrs = addrs + + return r.writeRouteMessage(msg, scopedRouteBudget) +} + +func afOf(a netip.Addr) string { + if a.Is4() { + return "IPv4" + } + return "IPv6" +} diff --git a/client/internal/routemanager/systemops/systemops_generic.go b/client/internal/routemanager/systemops/systemops_generic.go index ec219c7fe..4211eb057 100644 --- a/client/internal/routemanager/systemops/systemops_generic.go +++ b/client/internal/routemanager/systemops/systemops_generic.go @@ -21,6 +21,7 @@ import ( "github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/vars" 
"github.com/netbirdio/netbird/client/internal/statemanager" + nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/net/hooks" ) @@ -31,8 +32,6 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1) var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1) var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1) -var ErrRoutingIsSeparate = errors.New("routing is separate") - func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error { stateManager.RegisterState(&ShutdownState{}) @@ -397,12 +396,16 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) { } // IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix. +// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux, +// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped. func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) { - localRoutes, err := hasSeparateRouting() + if nbnet.AdvancedRouting() { + return false, netip.Prefix{} + } + + localRoutes, err := GetRoutesFromTable() if err != nil { - if !errors.Is(err, ErrRoutingIsSeparate) { - log.Errorf("Failed to get routes: %v", err) - } + log.Errorf("Failed to get routes: %v", err) return false, netip.Prefix{} } diff --git a/client/internal/routemanager/systemops/systemops_js.go b/client/internal/routemanager/systemops/systemops_js.go index 808507fc9..242571b3d 100644 --- a/client/internal/routemanager/systemops/systemops_js.go +++ b/client/internal/routemanager/systemops/systemops_js.go @@ -22,10 +22,6 @@ func GetRoutesFromTable() ([]netip.Prefix, error) { return []netip.Prefix{}, nil } -func hasSeparateRouting() ([]netip.Prefix, error) { - return []netip.Prefix{}, nil -} - // GetDetailedRoutesFromTable returns empty routes for WASM. 
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) { return []DetailedRoute{}, nil diff --git a/client/internal/routemanager/systemops/systemops_linux.go b/client/internal/routemanager/systemops/systemops_linux.go index bd10f131f..39a9fd978 100644 --- a/client/internal/routemanager/systemops/systemops_linux.go +++ b/client/internal/routemanager/systemops/systemops_linux.go @@ -894,13 +894,6 @@ func getAddressFamily(prefix netip.Prefix) int { return netlink.FAMILY_V6 } -func hasSeparateRouting() ([]netip.Prefix, error) { - if !nbnet.AdvancedRouting() { - return GetRoutesFromTable() - } - return nil, ErrRoutingIsSeparate -} - func isOpErr(err error) bool { // EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) { diff --git a/client/internal/routemanager/systemops/systemops_nonlinux.go b/client/internal/routemanager/systemops/systemops_nonlinux.go index 905a7bc12..016a62ebd 100644 --- a/client/internal/routemanager/systemops/systemops_nonlinux.go +++ b/client/internal/routemanager/systemops/systemops_nonlinux.go @@ -48,10 +48,6 @@ func EnableIPForwarding() error { return nil } -func hasSeparateRouting() ([]netip.Prefix, error) { - return GetRoutesFromTable() -} - // GetIPRules returns IP rules for debugging (not supported on non-Linux platforms) func GetIPRules() ([]IPRule, error) { log.Infof("IP rules collection is not supported on %s", runtime.GOOS) diff --git a/client/internal/routemanager/systemops/systemops_unix.go b/client/internal/routemanager/systemops/systemops_unix.go index 7089178fb..2d3f9b69a 100644 --- a/client/internal/routemanager/systemops/systemops_unix.go +++ b/client/internal/routemanager/systemops/systemops_unix.go @@ -25,6 +25,9 @@ import ( const ( envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG" + + // routeBudget bounds retries for per-prefix exclusion route programming. + routeBudget = 1 * time.Second ) var routeProtoFlag int @@ -41,26 +44,42 @@ func init() { } func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return r.setupAdvancedRouting() + } + + log.Infof("Using legacy routing setup with ref counters") return r.setupRefCounter(initAddresses, stateManager) } func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error { + if advancedRouting { + return r.cleanupAdvancedRouting() + } + return r.cleanupRefCounter(stateManager) } // FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag. +// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a +// crashed prior session can't leave crud in the table. 
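+// Errors from the platform-specific flush and from the RIB walk are collected
+// into a single multierror rather than short-circuiting, so a failure in one
+// step cannot block cleanup of the remaining entries.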
func (r *SysOps) FlushMarkedRoutes() error { + var merr *multierror.Error + + if err := r.flushPlatformExtras(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err)) + } + rib, err := retryFetchRIB() if err != nil { - return fmt.Errorf("fetch routing table: %w", err) + return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err))) } msgs, err := route.ParseRIB(route.RIBTypeRoute, rib) if err != nil { - return fmt.Errorf("parse routing table: %w", err) + return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err))) } - var merr *multierror.Error flushedCount := 0 for _, msg := range msgs { @@ -117,12 +136,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e return fmt.Errorf("invalid prefix: %s", prefix) } - expBackOff := backoff.NewExponentialBackOff() - expBackOff.InitialInterval = 50 * time.Millisecond - expBackOff.MaxInterval = 500 * time.Millisecond - expBackOff.MaxElapsedTime = 1 * time.Second + msg, err := r.buildRouteMessage(action, prefix, nexthop) + if err != nil { + return fmt.Errorf("build route message: %w", err) + } - if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil { + if err := r.writeRouteMessage(msg, routeBudget); err != nil { a := "add" if action == unix.RTM_DELETE { a = "remove" @@ -132,50 +151,91 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e return nil } -func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error { - operation := func() error { - fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) - if err != nil { - return fmt.Errorf("open routing socket: %w", err) +// writeRouteMessage sends a route message over AF_ROUTE and waits for the +// kernel's matching reply, retrying transient failures until budget elapses. +// Callers do not need to manage sockets or seq numbers themselves. +func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error { + expBackOff := backoff.NewExponentialBackOff() + expBackOff.InitialInterval = 50 * time.Millisecond + expBackOff.MaxInterval = 500 * time.Millisecond + expBackOff.MaxElapsedTime = budget + + return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff) +} + +func routeMessageRoundtrip(msg *route.RouteMessage) error { + fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC) + if err != nil { + return fmt.Errorf("open routing socket: %w", err) + } + defer func() { + if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) { + log.Warnf("close routing socket: %v", err) } - defer func() { - if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) { - log.Warnf("failed to close routing socket: %v", err) + }() + + tv := unix.Timeval{Sec: 1} + if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil { + return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err)) + } + + // AF_ROUTE is a broadcast channel: every route socket on the host sees + // every RTM_* event. With concurrent route programming the default + // per-socket queue overflows and our own reply gets dropped. 
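+	// Growing the buffer to 1 MiB is best effort: failure is only logged and
+	// the roundtrip proceeds, since the retry loop tolerates a dropped reply.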
+	if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
+		log.Debugf("set SO_RCVBUF on route socket: %v", err)
+	}
+
+	bytes, err := msg.Marshal()
+	if err != nil {
+		return backoff.Permanent(fmt.Errorf("marshal: %w", err))
+	}
+
+	if _, err = unix.Write(fd, bytes); err != nil {
+		if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
+			return fmt.Errorf("write: %w", err)
+		}
+		return backoff.Permanent(fmt.Errorf("write: %w", err))
+	}
+	return readRouteResponse(fd, msg.Type, msg.Seq)
+}
+
+// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
+// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
+// broadcast channel: interface up/down, third-party route changes and neighbor
+// discovery events can all land between our write and read, so we must filter.
+func readRouteResponse(fd, wantType, wantSeq int) error {
+	pid := int32(os.Getpid())
+	resp := make([]byte, 2048)
+	deadline := time.Now().Add(time.Second)
+	for {
+		if time.Now().After(deadline) {
+			// Transient: under concurrent pressure the kernel can drop our reply
+			// from the socket buffer. Let backoff.Retry re-send the request; the
+			// same message (and thus the same seq) is written and matched again.
+			return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
+		}
+		n, err := unix.Read(fd, resp)
+		if err != nil {
+			if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
+				// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
+				continue
 			}
-		}()
-
-		msg, err := r.buildRouteMessage(action, prefix, nexthop)
-		if err != nil {
-			return backoff.Permanent(fmt.Errorf("build route message: %w", err))
+			return backoff.Permanent(fmt.Errorf("read: %w", err))
 		}
-
-		msgBytes, err := msg.Marshal()
-		if err != nil {
-			return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
+		if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
+			continue
 		}
-
-		if _, err = unix.Write(fd, msgBytes); err != nil {
-			if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
-				return fmt.Errorf("write: %w", err)
-			}
-			return backoff.Permanent(fmt.Errorf("write: %w", err))
+		hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
+		// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
+		// uniquely identifies our own reply among broadcast traffic.
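+		// Anything else (RTM_IFINFO, third-party RTM_ADD/RTM_DELETE, replies
+		// addressed to other sockets) is skipped until the deadline expires.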
+ if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid { + continue } - - respBuf := make([]byte, 2048) - n, err := unix.Read(fd, respBuf) - if err != nil { - return backoff.Permanent(fmt.Errorf("read route response: %w", err)) + if hdr.Errno != 0 { + return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno))) } - - if n > 0 { - if err := r.parseRouteResponse(respBuf[:n]); err != nil { - return backoff.Permanent(err) - } - } - return nil } - return operation } func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) { @@ -183,6 +243,7 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next Type: action, Flags: unix.RTF_UP | routeProtoFlag, Version: unix.RTM_VERSION, + ID: uintptr(os.Getpid()), Seq: r.getSeq(), } @@ -221,19 +282,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next return msg, nil } -func (r *SysOps) parseRouteResponse(buf []byte) error { - if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) { - return nil - } - - rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0])) - if rtMsg.Errno != 0 { - return fmt.Errorf("parse: %d", rtMsg.Errno) - } - - return nil -} - // addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr). func addrToRouteAddr(addr netip.Addr) (route.Addr, error) { if addr.Is4() { diff --git a/client/net/dialer_init_darwin.go b/client/net/dialer_init_darwin.go new file mode 100644 index 000000000..e18909ff7 --- /dev/null +++ b/client/net/dialer_init_darwin.go @@ -0,0 +1,5 @@ +package net + +func (d *Dialer) init() { + d.Dialer.Control = applyBoundIfToSocket +} diff --git a/client/net/dialer_init_generic.go b/client/net/dialer_init_generic.go index 18ebc6ad1..78973b47d 100644 --- a/client/net/dialer_init_generic.go +++ b/client/net/dialer_init_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows +//go:build !linux && !windows && !darwin package net diff --git a/client/net/env_android.go b/client/net/env_android.go deleted file mode 100644 index 9d89951a1..000000000 --- a/client/net/env_android.go +++ /dev/null @@ -1,24 +0,0 @@ -//go:build android - -package net - -// Init initializes the network environment for Android -func Init() { - // No initialization needed on Android -} - -// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. -// Always returns true on Android since we cannot handle routes dynamically. 
-func AdvancedRouting() bool { - return true -} - -// SetVPNInterfaceName is a no-op on Android -func SetVPNInterfaceName(name string) { - // No-op on Android - not needed for Android VPN service -} - -// GetVPNInterfaceName returns empty string on Android -func GetVPNInterfaceName() string { - return "" -} diff --git a/client/net/env_windows.go b/client/net/env_bound_iface.go similarity index 71% rename from client/net/env_windows.go rename to client/net/env_bound_iface.go index 7e8868ba5..593988c2c 100644 --- a/client/net/env_windows.go +++ b/client/net/env_bound_iface.go @@ -1,4 +1,4 @@ -//go:build windows +//go:build (darwin && !ios) || windows package net @@ -24,17 +24,22 @@ func Init() { } func checkAdvancedRoutingSupport() bool { - var err error - var legacyRouting bool + legacyRouting := false if val := os.Getenv(envUseLegacyRouting); val != "" { - legacyRouting, err = strconv.ParseBool(val) + parsed, err := strconv.ParseBool(val) if err != nil { - log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err) + log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err) + } else { + legacyRouting = parsed } } - if legacyRouting || netstack.IsEnabled() { - log.Info("advanced routing has been requested to be disabled") + if legacyRouting { + log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting) + return false + } + if netstack.IsEnabled() { + log.Info("advanced routing disabled: netstack mode is enabled") return false } diff --git a/client/net/env_generic.go b/client/net/env_generic.go index f467930c3..18c10bb78 100644 --- a/client/net/env_generic.go +++ b/client/net/env_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows && !android +//go:build !linux && !windows && !darwin package net diff --git a/client/net/env_mobile.go b/client/net/env_mobile.go new file mode 100644 index 000000000..80b0fad8d --- /dev/null +++ b/client/net/env_mobile.go @@ -0,0 +1,25 @@ +//go:build ios || android + +package net + +// Init initializes the network environment for mobile platforms. +func Init() { + // no-op on mobile: routing scope is owned by the VPN extension. +} + +// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes. +// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension +// owns the routing scope. +func AdvancedRouting() bool { + return true +} + +// SetVPNInterfaceName is a no-op on mobile. +func SetVPNInterfaceName(string) { + // no-op on mobile: the VPN extension manages the interface. +} + +// GetVPNInterfaceName returns an empty string on mobile. 
+func GetVPNInterfaceName() string { + return "" +} diff --git a/client/net/listener_init_darwin.go b/client/net/listener_init_darwin.go new file mode 100644 index 000000000..f2fcc80ed --- /dev/null +++ b/client/net/listener_init_darwin.go @@ -0,0 +1,5 @@ +package net + +func (l *ListenerConfig) init() { + l.ListenConfig.Control = applyBoundIfToSocket +} diff --git a/client/net/listener_init_generic.go b/client/net/listener_init_generic.go index 4f8f17ab2..65a785222 100644 --- a/client/net/listener_init_generic.go +++ b/client/net/listener_init_generic.go @@ -1,4 +1,4 @@ -//go:build !linux && !windows +//go:build !linux && !windows && !darwin package net diff --git a/client/net/net_darwin.go b/client/net/net_darwin.go new file mode 100644 index 000000000..00d858a6a --- /dev/null +++ b/client/net/net_darwin.go @@ -0,0 +1,160 @@ +package net + +import ( + "fmt" + "net" + "net/netip" + "strconv" + "strings" + "sync" + "syscall" + + log "github.com/sirupsen/logrus" + "golang.org/x/sys/unix" +) + +// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack +// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6" +// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns +// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is +// dispatched by socket domain (AF_INET only) not by inp_vflag. + +// boundIface holds the physical interface chosen at routing setup time. Sockets +// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF +// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup +// hits the RTF_IFSCOPE default installed by the routemanager, rather than +// following the VPN's split default. +var ( + boundIfaceMu sync.RWMutex + boundIface4 *net.Interface + boundIface6 *net.Interface +) + +// SetBoundInterface records the egress interface for an address family. Called +// by the routemanager after a scoped default route has been installed. +// af must be unix.AF_INET or unix.AF_INET6; other values are ignored. +// nil iface is rejected — use ClearBoundInterfaces to clear all slots. +func SetBoundInterface(af int, iface *net.Interface) { + if iface == nil { + log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af) + return + } + boundIfaceMu.Lock() + defer boundIfaceMu.Unlock() + switch af { + case unix.AF_INET: + boundIface4 = iface + case unix.AF_INET6: + boundIface6 = iface + default: + log.Warnf("SetBoundInterface: unsupported address family %d", af) + } +} + +// ClearBoundInterfaces resets the cached egress interfaces. Called by the +// routemanager during cleanup. +func ClearBoundInterfaces() { + boundIfaceMu.Lock() + defer boundIfaceMu.Unlock() + boundIface4 = nil + boundIface6 = nil +} + +// boundInterfaceFor returns the cached egress interface for a socket's address +// family, falling back to the other family if the preferred slot is empty. +// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so +// either setsockopt scopes the socket; preferring same-family still matters +// when v4 and v6 defaults egress different NICs. 
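+//
+// For example, a "udp6" dial to "[fe80::1%en0]:80" resolves to en0 via the
+// explicit zone, while a "tcp4" dial prefers boundIface4 and falls back to
+// boundIface6 when only the v6 slot is populated.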
+func boundInterfaceFor(network, address string) *net.Interface { + if iface := zoneInterface(address); iface != nil { + return iface + } + + boundIfaceMu.RLock() + defer boundIfaceMu.RUnlock() + + primary, secondary := boundIface4, boundIface6 + if isV6Network(network) { + primary, secondary = boundIface6, boundIface4 + } + if primary != nil { + return primary + } + return secondary +} + +func isV6Network(network string) bool { + return strings.HasSuffix(network, "6") +} + +// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0). +func zoneInterface(address string) *net.Interface { + if address == "" { + return nil + } + addr, err := netip.ParseAddrPort(address) + if err != nil { + a, err := netip.ParseAddr(address) + if err != nil { + return nil + } + addr = netip.AddrPortFrom(a, 0) + } + zone := addr.Addr().Zone() + if zone == "" { + return nil + } + if iface, err := net.InterfaceByName(zone); err == nil { + return iface + } + if idx, err := strconv.Atoi(zone); err == nil { + if iface, err := net.InterfaceByIndex(idx); err == nil { + return iface + } + } + return nil +} + +func setIPv4BoundIf(fd uintptr, iface *net.Interface) error { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil { + return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +func setIPv6BoundIf(fd uintptr, iface *net.Interface) error { + if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil { + return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index) + } + return nil +} + +// applyBoundIfToSocket binds the socket to the cached physical egress interface +// so scoped route lookup avoids the VPN utun and egresses the underlay directly. 
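+//
+// It is installed as the Control hook of the nbnet Dialer and ListenerConfig
+// on darwin (see dialer_init_darwin.go and listener_init_darwin.go above).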
+func applyBoundIfToSocket(network, address string, c syscall.RawConn) error { + if !AdvancedRouting() { + return nil + } + + iface := boundInterfaceFor(network, address) + if iface == nil { + log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address) + return nil + } + + isV6 := isV6Network(network) + var controlErr error + if err := c.Control(func(fd uintptr) { + if isV6 { + controlErr = setIPv6BoundIf(fd, iface) + } else { + controlErr = setIPv4BoundIf(fd, iface) + } + if controlErr == nil { + log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address) + } + }); err != nil { + return fmt.Errorf("control: %w", err) + } + return controlErr +} diff --git a/client/proto/daemon.pb.go b/client/proto/daemon.pb.go index 47283e216..d2a2c43c5 100644 --- a/client/proto/daemon.pb.go +++ b/client/proto/daemon.pb.go @@ -4979,6 +4979,7 @@ type GetFeaturesResponse struct { state protoimpl.MessageState `protogen:"open.v1"` DisableProfiles bool `protobuf:"varint,1,opt,name=disable_profiles,json=disableProfiles,proto3" json:"disable_profiles,omitempty"` DisableUpdateSettings bool `protobuf:"varint,2,opt,name=disable_update_settings,json=disableUpdateSettings,proto3" json:"disable_update_settings,omitempty"` + DisableNetworks bool `protobuf:"varint,3,opt,name=disable_networks,json=disableNetworks,proto3" json:"disable_networks,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -5027,6 +5028,13 @@ func (x *GetFeaturesResponse) GetDisableUpdateSettings() bool { return false } +func (x *GetFeaturesResponse) GetDisableNetworks() bool { + if x != nil { + return x.DisableNetworks + } + return false +} + type TriggerUpdateRequest struct { state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields @@ -6754,10 +6762,11 @@ const file_daemon_proto_rawDesc = "" + "\f_profileNameB\v\n" + "\t_username\"\x10\n" + "\x0eLogoutResponse\"\x14\n" + - "\x12GetFeaturesRequest\"x\n" + + "\x12GetFeaturesRequest\"\xa3\x01\n" + "\x13GetFeaturesResponse\x12)\n" + "\x10disable_profiles\x18\x01 \x01(\bR\x0fdisableProfiles\x126\n" + - "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\"\x16\n" + + "\x17disable_update_settings\x18\x02 \x01(\bR\x15disableUpdateSettings\x12)\n" + + "\x10disable_networks\x18\x03 \x01(\bR\x0fdisableNetworks\"\x16\n" + "\x14TriggerUpdateRequest\"M\n" + "\x15TriggerUpdateResponse\x12\x18\n" + "\asuccess\x18\x01 \x01(\bR\asuccess\x12\x1a\n" + diff --git a/client/proto/daemon.proto b/client/proto/daemon.proto index b8d1418f3..f3043f236 100644 --- a/client/proto/daemon.proto +++ b/client/proto/daemon.proto @@ -738,6 +738,7 @@ message GetFeaturesRequest{} message GetFeaturesResponse{ bool disable_profiles = 1; bool disable_update_settings = 2; + bool disable_networks = 3; } message TriggerUpdateRequest {} diff --git a/client/server/network.go b/client/server/network.go index bb1cce56c..76c5af40e 100644 --- a/client/server/network.go +++ b/client/server/network.go @@ -9,6 +9,8 @@ import ( "strings" "golang.org/x/exp/maps" + "google.golang.org/grpc/codes" + gstatus "google.golang.org/grpc/status" "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/route" @@ -27,6 +29,10 @@ func (s *Server) ListNetworks(context.Context, *proto.ListNetworksRequest) (*pro s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not 
connected") } @@ -118,6 +124,10 @@ func (s *Server) SelectNetworks(_ context.Context, req *proto.SelectNetworksRequ s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } @@ -164,6 +174,10 @@ func (s *Server) DeselectNetworks(_ context.Context, req *proto.SelectNetworksRe s.mutex.Lock() defer s.mutex.Unlock() + if s.networksDisabled { + return nil, gstatus.Errorf(codes.Unavailable, errNetworksDisabled) + } + if s.connectClient == nil { return nil, fmt.Errorf("not connected") } diff --git a/client/server/server.go b/client/server/server.go index 78d645d12..2e7aeec60 100644 --- a/client/server/server.go +++ b/client/server/server.go @@ -33,6 +33,7 @@ import ( "github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/updater" "github.com/netbirdio/netbird/client/proto" + "github.com/netbirdio/netbird/util/capture" "github.com/netbirdio/netbird/version" ) @@ -53,6 +54,7 @@ const ( errRestoreResidualState = "failed to restore residual state: %v" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" + errNetworksDisabled = "network selection is disabled by the administrator" ) var ErrServiceNotUp = errors.New("service is not up") @@ -90,6 +92,9 @@ type Server struct { updateSettingsDisabled bool captureEnabled bool bundleCapture *bundleCapture + // activeCapture is the session currently installed on the engine; guarded by s.mutex. + activeCapture *capture.Session + networksDisabled bool sleepHandler *sleephandler.SleepHandler @@ -106,7 +111,7 @@ type oauthAuthFlow struct { } // New server instance constructor. 
-func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool) *Server { +func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server { s := &Server{ rootCtx: ctx, logFile: logFile, @@ -116,6 +121,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable profilesDisabled: profilesDisabled, updateSettingsDisabled: updateSettingsDisabled, captureEnabled: captureEnabled, + networksDisabled: networksDisabled, jwtCache: newJWTCache(), } agent := &serverAgent{s} @@ -1631,6 +1637,7 @@ func (s *Server) GetFeatures(ctx context.Context, msg *proto.GetFeaturesRequest) features := &proto.GetFeaturesResponse{ DisableProfiles: s.checkProfilesDisabled(), DisableUpdateSettings: s.checkUpdateSettingsDisabled(), + DisableNetworks: s.networksDisabled, } return features, nil diff --git a/client/server/server_test.go b/client/server/server_test.go index c5148104f..caae3e6cd 100644 --- a/client/server/server_test.go +++ b/client/server/server_test.go @@ -36,6 +36,7 @@ import ( daemonProto "github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/permissions" "github.com/netbirdio/netbird/management/server/settings" @@ -103,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "debug", "", false, false, false) + s := New(ctx, "debug", "", false, false, false, false) s.config = config @@ -164,7 +165,7 @@ func TestServer_Up(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) @@ -234,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) { t.Fatalf("failed to set active profile state: %v", err) } - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) err = s.Start() require.NoError(t, err) @@ -309,7 +310,12 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve jobManager := job.NewJobManager(nil, store, peersManager) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore) + cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, "", err + } + + ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, settingsManagerMock, eventStore, cacheStore) metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) require.NoError(t, err) @@ -320,7 +326,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve requestBuffer := server.NewAccountRequestBuffer(context.Background(), store) peersUpdateManager := update_channel.NewPeersUpdateManager(metrics) networkMapController := controller.NewController(context.Background(), store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), 
manager.NewEphemeralManager(store, peersManager), config) - accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { return nil, "", err } diff --git a/client/server/setconfig_test.go b/client/server/setconfig_test.go index 7f6847c43..b90b5653d 100644 --- a/client/server/setconfig_test.go +++ b/client/server/setconfig_test.go @@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) { require.NoError(t, err) ctx := context.Background() - s := New(ctx, "console", "", false, false, false) + s := New(ctx, "console", "", false, false, false, false) rosenpassEnabled := true rosenpassPermissive := true diff --git a/client/server/state.go b/client/server/state.go index 8dca6bde1..f2d823465 100644 --- a/client/server/state.go +++ b/client/server/state.go @@ -12,7 +12,6 @@ import ( "github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/statemanager" - nbnet "github.com/netbirdio/netbird/client/net" "github.com/netbirdio/netbird/client/proto" ) @@ -138,10 +137,8 @@ func restoreResidualState(ctx context.Context, statePath string) error { } // clean up any remaining routes independently of the state file - if !nbnet.AdvancedRouting() { - if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { - merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) - } + if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil { + merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err)) } return nberrors.FormatErrorOrNil(merr) diff --git a/client/ssh/config/manager.go b/client/ssh/config/manager.go index cc47fd2d2..6e584b2c3 100644 --- a/client/ssh/config/manager.go +++ b/client/ssh/config/manager.go @@ -187,24 +187,23 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) { return "", fmt.Errorf("get NetBird executable path: %w", err) } - hostLine := strings.Join(deduplicatedPatterns, " ") - config := fmt.Sprintf("Host %s\n", hostLine) - config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath) - config += " PreferredAuthentications password,publickey,keyboard-interactive\n" - config += " PasswordAuthentication yes\n" - config += " PubkeyAuthentication yes\n" - config += " BatchMode no\n" - config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) - config += " StrictHostKeyChecking no\n" + hostList := strings.Join(deduplicatedPatterns, ",") + config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath) + config += " PreferredAuthentications password,publickey,keyboard-interactive\n" + config += " PasswordAuthentication yes\n" + config += " PubkeyAuthentication yes\n" + config += " BatchMode no\n" + config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath) + config += " StrictHostKeyChecking no\n" if runtime.GOOS == "windows" { - config += " UserKnownHostsFile NUL\n" + config += " UserKnownHostsFile NUL\n" } else { - config += " UserKnownHostsFile /dev/null\n" + config += " 
UserKnownHostsFile /dev/null\n" } - config += " CheckHostIP no\n" - config += " LogLevel ERROR\n\n" + config += " CheckHostIP no\n" + config += " LogLevel ERROR\n\n" return config, nil } diff --git a/client/ssh/config/manager_test.go b/client/ssh/config/manager_test.go index dc3ad95b3..e7380c7f2 100644 --- a/client/ssh/config/manager_test.go +++ b/client/ssh/config/manager_test.go @@ -116,6 +116,37 @@ func TestManager_PeerLimit(t *testing.T) { assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") } +func TestManager_MatchHostFormat(t *testing.T) { + tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") + require.NoError(t, err) + defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }() + + manager := &Manager{ + sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), + sshConfigFile: "99-netbird.conf", + } + + peers := []PeerSSHInfo{ + {Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"}, + {Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"}, + } + + err = manager.SetupSSHClientConfig(peers) + require.NoError(t, err) + + configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) + content, err := os.ReadFile(configPath) + require.NoError(t, err) + configStr := string(content) + + // Must use "Match host" with comma-separated patterns, not a bare "Host" directive. + // A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block + // ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped. + assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive") + assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"", + "should use Match host with comma-separated patterns") +} + func TestManager_ForcedSSHConfig(t *testing.T) { // Set force environment variable t.Setenv(EnvForceSSHConfig, "true") diff --git a/client/system/info.go b/client/system/info.go index f2546cfe6..175d1f07f 100644 --- a/client/system/info.go +++ b/client/system/info.go @@ -2,7 +2,6 @@ package system import ( "context" - "net" "net/netip" "strings" @@ -145,59 +144,6 @@ func extractDeviceName(ctx context.Context, defaultName string) string { return v } -func networkAddresses() ([]NetworkAddress, error) { - interfaces, err := net.Interfaces() - if err != nil { - return nil, err - } - - var netAddresses []NetworkAddress - for _, iface := range interfaces { - if iface.Flags&net.FlagUp == 0 { - continue - } - if iface.HardwareAddr.String() == "" { - continue - } - addrs, err := iface.Addrs() - if err != nil { - continue - } - - for _, address := range addrs { - ipNet, ok := address.(*net.IPNet) - if !ok { - continue - } - - if ipNet.IP.IsLoopback() { - continue - } - - netAddr := NetworkAddress{ - NetIP: netip.MustParsePrefix(ipNet.String()), - Mac: iface.HardwareAddr.String(), - } - - if isDuplicated(netAddresses, netAddr) { - continue - } - - netAddresses = append(netAddresses, netAddr) - } - } - return netAddresses, nil -} - -func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { - for _, duplicated := range addresses { - if duplicated.NetIP == addr.NetIP { - return true - } - } - return false -} - // GetInfoWithChecks retrieves and parses the system information with applied checks. 
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) { log.Debugf("gathering system information with checks: %d", len(checks)) diff --git a/client/system/info_ios.go b/client/system/info_ios.go index 322609db4..ad42b1edf 100644 --- a/client/system/info_ios.go +++ b/client/system/info_ios.go @@ -2,12 +2,16 @@ package system import ( "context" + "net" + "net/netip" "runtime" + log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/version" ) -// UpdateStaticInfoAsync is a no-op on Android as there is no static info to update +// UpdateStaticInfoAsync is a no-op on iOS as there is no static info to update func UpdateStaticInfoAsync() { // do nothing } @@ -15,11 +19,24 @@ func UpdateStaticInfoAsync() { // GetInfo retrieves and parses the system information func GetInfo(ctx context.Context) *Info { - // Convert fixed-size byte arrays to Go strings sysName := extractOsName(ctx, "sysName") swVersion := extractOsVersion(ctx, "swVersion") - gio := &Info{Kernel: sysName, OSVersion: swVersion, Platform: "unknown", OS: sysName, GoOS: runtime.GOOS, CPUs: runtime.NumCPU(), KernelVersion: swVersion} + addrs, err := networkAddresses() + if err != nil { + log.Warnf("failed to discover network addresses: %s", err) + } + + gio := &Info{ + Kernel: sysName, + OSVersion: swVersion, + Platform: "unknown", + OS: sysName, + GoOS: runtime.GOOS, + CPUs: runtime.NumCPU(), + KernelVersion: swVersion, + NetworkAddresses: addrs, + } gio.Hostname = extractDeviceName(ctx, "hostname") gio.NetbirdVersion = version.NetbirdVersion() gio.UIVersion = extractUserAgent(ctx) @@ -27,6 +44,66 @@ func GetInfo(ctx context.Context) *Info { return gio } +// networkAddresses returns the list of network addresses on iOS. +// On iOS, hardware (MAC) addresses are not available due to Apple's privacy +// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we +// leave Mac empty to match Android's behavior. We also skip the HardwareAddr +// check that other platforms use and filter out link-local addresses as they +// are not useful for posture checks. +func networkAddresses() ([]NetworkAddress, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var netAddresses []NetworkAddress + for _, iface := range interfaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + + for _, address := range addrs { + netAddr, ok := toNetworkAddress(address) + if !ok { + continue + } + if isDuplicated(netAddresses, netAddr) { + continue + } + netAddresses = append(netAddresses, netAddr) + } + } + return netAddresses, nil +} + +func toNetworkAddress(address net.Addr) (NetworkAddress, bool) { + ipNet, ok := address.(*net.IPNet) + if !ok { + return NetworkAddress{}, false + } + if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() { + return NetworkAddress{}, false + } + prefix, err := netip.ParsePrefix(ipNet.String()) + if err != nil { + return NetworkAddress{}, false + } + return NetworkAddress{NetIP: prefix, Mac: ""}, true +} + +func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { + for _, duplicated := range addresses { + if duplicated.NetIP == addr.NetIP { + return true + } + } + return false +} + // checkFileAndProcess checks if the file path exists and if a process is running at that path. 
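The iOS collector above differs from the shared one mainly in its filter predicate: no MAC requirement, and link-local and multicast prefixes are dropped alongside loopback. Here is a standalone sketch of just that predicate under illustrative names (the real logic lives in toNetworkAddress):

package main

import (
	"fmt"
	"net"
	"net/netip"
)

// usablePrefix mirrors the iOS toNetworkAddress filter: reject loopback,
// link-local unicast and multicast addresses, then normalize to netip.Prefix.
func usablePrefix(addr net.Addr) (netip.Prefix, bool) {
	ipNet, ok := addr.(*net.IPNet)
	if !ok {
		return netip.Prefix{}, false
	}
	if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
		return netip.Prefix{}, false
	}
	prefix, err := netip.ParsePrefix(ipNet.String())
	if err != nil {
		return netip.Prefix{}, false
	}
	return prefix, true
}

func main() {
	addrs := []net.Addr{
		&net.IPNet{IP: net.ParseIP("127.0.0.1"), Mask: net.CIDRMask(8, 32)},     // dropped: loopback
		&net.IPNet{IP: net.ParseIP("169.254.10.5"), Mask: net.CIDRMask(16, 32)}, // dropped: link-local
		&net.IPNet{IP: net.ParseIP("192.168.1.10"), Mask: net.CIDRMask(24, 32)}, // kept
	}
	for _, a := range addrs {
		if p, ok := usablePrefix(a); ok {
			fmt.Println("keep:", p)
		}
	}
}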
func checkFileAndProcess(paths []string) ([]File, error) { return []File{}, nil diff --git a/client/system/network_addr.go b/client/system/network_addr.go new file mode 100644 index 000000000..5423cf8ad --- /dev/null +++ b/client/system/network_addr.go @@ -0,0 +1,66 @@ +//go:build !ios + +package system + +import ( + "net" + "net/netip" +) + +func networkAddresses() ([]NetworkAddress, error) { + interfaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + var netAddresses []NetworkAddress + for _, iface := range interfaces { + if iface.Flags&net.FlagUp == 0 { + continue + } + if iface.HardwareAddr.String() == "" { + continue + } + addrs, err := iface.Addrs() + if err != nil { + continue + } + + mac := iface.HardwareAddr.String() + for _, address := range addrs { + netAddr, ok := toNetworkAddress(address, mac) + if !ok { + continue + } + if isDuplicated(netAddresses, netAddr) { + continue + } + netAddresses = append(netAddresses, netAddr) + } + } + return netAddresses, nil +} + +func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) { + ipNet, ok := address.(*net.IPNet) + if !ok { + return NetworkAddress{}, false + } + if ipNet.IP.IsLoopback() { + return NetworkAddress{}, false + } + prefix, err := netip.ParsePrefix(ipNet.String()) + if err != nil { + return NetworkAddress{}, false + } + return NetworkAddress{NetIP: prefix, Mac: mac}, true +} + +func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool { + for _, duplicated := range addresses { + if duplicated.NetIP == addr.NetIP { + return true + } + } + return false +} diff --git a/client/ui/client_ui.go b/client/ui/client_ui.go index b1e0aec41..c149b2152 100644 --- a/client/ui/client_ui.go +++ b/client/ui/client_ui.go @@ -314,6 +314,7 @@ type serviceClient struct { lastNotifiedVersion string settingsEnabled bool profilesEnabled bool + networksEnabled bool showNetworks bool wNetworks fyne.Window wProfiles fyne.Window @@ -368,6 +369,7 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient { showAdvancedSettings: args.showSettings, showNetworks: args.showNetworks, + networksEnabled: true, } s.eventHandler = newEventHandler(s) @@ -920,8 +922,10 @@ func (s *serviceClient) updateStatus() error { s.mStatus.SetIcon(s.icConnectedDot) s.mUp.Disable() s.mDown.Enable() - s.mNetworks.Enable() - s.mExitNode.Enable() + if s.networksEnabled { + s.mNetworks.Enable() + s.mExitNode.Enable() + } s.startExitNodeRefresh() systrayIconState = true case status.Status == string(internal.StatusConnecting): @@ -1093,14 +1097,14 @@ func (s *serviceClient) onTrayReady() { s.getSrvConfig() time.Sleep(100 * time.Millisecond) // To prevent race condition caused by systray not being fully initialized and ignoring setIcon for { + // Check features before status so menus respect disable flags before being enabled + s.checkAndUpdateFeatures() + err := s.updateStatus() if err != nil { log.Errorf("error while updating status: %v", err) } - // Check features periodically to handle daemon restarts - s.checkAndUpdateFeatures() - time.Sleep(2 * time.Second) } }() @@ -1299,6 +1303,16 @@ func (s *serviceClient) checkAndUpdateFeatures() { s.mProfile.setEnabled(profilesEnabled) } } + + // Update networks and exit node menus based on current features + s.networksEnabled = features == nil || !features.DisableNetworks + if s.networksEnabled && s.connected { + s.mNetworks.Enable() + s.mExitNode.Enable() + } else { + s.mNetworks.Disable() + s.mExitNode.Disable() + } } // getFeatures from the daemon to determine which 
features are enabled/disabled. diff --git a/combined/config.yaml.example b/combined/config.yaml.example index dce658d89..af85b0477 100644 --- a/combined/config.yaml.example +++ b/combined/config.yaml.example @@ -119,6 +119,8 @@ server: # Reverse proxy settings (optional) # reverseProxy: - # trustedHTTPProxies: [] - # trustedHTTPProxiesCount: 0 - # trustedPeers: [] + # trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"]) + # trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies) + # trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"]) + # accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely). + # accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value. diff --git a/flow/client/client_test.go b/flow/client/client_test.go index 55157acbc..c8f5f4af4 100644 --- a/flow/client/client_test.go +++ b/flow/client/client_test.go @@ -457,6 +457,18 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second) require.NoError(t, err) + + // Cleanups run LIFO: the goroutine-drain registered here runs after Close below, + // which is when Receive has actually returned. Without this, the Receive goroutine + // can outlive the test and call t.Logf after teardown, panicking. + receiveDone := make(chan struct{}) + t.Cleanup(func() { + select { + case <-receiveDone: + case <-time.After(2 * time.Second): + t.Error("Receive goroutine did not exit after Close") + } + }) t.Cleanup(func() { err := client.Close() assert.NoError(t, err, "failed to close flow") @@ -468,6 +480,7 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) { receivedAfterReconnect := make(chan struct{}) go func() { + defer close(receiveDone) err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error { if msg.IsInitiator || len(msg.EventId) == 0 { return nil diff --git a/go.mod b/go.mod index 76fb8b7be..1b5861a37 100644 --- a/go.mod +++ b/go.mod @@ -17,12 +17,12 @@ require ( github.com/spf13/cobra v1.10.1 github.com/spf13/pflag v1.0.9 github.com/vishvananda/netlink v1.3.1 - golang.org/x/crypto v0.48.0 + golang.org/x/crypto v0.49.0 golang.org/x/sys v0.42.0 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/windows v0.5.3 - google.golang.org/grpc v1.79.3 + google.golang.org/grpc v1.80.0 google.golang.org/protobuf v1.36.11 gopkg.in/natefinch/lumberjack.v2 v2.2.1 ) @@ -71,7 +71,7 @@ require ( github.com/mdlayher/socket v0.5.1 github.com/miekg/dns v1.1.59 github.com/mitchellh/hashstructure/v2 v2.0.2 - github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 + github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 github.com/oapi-codegen/runtime v1.1.2 github.com/okta/okta-sdk-golang/v2 v2.18.0 @@ -115,13 +115,13 @@ require ( goauthentik.io/api/v3 v3.2023051.3 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b golang.org/x/mobile v0.0.0-20251113184115-a159579294ab - 
golang.org/x/mod v0.32.0 - golang.org/x/net v0.51.0 - golang.org/x/oauth2 v0.34.0 - golang.org/x/sync v0.19.0 - golang.org/x/term v0.40.0 - golang.org/x/time v0.14.0 - google.golang.org/api v0.257.0 + golang.org/x/mod v0.33.0 + golang.org/x/net v0.52.0 + golang.org/x/oauth2 v0.36.0 + golang.org/x/sync v0.20.0 + golang.org/x/term v0.41.0 + golang.org/x/time v0.15.0 + google.golang.org/api v0.276.0 gopkg.in/yaml.v3 v3.0.1 gorm.io/driver/mysql v1.5.7 gorm.io/driver/postgres v1.5.7 @@ -131,7 +131,7 @@ require ( ) require ( - cloud.google.com/go/auth v0.17.0 // indirect + cloud.google.com/go/auth v0.20.0 // indirect cloud.google.com/go/auth/oauth2adapt v0.2.8 // indirect cloud.google.com/go/compute/metadata v0.9.0 // indirect dario.cat/mergo v1.0.1 // indirect @@ -210,8 +210,8 @@ require ( github.com/google/btree v1.1.2 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/google/s2a-go v0.1.9 // indirect - github.com/googleapis/enterprise-certificate-proxy v0.3.7 // indirect - github.com/googleapis/gax-go/v2 v2.15.0 // indirect + github.com/googleapis/enterprise-certificate-proxy v0.3.14 // indirect + github.com/googleapis/gax-go/v2 v2.21.0 // indirect github.com/gorilla/handlers v1.5.2 // indirect github.com/hack-pad/go-indexeddb v0.3.2 // indirect github.com/hack-pad/safejs v0.1.0 // indirect @@ -295,16 +295,16 @@ require ( github.com/zeebo/blake3 v0.2.3 // indirect go.mongodb.org/mongo-driver v1.17.9 // indirect go.opentelemetry.io/auto/sdk v1.2.1 // indirect - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 // indirect go.opentelemetry.io/otel/sdk v1.43.0 // indirect go.opentelemetry.io/otel/trace v1.43.0 // indirect go.uber.org/multierr v1.11.0 // indirect go.yaml.in/yaml/v2 v2.4.3 // indirect golang.org/x/image v0.33.0 // indirect - golang.org/x/text v0.34.0 // indirect - golang.org/x/tools v0.41.0 // indirect + golang.org/x/text v0.35.0 // indirect + golang.org/x/tools v0.42.0 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect @@ -323,3 +323,5 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184 replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0 + +replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0 diff --git a/go.sum b/go.sum index f06f7deba..3772946e1 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,5 @@ -cloud.google.com/go/auth v0.17.0 h1:74yCm7hCj2rUyyAocqnFzsAYXgJhrG26XCFimrc/Kz4= -cloud.google.com/go/auth v0.17.0/go.mod h1:6wv/t5/6rOPAX4fJiRjKkJCvswLwdet7G8+UGXt7nCQ= +cloud.google.com/go/auth v0.20.0 h1:kXTssoVb4azsVDoUiF8KvxAqrsQcQtB53DcSgta74CA= +cloud.google.com/go/auth v0.20.0/go.mod h1:942/yi/itH1SsmpyrbnTMDgGfdy2BUqIKyd0cyYLc5Q= cloud.google.com/go/auth/oauth2adapt v0.2.8 h1:keo8NaayQZ6wimpNSmW5OPc283g65QNIiLpZnkHRbnc= cloud.google.com/go/auth/oauth2adapt v0.2.8/go.mod h1:XQ9y31RkqZCcwJWNSx2Xvric3RrU88hAYYbjDWYDL+c= cloud.google.com/go/compute/metadata v0.2.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= @@ -285,10 
+285,10 @@ github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0= github.com/google/s2a-go v0.1.9/go.mod h1:YA0Ei2ZQL3acow2O62kdp9UlnvMmU7kA6Eutn0dXayM= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/googleapis/enterprise-certificate-proxy v0.3.7 h1:zrn2Ee/nWmHulBx5sAVrGgAa0f2/R35S4DJwfFaUPFQ= -github.com/googleapis/enterprise-certificate-proxy v0.3.7/go.mod h1:MkHOF77EYAE7qfSuSS9PU6g4Nt4e11cnsDUowfwewLA= -github.com/googleapis/gax-go/v2 v2.15.0 h1:SyjDc1mGgZU5LncH8gimWo9lW1DtIfPibOG81vgd/bo= -github.com/googleapis/gax-go/v2 v2.15.0/go.mod h1:zVVkkxAQHa1RQpg9z2AUCMnKhi0Qld9rcmyfL1OZhoc= +github.com/googleapis/enterprise-certificate-proxy v0.3.14 h1:yh8ncqsbUY4shRD5dA6RlzjJaT4hi3kII+zYw8wmLb8= +github.com/googleapis/enterprise-certificate-proxy v0.3.14/go.mod h1:vqVt9yG9480NtzREnTlmGSBmFrA+bzb0yl0TxoBQXOg= +github.com/googleapis/gax-go/v2 v2.21.0 h1:h45NjjzEO3faG9Lg/cFrBh2PgegVVgzqKzuZl/wMbiI= +github.com/googleapis/gax-go/v2 v2.21.0/go.mod h1:But/NJU6TnZsrLai/xBAQLLz+Hc7fHZJt/hsCz3Fih4= github.com/gopacket/gopacket v1.1.1 h1:zbx9F9d6A7sWNkFKrvMBZTfGgxFoY4NgUudFVVHMfcw= github.com/gopacket/gopacket v1.1.1/go.mod h1:HavMeONEl7W9036of9LbSWoonqhH7HA1+ZRO+rMIvFs= github.com/gorilla/handlers v1.5.2 h1:cLTUSsNkgcwhgRqvCNmdbRWG0A3N4F+M2nWKdScwyEE= @@ -400,8 +400,6 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k= github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE= github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4= -github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU= github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To= github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= @@ -449,12 +447,14 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U= github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU= +github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus= +github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk= github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI= github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51/go.mod h1:ZSIbPdBn5hePO8CpF1PekH2SfpTxg1PDhEwtbqZS7R8= -github.com/netbirdio/management-integrations/integrations v0.0.0-20260210160626-df4b180c7b25 h1:iwAq/Ncaq0etl4uAlVsbNBzC1yY52o0AmY7uCm2AMTs= -github.com/netbirdio/management-integrations/integrations 
v0.0.0-20260210160626-df4b180c7b25/go.mod h1:y7CxagMYzg9dgu+masRqYM7BQlOGA5Y8US85MCNFPlY= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 h1:F3zS5fT9xzD1OFLfcdAE+3FfyiwjGukF1hvj0jErgs8= +github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42/go.mod h1:n47r67ZSPgwSmT/Z1o48JjZQW9YJ6m/6Bd/uAXkL3Pg= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9axERMVN63dqyFqnvuD+EMJHzM7mNGON8= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= @@ -664,8 +664,8 @@ go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0 h1:yI1/OhfEPy7J9eoa6Sj051C7n5dvpj0QX8g4sRchg04= go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.67.0/go.mod h1:NoUCKYWK+3ecatC4HjkRktREheMeEtrXoQxrqYFeHSc= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0 h1:F7Jx+6hwnZ41NSFTO5q4LYDtJRXBf2PD0rNBkeB/lus= -go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.61.0/go.mod h1:UHB22Z8QsdRDrnAtX4PntOl36ajSxcdUMt1sF7Y6E7Q= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0 h1:OyrsyzuttWTSur2qN/Lm0m2a8yqyIjUVBZcxFPuXq2o= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.67.0/go.mod h1:C2NGBr+kAB4bk3xtMXfZ94gqFDtg/GkI7e9zqGh5Beg= go.opentelemetry.io/otel v1.43.0 h1:mYIM03dnh5zfN7HautFE4ieIig9amkNANT+xcVxAj9I= go.opentelemetry.io/otel v1.43.0/go.mod h1:JuG+u74mvjvcm8vj8pI5XiHy1zDeoCS2LB1spIq7Ay0= go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.19.0 h1:Mne5On7VWdx7omSrSSZvM4Kw7cS7NQkOOmLcgscI51U= @@ -707,8 +707,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= -golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= -golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/crypto v0.49.0 h1:+Ng2ULVvLHnJ/ZFEq4KdcDd/cfjrrjjNSXNzxg0Y4U4= +golang.org/x/crypto v0.49.0/go.mod h1:ErX4dUh2UM+CFYiXZRTcMpEcN8b/1gxEuv3nODoYtCA= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b h1:M2rDM6z3Fhozi9O7NWsxAkg/yqS/lQJ6PmkyIV3YP+o= golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b/go.mod h1:3//PLf8L/X+8b4vuAfHzxeRUl04Adcb341+IGKfnqS8= golang.org/x/image v0.33.0 h1:LXRZRnv1+zGd5XBUVRFmYEphyyKJjQjCRiOuAP3sZfQ= @@ -725,8 +725,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/mod v0.32.0 h1:9F4d3PHLljb6x//jOyokMv3eX+YDeepZSEo3mFJy93c= -golang.org/x/mod v0.32.0/go.mod h1:SgipZ/3h2Ci89DlEtEXWUk/HteuRin+HHhN+WbNhguU= +golang.org/x/mod v0.33.0 h1:tHFzIWbBifEmbwtGz65eaWyGiGZatSrT9prnU8DbVL8= +golang.org/x/mod v0.33.0/go.mod 
h1:swjeQEj+6r7fODbD2cqrnje9PnziFuw4bmLbBZFrQ5w= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= @@ -745,11 +745,11 @@ golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM= -golang.org/x/net v0.51.0 h1:94R/GTO7mt3/4wIKpcR5gkGmRLOuE/2hNGeWq/GBIFo= -golang.org/x/net v0.51.0/go.mod h1:aamm+2QF5ogm02fjy5Bb7CQ0WMt1/WVM7FtyaTLlA9Y= +golang.org/x/net v0.52.0 h1:He/TN1l0e4mmR3QqHMT2Xab3Aj3L9qjbhRm78/6jrW0= +golang.org/x/net v0.52.0/go.mod h1:R1MAz7uMZxVMualyPXb+VaqGSa3LIaUqk0eEt3w36Sw= golang.org/x/oauth2 v0.8.0/go.mod h1:yr7u4HXZRm1R1kBWqr/xKNqewf0plRYoB7sla+BCIXE= -golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= -golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.36.0 h1:peZ/1z27fi9hUOFCAZaHyrpWG5lwe0RJEEEeH0ThlIs= +golang.org/x/oauth2 v0.36.0/go.mod h1:YDBUJMTkDnJS+A4BP4eZBjCqtokkg1hODuPjwiGPO7Q= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -761,8 +761,8 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y= golang.org/x/sync v0.6.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -811,8 +811,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= -golang.org/x/term v0.40.0 h1:36e4zGLqU4yhjlmxEaagx2KuYbJq3EwY8K943ZsHcvg= -golang.org/x/term v0.40.0/go.mod h1:w2P8uVp06p2iyKKuvXIm7N/y0UCRt3UfJTfZ7oOpglM= +golang.org/x/term v0.41.0 h1:QCgPso/Q3RTJx2Th4bDLqML4W6iJiaXFq2/ftQF13YU= +golang.org/x/term v0.41.0/go.mod h1:3pfBgksrReYfZ5lvYM0kSO0LIkAl4Yl2bXOkKP7Ec2A= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= 
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= @@ -824,10 +824,10 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= -golang.org/x/text v0.34.0 h1:oL/Qq0Kdaqxa1KbNeMKwQq0reLCCaFtqu2eNuSeNHbk= -golang.org/x/text v0.34.0/go.mod h1:homfLqTYRFyVYemLBFl5GgL/DWEiH5wcsQ5gSh1yziA= -golang.org/x/time v0.14.0 h1:MRx4UaLrDotUKUdCIqzPC48t1Y9hANFKIRpNx+Te8PI= -golang.org/x/time v0.14.0/go.mod h1:eL/Oa2bBBK0TkX57Fyni+NgnyQQN4LitPmob2Hjnqw4= +golang.org/x/text v0.35.0 h1:JOVx6vVDFokkpaq1AEptVzLTpDe9KGpj5tR4/X+ybL8= +golang.org/x/text v0.35.0/go.mod h1:khi/HExzZJ2pGnjenulevKNX1W67CUy0AsXcNubPGCA= +golang.org/x/time v0.15.0 h1:bbrp8t3bGUeFOx08pvsMYRTCVSMk89u4tKbNOZbp88U= +golang.org/x/time v0.15.0/go.mod h1:Y4YMaQmXwGQZoFaVFk4YpCt4FLQMYKZe9oeV/f4MSno= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= @@ -839,8 +839,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= -golang.org/x/tools v0.41.0 h1:a9b8iMweWG+S0OBnlU36rzLp20z1Rp10w+IY2czHTQc= -golang.org/x/tools v0.41.0/go.mod h1:XSY6eDqxVNiYgezAVqqCeihT4j1U2CCsqvH3WhQpnlg= +golang.org/x/tools v0.42.0 h1:uNgphsn75Tdz5Ji2q36v/nsFSfR/9BRFvqhGBaJGd5k= +golang.org/x/tools v0.42.0/go.mod h1:Ma6lCIwGZvHK6XtgbswSoWroEkhugApmsXyrUmBhfr0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -851,19 +851,19 @@ golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 h1:CawjfCvY golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6/go.mod h1:3rxYc4HtVcSG9gVaTs2GEBdehh+sYPOwKtyUWEOTb80= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= -gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= -google.golang.org/api v0.257.0 h1:8Y0lzvHlZps53PEaw+G29SsQIkuKrumGWs9puiexNAA= -google.golang.org/api v0.257.0/go.mod h1:4eJrr+vbVaZSqs7vovFd1Jb/A6ml6iw2e6FBYf3GAO4= +gonum.org/v1/gonum v0.17.0 h1:VbpOemQlsSMrYmn7T2OUvQ4dqxQXU+ouZFQsZOx50z4= +gonum.org/v1/gonum v0.17.0/go.mod h1:El3tOrEuMpv2UdMrbNlKEh9vd86bmQ6vqIcDwxEOc1E= +google.golang.org/api v0.276.0 h1:nVArUtfLEihtW+b0DdcqRGK1xoEm2+ltAihyztq7MKY= +google.golang.org/api v0.276.0/go.mod h1:Fnag/EWUPIcJXuIkP1pjoTgS5vdxlk3eeemL7Do6bvw= google.golang.org/appengine v1.6.7/go.mod 
h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822 h1:rHWScKit0gvAPuOnu87KpaYtjK5zBMLcULh7gxkCXu4= -google.golang.org/genproto v0.0.0-20250603155806-513f23925822/go.mod h1:HubltRL7rMh0LfnQPkMH4NPDFEWp0jw3vixw7jEM53s= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= -google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171 h1:ggcbiqK8WWh6l1dnltU4BgWGIGo+EVYxCaAPih/zQXQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20260226221140-a57be14db171/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= -google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= -google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7 h1:XzmzkmB14QhVhgnawEVsOn6OFsnpyxNPRY9QV01dNB0= +google.golang.org/genproto v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:L43LFes82YgSonw6iTXTxXUX1OlULt4AQtkik4ULL/I= +google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7 h1:41r6JMbpzBMen0R/4TZeeAmGXSJC7DftGINUodzTkPI= +google.golang.org/genproto/googleapis/api v0.0.0-20260319201613-d00831a3d3e7/go.mod h1:EIQZ5bFCfRQDV4MhRle7+OgjNtZ6P1PiZBgAKuxXu/Y= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9 h1:m8qni9SQFH0tJc1X0vmnpw/0t+AImlSvp30sEupozUg= +google.golang.org/genproto/googleapis/rpc v0.0.0-20260401024825-9d38bb4040a9/go.mod h1:4Hqkh8ycfw05ld/3BWL7rJOSfebL2Q+DVDeRgYgxUU8= +google.golang.org/grpc v1.80.0 h1:Xr6m2WmWZLETvUNvIUmeD5OAagMw3FiKmMlTdViWsHM= +google.golang.org/grpc v1.80.0/go.mod h1:ho/dLnxwi3EDJA4Zghp7k2Ec1+c2jqup0bFkw07bwF4= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/infrastructure_files/getting-started.sh b/infrastructure_files/getting-started.sh index 9236d851d..2a3f840b4 100755 --- a/infrastructure_files/getting-started.sh +++ b/infrastructure_files/getting-started.sh @@ -182,6 +182,23 @@ read_enable_proxy() { return 0 } +read_enable_crowdsec() { + echo "" > /dev/stderr + echo "Do you want to enable CrowdSec IP reputation blocking?" > /dev/stderr + echo "CrowdSec checks client IPs against a community threat intelligence database" > /dev/stderr + echo "and blocks known malicious sources before they reach your services." > /dev/stderr + echo "A local CrowdSec LAPI container will be added to your deployment." > /dev/stderr + echo -n "Enable CrowdSec? [y/N]: " > /dev/stderr + read -r CHOICE < /dev/tty + + if [[ "$CHOICE" =~ ^[Yy]$ ]]; then + echo "true" + else + echo "false" + fi + return 0 +} + read_traefik_acme_email() { echo "" > /dev/stderr echo "Enter your email for Let's Encrypt certificate notifications." 
> /dev/stderr @@ -297,6 +314,10 @@ initialize_default_values() { # NetBird Proxy configuration ENABLE_PROXY="false" PROXY_TOKEN="" + + # CrowdSec configuration + ENABLE_CROWDSEC="false" + CROWDSEC_BOUNCER_KEY="" return 0 } @@ -325,6 +346,9 @@ configure_reverse_proxy() { if [[ "$REVERSE_PROXY_TYPE" == "0" ]]; then TRAEFIK_ACME_EMAIL=$(read_traefik_acme_email) ENABLE_PROXY=$(read_enable_proxy) + if [[ "$ENABLE_PROXY" == "true" ]]; then + ENABLE_CROWDSEC=$(read_enable_crowdsec) + fi fi # Handle external Traefik-specific prompts (option 1) @@ -354,7 +378,7 @@ check_existing_installation() { echo "Generated files already exist, if you want to reinitialize the environment, please remove them first." echo "You can use the following commands:" echo " $DOCKER_COMPOSE_COMMAND down --volumes # to remove all containers and volumes" - echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt" + echo " rm -f docker-compose.yml dashboard.env config.yaml proxy.env traefik-dynamic.yaml nginx-netbird.conf caddyfile-netbird.txt npm-advanced-config.txt && rm -rf crowdsec/" echo "Be aware that this will remove all data from the database, and you will have to reconfigure the dashboard." exit 1 fi @@ -375,6 +399,9 @@ generate_configuration_files() { echo "NB_PROXY_TOKEN=placeholder" >> proxy.env # TCP ServersTransport for PROXY protocol v2 to the proxy backend render_traefik_dynamic > traefik-dynamic.yaml + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + mkdir -p crowdsec + fi fi ;; 1) @@ -417,8 +444,12 @@ start_services_and_show_instructions() { if [[ "$ENABLE_PROXY" == "true" ]]; then # Phase 1: Start core services (without proxy) + local core_services="traefik dashboard netbird-server" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + core_services="$core_services crowdsec" + fi echo "Starting core services..." - $DOCKER_COMPOSE_COMMAND up -d traefik dashboard netbird-server + $DOCKER_COMPOSE_COMMAND up -d $core_services sleep 3 wait_management_proxy traefik @@ -438,7 +469,33 @@ start_services_and_show_instructions() { echo "Proxy token created successfully." - # Generate proxy.env with the token + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + echo "Registering CrowdSec bouncer..." + local cs_retries=0 + while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do + cs_retries=$((cs_retries + 1)) + if [[ $cs_retries -ge 30 ]]; then + echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr + echo "You can register a bouncer manually later with:" > /dev/stderr + echo " docker exec netbird-crowdsec cscli bouncers add netbird-proxy -o raw" > /dev/stderr + ENABLE_CROWDSEC="false" + break + fi + sleep 2 + done + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + CROWDSEC_BOUNCER_KEY=$($DOCKER_COMPOSE_COMMAND exec -T crowdsec \ + cscli bouncers add netbird-proxy -o raw 2>/dev/null) + if [[ -z "$CROWDSEC_BOUNCER_KEY" ]]; then + echo "WARNING: Failed to create CrowdSec bouncer key. Skipping CrowdSec setup." > /dev/stderr + ENABLE_CROWDSEC="false" + else + echo "CrowdSec bouncer registered." 
+ fi + fi + fi + render_proxy_env > proxy.env # Start proxy service @@ -525,11 +582,25 @@ render_docker_compose_traefik_builtin() { # Generate proxy service section and Traefik dynamic config if enabled local proxy_service="" local proxy_volumes="" + local crowdsec_service="" + local crowdsec_volumes="" local traefik_file_provider="" local traefik_dynamic_volume="" if [[ "$ENABLE_PROXY" == "true" ]]; then traefik_file_provider=' - "--providers.file.filename=/etc/traefik/dynamic.yaml"' traefik_dynamic_volume=" - ./traefik-dynamic.yaml:/etc/traefik/dynamic.yaml:ro" + + local proxy_depends=" + netbird-server: + condition: service_started" + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + proxy_depends=" + netbird-server: + condition: service_started + crowdsec: + condition: service_healthy" + fi + proxy_service=" # NetBird Proxy - exposes internal resources to the internet proxy: @@ -539,8 +610,7 @@ render_docker_compose_traefik_builtin() { - 51820:51820/udp restart: unless-stopped networks: [netbird] - depends_on: - - netbird-server + depends_on:${proxy_depends} env_file: - ./proxy.env volumes: @@ -563,6 +633,35 @@ render_docker_compose_traefik_builtin() { " proxy_volumes=" netbird_proxy_certs:" + + if [[ "$ENABLE_CROWDSEC" == "true" ]]; then + crowdsec_service=" + crowdsec: + image: crowdsecurity/crowdsec:v1.7.7 + container_name: netbird-crowdsec + restart: unless-stopped + networks: [netbird] + environment: + COLLECTIONS: crowdsecurity/linux + volumes: + - ./crowdsec:/etc/crowdsec + - crowdsec_db:/var/lib/crowdsec/data + healthcheck: + test: ["CMD", "cscli", "lapi", "status"] + interval: 10s + timeout: 5s + retries: 15 + labels: + - traefik.enable=false + logging: + driver: \"json-file\" + options: + max-size: \"500m\" + max-file: \"2\" +" + crowdsec_volumes=" + crowdsec_db:" + fi fi cat <" + echo " Get your enrollment key at: https://app.crowdsec.net" + echo "" + fi fi return 0 } diff --git a/management/internals/modules/reverseproxy/service/manager/manager_test.go b/management/internals/modules/reverseproxy/service/manager/manager_test.go index 69d48f10a..54ac8ab18 100644 --- a/management/internals/modules/reverseproxy/service/manager/manager_test.go +++ b/management/internals/modules/reverseproxy/service/manager/manager_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,6 +19,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/mock_server" resourcetypes "github.com/netbirdio/netbird/management/server/networks/resources/types" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -29,6 +31,13 @@ import ( "github.com/netbirdio/netbird/shared/management/status" ) +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + func TestInitializeServiceForCreate(t *testing.T) { ctx := context.Background() accountID := "test-account" @@ -422,10 +431,8 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) { newProxyServer := func(t *testing.T) *nbgrpc.ProxyServiceServer { t.Helper() - tokenStore, err := 
nbgrpc.NewOneTimeTokenStore(context.Background(), 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t)) srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) return srv } @@ -703,10 +710,8 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) { }, } - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) @@ -1128,10 +1133,8 @@ func TestDeleteService_DeletesTargets(t *testing.T) { mockPerms := permissions.NewMockManager(ctrl) mockAcct := account.NewMockManager(ctrl) - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 1*time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t)) proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil) proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter("")) diff --git a/management/internals/server/boot.go b/management/internals/server/boot.go index 88d37ca80..2b40c0aad 100644 --- a/management/internals/server/boot.go +++ b/management/internals/server/boot.go @@ -18,6 +18,7 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/keepalive" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/netbirdio/management-integrations/integrations" "github.com/netbirdio/netbird/encryption" @@ -26,8 +27,10 @@ import ( accesslogsmanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/accesslogs/manager" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" nbContext "github.com/netbirdio/netbird/management/server/context" nbhttp "github.com/netbirdio/netbird/management/server/http" + "github.com/netbirdio/netbird/management/server/http/middleware" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/telemetry" mgmtProto "github.com/netbirdio/netbird/shared/management/proto" @@ -58,6 +61,18 @@ func (s *BaseServer) Metrics() telemetry.AppMetrics { }) } +// CacheStore returns a shared cache store backed by Redis or in-memory depending on the environment. +// All consumers should reuse this store to avoid creating multiple Redis connections. 
+func (s *BaseServer) CacheStore() cachestore.StoreInterface { + return Create(s, func() cachestore.StoreInterface { + cs, err := nbcache.NewStore(context.Background(), nbcache.DefaultStoreMaxTimeout, nbcache.DefaultStoreCleanupInterval, nbcache.DefaultStoreMaxConn) + if err != nil { + log.Fatalf("failed to create shared cache store: %v", err) + } + return cs + }) +} + func (s *BaseServer) Store() store.Store { return Create(s, func() store.Store { store, err := store.NewStore(context.Background(), s.Config.StoreConfig.Engine, s.Config.Datadir, s.Metrics(), false) @@ -95,7 +110,7 @@ func (s *BaseServer) EventStore() activity.Store { func (s *BaseServer) APIHandler() http.Handler { return Create(s, func() http.Handler { - httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies) + httpAPIHandler, err := nbhttp.NewAPIHandler(context.Background(), s.AccountManager(), s.NetworksManager(), s.ResourcesManager(), s.RoutesManager(), s.GroupsManager(), s.GeoLocationManager(), s.AuthManager(), s.Metrics(), s.IntegratedValidator(), s.ProxyController(), s.PermissionsManager(), s.PeersManager(), s.SettingsManager(), s.ZonesManager(), s.RecordsManager(), s.NetworkMapController(), s.IdpManager(), s.ServiceManager(), s.ReverseProxyDomainManager(), s.AccessLogsManager(), s.ReverseProxyGRPCServer(), s.Config.ReverseProxy.TrustedHTTPProxies, s.RateLimiter()) if err != nil { log.Fatalf("failed to create API handler: %v", err) } @@ -103,6 +118,15 @@ func (s *BaseServer) APIHandler() http.Handler { }) } +func (s *BaseServer) RateLimiter() *middleware.APIRateLimiter { + return Create(s, func() *middleware.APIRateLimiter { + cfg, enabled := middleware.RateLimiterConfigFromEnv() + limiter := middleware.NewAPIRateLimiter(cfg) + limiter.SetEnabled(enabled) + return limiter + }) +} + func (s *BaseServer) GRPCServer() *grpc.Server { return Create(s, func() *grpc.Server { trustedPeers := s.Config.ReverseProxy.TrustedPeers @@ -195,10 +219,7 @@ func (s *BaseServer) proxyOIDCConfig() nbgrpc.ProxyOIDCConfig { func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { return Create(s, func() *nbgrpc.OneTimeTokenStore { - tokenStore, err := nbgrpc.NewOneTimeTokenStore(context.Background(), 5*time.Minute, 10*time.Minute, 100) - if err != nil { - log.Fatalf("failed to create proxy token store: %v", err) - } + tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), s.CacheStore()) log.Info("One-time token store initialized for proxy authentication") return tokenStore }) @@ -206,11 +227,7 @@ func (s *BaseServer) ProxyTokenStore() *nbgrpc.OneTimeTokenStore { func (s *BaseServer) PKCEVerifierStore() *nbgrpc.PKCEVerifierStore { return Create(s, func() *nbgrpc.PKCEVerifierStore { - pkceStore, err := nbgrpc.NewPKCEVerifierStore(context.Background(), 10*time.Minute, 10*time.Minute, 100) - if err != nil { - log.Fatalf("failed to create PKCE verifier store: %v", err) - } - return pkceStore + return nbgrpc.NewPKCEVerifierStore(context.Background(), s.CacheStore()) }) } diff --git a/management/internals/server/controllers.go 
b/management/internals/server/controllers.go index c7eab3d19..9a8e45d33 100644 --- a/management/internals/server/controllers.go +++ b/management/internals/server/controllers.go @@ -41,7 +41,8 @@ func (s *BaseServer) IntegratedValidator() integrated_validator.IntegratedValida context.Background(), s.PeersManager(), s.SettingsManager(), - s.EventStore()) + s.EventStore(), + s.CacheStore()) if err != nil { log.Errorf("failed to create integrated peer validator: %v", err) } diff --git a/management/internals/server/modules.go b/management/internals/server/modules.go index 374ea5c81..9b2ec2989 100644 --- a/management/internals/server/modules.go +++ b/management/internals/server/modules.go @@ -100,7 +100,7 @@ func (s *BaseServer) PeersManager() peers.Manager { func (s *BaseServer) AccountManager() account.Manager { return Create(s, func() account.Manager { - accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy) + accountManager, err := server.BuildManager(context.Background(), s.Config, s.Store(), s.NetworkMapController(), s.JobManager(), s.IdpManager(), s.mgmtSingleAccModeDomain, s.EventStore(), s.GeoLocationManager(), s.userDeleteFromIDPEnabled, s.IntegratedValidator(), s.Metrics(), s.ProxyController(), s.SettingsManager(), s.PermissionsManager(), s.Config.DisableDefaultPolicy, s.CacheStore()) if err != nil { log.Fatalf("failed to create account service: %v", err) } diff --git a/management/internals/shared/grpc/onetime_token.go b/management/internals/shared/grpc/onetime_token.go index 7999407db..acfd6eafb 100644 --- a/management/internals/shared/grpc/onetime_token.go +++ b/management/internals/shared/grpc/onetime_token.go @@ -14,8 +14,6 @@ import ( "github.com/eko/gocache/lib/v4/cache" "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" - - nbcache "github.com/netbirdio/netbird/management/server/cache" ) type tokenMetadata struct { @@ -32,17 +30,12 @@ type OneTimeTokenStore struct { ctx context.Context } -// NewOneTimeTokenStore creates a token store with automatic backend selection -func NewOneTimeTokenStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*OneTimeTokenStore, error) { - cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) - if err != nil { - return nil, fmt.Errorf("failed to create cache store: %w", err) - } - +// NewOneTimeTokenStore creates a token store using the provided shared cache store. +func NewOneTimeTokenStore(ctx context.Context, cacheStore store.StoreInterface) *OneTimeTokenStore { return &OneTimeTokenStore{ cache: cache.New[string](cacheStore), ctx: ctx, - }, nil + } } // GenerateToken creates a new cryptographically secure one-time token diff --git a/management/internals/shared/grpc/pkce_verifier.go b/management/internals/shared/grpc/pkce_verifier.go index 441e8b051..a1325256c 100644 --- a/management/internals/shared/grpc/pkce_verifier.go +++ b/management/internals/shared/grpc/pkce_verifier.go @@ -8,8 +8,6 @@ import ( "github.com/eko/gocache/lib/v4/cache" "github.com/eko/gocache/lib/v4/store" log "github.com/sirupsen/logrus" - - nbcache "github.com/netbirdio/netbird/management/server/cache" ) // PKCEVerifierStore manages PKCE verifiers for OAuth flows. 
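Taken together, the onetime_token.go and pkce_verifier.go hunks invert the dependency: callers now construct one shared store.StoreInterface and inject it, rather than each store dialing its own backend. A hedged sketch of the resulting wiring, using only constructors and defaults this diff introduces (in the server itself the shared store comes from BaseServer.CacheStore(), which memoizes it):

package wiring

import (
	"context"

	nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
	nbcache "github.com/netbirdio/netbird/management/server/cache"
)

// buildProxyStores is illustrative, not production code: one shared cache
// store (Redis or in-memory, chosen from the environment) feeds both
// consumers, so at most one Redis connection pool is created.
func buildProxyStores(ctx context.Context) (*nbgrpc.OneTimeTokenStore, *nbgrpc.PKCEVerifierStore, error) {
	cacheStore, err := nbcache.NewStore(ctx,
		nbcache.DefaultStoreMaxTimeout,
		nbcache.DefaultStoreCleanupInterval,
		nbcache.DefaultStoreMaxConn)
	if err != nil {
		return nil, nil, err
	}
	return nbgrpc.NewOneTimeTokenStore(ctx, cacheStore),
		nbgrpc.NewPKCEVerifierStore(ctx, cacheStore), nil
}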
@@ -19,17 +17,12 @@ type PKCEVerifierStore struct { ctx context.Context } -// NewPKCEVerifierStore creates a PKCE verifier store with automatic backend selection -func NewPKCEVerifierStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (*PKCEVerifierStore, error) { - cacheStore, err := nbcache.NewStore(ctx, maxTimeout, cleanupInterval, maxConn) - if err != nil { - return nil, fmt.Errorf("failed to create cache store: %w", err) - } - +// NewPKCEVerifierStore creates a PKCE verifier store using the provided shared cache store. +func NewPKCEVerifierStore(ctx context.Context, cacheStore store.StoreInterface) *PKCEVerifierStore { return &PKCEVerifierStore{ cache: cache.New[string](cacheStore), ctx: ctx, - }, nil + } } // Store saves a PKCE verifier associated with an OAuth state parameter. diff --git a/management/internals/shared/grpc/proxy_test.go b/management/internals/shared/grpc/proxy_test.go index d5aed3dee..de4e96d93 100644 --- a/management/internals/shared/grpc/proxy_test.go +++ b/management/internals/shared/grpc/proxy_test.go @@ -9,13 +9,22 @@ import ( "testing" "time" + cachestore "github.com/eko/gocache/lib/v4/store" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/shared/management/proto" ) +func testCacheStore(t *testing.T) cachestore.StoreInterface { + t.Helper() + s, err := nbcache.NewStore(context.Background(), 30*time.Minute, 10*time.Minute, 100) + require.NoError(t, err) + return s +} + type testProxyController struct { mu sync.Mutex clusterProxies map[string]map[string]struct{} @@ -114,11 +123,8 @@ func drainEmpty(ch chan *proto.GetMappingUpdateResponse) bool { func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -174,11 +180,8 @@ func TestSendServiceUpdateToCluster_UniqueTokensPerProxy(t *testing.T) { func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -211,11 +214,8 @@ func TestSendServiceUpdateToCluster_DeleteNoToken(t *testing.T) { func TestSendServiceUpdate_UniqueTokensPerProxy(t *testing.T) { ctx := context.Background() - tokenStore, err := NewOneTimeTokenStore(ctx, time.Hour, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -267,8 +267,7 @@ func generateState(s *ProxyServiceServer, redirectURL string) string { 
func TestOAuthState_NeverTheSame(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -296,8 +295,7 @@ func TestOAuthState_NeverTheSame(t *testing.T) { func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -307,7 +305,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { } // Old format had only 2 parts: base64(url)|hmac - err = s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) + err := s.pkceVerifierStore.Store("base64url|hmac", "test", 10*time.Minute) require.NoError(t, err) _, _, err = s.ValidateState("base64url|hmac") @@ -317,8 +315,7 @@ func TestValidateState_RejectsOldTwoPartFormat(t *testing.T) { func TestValidateState_RejectsInvalidHMAC(t *testing.T) { ctx := context.Background() - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) s := &ProxyServiceServer{ oidcConfig: ProxyOIDCConfig{ @@ -328,7 +325,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { } // Store with tampered HMAC - err = s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) + err := s.pkceVerifierStore.Store("dGVzdA==|nonce|wrong-hmac", "test", 10*time.Minute) require.NoError(t, err) _, _, err = s.ValidateState("dGVzdA==|nonce|wrong-hmac") @@ -337,8 +334,7 @@ func TestValidateState_RejectsInvalidHMAC(t *testing.T) { } func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -410,8 +406,7 @@ func TestSendServiceUpdateToCluster_FiltersOnCapability(t *testing.T) { } func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) s := &ProxyServiceServer{ tokenStore: tokenStore, @@ -442,8 +437,7 @@ func TestSendServiceUpdateToCluster_TLSNotFiltered(t *testing.T) { // scenario for an existing service, verifying the correct update types // reach the correct clusters. 
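The state tests above pin down the expected shape: three |-separated parts, base64(url)|nonce|hmac, with old two-part states and tampered HMACs rejected. The following sketch validates that shape with stdlib HMAC-SHA256; the hex digest encoding and helper names here are assumptions, not the server's verbatim scheme:

package main

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/hex"
	"fmt"
	"strings"
)

// sign computes the MAC over the first two parts, mirroring validation below.
func sign(payload, secret string) string {
	mac := hmac.New(sha256.New, []byte(secret))
	mac.Write([]byte(payload))
	return hex.EncodeToString(mac.Sum(nil))
}

// validateState checks the three-part format the tests exercise:
// base64(url)|nonce|hmac. The real server additionally consumes the state
// from the PKCE verifier store; that lookup is omitted here.
func validateState(state, secret string) (string, error) {
	parts := strings.Split(state, "|")
	if len(parts) != 3 { // old two-part states are rejected outright
		return "", fmt.Errorf("invalid state format")
	}
	expected := sign(parts[0]+"|"+parts[1], secret)
	if !hmac.Equal([]byte(expected), []byte(parts[2])) {
		return "", fmt.Errorf("state HMAC mismatch")
	}
	rawURL, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return "", fmt.Errorf("decode redirect URL: %w", err)
	}
	return string(rawURL), nil
}

func main() {
	secret := "test-secret" // hypothetical HMAC key
	encoded := base64.RawURLEncoding.EncodeToString([]byte("https://example.com/cb"))
	state := encoded + "|nonce123|" + sign(encoded+"|nonce123", secret)

	redirect, err := validateState(state, secret)
	fmt.Println(redirect, err) // https://example.com/cb <nil>

	_, err = validateState("base64url|hmac", secret) // old two-part format
	fmt.Println(err)                                 // invalid state format
}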
func TestServiceModifyNotifications(t *testing.T) { - tokenStore, err := NewOneTimeTokenStore(context.Background(), time.Hour, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(context.Background(), testCacheStore(t)) newServer := func() (*ProxyServiceServer, map[string]chan *proto.GetMappingUpdateResponse) { s := &ProxyServiceServer{ diff --git a/management/internals/shared/grpc/validate_session_test.go b/management/internals/shared/grpc/validate_session_test.go index 2f77de86e..d1d7fc8b7 100644 --- a/management/internals/shared/grpc/validate_session_test.go +++ b/management/internals/shared/grpc/validate_session_test.go @@ -39,11 +39,8 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup { usersManager := &testValidateSessionUsersManager{store: testStore} proxyManager := &testValidateSessionProxyManager{} - tokenStore, err := NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) - require.NoError(t, err) - - pkceStore, err := NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t)) + pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t)) proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager) proxyService.SetServiceManager(serviceManager) @@ -327,7 +324,7 @@ func (m *testValidateSessionServiceManager) GetActiveClusters(_ context.Context, type testValidateSessionProxyManager struct{} -func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string) error { +func (m *testValidateSessionProxyManager) Connect(_ context.Context, _, _, _ string, _ *proxy.Capabilities) error { return nil } @@ -335,7 +332,7 @@ func (m *testValidateSessionProxyManager) Disconnect(_ context.Context, _ string return nil } -func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _ string) error { +func (m *testValidateSessionProxyManager) Heartbeat(_ context.Context, _, _, _ string) error { return nil } @@ -351,6 +348,18 @@ func (m *testValidateSessionProxyManager) CleanupStale(_ context.Context, _ time return nil } +func (m *testValidateSessionProxyManager) ClusterSupportsCustomPorts(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterRequireSubdomain(_ context.Context, _ string) *bool { + return nil +} + +func (m *testValidateSessionProxyManager) ClusterSupportsCrowdSec(_ context.Context, _ string) *bool { + return nil +} + type testValidateSessionUsersManager struct { store store.Store } diff --git a/management/server/account.go b/management/server/account.go index d90b46659..7d53cef03 100644 --- a/management/server/account.go +++ b/management/server/account.go @@ -181,7 +181,7 @@ func (am *DefaultAccountManager) getJWTGroupsChanges(user *types.User, groups [] return modified, newUserAutoGroups, newGroupsToCreate, nil } -// BuildManager creates a new DefaultAccountManager with a provided Store +// BuildManager creates a new DefaultAccountManager with all dependencies. 
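The mock updates in validate_session_test.go above imply the proxy manager interface grew extra identifiers on Connect and Heartbeat plus per-cluster capability queries. This reconstruction is inferred from the mock signatures alone; parameter names are guesses, and a nil *bool presumably means "capability unknown":

package sketch

import (
	"context"
	"time"

	"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)

// ProxyManager is reconstructed from the test mocks; this is not the
// verbatim production interface.
type ProxyManager interface {
	Connect(ctx context.Context, clusterID, proxyID, version string, caps *proxy.Capabilities) error
	Disconnect(ctx context.Context, proxyID string) error
	Heartbeat(ctx context.Context, clusterID, proxyID, version string) error
	CleanupStale(ctx context.Context, olderThan time.Duration) error

	// A nil result presumably means the capability is unknown for the cluster.
	ClusterSupportsCustomPorts(ctx context.Context, clusterID string) *bool
	ClusterRequireSubdomain(ctx context.Context, clusterID string) *bool
	ClusterSupportsCrowdSec(ctx context.Context, clusterID string) *bool
}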
func BuildManager( ctx context.Context, config *nbconfig.Config, @@ -199,6 +199,7 @@ func BuildManager( settingsManager settings.Manager, permissionsManager permissions.Manager, disableDefaultPolicy bool, + sharedCacheStore cacheStore.StoreInterface, ) (*DefaultAccountManager, error) { start := time.Now() defer func() { @@ -247,16 +248,12 @@ func BuildManager( log.WithContext(ctx).Infof("single account mode disabled, accounts number %d", accountsCounter) } - cacheStore, err := nbcache.NewStore(ctx, nbcache.DefaultIDPCacheExpirationMax, nbcache.DefaultIDPCacheCleanupInterval, nbcache.DefaultIDPCacheOpenConn) - if err != nil { - return nil, fmt.Errorf("getting cache store: %s", err) - } - am.externalCacheManager = nbcache.NewUserDataCache(cacheStore) - am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, cacheStore) + am.externalCacheManager = nbcache.NewUserDataCache(sharedCacheStore) + am.cacheManager = nbcache.NewAccountUserDataCache(am.loadAccount, sharedCacheStore) if !isNil(am.idpManager) && !IsEmbeddedIdp(am.idpManager) { go func() { - err := am.warmupIDPCache(ctx, cacheStore) + err := am.warmupIDPCache(ctx, sharedCacheStore) if err != nil { log.WithContext(ctx).Warnf("failed warming up cache due to error: %v", err) // todo retry? diff --git a/management/server/account_test.go b/management/server/account_test.go index 2f0533281..bcc73d52f 100644 --- a/management/server/account_test.go +++ b/management/server/account_test.go @@ -2311,6 +2311,29 @@ func TestAccount_GetExpiredPeers(t *testing.T) { } } +func TestGetExpiredPeers_SkipsAlreadyExpired(t *testing.T) { + ctx := context.Background() + + testStore, cleanUp, err := store.NewTestStoreFromSQL(ctx, "testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanUp) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + // Verify the already-expired peer is excluded at the store level + peers, err := testStore.GetAccountPeersWithExpiration(ctx, store.LockingStrengthNone, accountID) + require.NoError(t, err) + + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should be excluded by the store query") + assert.False(t, peer.Status.LoginExpired, "returned peers should not already be marked as login expired") + } + + // Only the non-expired peer with expiration enabled should be returned + require.Len(t, peers, 1) + assert.Equal(t, "notexpired01", peers[0].ID) +} + func TestAccount_GetInactivePeers(t *testing.T) { type test struct { name string @@ -3134,10 +3157,15 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU ctx := context.Background() + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := 
BuildManager(ctx, &config.Config{}, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/cache/store.go b/management/server/cache/store.go index 54b0242de..2ca8e8603 100644 --- a/management/server/cache/store.go +++ b/management/server/cache/store.go @@ -17,12 +17,24 @@ import ( // RedisStoreEnvVar is the environment variable that determines if a redis store should be used. // The value should follow redis URL format. https://github.com/redis/redis-specifications/blob/master/uri/redis.txt -const RedisStoreEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" +const RedisStoreEnvVar = "NB_CACHE_REDIS_ADDRESS" + +// legacyIdPCacheRedisEnvVar is the previous environment variable used for IDP cache. +const legacyIdPCacheRedisEnvVar = "NB_IDP_CACHE_REDIS_ADDRESS" + +const ( + // DefaultStoreMaxTimeout is the default max timeout for the shared cache store. + DefaultStoreMaxTimeout = 7 * 24 * time.Hour + // DefaultStoreCleanupInterval is the default cleanup interval for the shared cache store. + DefaultStoreCleanupInterval = 30 * time.Minute + // DefaultStoreMaxConn is the default max connections for the shared cache store. + DefaultStoreMaxConn = 1000 +) // NewStore creates a new cache store with the given max timeout and cleanup interval. It checks for the environment Variable RedisStoreEnvVar // to determine if a redis store should be used. If the environment variable is set, it will attempt to connect to the redis store. func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, maxConn int) (store.StoreInterface, error) { - redisAddr := os.Getenv(RedisStoreEnvVar) + redisAddr := GetAddrFromEnv() if redisAddr != "" { return getRedisStore(ctx, redisAddr, maxConn) } @@ -30,6 +42,15 @@ func NewStore(ctx context.Context, maxTimeout, cleanupInterval time.Duration, ma return gocache_store.NewGoCache(goc), nil } +// GetAddrFromEnv returns the redis address from the environment variable RedisStoreEnvVar or its legacy counterpart. 
+func GetAddrFromEnv() string { + addr := os.Getenv(RedisStoreEnvVar) + if addr == "" { + addr = os.Getenv(legacyIdPCacheRedisEnvVar) + } + return addr +} + func getRedisStore(ctx context.Context, redisEnvAddr string, maxConn int) (store.StoreInterface, error) { options, err := redis.ParseURL(redisEnvAddr) if err != nil { diff --git a/management/server/dns_test.go b/management/server/dns_test.go index bd0755d0d..0e37a3b22 100644 --- a/management/server/dns_test.go +++ b/management/server/dns_test.go @@ -15,6 +15,7 @@ import ( "github.com/netbirdio/netbird/management/internals/modules/peers" ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" "github.com/netbirdio/netbird/management/server/permissions" @@ -225,11 +226,17 @@ func createDNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.test", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createDNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/http/handler.go b/management/server/http/handler.go index ad36b9d46..56b2d8203 100644 --- a/management/server/http/handler.go +++ b/management/server/http/handler.go @@ -5,9 +5,6 @@ import ( "fmt" "net/http" "net/netip" - "os" - "strconv" - "time" "github.com/gorilla/mux" "github.com/rs/cors" @@ -66,14 +63,11 @@ import ( ) const ( - apiPrefix = "/api" - rateLimitingEnabledKey = "NB_API_RATE_LIMITING_ENABLED" - rateLimitingBurstKey = "NB_API_RATE_LIMITING_BURST" - rateLimitingRPMKey = "NB_API_RATE_LIMITING_RPM" + apiPrefix = "/api" ) // NewAPIHandler creates the Management service HTTP API handler registering all the available endpoints. 
-func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix) (http.Handler, error) { +func NewAPIHandler(ctx context.Context, accountManager account.Manager, networksManager nbnetworks.Manager, resourceManager resources.Manager, routerManager routers.Manager, groupsManager nbgroups.Manager, LocationManager geolocation.Geolocation, authManager auth.Manager, appMetrics telemetry.AppMetrics, integratedValidator integrated_validator.IntegratedValidator, proxyController port_forwarding.Controller, permissionsManager permissions.Manager, peersManager nbpeers.Manager, settingsManager settings.Manager, zManager zones.Manager, rManager records.Manager, networkMapController network_map.Controller, idpManager idpmanager.Manager, serviceManager service.Manager, reverseProxyDomainManager *manager.Manager, reverseProxyAccessLogsManager accesslogs.Manager, proxyGRPCServer *nbgrpc.ProxyServiceServer, trustedHTTPProxies []netip.Prefix, rateLimiter *middleware.APIRateLimiter) (http.Handler, error) { // Register bypass paths for unauthenticated endpoints if err := bypass.AddBypassPath("/api/instance"); err != nil { @@ -94,34 +88,10 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks return nil, fmt.Errorf("failed to add bypass path: %w", err) } - var rateLimitingConfig *middleware.RateLimiterConfig - if os.Getenv(rateLimitingEnabledKey) == "true" { - rpm := 6 - if v := os.Getenv(rateLimitingRPMKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingRPMKey, err, rpm) - } else { - rpm = value - } - } - - burst := 500 - if v := os.Getenv(rateLimitingBurstKey); v != "" { - value, err := strconv.Atoi(v) - if err != nil { - log.Warnf("parsing %s env var: %v, using default %d", rateLimitingBurstKey, err, burst) - } else { - burst = value - } - } - - rateLimitingConfig = &middleware.RateLimiterConfig{ - RequestsPerMinute: float64(rpm), - Burst: burst, - CleanupInterval: 6 * time.Hour, - LimiterTTL: 24 * time.Hour, - } + if rateLimiter == nil { + log.Warn("NewAPIHandler: nil rate limiter, rate limiting disabled") + rateLimiter = middleware.NewAPIRateLimiter(nil) + rateLimiter.SetEnabled(false) } authMiddleware := middleware.NewAuthMiddleware( @@ -129,7 +99,7 @@ func NewAPIHandler(ctx context.Context, accountManager account.Manager, networks accountManager.GetAccountIDFromUserAuth, accountManager.SyncUserJWTGroups, accountManager.GetUserFromUserAuth, - rateLimitingConfig, + rateLimiter, appMetrics.GetMeter(), ) diff --git a/management/server/http/handlers/networks/routers_handler.go b/management/server/http/handlers/networks/routers_handler.go index c311a29fe..ce9efb78d 100644 --- 
a/management/server/http/handlers/networks/routers_handler.go +++ b/management/server/http/handlers/networks/routers_handler.go @@ -105,6 +105,12 @@ func (h *routersHandler) createRouter(w http.ResponseWriter, r *http.Request) { router.NetworkID = networkID router.AccountID = accountID router.Enabled = true + + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.CreateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) @@ -157,6 +163,11 @@ func (h *routersHandler) updateRouter(w http.ResponseWriter, r *http.Request) { router.ID = mux.Vars(r)["routerId"] router.AccountID = accountID + if err := router.Validate(); err != nil { + util.WriteErrorResponse(err.Error(), http.StatusBadRequest, w) + return + } + router, err = h.routersManager.UpdateRouter(r.Context(), userID, router) if err != nil { util.WriteError(r.Context(), err, w) diff --git a/management/server/http/handlers/proxy/auth_callback_integration_test.go b/management/server/http/handlers/proxy/auth_callback_integration_test.go index 922bf4352..c99acab63 100644 --- a/management/server/http/handlers/proxy/auth_callback_integration_test.go +++ b/management/server/http/handlers/proxy/auth_callback_integration_test.go @@ -22,6 +22,7 @@ import ( nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -191,11 +192,11 @@ func setupAuthCallbackTest(t *testing.T) *testSetup { oidcServer := newFakeOIDCServer() - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, time.Minute, 10*time.Minute, 100) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) usersManager := users.NewManager(testStore) diff --git a/management/server/http/middleware/auth_middleware.go b/management/server/http/middleware/auth_middleware.go index 63be672e6..6d075d9c2 100644 --- a/management/server/http/middleware/auth_middleware.go +++ b/management/server/http/middleware/auth_middleware.go @@ -12,6 +12,7 @@ import ( "go.opentelemetry.io/otel/metric" "github.com/netbirdio/management-integrations/integrations" + serverauth "github.com/netbirdio/netbird/management/server/auth" nbcontext "github.com/netbirdio/netbird/management/server/context" "github.com/netbirdio/netbird/management/server/http/middleware/bypass" @@ -42,14 +43,9 @@ func NewAuthMiddleware( ensureAccount EnsureAccountFunc, syncUserJWTGroups SyncUserJWTGroupsFunc, getUserFromUserAuth GetUserFromUserAuthFunc, - rateLimiterConfig *RateLimiterConfig, + rateLimiter *APIRateLimiter, meter metric.Meter, ) *AuthMiddleware { - var rateLimiter *APIRateLimiter - if rateLimiterConfig != nil { - rateLimiter = NewAPIRateLimiter(rateLimiterConfig) - } - var patUsageTracker *PATUsageTracker if meter != nil { var err error @@ -87,17 +83,14 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler 
{ switch authType { case "bearer": - request, err := m.checkJWTFromRequest(r, authHeader) - if err != nil { + if err := m.checkJWTFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Errorf("Error when validating JWT: %s", err.Error()) util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "token invalid"), w) return } - - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) case "token": - request, err := m.checkPATFromRequest(r, authHeader) - if err != nil { + if err := m.checkPATFromRequest(r, authHeader); err != nil { log.WithContext(r.Context()).Debugf("Error when validating PAT: %s", err.Error()) // Check if it's a status error, otherwise default to Unauthorized if _, ok := status.FromError(err); !ok { @@ -106,7 +99,7 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { util.WriteError(r.Context(), err, w) return } - h.ServeHTTP(w, request) + h.ServeHTTP(w, r) default: util.WriteError(r.Context(), status.Errorf(status.Unauthorized, "no valid authentication provided"), w) return @@ -115,19 +108,19 @@ func (m *AuthMiddleware) Handler(h http.Handler) http.Handler { } // CheckJWTFromRequest checks if the JWT is valid -func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromJWTRequest(authHeaderParts) // If an error occurs, call the error handler and return an error if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } ctx := r.Context() userAuth, validatedToken, err := m.authManager.ValidateAndParseToken(ctx, token) if err != nil { - return r, err + return err } if impersonate, ok := r.URL.Query()["account"]; ok && len(impersonate) == 1 { @@ -143,7 +136,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] // we need to call this method because if user is new, we will automatically add it to existing or create a new account accountId, _, err := m.ensureAccount(ctx, userAuth) if err != nil { - return r, err + return err } if userAuth.AccountId != accountId { @@ -153,7 +146,7 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] userAuth, err = m.authManager.EnsureUserAccessByJWTGroups(ctx, userAuth, validatedToken) if err != nil { - return r, err + return err } err = m.syncUserJWTGroups(ctx, userAuth) @@ -164,41 +157,41 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts [] _, err = m.getUserFromUserAuth(ctx, userAuth) if err != nil { log.WithContext(ctx).Errorf("HTTP server failed to update user from user auth: %s", err) - return r, err + return err } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } // CheckPATFromRequest checks if the PAT is valid -func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) (*http.Request, error) { +func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts []string) error { token, err := getTokenFromPATRequest(authHeaderParts) if err != nil { - return r, fmt.Errorf("error extracting token: %w", err) + return fmt.Errorf("error extracting token: %w", err) } if m.patUsageTracker != nil { m.patUsageTracker.IncrementUsage(token) } - if m.rateLimiter != nil && !isTerraformRequest(r) { - if !m.rateLimiter.Allow(token) 
{ - return r, status.Errorf(status.TooManyRequests, "too many requests") - } + if !isTerraformRequest(r) && !m.rateLimiter.Allow(token) { + return status.Errorf(status.TooManyRequests, "too many requests") } ctx := r.Context() user, pat, accDomain, accCategory, err := m.authManager.GetPATInfo(ctx, token) if err != nil { - return r, fmt.Errorf("invalid Token: %w", err) + return fmt.Errorf("invalid Token: %w", err) } if time.Now().After(pat.GetExpirationDate()) { - return r, fmt.Errorf("token expired") + return fmt.Errorf("token expired") } err = m.authManager.MarkPATUsed(ctx, pat.ID) if err != nil { - return r, err + return err } userAuth := auth.UserAuth{ @@ -216,7 +209,9 @@ func (m *AuthMiddleware) checkPATFromRequest(r *http.Request, authHeaderParts [] } } - return nbcontext.SetUserAuthInRequest(r, userAuth), nil + // propagates ctx change to upstream middleware + *r = *nbcontext.SetUserAuthInRequest(r, userAuth) + return nil } func isTerraformRequest(r *http.Request) bool { diff --git a/management/server/http/middleware/auth_middleware_test.go b/management/server/http/middleware/auth_middleware_test.go index f397c63a4..8f736fbfd 100644 --- a/management/server/http/middleware/auth_middleware_test.go +++ b/management/server/http/middleware/auth_middleware_test.go @@ -196,6 +196,8 @@ func TestAuthMiddleware_Handler(t *testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -207,7 +209,7 @@ func TestAuthMiddleware_Handler(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) @@ -266,7 +268,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -318,7 +320,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -361,7 +363,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -405,7 +407,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -469,7 +471,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -528,7 +530,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -583,7 +585,7 @@ func TestAuthMiddleware_RateLimiting(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - rateLimitConfig, + NewAPIRateLimiter(rateLimitConfig), nil, ) @@ -670,6 +672,8 @@ func TestAuthMiddleware_Handler_Child(t 
*testing.T) { GetPATInfoFunc: mockGetAccountInfoFromPAT, } + disabledLimiter := NewAPIRateLimiter(nil) + disabledLimiter.SetEnabled(false) authMiddleware := NewAuthMiddleware( mockAuth, func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) { @@ -681,7 +685,7 @@ func TestAuthMiddleware_Handler_Child(t *testing.T) { func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) { return &types.User{}, nil }, - nil, + disabledLimiter, nil, ) diff --git a/management/server/http/middleware/rate_limiter.go b/management/server/http/middleware/rate_limiter.go index 936b34319..bfd44afee 100644 --- a/management/server/http/middleware/rate_limiter.go +++ b/management/server/http/middleware/rate_limiter.go @@ -4,14 +4,27 @@ import ( "context" "net" "net/http" + "os" + "strconv" "sync" + "sync/atomic" "time" + log "github.com/sirupsen/logrus" "golang.org/x/time/rate" "github.com/netbirdio/netbird/shared/management/http/util" ) +const ( + RateLimitingEnabledEnv = "NB_API_RATE_LIMITING_ENABLED" + RateLimitingBurstEnv = "NB_API_RATE_LIMITING_BURST" + RateLimitingRPMEnv = "NB_API_RATE_LIMITING_RPM" + + defaultAPIRPM = 6 + defaultAPIBurst = 500 +) + // RateLimiterConfig holds configuration for the API rate limiter type RateLimiterConfig struct { // RequestsPerMinute defines the rate at which tokens are replenished @@ -34,6 +47,43 @@ func DefaultRateLimiterConfig() *RateLimiterConfig { } } +func RateLimiterConfigFromEnv() (cfg *RateLimiterConfig, enabled bool) { + rpm := defaultAPIRPM + if v := os.Getenv(RateLimitingRPMEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingRPMEnv, err, rpm) + } else { + rpm = value + } + } + if rpm <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingRPMEnv, rpm, defaultAPIRPM) + rpm = defaultAPIRPM + } + + burst := defaultAPIBurst + if v := os.Getenv(RateLimitingBurstEnv); v != "" { + value, err := strconv.Atoi(v) + if err != nil { + log.Warnf("parsing %s env var: %v, using default %d", RateLimitingBurstEnv, err, burst) + } else { + burst = value + } + } + if burst <= 0 { + log.Warnf("%s=%d is non-positive, using default %d", RateLimitingBurstEnv, burst, defaultAPIBurst) + burst = defaultAPIBurst + } + + return &RateLimiterConfig{ + RequestsPerMinute: float64(rpm), + Burst: burst, + CleanupInterval: 6 * time.Hour, + LimiterTTL: 24 * time.Hour, + }, os.Getenv(RateLimitingEnabledEnv) == "true" +} + // limiterEntry holds a rate limiter and its last access time type limiterEntry struct { limiter *rate.Limiter @@ -46,6 +96,7 @@ type APIRateLimiter struct { limiters map[string]*limiterEntry mu sync.RWMutex stopChan chan struct{} + enabled atomic.Bool } // NewAPIRateLimiter creates a new API rate limiter with the given configuration @@ -59,14 +110,53 @@ func NewAPIRateLimiter(config *RateLimiterConfig) *APIRateLimiter { limiters: make(map[string]*limiterEntry), stopChan: make(chan struct{}), } + rl.enabled.Store(true) go rl.cleanupLoop() return rl } +func (rl *APIRateLimiter) SetEnabled(enabled bool) { + rl.enabled.Store(enabled) +} + +func (rl *APIRateLimiter) Enabled() bool { + return rl.enabled.Load() +} + +func (rl *APIRateLimiter) UpdateConfig(config *RateLimiterConfig) { + if config == nil { + return + } + if config.RequestsPerMinute <= 0 || config.Burst <= 0 { + log.Warnf("UpdateConfig: ignoring invalid rpm=%v burst=%d", config.RequestsPerMinute, config.Burst) + return + } + + newRPS := rate.Limit(config.RequestsPerMinute / 60.0) + newBurst := 
config.Burst + + rl.mu.Lock() + rl.config.RequestsPerMinute = config.RequestsPerMinute + rl.config.Burst = newBurst + snapshot := make([]*rate.Limiter, 0, len(rl.limiters)) + for _, entry := range rl.limiters { + snapshot = append(snapshot, entry.limiter) + } + rl.mu.Unlock() + + for _, l := range snapshot { + l.SetLimit(newRPS) + l.SetBurst(newBurst) + } +} + // Allow checks if a request for the given key (token) is allowed func (rl *APIRateLimiter) Allow(key string) bool { + if !rl.enabled.Load() { + return true + } limiter := rl.getLimiter(key) return limiter.Allow() } @@ -74,6 +164,9 @@ func (rl *APIRateLimiter) Allow(key string) bool { // Wait blocks until the rate limiter allows another request for the given key // Returns an error if the context is canceled func (rl *APIRateLimiter) Wait(ctx context.Context, key string) error { + if !rl.enabled.Load() { + return nil + } limiter := rl.getLimiter(key) return limiter.Wait(ctx) } @@ -153,6 +246,10 @@ func (rl *APIRateLimiter) Reset(key string) { // Returns 429 Too Many Requests if the rate limit is exceeded. func (rl *APIRateLimiter) Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if !rl.enabled.Load() { + next.ServeHTTP(w, r) + return + } clientIP := getClientIP(r) if !rl.Allow(clientIP) { util.WriteErrorResponse("rate limit exceeded, please try again later", http.StatusTooManyRequests, w) diff --git a/management/server/http/middleware/rate_limiter_test.go b/management/server/http/middleware/rate_limiter_test.go index 68f804e57..4b97d1874 100644 --- a/management/server/http/middleware/rate_limiter_test.go +++ b/management/server/http/middleware/rate_limiter_test.go @@ -1,8 +1,10 @@ package middleware import ( + "fmt" "net/http" "net/http/httptest" + "sync" "testing" "time" @@ -156,3 +158,172 @@ func TestAPIRateLimiter_Reset(t *testing.T) { // Should be allowed again assert.True(t, rl.Allow("test-key")) } + +func TestAPIRateLimiter_SetEnabled(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("key")) + assert.False(t, rl.Allow("key"), "burst exhausted while enabled") + + rl.SetEnabled(false) + assert.False(t, rl.Enabled()) + for i := 0; i < 5; i++ { + assert.True(t, rl.Allow("key"), "disabled limiter must always allow") + } + + rl.SetEnabled(true) + assert.True(t, rl.Enabled()) + assert.False(t, rl.Allow("key"), "re-enabled limiter retains prior bucket state") +} + +func TestAPIRateLimiter_UpdateConfig(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 2, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k1")) + assert.True(t, rl.Allow("k1")) + assert.False(t, rl.Allow("k1"), "burst=2 exhausted") + + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + + // New burst applies to existing keys in place; bucket refills up to new burst over time, + // but importantly newly-added keys use the updated config immediately. 
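+ // (Mechanically: UpdateConfig stores the new values in rl.config under the
+ // mutex, snapshots the existing limiters, and applies SetLimit/SetBurst to
+ // each one outside the lock, so limiters created later by getLimiter also
+ // start from the new settings.)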
+ assert.True(t, rl.Allow("k2")) + for i := 0; i < 9; i++ { + assert.True(t, rl.Allow("k2")) + } + assert.False(t, rl.Allow("k2"), "new burst=10 exhausted") +} + +func TestAPIRateLimiter_UpdateConfig_NilIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + rl.UpdateConfig(nil) // must not panic or zero the config + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) +} + +func TestAPIRateLimiter_UpdateConfig_NonPositiveIgnored(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 60, + Burst: 1, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k")) + + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 0, Burst: 0, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: -1, Burst: 5, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + rl.UpdateConfig(&RateLimiterConfig{RequestsPerMinute: 60, Burst: -1, CleanupInterval: time.Minute, LimiterTTL: time.Minute}) + + rl.Reset("k") + assert.True(t, rl.Allow("k")) + assert.False(t, rl.Allow("k"), "burst should still be 1 — invalid UpdateConfig calls were ignored") +} + +func TestAPIRateLimiter_ConcurrentAllowAndUpdate(t *testing.T) { + rl := NewAPIRateLimiter(&RateLimiterConfig{ + RequestsPerMinute: 600, + Burst: 10, + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + defer rl.Stop() + + var wg sync.WaitGroup + stop := make(chan struct{}) + + for i := 0; i < 8; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + key := fmt.Sprintf("k%d", id) + for { + select { + case <-stop: + return + default: + rl.Allow(key) + } + } + }(i) + } + + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 200; i++ { + select { + case <-stop: + return + default: + rl.UpdateConfig(&RateLimiterConfig{ + RequestsPerMinute: float64(30 + (i % 90)), + Burst: 1 + (i % 20), + CleanupInterval: time.Minute, + LimiterTTL: time.Minute, + }) + rl.SetEnabled(i%2 == 0) + } + } + }() + + time.Sleep(100 * time.Millisecond) + close(stop) + wg.Wait() +} + +func TestRateLimiterConfigFromEnv(t *testing.T) { + t.Setenv(RateLimitingEnabledEnv, "true") + t.Setenv(RateLimitingRPMEnv, "42") + t.Setenv(RateLimitingBurstEnv, "7") + + cfg, enabled := RateLimiterConfigFromEnv() + assert.True(t, enabled) + assert.Equal(t, float64(42), cfg.RequestsPerMinute) + assert.Equal(t, 7, cfg.Burst) + + t.Setenv(RateLimitingEnabledEnv, "false") + _, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + + t.Setenv(RateLimitingEnabledEnv, "") + t.Setenv(RateLimitingRPMEnv, "") + t.Setenv(RateLimitingBurstEnv, "") + cfg, enabled = RateLimiterConfigFromEnv() + assert.False(t, enabled) + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute) + assert.Equal(t, defaultAPIBurst, cfg.Burst) + + t.Setenv(RateLimitingRPMEnv, "0") + t.Setenv(RateLimitingBurstEnv, "-5") + cfg, _ = RateLimiterConfigFromEnv() + assert.Equal(t, float64(defaultAPIRPM), cfg.RequestsPerMinute, "non-positive rpm must fall back to default") + assert.Equal(t, defaultAPIBurst, cfg.Burst, "non-positive burst must fall back to default") +} diff --git a/management/server/http/testing/integration/networks_handler_integration_test.go b/management/server/http/testing/integration/networks_handler_integration_test.go index 4cb6b268b..54f204a8f 100644 --- 
a/management/server/http/testing/integration/networks_handler_integration_test.go +++ b/management/server/http/testing/integration/networks_handler_integration_test.go @@ -1170,13 +1170,17 @@ func Test_NetworkRouters_Create(t *testing.T) { Metric: 100, Enabled: true, }, - expectedStatus: http.StatusOK, - verifyResponse: func(t *testing.T, router *api.NetworkRouter) { - t.Helper() - assert.NotEmpty(t, router.Id) - assert.Equal(t, peerID, *router.Peer) - assert.Equal(t, 1, len(*router.PeerGroups)) + expectedStatus: http.StatusBadRequest, + }, + { + name: "Create router without peer and peer_groups", + networkId: "testNetworkId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, }, + expectedStatus: http.StatusBadRequest, }, { name: "Create router in non-existing network", @@ -1341,13 +1345,18 @@ func Test_NetworkRouters_Update(t *testing.T) { Metric: 100, Enabled: true, }, - expectedStatus: http.StatusOK, - verifyResponse: func(t *testing.T, router *api.NetworkRouter) { - t.Helper() - assert.Equal(t, "testRouterId", router.Id) - assert.Equal(t, peerID, *router.Peer) - assert.Equal(t, 1, len(*router.PeerGroups)) + expectedStatus: http.StatusBadRequest, + }, + { + name: "Update router without peer and peer_groups", + networkId: "testNetworkId", + routerId: "testRouterId", + requestBody: &api.NetworkRouterRequest{ + Masquerade: true, + Metric: 100, + Enabled: true, }, + expectedStatus: http.StatusBadRequest, }, } diff --git a/management/server/http/testing/testing_tools/channel/channel.go b/management/server/http/testing/testing_tools/channel/channel.go index d9d85a0a2..1a8b83c7e 100644 --- a/management/server/http/testing/testing_tools/channel/channel.go +++ b/management/server/http/testing/testing_tools/channel/channel.go @@ -35,6 +35,7 @@ import ( "github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/activity" serverauth "github.com/netbirdio/netbird/management/server/auth" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/geolocation" "github.com/netbirdio/netbird/management/server/groups" http2 "github.com/netbirdio/netbird/management/server/http" @@ -87,22 +88,22 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee jobManager := job.NewJobManager(nil, store, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) - am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) if err != nil { t.Fatalf("Failed to create manager: %v", err) } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - 
proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create proxy token store: %v", err) - } - pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create PKCE verifier store: %v", err) - } + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if err != nil { @@ -134,7 +135,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } @@ -216,22 +217,22 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin jobManager := job.NewJobManager(nil, store, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("Failed to create cache store: %v", err) + } + requestBuffer := server.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, peersUpdateManager, requestBuffer, server.MockIntegratedValidator{}, settingsManager, "", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManager), &config.Config{}) - am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false) + am, err := server.BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", &activity.InMemoryEventStore{}, geoMock, false, validatorMock, metrics, proxyController, settingsManager, permissionsManager, false, cacheStore) if err != nil { t.Fatalf("Failed to create manager: %v", err) } accessLogsManager := accesslogsmanager.NewManager(store, permissionsManager, nil) - proxyTokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create proxy token store: %v", err) - } - pkceverifierStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - if err != nil { - t.Fatalf("Failed to create PKCE verifier store: %v", err) - } + proxyTokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceverifierStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) noopMeter := noop.NewMeterProvider().Meter("") proxyMgr, err := proxymanager.NewManager(store, noopMeter) if 
err != nil { @@ -263,7 +264,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin customZonesManager := zonesManager.NewManager(store, am, permissionsManager, "") zoneRecordsManager := recordsManager.NewManager(store, am, permissionsManager) - apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil) + apiHandler, err := http2.NewAPIHandler(context.Background(), am, networksManager, resourcesManager, routersManager, groupsManager, geoMock, authManagerMock, metrics, validatorMock, proxyController, permissionsManager, peersManager, settingsManager, customZonesManager, zoneRecordsManager, networkMapController, nil, serviceManager, nil, nil, nil, nil, nil) if err != nil { t.Fatalf("Failed to create API handler: %v", err) } diff --git a/management/server/identity_provider_test.go b/management/server/identity_provider_test.go index 9fce6b9c0..d51254c55 100644 --- a/management/server/identity_provider_test.go +++ b/management/server/identity_provider_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "path/filepath" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" @@ -19,6 +20,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/idp" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -83,10 +85,15 @@ func createManagerWithEmbeddedIdP(t testing.TB) (*DefaultAccountManager, *update permissionsManager := permissions.NewManager(testStore) peersManager := peers.NewManager(testStore, permissionsManager) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, testStore) networkMapController := controller.NewController(ctx, testStore, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(testStore, peersManager), &config.Config{}) - manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + manager, err := BuildManager(ctx, &config.Config{}, testStore, networkMapController, job.NewJobManager(nil, testStore, peersManager), idpManager, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/idp/google_workspace.go b/management/server/idp/google_workspace.go index 48e4f3000..dadbfd83e 100644 --- a/management/server/idp/google_workspace.go +++ 
b/management/server/idp/google_workspace.go @@ -66,14 +66,14 @@ func NewGoogleWorkspaceManager(ctx context.Context, config GoogleWorkspaceClient } // Create a new Admin SDK Directory service client - adminCredentials, err := getGoogleCredentials(ctx, config.ServiceAccountKey) + credentialsOption, err := getGoogleCredentialsOption(ctx, config.ServiceAccountKey) if err != nil { return nil, err } service, err := admin.NewService(context.Background(), option.WithScopes(admin.AdminDirectoryUserReadonlyScope), - option.WithCredentials(adminCredentials), + credentialsOption, ) if err != nil { return nil, err @@ -218,39 +218,32 @@ func (gm *GoogleWorkspaceManager) DeleteUser(_ context.Context, userID string) e return nil } -// getGoogleCredentials retrieves Google credentials based on the provided serviceAccountKey. -// It decodes the base64-encoded serviceAccountKey and attempts to obtain credentials using it. -// If that fails, it falls back to using the default Google credentials path. -// It returns the retrieved credentials or an error if unsuccessful. -func getGoogleCredentials(ctx context.Context, serviceAccountKey string) (*google.Credentials, error) { +// getGoogleCredentialsOption returns the google.golang.org/api option carrying +// Google credentials derived from the provided serviceAccountKey. +// It decodes the base64-encoded serviceAccountKey and uses it as the credentials JSON. +// If the key is empty, it falls back to the default Google credentials path. +func getGoogleCredentialsOption(ctx context.Context, serviceAccountKey string) (option.ClientOption, error) { log.WithContext(ctx).Debug("retrieving google credentials from the base64 encoded service account key") decodeKey, err := base64.StdEncoding.DecodeString(serviceAccountKey) if err != nil { return nil, fmt.Errorf("failed to decode service account key: %w", err) } - creds, err := google.CredentialsFromJSON( - context.Background(), - decodeKey, - admin.AdminDirectoryUserReadonlyScope, - ) - if err == nil { - // No need to fallback to the default Google credentials path - return creds, nil + if len(decodeKey) > 0 { + return option.WithAuthCredentialsJSON(option.ServiceAccount, decodeKey), nil } - log.WithContext(ctx).Debugf("failed to retrieve Google credentials from ServiceAccountKey: %v", err) - log.WithContext(ctx).Debug("falling back to default google credentials location") + log.WithContext(ctx).Debug("no service account key provided, falling back to default google credentials location") - creds, err = google.FindDefaultCredentials( - context.Background(), + creds, err := google.FindDefaultCredentials( + ctx, admin.AdminDirectoryUserReadonlyScope, ) if err != nil { return nil, err } - return creds, nil + return option.WithCredentials(creds), nil } // parseGoogleWorkspaceUser parse google user to UserData. 
diff --git a/management/server/management_proto_test.go b/management/server/management_proto_test.go index 090c99877..18d85315d 100644 --- a/management/server/management_proto_test.go +++ b/management/server/management_proto_test.go @@ -29,6 +29,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -266,8 +267,8 @@ func Test_SyncProtocol(t *testing.T) { } // expired peers come separately. - if len(networkMap.GetOfflinePeers()) != 1 { - t.Fatal("expecting SyncResponse to have NetworkMap with 1 offline peer") + if len(networkMap.GetOfflinePeers()) != 2 { + t.Fatal("expecting SyncResponse to have NetworkMap with 2 offline peers") } expiredPeerPubKey := "RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=" @@ -369,9 +370,15 @@ func startManagementForTest(t *testing.T, testFile string, config *config.Config requestBuffer := NewAccountRequestBuffer(ctx, store) ephemeralMgr := manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)) + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + cleanup() + return nil, nil, "", cleanup, err + } + networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeralMgr, config) accountManager, err := BuildManager(ctx, nil, store, networkMapController, jobManager, nil, "", - eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { cleanup() diff --git a/management/server/management_test.go b/management/server/management_test.go index de02855bf..3ac28cd4a 100644 --- a/management/server/management_test.go +++ b/management/server/management_test.go @@ -28,6 +28,7 @@ import ( nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -207,6 +208,12 @@ func startServer( jobManager := job.NewJobManager(nil, str, peersManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatalf("failed creating cache store: %v", err) + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := server.NewAccountRequestBuffer(ctx, str) networkMapController := controller.NewController(ctx, str, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(str, peers.NewManager(str, permissionsManager)),
config) @@ -227,7 +234,8 @@ func startServer( port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, - false) + false, + cacheStore) if err != nil { t.Fatalf("failed creating an account manager: %v", err) } diff --git a/management/server/nameserver_test.go b/management/server/nameserver_test.go index 90b4b9687..d10d4464f 100644 --- a/management/server/nameserver_test.go +++ b/management/server/nameserver_test.go @@ -17,6 +17,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" nbpeer "github.com/netbirdio/netbird/management/server/peer" @@ -794,11 +795,17 @@ func createNSManager(t *testing.T) (*DefaultAccountManager, error) { peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + return BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) } func createNSStore(t *testing.T) (store.Store, error) { diff --git a/management/server/networks/routers/types/router.go b/management/server/networks/routers/types/router.go index e90c61a97..1293a9934 100644 --- a/management/server/networks/routers/types/router.go +++ b/management/server/networks/routers/types/router.go @@ -21,11 +21,7 @@ type NetworkRouter struct { } func NewNetworkRouter(accountID string, networkID string, peer string, peerGroups []string, masquerade bool, metric int, enabled bool) (*NetworkRouter, error) { - if peer != "" && len(peerGroups) > 0 { - return nil, errors.New("peer and peerGroups cannot be set at the same time") - } - - return &NetworkRouter{ + r := &NetworkRouter{ ID: xid.New().String(), AccountID: accountID, NetworkID: networkID, @@ -34,7 +30,25 @@ func NewNetworkRouter(accountID string, networkID string, peer string, peerGroup Masquerade: masquerade, Metric: metric, Enabled: enabled, - }, nil + } + + if err := r.Validate(); err != nil { + return nil, err + } + + return r, nil +} + +func (n *NetworkRouter) Validate() error { + if n.Peer != "" && len(n.PeerGroups) > 0 { + return errors.New("peer and peer_groups cannot be set at the same time") + } + + if n.Peer == "" && len(n.PeerGroups) == 0 { + return errors.New("either peer or peer_groups must 
be provided") + } + + return nil } func (n *NetworkRouter) ToAPIResponse() *api.NetworkRouter { diff --git a/management/server/networks/routers/types/router_test.go b/management/server/networks/routers/types/router_test.go index 5801e3bfa..a2f2fe6e3 100644 --- a/management/server/networks/routers/types/router_test.go +++ b/management/server/networks/routers/types/router_test.go @@ -38,7 +38,7 @@ func TestNewNetworkRouter(t *testing.T) { expectedError: false, }, { - name: "Valid with no peer or peerGroups", + name: "Invalid with no peer or peerGroups", networkID: "network-3", accountID: "account-3", peer: "", @@ -46,7 +46,18 @@ func TestNewNetworkRouter(t *testing.T) { masquerade: true, metric: 300, enabled: true, - expectedError: false, + expectedError: true, + }, + { + name: "Invalid with empty peerGroups slice", + networkID: "network-5", + accountID: "account-5", + peer: "", + peerGroups: []string{}, + masquerade: true, + metric: 500, + enabled: true, + expectedError: true, }, // Invalid cases diff --git a/management/server/peer.go b/management/server/peer.go index a02e34e0d..a95ae17a3 100644 --- a/management/server/peer.go +++ b/management/server/peer.go @@ -1405,6 +1405,10 @@ func (am *DefaultAccountManager) getExpiredPeers(ctx context.Context, accountID var peers []*nbpeer.Peer for _, peer := range peersWithExpiry { + if peer.Status.LoginExpired { + continue + } + expired, _ := peer.LoginExpired(settings.PeerLoginExpiration) if expired { peers = append(peers, peer) diff --git a/management/server/peer_test.go b/management/server/peer_test.go index 51c16d730..6f8d924fd 100644 --- a/management/server/peer_test.go +++ b/management/server/peer_test.go @@ -32,6 +32,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/http/testing/testing_tools" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" @@ -1294,11 +1295,15 @@ func Test_RegisterPeerByUser(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1380,11 
+1385,15 @@ func Test_RegisterPeerBySetupKey(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1534,11 +1543,15 @@ func Test_RegisterPeerRollbackOnFailure(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" @@ -1615,11 +1628,15 @@ func Test_LoginPeer(t *testing.T) { peersManager := peers.NewManager(s, permissionsManager) ctx := context.Background() + + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + require.NoError(t, err) + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, s) networkMapController := controller.NewController(ctx, s, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(s, peers.NewManager(s, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, 
port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, s, networkMapController, job.NewJobManager(nil, s, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) assert.NoError(t, err) existingAccountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" diff --git a/management/server/policy.go b/management/server/policy.go index 3e84c3d10..48297ca11 100644 --- a/management/server/policy.go +++ b/management/server/policy.go @@ -5,6 +5,7 @@ import ( _ "embed" "github.com/rs/xid" + "github.com/sirupsen/logrus" "github.com/netbirdio/netbird/management/server/permissions/modules" "github.com/netbirdio/netbird/management/server/permissions/operations" @@ -46,25 +47,40 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user var isUpdate = policy.ID != "" var updateAccountPeers bool var action = activity.PolicyAdded + var unchanged bool err = am.Store.ExecuteInTransaction(ctx, func(transaction store.Store) error { - if err = validatePolicy(ctx, transaction, accountID, policy); err != nil { - return err - } - - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, isUpdate) + existingPolicy, err := validatePolicy(ctx, transaction, accountID, policy) if err != nil { return err } - saveFunc := transaction.CreatePolicy if isUpdate { - action = activity.PolicyUpdated - saveFunc = transaction.SavePolicy - } + if policy.Equal(existingPolicy) { + logrus.WithContext(ctx).Tracef("skipping update for policy %s: identical to the stored policy", policy.ID) + unchanged = true + return nil + } - if err = saveFunc(ctx, policy); err != nil { - return err + action = activity.PolicyUpdated + + updateAccountPeers, err = arePolicyChangesAffectPeersWithExisting(ctx, transaction, policy, existingPolicy) + if err != nil { + return err + } + + if err = transaction.SavePolicy(ctx, policy); err != nil { + return err + } + } else { + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) + if err != nil { + return err + } + + if err = transaction.CreatePolicy(ctx, policy); err != nil { + return err + } } return transaction.IncrementNetworkSerial(ctx, accountID) @@ -73,6 +89,10 @@ func (am *DefaultAccountManager) SavePolicy(ctx context.Context, accountID, user return nil, err } + if unchanged { + return policy, nil + } + am.StoreEvent(ctx, userID, policy.ID, accountID, action, policy.EventMeta()) if updateAccountPeers { @@ -101,7 +121,7 @@ func (am *DefaultAccountManager) DeletePolicy(ctx context.Context, accountID, po return err } - updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, accountID, policy, false) + updateAccountPeers, err = arePolicyChangesAffectPeers(ctx, transaction, policy) if err != nil { return err } @@ -138,34 +158,37 @@ func (am *DefaultAccountManager) ListPolicies(ctx context.Context, accountID, us return am.Store.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID) } -// arePolicyChangesAffectPeers checks if changes to a policy will affect any associated peers. 
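The SavePolicy rewrite above adds an early exit for no-op updates: when the submitted policy is semantically equal to the stored one, the transaction writes nothing and the activity event and network-map push are skipped. A minimal sketch of that short-circuit, using simplified stand-in types rather than the real NetBird store API:

```go
// Sketch of the no-op update short-circuit; Policy and the map-backed
// store are stand-ins, not NetBird's types.
package main

import "fmt"

type Policy struct {
	ID      string
	Name    string
	Enabled bool
}

func (p *Policy) Equal(other *Policy) bool {
	if p == nil || other == nil {
		return p == other
	}
	return *p == *other
}

// savePolicy persists the policy only when it differs from the stored
// copy; unchanged updates skip the write, the activity event, and the
// peer update entirely.
func savePolicy(stored map[string]*Policy, p *Policy) (unchanged bool) {
	if existing, ok := stored[p.ID]; ok && p.Equal(existing) {
		return true // nothing to persist, nothing to notify
	}
	stored[p.ID] = p
	return false
}

func main() {
	stored := map[string]*Policy{"pol1": {ID: "pol1", Name: "web", Enabled: true}}
	fmt.Println(savePolicy(stored, &Policy{ID: "pol1", Name: "web", Enabled: true}))  // true: skipped
	fmt.Println(savePolicy(stored, &Policy{ID: "pol1", Name: "web", Enabled: false})) // false: written
}
```

The equality check runs inside the transaction, so it compares against the same snapshot the update would be applied to.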
-func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy, isUpdate bool) (bool, error) { - if isUpdate { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) - if err != nil { - return false, err - } - - if !policy.Enabled && !existingPolicy.Enabled { - return false, nil - } - - for _, rule := range existingPolicy.Rules { - if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { - return true, nil - } - } - - hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) - if err != nil { - return false, err - } - - if hasPeers { +// arePolicyChangesAffectPeers checks if a policy (being created or deleted) will affect any associated peers. +func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, policy *types.Policy) (bool, error) { + for _, rule := range policy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil } } + return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) +} + +func arePolicyChangesAffectPeersWithExisting(ctx context.Context, transaction store.Store, policy *types.Policy, existingPolicy *types.Policy) (bool, error) { + if !policy.Enabled && !existingPolicy.Enabled { + return false, nil + } + + for _, rule := range existingPolicy.Rules { + if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { + return true, nil + } + } + + hasPeers, err := anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, existingPolicy.RuleGroups()) + if err != nil { + return false, err + } + + if hasPeers { + return true, nil + } + for _, rule := range policy.Rules { if rule.SourceResource.Type != "" || rule.DestinationResource.Type != "" { return true, nil @@ -175,12 +198,15 @@ func arePolicyChangesAffectPeers(ctx context.Context, transaction store.Store, a return anyGroupHasPeersOrResources(ctx, transaction, policy.AccountID, policy.RuleGroups()) } -// validatePolicy validates the policy and its rules. -func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) error { +// validatePolicy validates the policy and its rules. For updates it returns +// the existing policy loaded from the store so callers can avoid a second read. 
+func validatePolicy(ctx context.Context, transaction store.Store, accountID string, policy *types.Policy) (*types.Policy, error) { + var existingPolicy *types.Policy if policy.ID != "" { - existingPolicy, err := transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) + var err error + existingPolicy, err = transaction.GetPolicyByID(ctx, store.LockingStrengthNone, accountID, policy.ID) if err != nil { - return err + return nil, err } // TODO: Refactor to support multiple rules per policy @@ -191,7 +217,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri for _, rule := range policy.Rules { if rule.ID != "" && !existingRuleIDs[rule.ID] { - return status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) + return nil, status.Errorf(status.InvalidArgument, "invalid rule ID: %s", rule.ID) } } } else { @@ -201,12 +227,12 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri groups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, policy.RuleGroups()) if err != nil { - return err + return nil, err } postureChecks, err := transaction.GetPostureChecksByIDs(ctx, store.LockingStrengthNone, accountID, policy.SourcePostureChecks) if err != nil { - return err + return nil, err } for i, rule := range policy.Rules { @@ -225,7 +251,7 @@ func validatePolicy(ctx context.Context, transaction store.Store, accountID stri policy.SourcePostureChecks = getValidPostureCheckIDs(postureChecks, policy.SourcePostureChecks) } - return nil + return existingPolicy, nil } // getValidPostureCheckIDs filters and returns only the valid posture check IDs from the provided list. diff --git a/management/server/route_test.go b/management/server/route_test.go index d4882eff8..91b2cf982 100644 --- a/management/server/route_test.go +++ b/management/server/route_test.go @@ -20,6 +20,7 @@ import ( ephemeral_manager "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral/manager" "github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/job" resourceTypes "github.com/netbirdio/netbird/management/server/networks/resources/types" @@ -1293,11 +1294,17 @@ func createRouterManager(t *testing.T) (*DefaultAccountManager, *update_channel. 
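Every test constructor touched by this PR repeats the same cache-store setup before calling BuildManager. If that wiring keeps spreading, a small helper could absorb it; a sketch, assuming cache.NewStore keeps the (ctx, TTL, cleanup interval, max size) signature seen at the call sites. The helper name and its generic shape are hypothetical, not part of this PR:

```go
package testutil

import (
	"context"
	"testing"
	"time"

	"github.com/stretchr/testify/require"
)

// newTestCacheStore wraps the repeated cache-store wiring. It is generic
// over the store type so it compiles against whatever cache.NewStore
// actually returns; only the parameter signature is assumed.
func newTestCacheStore[T any](t *testing.T, newStore func(context.Context, time.Duration, time.Duration, int) (T, error)) T {
	t.Helper()
	s, err := newStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
	require.NoError(t, err)
	return s
}

// Usage at a call site (assuming the signature matches):
//	cacheStore := newTestCacheStore(t, cache.NewStore)
```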
peersManager := peers.NewManager(store, permissionsManager) ctx := context.Background() + + cacheStore, err := cache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + return nil, nil, err + } + updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peers.NewManager(store, permissionsManager)), &config.Config{}) - am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false) + am, err := BuildManager(context.Background(), nil, store, networkMapController, job.NewJobManager(nil, store, peersManager), nil, "", eventStore, nil, false, MockIntegratedValidator{}, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore) if err != nil { return nil, nil, err } diff --git a/management/server/store/sql_store.go b/management/server/store/sql_store.go index 8189548b7..0ff57b752 100644 --- a/management/server/store/sql_store.go +++ b/management/server/store/sql_store.go @@ -3310,7 +3310,7 @@ func (s *SqlStore) GetAccountPeersWithExpiration(ctx context.Context, lockStreng var peers []*nbpeer.Peer result := tx. - Where("login_expiration_enabled = ? AND user_id IS NOT NULL AND user_id != ''", true). + Where("login_expiration_enabled = ? AND peer_status_login_expired != ? AND user_id IS NOT NULL AND user_id != ''", true, true). 
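The new WHERE clause above excludes already-expired peers at the store level, and the getExpiredPeers change earlier in this diff keeps an equivalent in-memory guard, so a peer whose status is already LoginExpired is never scheduled for expiry twice. A self-contained sketch of the application-side check, with a simplified Peer standing in for nbpeer.Peer:

```go
// Sketch of the double-expiry guard; Peer is a stand-in and
// login_expiration_enabled is assumed true for every peer here.
package main

import (
	"fmt"
	"time"
)

type Peer struct {
	ID           string
	LoginExpired bool      // mirrors the peer_status_login_expired column
	LastLogin    time.Time
}

// expiredPeers returns peers whose login just expired, skipping peers
// already marked expired so they are not processed a second time.
func expiredPeers(peers []*Peer, expiration time.Duration) []*Peer {
	var out []*Peer
	for _, p := range peers {
		if p.LoginExpired {
			continue // already handled on a previous pass
		}
		if time.Since(p.LastLogin) > expiration {
			out = append(out, p)
		}
	}
	return out
}

func main() {
	peers := []*Peer{
		{ID: "a", LoginExpired: true, LastLogin: time.Now().Add(-48 * time.Hour)},
		{ID: "b", LastLogin: time.Now().Add(-48 * time.Hour)},
		{ID: "c", LastLogin: time.Now()},
	}
	for _, p := range expiredPeers(peers, 24*time.Hour) {
		fmt.Println("expired:", p.ID) // only "b"
	}
}
```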
Find(&peers, accountIDCondition, accountID) if err := result.Error; err != nil { log.WithContext(ctx).Errorf("failed to get peers with expiration from the store: %s", result.Error) diff --git a/management/server/store/sql_store_test.go b/management/server/store/sql_store_test.go index 8ea6c2ae5..5a5616abc 100644 --- a/management/server/store/sql_store_test.go +++ b/management/server/store/sql_store_test.go @@ -2729,7 +2729,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { { name: "should retrieve peers for an existing account ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 4, + expectedCount: 5, }, { name: "should return no peers for a non-existing account ID", @@ -2751,7 +2751,7 @@ func TestSqlStore_GetAccountPeers(t *testing.T) { name: "should filter peers by partial name", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", nameFilter: "host", - expectedCount: 3, + expectedCount: 4, }, { name: "should filter peers by ip", @@ -2777,14 +2777,16 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { require.NoError(t, err) tests := []struct { - name string - accountID string - expectedCount int + name string + accountID string + expectedCount int + expectedPeerIDs []string }{ { - name: "should retrieve peers with expiration for an existing account ID", - accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", - expectedCount: 1, + name: "should retrieve only non-expired peers with expiration enabled", + accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", + expectedCount: 1, + expectedPeerIDs: []string{"notexpired01"}, }, { name: "should return no peers with expiration for a non-existing account ID", @@ -2803,10 +2805,30 @@ func TestSqlStore_GetAccountPeersWithExpiration(t *testing.T) { peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, tt.accountID) require.NoError(t, err) require.Len(t, peers, tt.expectedCount) + for i, peer := range peers { + assert.Equal(t, tt.expectedPeerIDs[i], peer.ID) + } }) } } +func TestSqlStore_GetAccountPeersWithExpiration_ExcludesAlreadyExpired(t *testing.T) { + store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) + t.Cleanup(cleanup) + require.NoError(t, err) + + accountID := "bf1c8084-ba50-4ce7-9439-34653001fc3b" + + peers, err := store.GetAccountPeersWithExpiration(context.Background(), LockingStrengthNone, accountID) + require.NoError(t, err) + + // Verify the already-expired peer (cg05lnblo1hkg2j514p0) is not returned + for _, peer := range peers { + assert.NotEqual(t, "cg05lnblo1hkg2j514p0", peer.ID, "already expired peer should not be returned") + assert.False(t, peer.Status.LoginExpired, "returned peers should not have LoginExpired set") + } +} + func TestSqlStore_GetAccountPeersWithInactivity(t *testing.T) { store, cleanup, err := NewTestStoreFromSQL(context.Background(), "../testdata/store_with_expired_peers.sql", t.TempDir()) t.Cleanup(cleanup) @@ -2887,7 +2909,7 @@ func TestSqlStore_GetUserPeers(t *testing.T) { name: "should retrieve peers for another valid account ID and user ID", accountID: "bf1c8084-ba50-4ce7-9439-34653001fc3b", userID: "edafee4e-63fb-11ec-90d6-0242ac120003", - expectedCount: 2, + expectedCount: 3, }, { name: "should return no peers for existing account ID with empty user ID", diff --git a/management/server/telemetry/http_api_metrics.go b/management/server/telemetry/http_api_metrics.go index 28e8457e2..e48e6d64a 100644 --- a/management/server/telemetry/http_api_metrics.go +++ 
b/management/server/telemetry/http_api_metrics.go @@ -193,20 +193,12 @@ func (m *HTTPMiddleware) Handler(h http.Handler) http.Handler { } }) - h.ServeHTTP(w, r.WithContext(ctx)) + // Hold on to req so auth's in-place ctx update is visible after ServeHTTP. + req := r.WithContext(ctx) + h.ServeHTTP(w, req) close(handlerDone) - userAuth, err := nbContext.GetUserAuthFromContext(r.Context()) - if err == nil { - if userAuth.AccountId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.AccountIDKey, userAuth.AccountId) - } - if userAuth.UserId != "" { - //nolint - ctx = context.WithValue(ctx, nbContext.UserIDKey, userAuth.UserId) - } - } + ctx = req.Context() if w.Status() > 399 { log.WithContext(ctx).Errorf("HTTP response %v: %v %v status %v", reqID, r.Method, r.URL, w.Status()) diff --git a/management/server/testdata/store_with_expired_peers.sql b/management/server/testdata/store_with_expired_peers.sql index dfcaeee6f..189bd1262 100644 --- a/management/server/testdata/store_with_expired_peers.sql +++ b/management/server/testdata/store_with_expired_peers.sql @@ -31,6 +31,7 @@ INSERT INTO peers VALUES('cfvprsrlo1hqoo49ohog','bf1c8084-ba50-4ce7-9439-3465300 INSERT INTO peers VALUES('cg05lnblo1hkg2j514p0','bf1c8084-ba50-4ce7-9439-34653001fc3b','RlSy2vzoG2HyMBTUImXOiVhCBiiBa5qD5xzMxkiFDW4=','','"100.64.39.54"','expiredhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'expiredhost','expiredhost','2023-03-02 09:19:57.276717255+01:00',0,1,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIMbK5ZXJsGOOWoBT4OmkPtgdPZe2Q7bDuS/zjn2CZxhK',0,1,0,'2023-03-02 09:14:21.791679181+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('cg3161rlo1hs9cq94gdg','bf1c8084-ba50-4ce7-9439-34653001fc3b','mVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HU=','','"100.64.117.96"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO peers VALUES('csrnkiq7qv9d8aitqd50','bf1c8084-ba50-4ce7-9439-34653001fc3b','nVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HX=','','"100.64.117.97"','testhost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'testhost','testhost-1','2023-03-06 18:21:27.252010027+01:00',0,0,0,'f4f6d672-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,0,1,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); +INSERT INTO peers VALUES('notexpired01','bf1c8084-ba50-4ce7-9439-34653001fc3b','oVABSKj28gv+JRsf7e0NEGKgSOGTfU/nPB2cpuG56HY=','','"100.64.117.98"','activehost','linux','Linux','22.04','x86_64','Ubuntu','','development','','',NULL,'','','','{"Cloud":"","Platform":""}',NULL,'activehost','activehost','2023-03-06 18:21:27.252010027+01:00',0,0,0,'edafee4e-63fb-11ec-90d6-0242ac120003','ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAINWvvUkFFcrj48CWTkNUb/do/n52i1L5dH4DhGu+4ZuM',0,1,0,'2023-03-07 09:02:47.442857106+01:00','2024-10-02 17:00:32.527947+02:00',0,'""','','',0); INSERT INTO users 
VALUES('f4f6d672-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','user',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO users VALUES('edafee4e-63fb-11ec-90d6-0242ac120003','bf1c8084-ba50-4ce7-9439-34653001fc3b','admin',0,0,'','[]',0,NULL,'2024-10-02 17:00:32.528196+02:00','api',0,''); INSERT INTO installations VALUES(1,''); diff --git a/management/server/types/policy.go b/management/server/types/policy.go index d4e1a8816..d410aec8d 100644 --- a/management/server/types/policy.go +++ b/management/server/types/policy.go @@ -93,6 +93,44 @@ func (p *Policy) Copy() *Policy { return c } +func (p *Policy) Equal(other *Policy) bool { + if p == nil || other == nil { + return p == other + } + + if p.ID != other.ID || + p.AccountID != other.AccountID || + p.Name != other.Name || + p.Description != other.Description || + p.Enabled != other.Enabled { + return false + } + + if !stringSlicesEqualUnordered(p.SourcePostureChecks, other.SourcePostureChecks) { + return false + } + + if len(p.Rules) != len(other.Rules) { + return false + } + + otherRules := make(map[string]*PolicyRule, len(other.Rules)) + for _, r := range other.Rules { + otherRules[r.ID] = r + } + for _, r := range p.Rules { + otherRule, ok := otherRules[r.ID] + if !ok { + return false + } + if !r.Equal(otherRule) { + return false + } + } + + return true +} + // EventMeta returns activity event meta related to this policy func (p *Policy) EventMeta() map[string]any { return map[string]any{"name": p.Name} diff --git a/management/server/types/policy_test.go b/management/server/types/policy_test.go new file mode 100644 index 000000000..b1d7aabc2 --- /dev/null +++ b/management/server/types/policy_test.go @@ -0,0 +1,193 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyEqual_SameRulesDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1", Ports: []string{"443"}}, + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRules(t *testing.T) { + a := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"80"}}, + }, + } + b := &Policy{ + ID: "pol1", + Enabled: true, + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1", Ports: []string{"443"}}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentRuleCount(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_PostureChecksDifferentOrder(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc3", "pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2", "pc3"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentPostureChecks(t *testing.T) { + a := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc2"}, + } + b := &Policy{ + ID: "pol1", + SourcePostureChecks: []string{"pc1", "pc3"}, + } 
+ assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_DifferentScalarFields(t *testing.T) { + base := Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "test", + Description: "desc", + Enabled: true, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Description = "changed" + assert.False(t, base.Equal(&other)) +} + +func TestPolicyEqual_NilCases(t *testing.T) { + var a *Policy + var b *Policy + assert.True(t, a.Equal(b)) + + a = &Policy{ID: "pol1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyEqual_RulesMismatchByID(t *testing.T) { + a := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r1", PolicyID: "pol1"}, + }, + } + b := &Policy{ + ID: "pol1", + Rules: []*PolicyRule{ + {ID: "r2", PolicyID: "pol1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyEqual_FullScenario(t *testing.T) { + a := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc2", "pc1"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g2", "g1"}, + Destinations: []string{"g4", "g3"}, + Ports: []string{"443", "80", "8080"}, + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + }, + }, + } + b := &Policy{ + ID: "pol1", + AccountID: "acc1", + Name: "Web Access", + Description: "Allow web access", + Enabled: true, + SourcePostureChecks: []string{"pc1", "pc2"}, + Rules: []*PolicyRule{ + { + ID: "r1", + PolicyID: "pol1", + Name: "HTTP", + Enabled: true, + Action: PolicyTrafficActionAccept, + Protocol: PolicyRuleProtocolTCP, + Bidirectional: true, + Sources: []string{"g1", "g2"}, + Destinations: []string{"g3", "g4"}, + Ports: []string{"80", "8080", "443"}, + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + }, + }, + } + assert.True(t, a.Equal(b)) +} diff --git a/management/server/types/policyrule.go b/management/server/types/policyrule.go index bb75dd555..52c494a6a 100644 --- a/management/server/types/policyrule.go +++ b/management/server/types/policyrule.go @@ -1,6 +1,8 @@ package types import ( + "slices" + "github.com/netbirdio/netbird/shared/management/proto" ) @@ -118,3 +120,106 @@ func (pm *PolicyRule) Copy() *PolicyRule { } return rule } + +func (pm *PolicyRule) Equal(other *PolicyRule) bool { + if pm == nil || other == nil { + return pm == other + } + + if pm.ID != other.ID || + pm.PolicyID != other.PolicyID || + pm.Name != other.Name || + pm.Description != other.Description || + pm.Enabled != other.Enabled || + pm.Action != other.Action || + pm.Bidirectional != other.Bidirectional || + pm.Protocol != other.Protocol || + pm.SourceResource != other.SourceResource || + pm.DestinationResource != other.DestinationResource || + pm.AuthorizedUser != other.AuthorizedUser { + return false + } + + if !stringSlicesEqualUnordered(pm.Sources, other.Sources) { + return false + } + if !stringSlicesEqualUnordered(pm.Destinations, other.Destinations) { + return false + } + if !stringSlicesEqualUnordered(pm.Ports, other.Ports) { + return false + } + if !portRangeSlicesEqualUnordered(pm.PortRanges, other.PortRanges) { + return false + } + if !authorizedGroupsEqual(pm.AuthorizedGroups, other.AuthorizedGroups) { + return false + } + + return 
true +} + +func stringSlicesEqualUnordered(a, b []string) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + sorted1 := make([]string, len(a)) + sorted2 := make([]string, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.Sort(sorted1) + slices.Sort(sorted2) + return slices.Equal(sorted1, sorted2) +} + +func portRangeSlicesEqualUnordered(a, b []RulePortRange) bool { + if len(a) != len(b) { + return false + } + if len(a) == 0 { + return true + } + cmp := func(x, y RulePortRange) int { + if x.Start != y.Start { + if x.Start < y.Start { + return -1 + } + return 1 + } + if x.End != y.End { + if x.End < y.End { + return -1 + } + return 1 + } + return 0 + } + sorted1 := make([]RulePortRange, len(a)) + sorted2 := make([]RulePortRange, len(b)) + copy(sorted1, a) + copy(sorted2, b) + slices.SortFunc(sorted1, cmp) + slices.SortFunc(sorted2, cmp) + return slices.EqualFunc(sorted1, sorted2, func(x, y RulePortRange) bool { + return x.Start == y.Start && x.End == y.End + }) +} + +func authorizedGroupsEqual(a, b map[string][]string) bool { + if len(a) != len(b) { + return false + } + for k, va := range a { + vb, ok := b[k] + if !ok { + return false + } + if !stringSlicesEqualUnordered(va, vb) { + return false + } + } + return true +} diff --git a/management/server/types/policyrule_test.go b/management/server/types/policyrule_test.go new file mode 100644 index 000000000..816e72abb --- /dev/null +++ b/management/server/types/policyrule_test.go @@ -0,0 +1,194 @@ +package types + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestPolicyRuleEqual_SamePortsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80", "22"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"22", "443", "80"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPorts(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "80"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{"443", "22"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_SourcesDestinationsDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2", "g3"}, + Destinations: []string{"g4", "g5"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g3", "g1", "g2"}, + Destinations: []string{"g5", "g4"}, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentSources(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g2"}, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Sources: []string{"g1", "g3"}, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_PortRangesDifferentOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 8000, End: 9000}, + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + {Start: 8000, End: 9000}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentPortRanges(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 80}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + PortRanges: []RulePortRange{ + {Start: 80, End: 443}, + }, + } + assert.False(t, a.Equal(b)) +} + 
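The two unordered-comparison helpers above share the same copy-sort-compare shape and never mutate their inputs. A generic variant could cover both element types; a sketch assuming Go 1.21+ (slices, cmp), not part of this PR:

```go
// Generic order-insensitive slice comparison; a possible consolidation
// of stringSlicesEqualUnordered and portRangeSlicesEqualUnordered.
package main

import (
	"cmp"
	"fmt"
	"slices"
)

type RulePortRange struct{ Start, End uint16 }

// equalUnordered reports whether a and b hold the same elements,
// ignoring order. It sorts clones so the inputs stay untouched.
func equalUnordered[T any](a, b []T, compare func(T, T) int) bool {
	if len(a) != len(b) {
		return false
	}
	sa, sb := slices.Clone(a), slices.Clone(b)
	slices.SortFunc(sa, compare)
	slices.SortFunc(sb, compare)
	return slices.EqualFunc(sa, sb, func(x, y T) bool { return compare(x, y) == 0 })
}

func main() {
	fmt.Println(equalUnordered([]string{"80", "443"}, []string{"443", "80"}, cmp.Compare)) // true
	fmt.Println(equalUnordered(
		[]RulePortRange{{80, 80}, {8000, 9000}},
		[]RulePortRange{{8000, 9000}, {80, 80}},
		func(x, y RulePortRange) int {
			if c := cmp.Compare(x.Start, y.Start); c != 0 {
				return c
			}
			return cmp.Compare(x.End, y.End)
		},
	)) // true
}
```

Order-insensitive comparison matters here because rule groups and ports can legitimately be reordered by the API without representing a real policy change.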
+func TestPolicyRuleEqual_AuthorizedGroupsDifferentValueOrder(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1", "u2", "u3"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u3", "u1", "u2"}, + }, + } + assert.True(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentAuthorizedGroups(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g1": {"u1"}, + }, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + AuthorizedGroups: map[string][]string{ + "g2": {"u1"}, + }, + } + assert.False(t, a.Equal(b)) +} + +func TestPolicyRuleEqual_DifferentScalarFields(t *testing.T) { + base := PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Name: "test", + Description: "desc", + Enabled: true, + Action: PolicyTrafficActionAccept, + Bidirectional: true, + Protocol: PolicyRuleProtocolTCP, + } + + other := base + other.Name = "changed" + assert.False(t, base.Equal(&other)) + + other = base + other.Enabled = false + assert.False(t, base.Equal(&other)) + + other = base + other.Action = PolicyTrafficActionDrop + assert.False(t, base.Equal(&other)) + + other = base + other.Protocol = PolicyRuleProtocolUDP + assert.False(t, base.Equal(&other)) +} + +func TestPolicyRuleEqual_NilCases(t *testing.T) { + var a *PolicyRule + var b *PolicyRule + assert.True(t, a.Equal(b)) + + a = &PolicyRule{ID: "rule1"} + assert.False(t, a.Equal(nil)) +} + +func TestPolicyRuleEqual_EmptySlices(t *testing.T) { + a := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: []string{}, + Sources: nil, + } + b := &PolicyRule{ + ID: "rule1", + PolicyID: "pol1", + Ports: nil, + Sources: []string{}, + } + assert.True(t, a.Equal(b)) +} + diff --git a/proxy/internal/auth/middleware.go b/proxy/internal/auth/middleware.go index 055e4510f..3b383f8b4 100644 --- a/proxy/internal/auth/middleware.go +++ b/proxy/internal/auth/middleware.go @@ -433,6 +433,7 @@ func setSessionCookie(w http.ResponseWriter, token string, expiration time.Durat http.SetCookie(w, &http.Cookie{ Name: auth.SessionCookieName, Value: token, + Path: "/", HttpOnly: true, Secure: true, SameSite: http.SameSiteLaxMode, diff --git a/proxy/internal/auth/middleware_test.go b/proxy/internal/auth/middleware_test.go index 16d09800c..2c93d7912 100644 --- a/proxy/internal/auth/middleware_test.go +++ b/proxy/internal/auth/middleware_test.go @@ -391,6 +391,15 @@ func TestProtect_SchemeAuthRedirectsWithCookie(t *testing.T) { assert.Equal(t, http.SameSiteLaxMode, sessionCookie.SameSite) } +func TestSetSessionCookieHasRootPath(t *testing.T) { + w := httptest.NewRecorder() + setSessionCookie(w, "test-token", time.Hour) + + cookies := w.Result().Cookies() + require.Len(t, cookies, 1) + assert.Equal(t, "/", cookies[0].Path, "session cookie must be scoped to root so it applies to all paths") +} + func TestProtect_FailedAuthDoesNotSetCookie(t *testing.T) { mw := NewMiddleware(log.StandardLogger(), nil, nil) kp := generateTestKeyPair(t) diff --git a/proxy/management_integration_test.go b/proxy/management_integration_test.go index 17510f37e..4b1ecf922 100644 --- a/proxy/management_integration_test.go +++ b/proxy/management_integration_test.go @@ -22,6 +22,7 @@ import ( nbproxy "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy" "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service" nbgrpc 
"github.com/netbirdio/netbird/management/internals/shared/grpc" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/store" "github.com/netbirdio/netbird/management/server/types" "github.com/netbirdio/netbird/management/server/users" @@ -113,11 +114,11 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup { } // Create real token store - tokenStore, err := nbgrpc.NewOneTimeTokenStore(ctx, 5*time.Minute, 10*time.Minute, 100) + cacheStore, err := nbcache.NewStore(ctx, 30*time.Minute, 10*time.Minute, 100) require.NoError(t, err) - pkceStore, err := nbgrpc.NewPKCEVerifierStore(ctx, 10*time.Minute, 10*time.Minute, 100) - require.NoError(t, err) + tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, cacheStore) + pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, cacheStore) // Create real users manager usersManager := users.NewManager(testStore) diff --git a/shared/management/client/client_test.go b/shared/management/client/client_test.go index f5edb6b95..d9a1a7d65 100644 --- a/shared/management/client/client_test.go +++ b/shared/management/client/client_test.go @@ -31,6 +31,7 @@ import ( "github.com/netbirdio/netbird/management/internals/server/config" mgmt "github.com/netbirdio/netbird/management/server" "github.com/netbirdio/netbird/management/server/activity" + nbcache "github.com/netbirdio/netbird/management/server/cache" "github.com/netbirdio/netbird/management/server/groups" "github.com/netbirdio/netbird/management/server/integrations/port_forwarding" "github.com/netbirdio/netbird/management/server/mock_server" @@ -95,9 +96,16 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { settingsManagerMock := settings.NewMockManager(ctrl) jobManager := job.NewJobManager(nil, store, peersManger) - ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManger, settingsManagerMock, eventStore) + ctx := context.Background() - metrics, err := telemetry.NewDefaultAppMetrics(context.Background()) + cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100) + if err != nil { + t.Fatal(err) + } + + ia, _ := integrations.NewIntegratedValidator(ctx, peersManger, settingsManagerMock, eventStore, cacheStore) + + metrics, err := telemetry.NewDefaultAppMetrics(ctx) require.NoError(t, err) settingsMockManager := settings.NewMockManager(ctrl) @@ -116,11 +124,10 @@ func startManagement(t *testing.T) (*grpc.Server, net.Listener) { Return(&types.ExtraSettings{}, nil). 
AnyTimes() - ctx := context.Background() updateManager := update_channel.NewPeersUpdateManager(metrics) requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store) networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), ephemeral_manager.NewEphemeralManager(store, peersManger), config) - accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false) + accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore) if err != nil { t.Fatal(err) } diff --git a/shared/management/client/grpc.go b/shared/management/client/grpc.go index a01e51abc..e9bea7ffb 100644 --- a/shared/management/client/grpc.go +++ b/shared/management/client/grpc.go @@ -30,6 +30,8 @@ import ( const ConnectTimeout = 10 * time.Second +const healthCheckTimeout = 5 * time.Second + const ( // EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB) // for the management client connection. Value is in bytes. @@ -532,7 +534,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() _, err := c.realClient.GetServerKey(ctx, &proto.Empty{}) diff --git a/shared/signal/client/grpc.go b/shared/signal/client/grpc.go index 5368b57a2..d0f598dd7 100644 --- a/shared/signal/client/grpc.go +++ b/shared/signal/client/grpc.go @@ -23,6 +23,8 @@ import ( "github.com/netbirdio/netbird/util/wsproxy" ) +const healthCheckTimeout = 5 * time.Second + // ConnStateNotifier is a wrapper interface of the status recorder type ConnStateNotifier interface { MarkSignalDisconnected(error) @@ -263,7 +265,7 @@ func (c *GrpcClient) IsHealthy() bool { case connectivity.Ready: } - ctx, cancel := context.WithTimeout(c.ctx, 1*time.Second) + ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout) defer cancel() _, err := c.realClient.Send(ctx, &proto.EncryptedMessage{ Key: c.key.PublicKey().String(),
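Both the management and signal health checks now probe with a 5-second deadline instead of 1 second, which tolerates high-latency links without letting a stalled server block the caller. A sketch of the shared pattern, with a stand-in probe instead of the real gRPC calls:

```go
// Pattern used by both IsHealthy implementations above: a unary probe
// guarded by a named timeout constant. The probe is a stand-in here.
package main

import (
	"context"
	"fmt"
	"time"
)

const healthCheckTimeout = 5 * time.Second // was 1s; too tight on slow links

// isHealthy runs the probe with a bounded deadline so a stalled server
// cannot block the caller indefinitely.
func isHealthy(ctx context.Context, probe func(context.Context) error) bool {
	ctx, cancel := context.WithTimeout(ctx, healthCheckTimeout)
	defer cancel()
	return probe(ctx) == nil
}

func main() {
	slow := func(ctx context.Context) error {
		select {
		case <-time.After(2 * time.Second): // would have failed the old 1s budget
			return nil
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	fmt.Println(isHealthy(context.Background(), slow)) // true with the 5s budget
}
```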