diff --git a/client/embed/embed.go b/client/embed/embed.go index 7e7f6c337..04bc60fb8 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -12,6 +12,7 @@ import ( "sync" "github.com/sirupsen/logrus" + wgdevice "golang.zx2c4.com/wireguard/device" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" @@ -100,6 +101,26 @@ type Options struct { MTU *uint16 // DNSLabels defines additional DNS labels configured in the peer. DNSLabels []string + // Performance configures the tunnel's buffer pool cap and batch size. + Performance Performance +} + +// Performance configures the embedded client's tunnel memory/throughput knobs. +// +// These settings are process-global: any non-nil field also becomes the +// default for Clients constructed by later embed.New calls in the same +// process. Nil fields are ignored. +type Performance struct { + // PreallocatedBuffersPerPool caps the per-tunnel buffer pool. Zero + // leaves the pool unbounded. Lower values trade throughput for a + // tighter memory ceiling. May also be changed on a running Client via + // Client.SetPerformance, provided this field was nonzero at construction. + PreallocatedBuffersPerPool *uint32 + // MaxBatchSize overrides the number of packets the tunnel reads or + // writes per syscall, which also bounds eager buffer allocation per + // worker. Zero uses the platform default. Applied at construction + // only; ignored by Client.SetPerformance. + MaxBatchSize *uint32 } // validateCredentials checks that exactly one credential type is provided @@ -199,6 +220,13 @@ func New(opts Options) (*Client, error) { config.PrivateKey = opts.PrivateKey } + if opts.Performance.PreallocatedBuffersPerPool != nil { + wgdevice.SetPreallocatedBuffersPerPool(*opts.Performance.PreallocatedBuffersPerPool) + } + if opts.Performance.MaxBatchSize != nil { + wgdevice.SetMaxBatchSizeOverride(*opts.Performance.MaxBatchSize) + } + return &Client{ deviceName: opts.DeviceName, setupKey: opts.SetupKey, @@ -495,6 +523,25 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { return sshcommon.VerifyHostKey(storedKey, key, peerAddress) } +// SetPerformance retunes a running Client. Only PreallocatedBuffersPerPool +// takes effect, and only when it was nonzero at construction; +// MaxBatchSize is construction-only and returns an error if set here. +// +// Returns ErrClientNotStarted / ErrEngineNotStarted if the Client is not +// running yet. +func (c *Client) SetPerformance(t Performance) error { + if t.MaxBatchSize != nil { + return errors.New("MaxBatchSize is construction-only and cannot be changed at runtime") + } + engine, err := c.getEngine() + if err != nil { + return err + } + return engine.SetPerformance(internal.Performance{ + PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool, + }) +} + // StartCapture begins capturing packets on this client's tunnel device. // Only one capture can be active at a time; starting a new one stops the previous. // Call StopCapture (or CaptureSession.Stop) to end it. diff --git a/client/internal/engine.go b/client/internal/engine.go index 3bd0d4621..b82eb95b7 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1967,6 +1967,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics { return e.clientMetrics } +// Performance bundles runtime-adjustable tunnel pool knobs. +// See Engine.SetPerformance. Nil fields are ignored. +type Performance struct { + PreallocatedBuffersPerPool *uint32 +} + +// SetPerformance applies the given tuning to this engine's live Device. +func (e *Engine) SetPerformance(t Performance) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + if e.wgInterface == nil { + return fmt.Errorf("wg interface not initialized") + } + dev := e.wgInterface.GetWGDevice() + if dev == nil { + return fmt.Errorf("wg device not initialized") + } + if t.PreallocatedBuffersPerPool != nil { + dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool) + } + return nil +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/go.mod b/go.mod index 7c1a95e79..ea0d8d73d 100644 --- a/go.mod +++ b/go.mod @@ -335,7 +335,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f replace github.com/cloudflare/circl => codeberg.org/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index 53789f49d..f95efefa6 100644 --- a/go.sum +++ b/go.sum @@ -499,8 +499,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ= -github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= +github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f h1:ff2D57RBjWtyQ2wVwJOxOgXAXOe/J2lJWtSX0Bz/BRk= +github.com/netbirdio/wireguard-go v0.0.0-20260523085312-4b4a4e36017f/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 49afc7638..360c7a516 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -109,6 +109,22 @@ var debugStopCmd = &cobra.Command{ SilenceUsage: true, } +var debugPerfCmd = &cobra.Command{ + Use: "perf ", + Short: "Live-retune the tunnel buffer pool cap on all running clients", + Args: cobra.ExactArgs(1), + RunE: runDebugPerfSet, + SilenceUsage: true, +} + +var debugRuntimeCmd = &cobra.Command{ + Use: "runtime", + Short: "Show runtime stats (heap, goroutines, RSS)", + Args: cobra.NoArgs, + RunE: runDebugRuntime, + SilenceUsage: true, +} + var debugCaptureCmd = &cobra.Command{ Use: "capture [filter expression]", Short: "Capture packets on a client's WireGuard interface", @@ -159,6 +175,8 @@ func init() { debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStopCmd) + debugCmd.AddCommand(debugPerfCmd) + debugCmd.AddCommand(debugRuntimeCmd) debugCmd.AddCommand(debugCaptureCmd) rootCmd.AddCommand(debugCmd) @@ -220,6 +238,18 @@ func runDebugStop(cmd *cobra.Command, args []string) error { return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) } +func runDebugPerfSet(cmd *cobra.Command, args []string) error { + n, err := strconv.ParseUint(args[0], 10, 32) + if err != nil { + return fmt.Errorf("invalid value %q: %w", args[0], err) + } + return getDebugClient(cmd).PerfSet(cmd.Context(), uint32(n)) +} + +func runDebugRuntime(cmd *cobra.Command, _ []string) error { + return getDebugClient(cmd).Runtime(cmd.Context()) +} + func runDebugCapture(cmd *cobra.Command, args []string) error { duration, _ := cmd.Flags().GetDuration("duration") forcePcap, _ := cmd.Flags().GetBool("pcap") diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index 5970886da..405fa2789 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -15,11 +15,22 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/client/embed" "github.com/netbirdio/netbird/proxy" nbacme "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/util" ) +const ( + // envPreallocatedBuffers caps the per-tunnel buffer pool. Zero (unset) + // keeps the upstream uncapped default. + envPreallocatedBuffers = "NB_PROXY_PREALLOCATED_BUFFERS" + // envMaxBatchSize overrides the per-tunnel batch size, which controls + // how many buffers each receive/TUN worker eagerly allocates. Zero + // (unset) keeps the platform default. + envMaxBatchSize = "NB_PROXY_MAX_BATCH_SIZE" +) + const DefaultManagementURL = "https://api.netbird.io:443" // envProxyToken is the environment variable name for the proxy access token. @@ -148,6 +159,45 @@ func runServer(cmd *cobra.Command, args []string) error { logger.Infof("configured log level: %s", level) + var wgPool, wgBatch uint64 + var perf embed.Performance + if raw := os.Getenv(envPreallocatedBuffers); raw != "" { + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return fmt.Errorf("invalid %s %q: %w", envPreallocatedBuffers, raw, err) + } + wgPool = n + v := uint32(n) + perf.PreallocatedBuffersPerPool = &v + logger.Infof("tunnel preallocated buffers per pool: %d", n) + } + if raw := os.Getenv(envMaxBatchSize); raw != "" { + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return fmt.Errorf("invalid %s %q: %w", envMaxBatchSize, raw, err) + } + wgBatch = n + v := uint32(n) + perf.MaxBatchSize = &v + logger.Infof("tunnel max batch size override: %d", n) + } + if wgPool > 0 { + // Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus + // RoutineReadFromTUN eagerly reserves `batch` message buffers for + // the lifetime of the Device. A pool cap below that floor blocks + // the receive pipeline at startup. + batch := wgBatch + if batch == 0 { + batch = 128 + } + const recvGoroutines = 4 + floor := batch * recvGoroutines + if wgPool < floor { + logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock", + envPreallocatedBuffers, wgPool, floor, batch) + } + } + switch forwardedProto { case "auto", "http", "https": default: @@ -188,6 +238,7 @@ func runServer(cmd *cobra.Command, args []string) error { CertLockMethod: nbacme.CertLockMethod(certLockMethod), WildcardCertDir: wildcardCertDir, WireguardPort: wgPort, + Performance: perf, ProxyProtocol: proxyProtocol, PreSharedKey: preSharedKey, SupportsCustomPorts: supportsCustomPorts, diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 736781652..77772637c 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -333,6 +333,63 @@ func (c *Client) printLogLevelResult(data map[string]any) { } } +// PerfSet live-retunes the tunnel buffer pool cap on all running embedded +// clients. Batch size is not live-tunable; configure it at proxy startup. +func (c *Client) PerfSet(ctx context.Context, value uint32) error { + path := fmt.Sprintf("/debug/perf?value=%d", value) + return c.fetchAndPrint(ctx, path, c.printPerfSet) +} + +func (c *Client) printPerfSet(data map[string]any) { + if errMsg, ok := data["error"].(string); ok && errMsg != "" { + c.printError(data) + return + } + val, _ := data["value"].(float64) + applied, _ := data["applied"].(float64) + _, _ = fmt.Fprintf(c.out, "Pool cap set to: %d\n", uint32(val)) + _, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied)) + if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 { + _, _ = fmt.Fprintln(c.out, "Failed:") + for k, v := range failed { + _, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v) + } + } +} + +// Runtime fetches runtime stats (heap, goroutines, RSS). +func (c *Client) Runtime(ctx context.Context) error { + return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime) +} + +func (c *Client) printRuntime(data map[string]any) { + i := func(k string) uint64 { + v, _ := data[k].(float64) + return uint64(v) + } + mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) } + + _, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"]) + _, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs"))) + _, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines")) + _, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects")) + _, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns"))) + _, _ = fmt.Fprintln(c.out, "Heap:") + _, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc"))) + _, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse"))) + _, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle"))) + _, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released"))) + _, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys"))) + _, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys"))) + if _, ok := data["vm_rss"]; ok { + _, _ = fmt.Fprintln(c.out, "Process:") + _, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss"))) + _, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size"))) + _, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data"))) + } + _, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started")) +} + // StartClient starts a specific client. func (c *Client) StartClient(ctx context.Context, accountID string) error { path := "/debug/clients/" + url.PathEscape(accountID) + "/start" diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index 1dbfe1522..826c6817f 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -11,6 +11,8 @@ import ( "maps" "net" "net/http" + "os" + "runtime" "slices" "strconv" "strings" @@ -59,6 +61,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A type clientProvider interface { GetClient(accountID types.AccountID) (*nbembed.Client, bool) ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo + ListClientsForStartup() map[types.AccountID]*nbembed.Client } // InboundListenerInfo describes a per-account inbound listener as @@ -165,6 +168,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.handleListClients(w, r, wantJSON) case "/debug/health": h.handleHealth(w, r, wantJSON) + case "/debug/perf": + h.handlePerf(w, r) + case "/debug/runtime": + h.handleRuntime(w, r) default: if h.handleClientRoutes(w, r, path, wantJSON) { return @@ -258,10 +265,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b } if wantJSON { - clientsJSON := make([]map[string]interface{}, 0, len(clients)) + clientsJSON := make([]map[string]any, 0, len(clients)) for _, id := range sortedIDs { info := clients[id] - clientsJSON = append(clientsJSON, map[string]interface{}{ + clientsJSON = append(clientsJSON, map[string]any{ "account_id": info.AccountID, "service_count": info.ServiceCount, "service_keys": info.ServiceKeys, @@ -270,7 +277,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } - resp := map[string]interface{}{ + resp := map[string]any{ "version": version.NetbirdVersion(), "uptime": time.Since(h.startTime).Round(time.Second).String(), "client_count": len(clients), @@ -352,10 +359,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want if h.inbound != nil { inboundAll = h.inbound.InboundListeners() } - clientsJSON := make([]map[string]interface{}, 0, len(clients)) + clientsJSON := make([]map[string]any, 0, len(clients)) for _, id := range sortedIDs { info := clients[id] - row := map[string]interface{}{ + row := map[string]any{ "account_id": info.AccountID, "service_count": info.ServiceCount, "service_keys": info.ServiceKeys, @@ -368,7 +375,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want } clientsJSON = append(clientsJSON, row) } - resp := map[string]interface{}{ + resp := map[string]any{ "uptime": time.Since(h.startTime).Round(time.Second).String(), "client_count": len(clients), "clients": clientsJSON, @@ -458,7 +465,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc }) if wantJSON { - resp := map[string]interface{}{ + resp := map[string]any{ "account_id": accountID, "status": overview.FullDetailSummary(), } @@ -557,20 +564,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } host := r.URL.Query().Get("host") portStr := r.URL.Query().Get("port") if host == "" || portStr == "" { - h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"}) + h.writeJSON(w, map[string]any{"error": "host and port parameters required"}) return } port, err := strconv.Atoi(portStr) if err != nil || port < 1 || port > 65535 { - h.writeJSON(w, map[string]interface{}{"error": "invalid port"}) + h.writeJSON(w, map[string]any{"error": "invalid port"}) return } @@ -594,7 +601,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI conn, err := client.Dial(ctx, network, address) if err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "host": host, "port": port, @@ -609,39 +616,38 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } latency := time.Since(start) - resp := map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "host": host, "port": port, "remote": remote, "latency_ms": latency.Milliseconds(), "latency": formatDuration(latency), - } - h.writeJSON(w, resp) + }) } func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } level := r.URL.Query().Get("level") if level == "" { - h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"}) + h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"}) return } if err := client.SetLogLevel(level); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "level": level, }) @@ -652,7 +658,7 @@ const clientActionTimeout = 30 * time.Second func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } @@ -660,14 +666,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco defer cancel() if err := client.Start(ctx); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "message": "client started", }) @@ -676,7 +682,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } @@ -684,19 +690,125 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou defer cancel() if err := client.Stop(ctx); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "message": "client stopped", }) } +func (h *Handler) handlePerf(w http.ResponseWriter, r *http.Request) { + raw := r.URL.Query().Get("value") + if raw == "" { + http.Error(w, "value parameter is required", http.StatusBadRequest) + return + } + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest) + return + } + + capN := uint32(n) + applied := 0 + failed := map[string]string{} + for accountID, client := range h.provider.ListClientsForStartup() { + if err := client.SetPerformance(nbembed.Performance{PreallocatedBuffersPerPool: &capN}); err != nil { + failed[string(accountID)] = err.Error() + continue + } + applied++ + } + + resp := map[string]any{ + "success": true, + "value": capN, + "applied": applied, + } + if len(failed) > 0 { + resp["failed"] = failed + } + h.writeJSON(w, resp) +} + +// handleRuntime returns cheap runtime and process stats. Safe to hit on a +// running proxy; does not read pprof profiles. +func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + + clients := h.provider.ListClientsForDebug() + started := 0 + for _, c := range clients { + if c.HasClient { + started++ + } + } + + resp := map[string]any{ + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "goroutines": runtime.NumGoroutine(), + "num_cpu": runtime.NumCPU(), + "gomaxprocs": runtime.GOMAXPROCS(0), + "go_version": runtime.Version(), + "heap_alloc": m.HeapAlloc, + "heap_inuse": m.HeapInuse, + "heap_idle": m.HeapIdle, + "heap_released": m.HeapReleased, + "heap_sys": m.HeapSys, + "sys": m.Sys, + "live_objects": m.Mallocs - m.Frees, + "num_gc": m.NumGC, + "pause_total_ns": m.PauseTotalNs, + "clients": len(clients), + "started": started, + } + + if proc := readProcStatus(); proc != nil { + resp["vm_rss"] = proc["VmRSS"] + resp["vm_size"] = proc["VmSize"] + resp["vm_data"] = proc["VmData"] + } + + h.writeJSON(w, resp) +} + +// readProcStatus parses /proc/self/status on Linux and returns size fields +// in bytes. Returns nil on non-Linux or read failure. +func readProcStatus() map[string]uint64 { + raw, err := os.ReadFile("/proc/self/status") + if err != nil { + return nil + } + out := map[string]uint64{} + for _, line := range strings.Split(string(raw), "\n") { + k, v, ok := strings.Cut(line, ":") + if !ok { + continue + } + if k != "VmRSS" && k != "VmSize" && k != "VmData" { + continue + } + fields := strings.Fields(v) + if len(fields) < 1 { + continue + } + n, err := strconv.ParseUint(fields[0], 10, 64) + if err != nil { + continue + } + // Values are reported in kB. + out[k] = n * 1024 + } + return out +} + const maxCaptureDuration = 30 * time.Minute // handleCapture streams a pcap or text packet capture for the given client. @@ -825,7 +937,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON h.writeJSON(w, resp) } -func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) { +func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) { w.Header().Set("Content-Type", "text/html; charset=utf-8") tmpl := h.getTemplates() if tmpl == nil { @@ -838,7 +950,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf } } -func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) { +func (h *Handler) writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") enc := json.NewEncoder(w) enc.SetIndent("", " ") diff --git a/proxy/internal/roundtrip/netbird.go b/proxy/internal/roundtrip/netbird.go index 133e86f05..11bca22e3 100644 --- a/proxy/internal/roundtrip/netbird.go +++ b/proxy/internal/roundtrip/netbird.go @@ -131,6 +131,7 @@ type ClientConfig struct { MgmtAddr string WGPort uint16 PreSharedKey string + Performance embed.Performance // BlockInbound mirrors embed.Options.BlockInbound. Set to true on the // standalone proxy where the embedded client never accepts inbound; // set to false on the private/embedded proxy so the engine creates @@ -306,7 +307,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account ManagementURL: n.clientCfg.MgmtAddr, PrivateKey: privateKey.String(), LogLevel: log.WarnLevel.String(), - BlockInbound: n.clientCfg.BlockInbound, + BlockInbound: n.clientCfg.BlockInbound, // The embedded proxy peer must never be a stepping stone into // the proxy host's LAN: it only exists to reach NetBird mesh // targets or, when direct_upstream is set, the host network @@ -315,6 +316,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account BlockLANAccess: true, WireguardPort: &wgPort, PreSharedKey: n.clientCfg.PreSharedKey, + Performance: n.clientCfg.Performance, }) if err != nil { return nil, fmt.Errorf("create netbird client: %w", err) diff --git a/proxy/lifecycle.go b/proxy/lifecycle.go index 6cb420722..9787f237e 100644 --- a/proxy/lifecycle.go +++ b/proxy/lifecycle.go @@ -6,6 +6,7 @@ import ( log "github.com/sirupsen/logrus" + "github.com/netbirdio/netbird/client/embed" "github.com/netbirdio/netbird/proxy/internal/acme" ) @@ -89,6 +90,10 @@ type Config struct { // PreSharedKey is the WireGuard pre-shared key used between the // proxy's embedded clients and peers. PreSharedKey string + // Performance configures the tunnel pool/batch sizes for every + // embedded client this proxy creates. Zero values fall back to + // upstream defaults. + Performance embed.Performance // SupportsCustomPorts indicates whether the proxy can bind arbitrary // ports for TCP/UDP/TLS services. @@ -148,6 +153,7 @@ func New(cfg Config) *Server { WireguardPort: cfg.WireguardPort, ProxyProtocol: cfg.ProxyProtocol, PreSharedKey: cfg.PreSharedKey, + Performance: cfg.Performance, SupportsCustomPorts: cfg.SupportsCustomPorts, RequireSubdomain: cfg.RequireSubdomain, Private: cfg.Private, diff --git a/proxy/server.go b/proxy/server.go index 63a0c577a..037da925c 100644 --- a/proxy/server.go +++ b/proxy/server.go @@ -41,6 +41,7 @@ import ( goproto "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/netbirdio/netbird/client/embed" "github.com/netbirdio/netbird/proxy/internal/accesslog" "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/proxy/internal/auth" @@ -185,6 +186,9 @@ type Server struct { // single-account deployments; multiple accounts will fail to bind // the same port. WireguardPort uint16 + // Performance configures the tunnel pool/batch sizes for every + // embedded client this proxy spawns. + Performance embed.Performance // ProxyProtocol enables PROXY protocol (v1/v2) on TCP listeners. // When enabled, the real client IP is extracted from the PROXY header // sent by upstream L4 proxies that support PROXY protocol. @@ -333,6 +337,8 @@ func (s *Server) Start(ctx context.Context) error { s.runCancel = runCancel s.initNetBirdClient() + // Create health checker before the mapping worker so it can track + // management connectivity from the first stream connection. s.healthChecker = health.NewChecker(s.Logger, s.netbird) s.crowdsecRegistry = crowdsec.NewRegistry(s.CrowdSecAPIURL, s.CrowdSecAPIKey, log.NewEntry(s.Logger)) @@ -529,6 +535,7 @@ func (s *Server) initNetBirdClient() { MgmtAddr: s.ManagementAddress, WGPort: s.WireguardPort, PreSharedKey: s.PreSharedKey, + Performance: s.Performance, // On --private the embedded client serves per-account inbound // listeners and must apply management's ACL: keep BlockInbound off // so the engine creates the ACL manager. On the standalone proxy