From e81ce81494bf69e440107971fb8c09893626f1c7 Mon Sep 17 00:00:00 2001 From: Viktor Liu Date: Wed, 22 Apr 2026 10:38:04 +0200 Subject: [PATCH] Bound embed client WireGuard per-Device memory --- client/embed/embed.go | 50 ++++++++++ client/internal/engine.go | 23 +++++ go.mod | 2 +- go.sum | 4 +- proxy/cmd/proxy/cmd/debug.go | 49 +++++++++ proxy/cmd/proxy/cmd/root.go | 47 +++++++++ proxy/internal/debug/client.go | 68 +++++++++++++ proxy/internal/debug/handler.go | 172 +++++++++++++++++++++++++++----- 8 files changed, 388 insertions(+), 27 deletions(-) diff --git a/client/embed/embed.go b/client/embed/embed.go index 88f7e541c..c15a75be0 100644 --- a/client/embed/embed.go +++ b/client/embed/embed.go @@ -12,6 +12,7 @@ import ( "sync" "github.com/sirupsen/logrus" + wgdevice "golang.zx2c4.com/wireguard/device" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" "github.com/netbirdio/netbird/client/iface" @@ -469,6 +470,55 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error { return sshcommon.VerifyHostKey(storedKey, key, peerAddress) } +// WGTuning bundles runtime-adjustable WireGuard knobs exposed by the embed +// client. Nil fields are left unchanged; set a non-nil pointer to apply. +type WGTuning struct { + // PreallocatedBuffersPerPool caps each per-Device WaitPool. + // Zero means "unbounded" (no cap). Live-tunable only if the underlying + // Device was originally created with a nonzero cap. + PreallocatedBuffersPerPool *uint32 +} + +// SetWGTuning applies the given tuning to this client's live Device. +// Startup-only knobs (batch size) must be set via the package-level +// setters before Start. +func (c *Client) SetWGTuning(t WGTuning) error { + engine, err := c.getEngine() + if err != nil { + return err + } + return engine.SetWGTuning(internal.WGTuning{ + PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool, + }) +} + +// SetWGDefaultPreallocatedBuffersPerPool sets the default WaitPool cap +// applied to Devices created after this call. Zero disables the cap. +// Existing Devices are unaffected; use Client.SetWGTuning for that. +func SetWGDefaultPreallocatedBuffersPerPool(n uint32) { + wgdevice.SetPreallocatedBuffersPerPool(n) +} + +// WGDefaultPreallocatedBuffersPerPool returns the current default WaitPool +// cap applied to newly-created Devices. +func WGDefaultPreallocatedBuffersPerPool() uint32 { + return wgdevice.PreallocatedBuffersPerPool +} + +// SetWGDefaultMaxBatchSize sets the default per-Device batch size applied +// to Devices created after this call. Zero means "use the bind+tun default" +// (NOT unlimited). Must be called before Start to take effect for a new +// Client. +func SetWGDefaultMaxBatchSize(n uint32) { + wgdevice.SetMaxBatchSizeOverride(n) +} + +// WGDefaultMaxBatchSize returns the current default batch-size override. +// Zero means "no override". +func WGDefaultMaxBatchSize() uint32 { + return wgdevice.MaxBatchSizeOverride +} + // getEngine safely retrieves the engine from the client with proper locking. // Returns ErrClientNotStarted if the client is not started. // Returns ErrEngineNotStarted if the engine is not available. diff --git a/client/internal/engine.go b/client/internal/engine.go index 8d7e02bd5..f6fe5a64e 100644 --- a/client/internal/engine.go +++ b/client/internal/engine.go @@ -1874,6 +1874,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics { return e.clientMetrics } +// WGTuning bundles runtime-adjustable WireGuard pool knobs. +// See Engine.SetWGTuning. Nil fields are ignored. +type WGTuning struct { + PreallocatedBuffersPerPool *uint32 +} + +// SetWGTuning applies the given tuning to this engine's live Device. +func (e *Engine) SetWGTuning(t WGTuning) error { + e.syncMsgMux.Lock() + defer e.syncMsgMux.Unlock() + if e.wgInterface == nil { + return fmt.Errorf("wg interface not initialized") + } + dev := e.wgInterface.GetWGDevice() + if dev == nil { + return fmt.Errorf("wg device not initialized") + } + if t.PreallocatedBuffersPerPool != nil { + dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool) + } + return nil +} + func findIPFromInterfaceName(ifaceName string) (net.IP, error) { iface, err := net.InterfaceByName(ifaceName) if err != nil { diff --git a/go.mod b/go.mod index 1b5861a37..5e60cf2a4 100644 --- a/go.mod +++ b/go.mod @@ -314,7 +314,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 -replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 +replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 diff --git a/go.sum b/go.sum index 3772946e1..bee926939 100644 --- a/go.sum +++ b/go.sum @@ -459,8 +459,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= -github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ= -github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= +github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 h1:6REpBYpJBLTTgqCcLGpTqvRDoEoLbA5r2nAXqMd2La0= +github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= diff --git a/proxy/cmd/proxy/cmd/debug.go b/proxy/cmd/proxy/cmd/debug.go index 59f7a6b65..2bf8b4e8e 100644 --- a/proxy/cmd/proxy/cmd/debug.go +++ b/proxy/cmd/proxy/cmd/debug.go @@ -99,6 +99,35 @@ var debugStopCmd = &cobra.Command{ SilenceUsage: true, } +var debugWGTuneCmd = &cobra.Command{ + Use: "wgtune", + Short: "Inspect and live-tune WireGuard pool settings", +} + +var debugWGTuneGetCmd = &cobra.Command{ + Use: "get", + Short: "Show pool cap and batch size defaults", + Args: cobra.NoArgs, + RunE: runDebugWGTuneGet, + SilenceUsage: true, +} + +var debugWGTuneSetCmd = &cobra.Command{ + Use: "set ", + Short: "Set the pool cap (new and live clients)", + Args: cobra.ExactArgs(1), + RunE: runDebugWGTuneSet, + SilenceUsage: true, +} + +var debugRuntimeCmd = &cobra.Command{ + Use: "runtime", + Short: "Show runtime stats (heap, goroutines, RSS)", + Args: cobra.NoArgs, + RunE: runDebugRuntime, + SilenceUsage: true, +} + func init() { debugCmd.PersistentFlags().StringVar(&debugAddr, "addr", envStringOrDefault("NB_PROXY_DEBUG_ADDRESS", "localhost:8444"), "Debug endpoint address") debugCmd.PersistentFlags().BoolVar(&jsonOutput, "json", false, "Output JSON instead of pretty format") @@ -119,6 +148,10 @@ func init() { debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStopCmd) + debugWGTuneCmd.AddCommand(debugWGTuneGetCmd) + debugWGTuneCmd.AddCommand(debugWGTuneSetCmd) + debugCmd.AddCommand(debugWGTuneCmd) + debugCmd.AddCommand(debugRuntimeCmd) rootCmd.AddCommand(debugCmd) } @@ -171,3 +204,19 @@ func runDebugStart(cmd *cobra.Command, args []string) error { func runDebugStop(cmd *cobra.Command, args []string) error { return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) } + +func runDebugWGTuneGet(cmd *cobra.Command, _ []string) error { + return getDebugClient(cmd).WGTuneGet(cmd.Context()) +} + +func runDebugWGTuneSet(cmd *cobra.Command, args []string) error { + n, err := strconv.ParseUint(args[0], 10, 32) + if err != nil { + return fmt.Errorf("invalid value %q: %w", args[0], err) + } + return getDebugClient(cmd).WGTuneSet(cmd.Context(), uint32(n)) +} + +func runDebugRuntime(cmd *cobra.Command, _ []string) error { + return getDebugClient(cmd).Runtime(cmd.Context()) +} diff --git a/proxy/cmd/proxy/cmd/root.go b/proxy/cmd/proxy/cmd/root.go index ec8980ad9..fc09ec492 100644 --- a/proxy/cmd/proxy/cmd/root.go +++ b/proxy/cmd/proxy/cmd/root.go @@ -15,11 +15,22 @@ import ( "github.com/netbirdio/netbird/shared/management/domain" + "github.com/netbirdio/netbird/client/embed" "github.com/netbirdio/netbird/proxy" nbacme "github.com/netbirdio/netbird/proxy/internal/acme" "github.com/netbirdio/netbird/util" ) +const ( + // envWGPreallocatedBuffers caps the per-Device WireGuard buffer pool + // size. Zero (unset) keeps the uncapped upstream default. + envWGPreallocatedBuffers = "NB_WG_PREALLOCATED_BUFFERS" + // envWGMaxBatchSize overrides the per-Device WireGuard batch size, + // which controls how many buffers each receive/TUN worker eagerly + // allocates. Zero (unset) keeps the bind+tun default. + envWGMaxBatchSize = "NB_WG_MAX_BATCH_SIZE" +) + const DefaultManagementURL = "https://api.netbird.io:443" // envProxyToken is the environment variable name for the proxy access token. @@ -145,6 +156,42 @@ func runServer(cmd *cobra.Command, args []string) error { logger.Infof("configured log level: %s", level) + var wgPool, wgBatch uint64 + if raw := os.Getenv(envWGPreallocatedBuffers); raw != "" { + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return fmt.Errorf("invalid %s %q: %w", envWGPreallocatedBuffers, raw, err) + } + wgPool = n + embed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n)) + logger.Infof("wireguard preallocated buffers per pool: %d", n) + } + if raw := os.Getenv(envWGMaxBatchSize); raw != "" { + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + return fmt.Errorf("invalid %s %q: %w", envWGMaxBatchSize, raw, err) + } + wgBatch = n + embed.SetWGDefaultMaxBatchSize(uint32(n)) + logger.Infof("wireguard max batch size override: %d", n) + } + if wgPool > 0 { + // Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus + // RoutineReadFromTUN eagerly reserves `batch` message buffers for + // the lifetime of the Device. A pool cap below that floor blocks + // the receive pipeline at startup. + batch := wgBatch + if batch == 0 { + batch = 128 + } + const recvGoroutines = 4 + floor := batch * recvGoroutines + if wgPool < floor { + logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock", + envWGPreallocatedBuffers, wgPool, floor, batch) + } + } + switch forwardedProto { case "auto", "http", "https": default: diff --git a/proxy/internal/debug/client.go b/proxy/internal/debug/client.go index 01b0bc8e6..a5f51029e 100644 --- a/proxy/internal/debug/client.go +++ b/proxy/internal/debug/client.go @@ -272,6 +272,74 @@ func (c *Client) printLogLevelResult(data map[string]any) { } } +// WGTuneGet fetches the current WireGuard pool cap. +func (c *Client) WGTuneGet(ctx context.Context) error { + return c.fetchAndPrint(ctx, "/debug/wgtune", c.printWGTuneGet) +} + +// WGTuneSet updates the WireGuard pool cap on the global default and all live clients. +func (c *Client) WGTuneSet(ctx context.Context, value uint32) error { + path := fmt.Sprintf("/debug/wgtune?value=%d", value) + return c.fetchAndPrint(ctx, path, c.printWGTuneSet) +} + +func (c *Client) printWGTuneGet(data map[string]any) { + def, _ := data["default"].(float64) + batch, _ := data["batch_size"].(float64) + _, _ = fmt.Fprintf(c.out, "Default: %d\n", uint32(def)) + _, _ = fmt.Fprintf(c.out, "Batch size: %d (0 = unset)\n", uint32(batch)) +} + +func (c *Client) printWGTuneSet(data map[string]any) { + if errMsg, ok := data["error"].(string); ok && errMsg != "" { + c.printError(data) + return + } + def, _ := data["default"].(float64) + applied, _ := data["applied"].(float64) + _, _ = fmt.Fprintf(c.out, "Default set to: %d\n", uint32(def)) + _, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied)) + if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 { + _, _ = fmt.Fprintln(c.out, "Failed:") + for k, v := range failed { + _, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v) + } + } +} + +// Runtime fetches runtime stats (heap, goroutines, RSS). +func (c *Client) Runtime(ctx context.Context) error { + return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime) +} + +func (c *Client) printRuntime(data map[string]any) { + i := func(k string) uint64 { + v, _ := data[k].(float64) + return uint64(v) + } + mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) } + + _, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"]) + _, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs"))) + _, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines")) + _, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects")) + _, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns"))) + _, _ = fmt.Fprintln(c.out, "Heap:") + _, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc"))) + _, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse"))) + _, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle"))) + _, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released"))) + _, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys"))) + _, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys"))) + if _, ok := data["vm_rss"]; ok { + _, _ = fmt.Fprintln(c.out, "Process:") + _, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss"))) + _, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size"))) + _, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data"))) + } + _, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started")) +} + // StartClient starts a specific client. func (c *Client) StartClient(ctx context.Context, accountID string) error { path := "/debug/clients/" + url.PathEscape(accountID) + "/start" diff --git a/proxy/internal/debug/handler.go b/proxy/internal/debug/handler.go index c507cfad9..63420c59a 100644 --- a/proxy/internal/debug/handler.go +++ b/proxy/internal/debug/handler.go @@ -10,6 +10,8 @@ import ( "html/template" "maps" "net/http" + "os" + "runtime" "slices" "strconv" "strings" @@ -58,6 +60,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A type clientProvider interface { GetClient(accountID types.AccountID) (*nbembed.Client, bool) ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo + ListClientsForStartup() map[types.AccountID]*nbembed.Client } // healthChecker provides health probe state. @@ -139,6 +142,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.handleListClients(w, r, wantJSON) case "/debug/health": h.handleHealth(w, r, wantJSON) + case "/debug/wgtune": + h.handleWGTune(w, r) + case "/debug/runtime": + h.handleRuntime(w, r) default: if h.handleClientRoutes(w, r, path, wantJSON) { return @@ -230,10 +237,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b } if wantJSON { - clientsJSON := make([]map[string]interface{}, 0, len(clients)) + clientsJSON := make([]map[string]any, 0, len(clients)) for _, id := range sortedIDs { info := clients[id] - clientsJSON = append(clientsJSON, map[string]interface{}{ + clientsJSON = append(clientsJSON, map[string]any{ "account_id": info.AccountID, "service_count": info.ServiceCount, "service_keys": info.ServiceKeys, @@ -242,7 +249,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } - resp := map[string]interface{}{ + resp := map[string]any{ "version": version.NetbirdVersion(), "uptime": time.Since(h.startTime).Round(time.Second).String(), "client_count": len(clients), @@ -320,10 +327,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want sortedIDs := sortedAccountIDs(clients) if wantJSON { - clientsJSON := make([]map[string]interface{}, 0, len(clients)) + clientsJSON := make([]map[string]any, 0, len(clients)) for _, id := range sortedIDs { info := clients[id] - clientsJSON = append(clientsJSON, map[string]interface{}{ + clientsJSON = append(clientsJSON, map[string]any{ "account_id": info.AccountID, "service_count": info.ServiceCount, "service_keys": info.ServiceKeys, @@ -332,7 +339,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want "age": time.Since(info.CreatedAt).Round(time.Second).String(), }) } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "uptime": time.Since(h.startTime).Round(time.Second).String(), "client_count": len(clients), "clients": clientsJSON, @@ -418,7 +425,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc }) if wantJSON { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "account_id": accountID, "status": overview.FullDetailSummary(), }) @@ -501,20 +508,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } host := r.URL.Query().Get("host") portStr := r.URL.Query().Get("port") if host == "" || portStr == "" { - h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"}) + h.writeJSON(w, map[string]any{"error": "host and port parameters required"}) return } port, err := strconv.Atoi(portStr) if err != nil || port < 1 || port > 65535 { - h.writeJSON(w, map[string]interface{}{"error": "invalid port"}) + h.writeJSON(w, map[string]any{"error": "invalid port"}) return } @@ -533,7 +540,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI conn, err := client.Dial(ctx, "tcp", address) if err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "host": host, "port": port, @@ -546,7 +553,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI } latency := time.Since(start) - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "host": host, "port": port, @@ -558,25 +565,25 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } level := r.URL.Query().Get("level") if level == "" { - h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"}) + h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"}) return } if err := client.SetLogLevel(level); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "level": level, }) @@ -587,7 +594,7 @@ const clientActionTimeout = 30 * time.Second func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } @@ -595,14 +602,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco defer cancel() if err := client.Start(ctx); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "message": "client started", }) @@ -611,7 +618,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { client, ok := h.provider.GetClient(accountID) if !ok { - h.writeJSON(w, map[string]interface{}{"error": "client not found"}) + h.writeJSON(w, map[string]any{"error": "client not found"}) return } @@ -619,19 +626,136 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou defer cancel() if err := client.Stop(ctx); err != nil { - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": false, "error": err.Error(), }) return } - h.writeJSON(w, map[string]interface{}{ + h.writeJSON(w, map[string]any{ "success": true, "message": "client stopped", }) } +func (h *Handler) handleWGTune(w http.ResponseWriter, r *http.Request) { + values, ok := r.URL.Query()["value"] + if !ok { + h.writeJSON(w, map[string]any{ + "default": nbembed.WGDefaultPreallocatedBuffersPerPool(), + "batch_size": nbembed.WGDefaultMaxBatchSize(), + }) + return + } + if len(values) == 0 || values[0] == "" { + http.Error(w, "value parameter must not be empty", http.StatusBadRequest) + return + } + raw := values[0] + + n, err := strconv.ParseUint(raw, 10, 32) + if err != nil { + http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest) + return + } + nbembed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n)) + + applied := 0 + failed := map[string]string{} + for accountID, client := range h.provider.ListClientsForStartup() { + capN := uint32(n) + if err := client.SetWGTuning(nbembed.WGTuning{PreallocatedBuffersPerPool: &capN}); err != nil { + failed[string(accountID)] = err.Error() + continue + } + applied++ + } + + resp := map[string]any{ + "success": true, + "default": uint32(n), + "batch_size": nbembed.WGDefaultMaxBatchSize(), + "applied": applied, + } + if len(failed) > 0 { + resp["failed"] = failed + } + h.writeJSON(w, resp) +} + +// handleRuntime returns cheap runtime and process stats. Safe to hit on a +// running proxy; does not read pprof profiles. +func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) { + var m runtime.MemStats + runtime.ReadMemStats(&m) + + clients := h.provider.ListClientsForDebug() + started := 0 + for _, c := range clients { + if c.HasClient { + started++ + } + } + + resp := map[string]any{ + "uptime": time.Since(h.startTime).Round(time.Second).String(), + "goroutines": runtime.NumGoroutine(), + "num_cpu": runtime.NumCPU(), + "gomaxprocs": runtime.GOMAXPROCS(0), + "go_version": runtime.Version(), + "heap_alloc": m.HeapAlloc, + "heap_inuse": m.HeapInuse, + "heap_idle": m.HeapIdle, + "heap_released": m.HeapReleased, + "heap_sys": m.HeapSys, + "sys": m.Sys, + "live_objects": m.Mallocs - m.Frees, + "num_gc": m.NumGC, + "pause_total_ns": m.PauseTotalNs, + "clients": len(clients), + "started": started, + } + + if proc := readProcStatus(); proc != nil { + resp["vm_rss"] = proc["VmRSS"] + resp["vm_size"] = proc["VmSize"] + resp["vm_data"] = proc["VmData"] + } + + h.writeJSON(w, resp) +} + +// readProcStatus parses /proc/self/status on Linux and returns size fields +// in bytes. Returns nil on non-Linux or read failure. +func readProcStatus() map[string]uint64 { + raw, err := os.ReadFile("/proc/self/status") + if err != nil { + return nil + } + out := map[string]uint64{} + for _, line := range strings.Split(string(raw), "\n") { + k, v, ok := strings.Cut(line, ":") + if !ok { + continue + } + if k != "VmRSS" && k != "VmSize" && k != "VmData" { + continue + } + fields := strings.Fields(v) + if len(fields) < 1 { + continue + } + n, err := strconv.ParseUint(fields[0], 10, 64) + if err != nil { + continue + } + // Values are reported in kB. + out[k] = n * 1024 + } + return out +} + func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON bool) { if !wantJSON { http.Redirect(w, r, "/debug", http.StatusSeeOther) @@ -685,7 +809,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON h.writeJSON(w, resp) } -func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) { +func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) { w.Header().Set("Content-Type", "text/html; charset=utf-8") tmpl := h.getTemplates() if tmpl == nil { @@ -698,7 +822,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf } } -func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) { +func (h *Handler) writeJSON(w http.ResponseWriter, v any) { w.Header().Set("Content-Type", "application/json") enc := json.NewEncoder(w) enc.SetIndent("", " ")