Compare commits

...

3 Commits

Author SHA1 Message Date
Viktor Liu
8520bbbe57 Use builtin clear instead of maps.Clear 2026-05-04 12:00:59 +02:00
Viktor Liu
533dd9577b Merge remote-tracking branch 'origin/main' into reduce-embed-wg-pool
# Conflicts:
#	client/embed/embed.go
#	proxy/cmd/proxy/cmd/debug.go
#	proxy/internal/debug/handler.go
2026-05-04 11:37:29 +02:00
Viktor Liu
e81ce81494 Bound embed client WireGuard per-Device memory 2026-04-22 13:04:40 +02:00
10 changed files with 393 additions and 33 deletions

View File

@@ -12,6 +12,7 @@ import (
"sync" "sync"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack" wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
@@ -470,6 +471,55 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress) return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
} }
// WGTuning bundles runtime-adjustable WireGuard knobs exposed by the embed
// client. Nil fields are left unchanged; set a non-nil pointer to apply.
type WGTuning struct {
// PreallocatedBuffersPerPool caps each per-Device WaitPool.
// Zero means "unbounded" (no cap). Live-tunable only if the underlying
// Device was originally created with a nonzero cap.
PreallocatedBuffersPerPool *uint32
}
// SetWGTuning applies the given tuning to this client's live Device.
// Startup-only knobs (batch size) must be set via the package-level
// setters before Start.
func (c *Client) SetWGTuning(t WGTuning) error {
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetWGTuning(internal.WGTuning{
PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool,
})
}
// SetWGDefaultPreallocatedBuffersPerPool sets the default WaitPool cap
// applied to Devices created after this call. Zero disables the cap.
// Existing Devices are unaffected; use Client.SetWGTuning for that.
func SetWGDefaultPreallocatedBuffersPerPool(n uint32) {
wgdevice.SetPreallocatedBuffersPerPool(n)
}
// WGDefaultPreallocatedBuffersPerPool returns the current default WaitPool
// cap applied to newly-created Devices.
func WGDefaultPreallocatedBuffersPerPool() uint32 {
return wgdevice.PreallocatedBuffersPerPool
}
// SetWGDefaultMaxBatchSize sets the default per-Device batch size applied
// to Devices created after this call. Zero means "use the bind+tun default"
// (NOT unlimited). Must be called before Start to take effect for a new
// Client.
func SetWGDefaultMaxBatchSize(n uint32) {
wgdevice.SetMaxBatchSizeOverride(n)
}
// WGDefaultMaxBatchSize returns the current default batch-size override.
// Zero means "no override".
func WGDefaultMaxBatchSize() uint32 {
return wgdevice.MaxBatchSizeOverride
}
// StartCapture begins capturing packets on this client's tunnel device. // StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous. // Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it. // Call StopCapture (or CaptureSession.Stop) to end it.

View File

