[proxy, client] Bound embed client WireGuard per-Device memory (#5962)

This commit is contained in:
Viktor Liu
2026-05-26 18:51:53 +09:00
committed by GitHub
parent d542c60e21
commit e89b1e0596
11 changed files with 365 additions and 30 deletions

View File

@@ -333,6 +333,63 @@ func (c *Client) printLogLevelResult(data map[string]any) {
}
}
// PerfSet live-retunes the tunnel buffer pool cap on all running embedded
// clients. Batch size is not live-tunable; configure it at proxy startup.
func (c *Client) PerfSet(ctx context.Context, value uint32) error {
path := fmt.Sprintf("/debug/perf?value=%d", value)
return c.fetchAndPrint(ctx, path, c.printPerfSet)
}
func (c *Client) printPerfSet(data map[string]any) {
if errMsg, ok := data["error"].(string); ok && errMsg != "" {
c.printError(data)
return
}
val, _ := data["value"].(float64)
applied, _ := data["applied"].(float64)
_, _ = fmt.Fprintf(c.out, "Pool cap set to: %d\n", uint32(val))
_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(applied))
if failed, ok := data["failed"].(map[string]any); ok && len(failed) > 0 {
_, _ = fmt.Fprintln(c.out, "Failed:")
for k, v := range failed {
_, _ = fmt.Fprintf(c.out, " %s: %v\n", k, v)
}
}
}
// Runtime fetches runtime stats (heap, goroutines, RSS).
func (c *Client) Runtime(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
}
func (c *Client) printRuntime(data map[string]any) {
i := func(k string) uint64 {
v, _ := data[k].(float64)
return uint64(v)
}
mb := func(n uint64) string { return fmt.Sprintf("%.1f MB", float64(n)/(1<<20)) }
_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], uint32(i("num_cpu")), uint32(i("gomaxprocs")))
_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", i("goroutines"))
_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", i("live_objects"))
_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", i("num_gc"), time.Duration(i("pause_total_ns")))
_, _ = fmt.Fprintln(c.out, "Heap:")
_, _ = fmt.Fprintf(c.out, " alloc: %s\n", mb(i("heap_alloc")))
_, _ = fmt.Fprintf(c.out, " in-use: %s\n", mb(i("heap_inuse")))
_, _ = fmt.Fprintf(c.out, " idle: %s\n", mb(i("heap_idle")))
_, _ = fmt.Fprintf(c.out, " released: %s\n", mb(i("heap_released")))
_, _ = fmt.Fprintf(c.out, " sys: %s\n", mb(i("heap_sys")))
_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", mb(i("sys")))
if _, ok := data["vm_rss"]; ok {
_, _ = fmt.Fprintln(c.out, "Process:")
_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", mb(i("vm_rss")))
_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", mb(i("vm_size")))
_, _ = fmt.Fprintf(c.out, " VmData: %s\n", mb(i("vm_data")))
}
_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", i("clients"), i("started"))
}
// StartClient starts a specific client.
func (c *Client) StartClient(ctx context.Context, accountID string) error {
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"

View File

@@ -11,6 +11,8 @@ import (
"maps"
"net"
"net/http"
"os"
"runtime"
"slices"
"strconv"
"strings"
@@ -59,6 +61,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
type clientProvider interface {
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
ListClientsForStartup() map[types.AccountID]*nbembed.Client
}
// InboundListenerInfo describes a per-account inbound listener as
@@ -165,6 +168,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handleListClients(w, r, wantJSON)
case "/debug/health":
h.handleHealth(w, r, wantJSON)
case "/debug/perf":
h.handlePerf(w, r)
case "/debug/runtime":
h.handleRuntime(w, r)
default:
if h.handleClientRoutes(w, r, path, wantJSON) {
return
@@ -258,10 +265,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
}
if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients))
clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
clientsJSON = append(clientsJSON, map[string]any{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
@@ -270,7 +277,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
resp := map[string]interface{}{
resp := map[string]any{
"version": version.NetbirdVersion(),
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
@@ -352,10 +359,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
if h.inbound != nil {
inboundAll = h.inbound.InboundListeners()
}
clientsJSON := make([]map[string]interface{}, 0, len(clients))
clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
row := map[string]interface{}{
row := map[string]any{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
@@ -368,7 +375,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
}
clientsJSON = append(clientsJSON, row)
}
resp := map[string]interface{}{
resp := map[string]any{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"clients": clientsJSON,
@@ -458,7 +465,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
})
if wantJSON {
resp := map[string]interface{}{
resp := map[string]any{
"account_id": accountID,
"status": overview.FullDetailSummary(),
}
@@ -557,20 +564,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
host := r.URL.Query().Get("host")
portStr := r.URL.Query().Get("port")
if host == "" || portStr == "" {
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
return
}
port, err := strconv.Atoi(portStr)
if err != nil || port < 1 || port > 65535 {
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
h.writeJSON(w, map[string]any{"error": "invalid port"})
return
}
@@ -594,7 +601,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
conn, err := client.Dial(ctx, network, address)
if err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"host": host,
"port": port,
@@ -609,39 +616,38 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
}
latency := time.Since(start)
resp := map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"host": host,
"port": port,
"remote": remote,
"latency_ms": latency.Milliseconds(),
"latency": formatDuration(latency),
}
h.writeJSON(w, resp)
})
}
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
level := r.URL.Query().Get("level")
if level == "" {
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
return
}
if err := client.SetLogLevel(level); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"level": level,
})
@@ -652,7 +658,7 @@ const clientActionTimeout = 30 * time.Second
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
@@ -660,14 +666,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
defer cancel()
if err := client.Start(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"message": "client started",
})
@@ -676,7 +682,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
@@ -684,19 +690,125 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
defer cancel()
if err := client.Stop(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"message": "client stopped",
})
}
func (h *Handler) handlePerf(w http.ResponseWriter, r *http.Request) {
raw := r.URL.Query().Get("value")
if raw == "" {
http.Error(w, "value parameter is required", http.StatusBadRequest)
return
}
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
return
}
capN := uint32(n)
applied := 0
failed := map[string]string{}
for accountID, client := range h.provider.ListClientsForStartup() {
if err := client.SetPerformance(nbembed.Performance{PreallocatedBuffersPerPool: &capN}); err != nil {
failed[string(accountID)] = err.Error()
continue
}
applied++
}
resp := map[string]any{
"success": true,
"value": capN,
"applied": applied,
}
if len(failed) > 0 {
resp["failed"] = failed
}
h.writeJSON(w, resp)
}
// handleRuntime returns cheap runtime and process stats. Safe to hit on a
// running proxy; does not read pprof profiles.
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
var m runtime.MemStats
runtime.ReadMemStats(&m)
clients := h.provider.ListClientsForDebug()
started := 0
for _, c := range clients {
if c.HasClient {
started++
}
}
resp := map[string]any{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"goroutines": runtime.NumGoroutine(),
"num_cpu": runtime.NumCPU(),
"gomaxprocs": runtime.GOMAXPROCS(0),
"go_version": runtime.Version(),
"heap_alloc": m.HeapAlloc,
"heap_inuse": m.HeapInuse,
"heap_idle": m.HeapIdle,
"heap_released": m.HeapReleased,
"heap_sys": m.HeapSys,
"sys": m.Sys,
"live_objects": m.Mallocs - m.Frees,
"num_gc": m.NumGC,
"pause_total_ns": m.PauseTotalNs,
"clients": len(clients),
"started": started,
}
if proc := readProcStatus(); proc != nil {
resp["vm_rss"] = proc["VmRSS"]
resp["vm_size"] = proc["VmSize"]
resp["vm_data"] = proc["VmData"]
}
h.writeJSON(w, resp)
}
// readProcStatus parses /proc/self/status on Linux and returns size fields
// in bytes. Returns nil on non-Linux or read failure.
func readProcStatus() map[string]uint64 {
raw, err := os.ReadFile("/proc/self/status")
if err != nil {
return nil
}
out := map[string]uint64{}
for _, line := range strings.Split(string(raw), "\n") {
k, v, ok := strings.Cut(line, ":")
if !ok {
continue
}
if k != "VmRSS" && k != "VmSize" && k != "VmData" {
continue
}
fields := strings.Fields(v)
if len(fields) < 1 {
continue
}
n, err := strconv.ParseUint(fields[0], 10, 64)
if err != nil {
continue
}
// Values are reported in kB.
out[k] = n * 1024
}
return out
}
const maxCaptureDuration = 30 * time.Minute
// handleCapture streams a pcap or text packet capture for the given client.
@@ -825,7 +937,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
h.writeJSON(w, resp)
}
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl := h.getTemplates()
if tmpl == nil {
@@ -838,7 +950,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
}
}
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")