@@ -19,7 +19,6 @@ import (
"github.com/google/gopacket/layers" "github.com/google/gopacket/layers"
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
firewall "github.com/netbirdio/netbird/client/firewall/manager" firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common" "github.com/netbirdio/netbird/client/firewall/uspfilter/common"
@@ -612,10 +611,10 @@ func (m *Manager) Flush() error { return nil }
// resetState clears all firewall rules and closes connection trackers. // resetState clears all firewall rules and closes connection trackers.
// Must be called with m.mutex held. // Must be called with m.mutex held.
func (m *Manager) resetState() { func (m *Manager) resetState() {
maps.Clear(m.outgoingRules) clear(m.outgoingRules)
maps.Clear(m.incomingDenyRules) clear(m.incomingDenyRules)
maps.Clear(m.incomingRules) clear(m.incomingRules)
maps.Clear(m.routeRulesMap) clear(m.routeRulesMap)
m.routeRules = m.routeRules[:0] m.routeRules = m.routeRules[:0]
m.udpHookOut.Store(nil) m.udpHookOut.Store(nil)
m.tcpHookOut.Store(nil) m.tcpHookOut.Store(nil)

View File

@@ -410,7 +410,7 @@ func (s *DefaultServer) Stop() {
log.Errorf("failed to disable DNS: %v", err) log.Errorf("failed to disable DNS: %v", err)
} }
maps.Clear(s.extraDomains) clear(s.extraDomains)
} }
func (s *DefaultServer) disableDNS() (retErr error) { func (s *DefaultServer) disableDNS() (retErr error) {

View File

@@ -1888,6 +1888,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics return e.clientMetrics
} }
// WGTuning bundles runtime-adjustable WireGuard pool knobs.
// See Engine.SetWGTuning. Nil fields are ignored.
type WGTuning struct {
PreallocatedBuffersPerPool *uint32
}
// SetWGTuning applies the given tuning to this engine's live Device.
func (e *Engine) SetWGTuning(t WGTuning) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
if e.wgInterface == nil {
return fmt.Errorf("wg interface not initialized")
}
dev := e.wgInterface.GetWGDevice()
if dev == nil {
return fmt.Errorf("wg device not initialized")
}
if t.PreallocatedBuffersPerPool != nil {
dev.SetPreallocatedBuffersPerPool(*t.PreallocatedBuffersPerPool)
}
return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) { func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName) iface, err := net.InterfaceByName(ifaceName)
if err != nil { if err != nil {

2
go.mod
View File

@@ -315,7 +315,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949 replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6 replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

4
go.sum
View File

@@ -461,8 +461,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM= github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ= github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ= github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 h1:6REpBYpJBLTTgqCcLGpTqvRDoEoLbA5r2nAXqMd2La0=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw= github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk= github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=

View File

@@ -104,6 +104,35 @@ var debugStopCmd = &cobra.Command{
SilenceUsage: true, SilenceUsage: true,
} }
var debugWGTuneCmd = &cobra.Command{
Use: "wgtune",
Short: "Inspect and live-tune WireGuard pool settings",
}
var debugWGTuneGetCmd = &cobra.Command{
Use: "get",
Short: "Show pool cap and batch size defaults",
Args: cobra.NoArgs,
RunE: runDebugWGTuneGet,
SilenceUsage: true,
}
var debugWGTuneSetCmd = &cobra.Command{
Use: "set <pool-cap>",
Short: "Set the pool cap (new and live clients)",
Args: cobra.ExactArgs(1),
RunE: runDebugWGTuneSet,
SilenceUsage: true,
}
var debugRuntimeCmd = &cobra.Command{
Use: "runtime",
Short: "Show runtime stats (heap, goroutines, RSS)",
Args: cobra.NoArgs,
RunE: runDebugRuntime,
SilenceUsage: true,
}
var debugCaptureCmd = &cobra.Command{ var debugCaptureCmd = &cobra.Command{
Use: "capture <account-id> [filter expression]", Use: "capture <account-id> [filter expression]",
Short: "Capture packets on a client's WireGuard interface", Short: "Capture packets on a client's WireGuard interface",
@@ -151,6 +180,10 @@ func init() {
debugCmd.AddCommand(debugLogCmd) debugCmd.AddCommand(debugLogCmd)
debugCmd.AddCommand(debugStartCmd) debugCmd.AddCommand(debugStartCmd)
debugCmd.AddCommand(debugStopCmd) debugCmd.AddCommand(debugStopCmd)
debugWGTuneCmd.AddCommand(debugWGTuneGetCmd)
debugWGTuneCmd.AddCommand(debugWGTuneSetCmd)
debugCmd.AddCommand(debugWGTuneCmd)
debugCmd.AddCommand(debugRuntimeCmd)
debugCmd.AddCommand(debugCaptureCmd) debugCmd.AddCommand(debugCaptureCmd)
rootCmd.AddCommand(debugCmd) rootCmd.AddCommand(debugCmd)
@@ -205,6 +238,22 @@ func runDebugStop(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).StopClient(cmd.Context(), args[0]) return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
} }
func runDebugWGTuneGet(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).WGTuneGet(cmd.Context())
}
func runDebugWGTuneSet(cmd *cobra.Command, args []string) error {
n, err := strconv.ParseUint(args[0], 10, 32)
if err != nil {
return fmt.Errorf("invalid value %q: %w", args[0], err)
}
return getDebugClient(cmd).WGTuneSet(cmd.Context(), uint32(n))
}
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).Runtime(cmd.Context())
}
func runDebugCapture(cmd *cobra.Command, args []string) error { func runDebugCapture(cmd *cobra.Command, args []string) error {
duration, _ := cmd.Flags().GetDuration("duration") duration, _ := cmd.Flags().GetDuration("duration")
forcePcap, _ := cmd.Flags().GetBool("pcap") forcePcap, _ := cmd.Flags().GetBool("pcap")

View File

@@ -15,11 +15,22 @@ import (
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy" "github.com/netbirdio/netbird/proxy"
nbacme "github.com/netbirdio/netbird/proxy/internal/acme" nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
const (
// envWGPreallocatedBuffers caps the per-Device WireGuard buffer pool
// size. Zero (unset) keeps the uncapped upstream default.
envWGPreallocatedBuffers = "NB_WG_PREALLOCATED_BUFFERS"
// envWGMaxBatchSize overrides the per-Device WireGuard batch size,
// which controls how many buffers each receive/TUN worker eagerly
// allocates. Zero (unset) keeps the bind+tun default.
envWGMaxBatchSize = "NB_WG_MAX_BATCH_SIZE"
)
const DefaultManagementURL = "https://api.netbird.io:443" const DefaultManagementURL = "https://api.netbird.io:443"
// envProxyToken is the environment variable name for the proxy access token. // envProxyToken is the environment variable name for the proxy access token.
@@ -145,6 +156,42 @@ func runServer(cmd *cobra.Command, args []string) error {
logger.Infof("configured log level: %s", level) logger.Infof("configured log level: %s", level)
var wgPool, wgBatch uint64
if raw := os.Getenv(envWGPreallocatedBuffers); raw != "" {
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
return fmt.Errorf("invalid %s %q: %w", envWGPreallocatedBuffers, raw, err)
}
wgPool = n
embed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n))
logger.Infof("wireguard preallocated buffers per pool: %d", n)
}
if raw := os.Getenv(envWGMaxBatchSize); raw != "" {
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
return fmt.Errorf("invalid %s %q: %w", envWGMaxBatchSize, raw, err)
}
wgBatch = n
embed.SetWGDefaultMaxBatchSize(uint32(n))
logger.Infof("wireguard max batch size override: %d", n)
}
if wgPool > 0 {
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
// the lifetime of the Device. A pool cap below that floor blocks
// the receive pipeline at startup.
batch := wgBatch
if batch == 0 {
batch = 128
}
const recvGoroutines = 4
floor := batch * recvGoroutines
if wgPool < floor {
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
envWGPreallocatedBuffers, wgPool, floor, batch)
}
}
switch forwardedProto { switch forwardedProto {
case "auto", "http", "https": case "auto", "http", "https":
default: default:

View File

@@ -272,6 +272,74 @@ func (c *Client) printLogLevelResult(data map[string]any) {
} }
} }
// WGTuneGet fetches the current WireGuard pool cap.
func (c *Client) WGTuneGet(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/wgtune", c.printWGTuneGet)
}
// WGTuneSet updates the WireGuard pool cap on the global default and all live clients.
func (c *Client) WGTuneSet(ctx context.Context, value uint32) error {
path := fmt.Sprintf("/debug/wgtune?value=%d", value)
return c.fetchAndPrint(ctx, path, c.printWGTuneSet)
}
func (c *Client) printWGTuneGet(data map[string]any) {
def, _ := data["default"].(float64)
batch, _ := data["batch_size"].(float64)
_, _ = fmt.Fprintf(c.out, "Default: %d\n", uint32(def))
_, _ = fmt.Fprintf(c.out, "Batch size: %d (0 = unset)\n", uint32(batch))
}
func (c *Client) printWGTuneSet(data map[string]any) {
if errMsg, ok := data["error"].(string); ok && errMsg != "" {
c.printError(data)
return
}
def, _ := data["default"].(float64)
applied, _ := data["applied"].(float64)
_, _ = fmt.Fprintf(c.out, "Default set to: %d\n", uint32(def))
_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied))
if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 {
_, _ = fmt.Fprintln(c.out, "Failed:")
for k, v := range failed {
_, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v)
}
}
}
// Runtime fetches runtime stats (heap, goroutines, RSS).
func (c *Client) Runtime(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
}
func (c *Client) printRuntime(data map[string]any) {
i := func(k string) uint64 {
v, _ := data[k].(float64)
return uint64(v)
}
mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) }
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs")))
_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines"))
_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects"))
_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns")))
_, _ = fmt.Fprintln(c.out, "Heap:")
_, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc")))
_, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse")))
_, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle")))
_, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released")))
_, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys")))
_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys")))
if _, ok := data["vm_rss"]; ok {
_, _ = fmt.Fprintln(c.out, "Process:")
_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss")))
_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size")))
_, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data")))
}
_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started"))
}
// StartClient starts a specific client. // StartClient starts a specific client.
func (c *Client) StartClient(ctx context.Context, accountID string) error { func (c *Client) StartClient(ctx context.Context, accountID string) error {
path := "/debug/clients/" + url.PathEscape(accountID) + "/start" path := "/debug/clients/" + url.PathEscape(accountID) + "/start"

View File

@@ -10,6 +10,8 @@ import (
"html/template" "html/template"
"maps" "maps"
"net/http" "net/http"
"os"
"runtime"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -58,6 +60,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
type clientProvider interface { type clientProvider interface {
GetClient(accountID types.AccountID) (*nbembed.Client, bool) GetClient(accountID types.AccountID) (*nbembed.Client, bool)
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
ListClientsForStartup() map[types.AccountID]*nbembed.Client
} }
// healthChecker provides health probe state. // healthChecker provides health probe state.
@@ -139,6 +142,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handleListClients(w, r, wantJSON) h.handleListClients(w, r, wantJSON)
case "/debug/health": case "/debug/health":
h.handleHealth(w, r, wantJSON) h.handleHealth(w, r, wantJSON)
case "/debug/wgtune":
h.handleWGTune(w, r)
case "/debug/runtime":
h.handleRuntime(w, r)
default: default:
if h.handleClientRoutes(w, r, path, wantJSON) { if h.handleClientRoutes(w, r, path, wantJSON) {
return return
@@ -232,10 +239,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
} }
if wantJSON { if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients)) clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs { for _, id := range sortedIDs {
info := clients[id] info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{ clientsJSON = append(clientsJSON, map[string]any{
"account_id": info.AccountID, "account_id": info.AccountID,
"service_count": info.ServiceCount, "service_count": info.ServiceCount,
"service_keys": info.ServiceKeys, "service_keys": info.ServiceKeys,
@@ -244,7 +251,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
"age": time.Since(info.CreatedAt).Round(time.Second).String(), "age": time.Since(info.CreatedAt).Round(time.Second).String(),
}) })
} }
resp := map[string]interface{}{ resp := map[string]any{
"version": version.NetbirdVersion(), "version": version.NetbirdVersion(),
"uptime": time.Since(h.startTime).Round(time.Second).String(), "uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients), "client_count": len(clients),
@@ -322,10 +329,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
sortedIDs := sortedAccountIDs(clients) sortedIDs := sortedAccountIDs(clients)
if wantJSON { if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients)) clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs { for _, id := range sortedIDs {
info := clients[id] info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{ clientsJSON = append(clientsJSON, map[string]any{
"account_id": info.AccountID, "account_id": info.AccountID,
"service_count": info.ServiceCount, "service_count": info.ServiceCount,
"service_keys": info.ServiceKeys, "service_keys": info.ServiceKeys,
@@ -334,7 +341,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
"age": time.Since(info.CreatedAt).Round(time.Second).String(), "age": time.Since(info.CreatedAt).Round(time.Second).String(),
}) })
} }
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"uptime": time.Since(h.startTime).Round(time.Second).String(), "uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients), "client_count": len(clients),
"clients": clientsJSON, "clients": clientsJSON,
@@ -420,7 +427,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
}) })
if wantJSON { if wantJSON {
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"account_id": accountID, "account_id": accountID,
"status": overview.FullDetailSummary(), "status": overview.FullDetailSummary(),
}) })
@@ -503,20 +510,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID) client, ok := h.provider.GetClient(accountID)
if !ok { if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"}) h.writeJSON(w, map[string]any{"error": "client not found"})
return return
} }
host := r.URL.Query().Get("host") host := r.URL.Query().Get("host")
portStr := r.URL.Query().Get("port") portStr := r.URL.Query().Get("port")
if host == "" || portStr == "" { if host == "" || portStr == "" {
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"}) h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
return return
} }
port, err := strconv.Atoi(portStr) port, err := strconv.Atoi(portStr)
if err != nil || port < 1 || port > 65535 { if err != nil || port < 1 || port > 65535 {
h.writeJSON(w, map[string]interface{}{"error": "invalid port"}) h.writeJSON(w, map[string]any{"error": "invalid port"})
return return
} }
@@ -535,7 +542,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
conn, err := client.Dial(ctx, "tcp", address) conn, err := client.Dial(ctx, "tcp", address)
if err != nil { if err != nil {
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": false, "success": false,
"host": host, "host": host,
"port": port, "port": port,
@@ -548,7 +555,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
} }
latency := time.Since(start) latency := time.Since(start)
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": true, "success": true,
"host": host, "host": host,
"port": port, "port": port,
@@ -560,25 +567,25 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID) client, ok := h.provider.GetClient(accountID)
if !ok { if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"}) h.writeJSON(w, map[string]any{"error": "client not found"})
return return
} }
level := r.URL.Query().Get("level") level := r.URL.Query().Get("level")
if level == "" { if level == "" {
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"}) h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
return return
} }
if err := client.SetLogLevel(level); err != nil { if err := client.SetLogLevel(level); err != nil {
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": false, "success": false,
"error": err.Error(), "error": err.Error(),
}) })
return return
} }
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": true, "success": true,
"level": level, "level": level,
}) })
@@ -589,7 +596,7 @@ const clientActionTimeout = 30 * time.Second
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID) client, ok := h.provider.GetClient(accountID)
if !ok { if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"}) h.writeJSON(w, map[string]any{"error": "client not found"})
return return
} }
@@ -597,14 +604,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
defer cancel() defer cancel()
if err := client.Start(ctx); err != nil { if err := client.Start(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": false, "success": false,
"error": err.Error(), "error": err.Error(),
}) })
return return
} }
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": true, "success": true,
"message": "client started", "message": "client started",
}) })
@@ -613,7 +620,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) { func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID) client, ok := h.provider.GetClient(accountID)
if !ok { if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"}) h.writeJSON(w, map[string]any{"error": "client not found"})
return return
} }
@@ -621,19 +628,136 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
defer cancel() defer cancel()
if err := client.Stop(ctx); err != nil { if err := client.Stop(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": false, "success": false,
"error": err.Error(), "error": err.Error(),
}) })
return return
} }
h.writeJSON(w, map[string]interface{}{ h.writeJSON(w, map[string]any{
"success": true, "success": true,
"message": "client stopped", "message": "client stopped",
}) })
} }
func (h *Handler) handleWGTune(w http.ResponseWriter, r *http.Request) {
values, ok := r.URL.Query()["value"]
if !ok {
h.writeJSON(w, map[string]any{
"default": nbembed.WGDefaultPreallocatedBuffersPerPool(),
"batch_size": nbembed.WGDefaultMaxBatchSize(),
})
return
}
if len(values) == 0 || values[0] == "" {
http.Error(w, "value parameter must not be empty", http.StatusBadRequest)
return
}
raw := values[0]
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
return
}
nbembed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n))
applied := 0
failed := map[string]string{}
for accountID, client := range h.provider.ListClientsForStartup() {
capN := uint32(n)
if err := client.SetWGTuning(nbembed.WGTuning{PreallocatedBuffersPerPool: &capN}); err != nil {
failed[string(accountID)] = err.Error()
continue
}
applied++
}
resp := map[string]any{
"success": true,
"default": uint32(n),
"batch_size": nbembed.WGDefaultMaxBatchSize(),
"applied": applied,
}
if len(failed) > 0 {
resp["failed"] = failed
}
h.writeJSON(w, resp)
}
// handleRuntime returns cheap runtime and process stats. Safe to hit on a
// running proxy; does not read pprof profiles.
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
var m runtime.MemStats
runtime.ReadMemStats(&m)
clients := h.provider.ListClientsForDebug()
started := 0
for _, c := range clients {
if c.HasClient {
started++
}
}
resp := map[string]any{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"goroutines": runtime.NumGoroutine(),
"num_cpu": runtime.NumCPU(),
"gomaxprocs": runtime.GOMAXPROCS(0),
"go_version": runtime.Version(),
"heap_alloc": m.HeapAlloc,
"heap_inuse": m.HeapInuse,
"heap_idle": m.HeapIdle,
"heap_released": m.HeapReleased,
"heap_sys": m.HeapSys,
"sys": m.Sys,
"live_objects": m.Mallocs - m.Frees,
"num_gc": m.NumGC,
"pause_total_ns": m.PauseTotalNs,
"clients": len(clients),
"started": started,
}
if proc := readProcStatus(); proc != nil {
resp["vm_rss"] = proc["VmRSS"]
resp["vm_size"] = proc["VmSize"]
resp["vm_data"] = proc["VmData"]
}
h.writeJSON(w, resp)
}
// readProcStatus parses /proc/self/status on Linux and returns size fields
// in bytes. Returns nil on non-Linux or read failure.
func readProcStatus() map[string]uint64 {
raw, err := os.ReadFile("/proc/self/status")
if err != nil {
return nil
}
out := map[string]uint64{}
for _, line := range strings.Split(string(raw), "\n") {
k, v, ok := strings.Cut(line, ":")
if !ok {
continue
}
if k != "VmRSS" && k != "VmSize" && k != "VmData" {
continue
}
fields := strings.Fields(v)
if len(fields) < 1 {
continue
}
n, err := strconv.ParseUint(fields[0], 10, 64)
if err != nil {
continue
}
// Values are reported in kB.
out[k] = n * 1024
}
return out
}
const maxCaptureDuration = 30 * time.Minute const maxCaptureDuration = 30 * time.Minute
// handleCapture streams a pcap or text packet capture for the given client. // handleCapture streams a pcap or text packet capture for the given client.
@@ -762,7 +886,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
h.writeJSON(w, resp) h.writeJSON(w, resp)
} }
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) { func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl := h.getTemplates() tmpl := h.getTemplates()
if tmpl == nil { if tmpl == nil {
@@ -775,7 +899,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
} }
} }
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) { func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w) enc := json.NewEncoder(w)
enc.SetIndent("", " ") enc.SetIndent("", " ")