Compare commits

..

5 Commits

Author SHA1 Message Date
Viktor Liu
8f7fd69813 Merge remote-tracking branch 'origin/main' into reduce-embed-wg-pool
# Conflicts:
#	proxy/internal/debug/handler.go
2026-05-07 12:25:30 +02:00
Viktor Liu
7720484c85 Merge branch 'main' into reduce-embed-wg-pool 2026-05-06 11:10:27 +02:00
Viktor Liu
8520bbbe57 Use builtin clear instead of maps.Clear 2026-05-04 12:00:59 +02:00
Viktor Liu
533dd9577b Merge remote-tracking branch 'origin/main' into reduce-embed-wg-pool
# Conflicts:
#	client/embed/embed.go
#	proxy/cmd/proxy/cmd/debug.go
#	proxy/internal/debug/handler.go
2026-05-04 11:37:29 +02:00
Viktor Liu
e81ce81494 Bound embed client WireGuard per-Device memory 2026-04-22 13:04:40 +02:00
12 changed files with 397 additions and 290 deletions

View File

@@ -19,7 +19,7 @@ jobs:
- name: codespell
uses: codespell-project/actions-codespell@v2
with:
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA,ede,additionals
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
skip: go.mod,go.sum,**/proxy/web/**
golangci:
strategy:

View File

@@ -12,6 +12,7 @@ import (
"sync"
"github.com/sirupsen/logrus"
wgdevice "golang.zx2c4.com/wireguard/device"
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface"
@@ -473,6 +474,55 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// WGTuning bundles runtime-adjustable WireGuard knobs exposed by the embed
// client. Nil fields are left unchanged; set a non-nil pointer to apply.
type WGTuning struct {
// PreallocatedBuffersPerPool caps each per-Device WaitPool.
// Zero means "unbounded" (no cap). Live-tunable only if the underlying
// Device was originally created with a nonzero cap.
//
// The field is a pointer so that "leave unchanged" (nil) can be told
// apart from an explicit zero.
PreallocatedBuffersPerPool *uint32
}
// SetWGTuning hands the given tuning to the engine behind this client so it
// is applied to the client's live Device. The error from getEngine is
// returned unchanged when no engine is available; otherwise the engine's own
// result is returned.
// Startup-only knobs (batch size) must be set via the package-level
// setters before Start.
func (c *Client) SetWGTuning(t WGTuning) error {
	eng, err := c.getEngine()
	if err != nil {
		return err
	}

	// Translate the public embed type into the engine's internal one.
	tuning := internal.WGTuning{PreallocatedBuffersPerPool: t.PreallocatedBuffersPerPool}
	return eng.SetWGTuning(tuning)
}
// SetWGDefaultPreallocatedBuffersPerPool sets the default WaitPool cap
// applied to Devices created after this call. Zero disables the cap.
// Existing Devices are unaffected; use Client.SetWGTuning for that.
//
// The value is passed to the wireguard device package's package-level
// setter, so it is not scoped to a single Client.
func SetWGDefaultPreallocatedBuffersPerPool(n uint32) {
wgdevice.SetPreallocatedBuffersPerPool(n)
}
// WGDefaultPreallocatedBuffersPerPool returns the current default WaitPool
// cap applied to newly-created Devices.
//
// It reads the wireguard device package's PreallocatedBuffersPerPool value
// directly. NOTE(review): no synchronization is visible here; confirm
// whether the setter may run concurrently with this read.
func WGDefaultPreallocatedBuffersPerPool() uint32 {
return wgdevice.PreallocatedBuffersPerPool
}
// SetWGDefaultMaxBatchSize sets the default per-Device batch size applied
// to Devices created after this call. Zero means "use the bind+tun default"
// (NOT unlimited). Must be called before Start to take effect for a new
// Client.
//
// The value is passed to the wireguard device package's package-level
// override setter, so it is not scoped to a single Client.
func SetWGDefaultMaxBatchSize(n uint32) {
wgdevice.SetMaxBatchSizeOverride(n)
}
// WGDefaultMaxBatchSize returns the current default batch-size override.
// Zero means "no override".
//
// It reads the wireguard device package's MaxBatchSizeOverride value
// directly. NOTE(review): no synchronization is visible here; confirm
// whether the setter may run concurrently with this read.
func WGDefaultMaxBatchSize() uint32 {
return wgdevice.MaxBatchSizeOverride
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.

View File

@@ -30,27 +30,6 @@ import (
var currentMTU uint16 = iface.DefaultMTU
// nonRetryableEDECodes lists EDE info codes (RFC 8914) for which a SERVFAIL
// from one upstream means another upstream would return the same answer:
// DNSSEC validation outcomes and policy-based blocks. Transient errors
// (network, cached, not ready) are not included.
//
// The map is used as a set: only key presence matters, hence the empty
// struct values.
var nonRetryableEDECodes = map[uint16]struct{}{
// DNSSEC validation outcomes.
dns.ExtendedErrorCodeUnsupportedDNSKEYAlgorithm: {},
dns.ExtendedErrorCodeUnsupportedDSDigestType: {},
dns.ExtendedErrorCodeDNSSECIndeterminate: {},
dns.ExtendedErrorCodeDNSBogus: {},
dns.ExtendedErrorCodeSignatureExpired: {},
dns.ExtendedErrorCodeSignatureNotYetValid: {},
dns.ExtendedErrorCodeDNSKEYMissing: {},
dns.ExtendedErrorCodeRRSIGsMissing: {},
dns.ExtendedErrorCodeNoZoneKeyBitSet: {},
dns.ExtendedErrorCodeNSECMissing: {},
// Policy-based blocks.
dns.ExtendedErrorCodeBlocked: {},
dns.ExtendedErrorCodeCensored: {},
dns.ExtendedErrorCodeFiltered: {},
dns.ExtendedErrorCodeProhibited: {},
}
// privateClientIface is the subset of the WireGuard interface needed by GetClientPrivate.
type privateClientIface interface {
Name() string
@@ -271,18 +250,6 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
var t time.Duration
var err error
// Advertise EDNS0 so the upstream may include Extended DNS Errors
// (RFC 8914) in failure responses; we use those to short-circuit
// failover for definitive answers like DNSSEC validation failures.
// Operate on a copy so the inbound request is unchanged: a client that
// did not advertise EDNS0 must not see an OPT in the response.
hadEdns := r.IsEdns0() != nil
reqUp := r
if !hadEdns {
reqUp = r.Copy()
reqUp.SetEdns0(upstreamUDPSize(), false)
}
var startTime time.Time
var upstreamProto *upstreamProtocolResult
func() {
@@ -290,7 +257,7 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
defer cancel()
ctx, upstreamProto = contextWithupstreamProtocolResult(ctx)
startTime = time.Now()
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), reqUp)
rm, t, err = u.upstreamClient.exchange(ctx, upstream.String(), r)
}()
if err != nil {
@@ -302,49 +269,13 @@ func (u *upstreamResolverBase) queryUpstream(parentCtx context.Context, w dns.Re
}
if rm.Rcode == dns.RcodeServerFailure || rm.Rcode == dns.RcodeRefused {
if code, ok := nonRetryableEDE(rm); ok {
resutil.SetMeta(w, "ede", edeName(code))
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
return &upstreamFailure{upstream: upstream, reason: dns.RcodeToString[rm.Rcode]}
}
if !hadEdns {
stripOPT(rm)
}
u.writeSuccessResponse(w, rm, upstream, r.Question[0].Name, t, upstreamProto, logger)
return nil
}
// upstreamUDPSize computes the EDNS0 UDP buffer size advertised to upstream
// resolvers: the tunnel MTU minus the IP/UDP header overhead. When the MTU
// does not exceed that overhead the subtraction would underflow, so
// dns.MinMsgSize is returned instead.
func upstreamUDPSize() uint16 {
	if currentMTU <= ipUDPHeaderSize {
		return dns.MinMsgSize
	}
	return currentMTU - ipUDPHeaderSize
}
// stripOPT drops every OPT pseudo-RR from the Extra section of rm, keeping
// the remaining records in their original order, so the response complies
// with RFC 6891 when the client did not advertise EDNS0. The filtering is
// done in place on the existing backing array.
func stripOPT(rm *dns.Msg) {
	if len(rm.Extra) == 0 {
		return
	}

	kept := 0
	for _, rr := range rm.Extra {
		if _, isOPT := rr.(*dns.OPT); !isOPT {
			// Compact surviving records towards the front of the slice.
			rm.Extra[kept] = rr
			kept++
		}
	}
	rm.Extra = rm.Extra[:kept]
}
func (u *upstreamResolverBase) handleUpstreamError(err error, upstream netip.AddrPort, startTime time.Time) *upstreamFailure {
if !errors.Is(err, context.DeadlineExceeded) && !isTimeout(err) {
return &upstreamFailure{upstream: upstream, reason: err.Error()}
@@ -406,34 +337,6 @@ func formatFailures(failures []upstreamFailure) string {
return strings.Join(parts, ", ")
}
// nonRetryableEDE scans the EDNS0 options of rm and returns the first
// Extended DNS Error info code that is listed in nonRetryableEDECodes.
// The boolean is false when the message has no OPT record or none of its
// EDE options carries a non-retryable code.
func nonRetryableEDE(rm *dns.Msg) (uint16, bool) {
	opt := rm.IsEdns0()
	if opt == nil {
		return 0, false
	}

	for _, option := range opt.Option {
		if ede, isEDE := option.(*dns.EDNS0_EDE); isEDE {
			if _, final := nonRetryableEDECodes[ede.InfoCode]; final {
				return ede.InfoCode, true
			}
		}
	}
	return 0, false
}
// edeName maps an EDE info code to the human-readable name found in
// dns.ExtendedErrorCodeToString. Codes missing from that table are rendered
// as "EDE <number>".
func edeName(code uint16) string {
	name, known := dns.ExtendedErrorCodeToString[code]
	if !known {
		return fmt.Sprintf("EDE %d", code)
	}
	return name
}
// ProbeAvailability tests all upstream servers simultaneously and
// disables the resolver if none work
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {

View File

@@ -770,132 +770,3 @@ func TestExchangeWithFallback_TCPTruncatesToClientSize(t *testing.T) {
assert.Less(t, len(rm2.Answer), 20, "small EDNS0 client should get fewer records")
assert.True(t, rm2.Truncated, "response should be truncated for small buffer client")
}
// msgWithEDE builds a DNS response with the given rcode. When codes are
// supplied, a single OPT record (UDP size dns.MinMsgSize) carrying one EDE
// option per code, in order, is appended to the Extra section; with no codes
// the response has no OPT record at all.
func msgWithEDE(rcode int, codes ...uint16) *dns.Msg {
	resp := new(dns.Msg)
	resp.Response = true
	resp.Rcode = rcode

	if len(codes) > 0 {
		opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
		opt.SetUDPSize(dns.MinMsgSize)
		for i := range codes {
			opt.Option = append(opt.Option, &dns.EDNS0_EDE{InfoCode: codes[i]})
		}
		resp.Extra = append(resp.Extra, opt)
	}
	return resp
}
func TestNonRetryableEDE(t *testing.T) {
tests := []struct {
name string
msg *dns.Msg
wantOK bool
wantCode uint16
}{
{name: "no edns0", msg: msgWithEDE(dns.RcodeServerFailure)},
{
name: "opt without ede",
msg: func() *dns.Msg {
m := msgWithEDE(dns.RcodeServerFailure)
opt := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
opt.Option = append(opt.Option, &dns.EDNS0_NSID{Code: dns.EDNS0NSID})
m.Extra = []dns.RR{opt}
return m
}(),
},
{name: "ede dnsbogus", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus), wantOK: true, wantCode: dns.ExtendedErrorCodeDNSBogus},
{name: "ede signature expired", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeSignatureExpired), wantOK: true, wantCode: dns.ExtendedErrorCodeSignatureExpired},
{name: "ede blocked", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeBlocked), wantOK: true, wantCode: dns.ExtendedErrorCodeBlocked},
{name: "ede prohibited", msg: msgWithEDE(dns.RcodeRefused, dns.ExtendedErrorCodeProhibited), wantOK: true, wantCode: dns.ExtendedErrorCodeProhibited},
{name: "ede cached error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeCachedError)},
{name: "ede network error retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError)},
{name: "ede not ready retryable", msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNotReady)},
{
name: "first non-retryable wins",
msg: msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeNetworkError, dns.ExtendedErrorCodeDNSBogus),
wantOK: true,
wantCode: dns.ExtendedErrorCodeDNSBogus,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
code, ok := nonRetryableEDE(tc.msg)
assert.Equal(t, tc.wantOK, ok, "ok should match")
if tc.wantOK {
assert.Equal(t, tc.wantCode, code, "code should match")
}
})
}
}
func TestEDEName(t *testing.T) {
assert.Equal(t, "DNSSEC Bogus", edeName(dns.ExtendedErrorCodeDNSBogus))
assert.Equal(t, "Signature Expired", edeName(dns.ExtendedErrorCodeSignatureExpired))
assert.Equal(t, "EDE 9999", edeName(9999), "unknown code falls back to numeric")
}
// TestStripOPT checks that stripOPT removes the OPT pseudo-RR from the Extra
// section and keeps the ordinary A record.
func TestStripOPT(t *testing.T) {
	optRR := &dns.OPT{Hdr: dns.RR_Header{Name: ".", Rrtype: dns.TypeOPT}}
	aRR := &dns.A{Hdr: dns.RR_Header{Name: "x.", Rrtype: dns.TypeA}, A: net.IPv4(1, 2, 3, 4)}
	msg := &dns.Msg{Extra: []dns.RR{optRR, aRR}}

	stripOPT(msg)

	assert.Len(t, msg.Extra, 1, "OPT should be removed, A kept")
	_, stillOPT := msg.Extra[0].(*dns.OPT)
	assert.False(t, stillOPT, "remaining record must not be OPT")
}
// TestUpstreamResolver_NonRetryableEDEShortCircuits configures two upstreams:
// the first answers SERVFAIL with a DNSSEC Bogus EDE, the second would answer
// successfully. It asserts that the SERVFAIL is written to the client, that
// only the first upstream is queried, and that no OPT record reaches a client
// that did not advertise EDNS0.
func TestUpstreamResolver_NonRetryableEDEShortCircuits(t *testing.T) {
upstream1 := netip.MustParseAddrPort("192.0.2.1:53")
upstream2 := netip.MustParseAddrPort("192.0.2.2:53")
servfailWithEDE := msgWithEDE(dns.RcodeServerFailure, dns.ExtendedErrorCodeDNSBogus)
successResp := buildMockResponse(dns.RcodeSuccess, "192.0.2.100")
// queried is handed to the tracking client by pointer and inspected below
// to see which upstreams were contacted.
var queried []string
tracking := &trackingMockClient{
inner: &mockUpstreamResolverPerServer{
responses: map[string]mockUpstreamResponse{
upstream1.String(): {msg: servfailWithEDE},
upstream2.String(): {msg: successResp},
},
rtt: time.Millisecond,
},
queriedUpstreams: &queried,
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
resolver := &upstreamResolverBase{
ctx: ctx,
upstreamClient: tracking,
upstreamServers: []netip.AddrPort{upstream1, upstream2},
upstreamTimeout: UpstreamTimeout,
}
// Capture whatever the resolver writes back to the client.
var written *dns.Msg
w := &test.MockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
written = m
return nil
},
}
// Client query without EDNS0 must not see an OPT in the response.
q := new(dns.Msg).SetQuestion("example.com.", dns.TypeA)
resolver.ServeDNS(w, q)
require.NotNil(t, written, "response must be written")
assert.Equal(t, dns.RcodeServerFailure, written.Rcode, "SERVFAIL must propagate")
assert.Len(t, queried, 1, "only first upstream should be queried")
assert.Equal(t, upstream1.String(), queried[0])
for _, rr := range written.Extra {
_, isOPT := rr.(*dns.OPT)
assert.False(t, isOPT, "synthetic OPT must not leak to a non-EDNS0 client")
}
}

View File

@@ -1979,6 +1979,29 @@ func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
return e.clientMetrics
}
// WGTuning bundles runtime-adjustable WireGuard pool knobs.
// See Engine.SetWGTuning. Nil fields are ignored.
type WGTuning struct {
// PreallocatedBuffersPerPool, when non-nil, is the value passed to the
// live Device's SetPreallocatedBuffersPerPool. A nil pointer leaves the
// Device's current setting untouched.
PreallocatedBuffersPerPool *uint32
}
// SetWGTuning applies the non-nil fields of t to this engine's live Device
// while holding syncMsgMux. It returns an error when the WireGuard interface
// or its underlying Device is not available yet; nil fields of t are skipped.
func (e *Engine) SetWGTuning(t WGTuning) error {
	e.syncMsgMux.Lock()
	defer e.syncMsgMux.Unlock()

	if e.wgInterface == nil {
		return fmt.Errorf("wg interface not initialized")
	}

	device := e.wgInterface.GetWGDevice()
	if device == nil {
		return fmt.Errorf("wg device not initialized")
	}

	if poolCap := t.PreallocatedBuffersPerPool; poolCap != nil {
		device.SetPreallocatedBuffersPerPool(*poolCap)
	}
	return nil
}
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
iface, err := net.InterfaceByName(ifaceName)
if err != nil {

4
go.mod
View File

@@ -72,7 +72,7 @@ require (
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.72
github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45
@@ -317,7 +317,7 @@ replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-2024
replace github.com/getlantern/systray => github.com/netbirdio/systray v0.0.0-20231030152038-ef1ed2a27949
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0
replace golang.zx2c4.com/wireguard => github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58
replace github.com/cloudflare/circl => github.com/cunicu/circl v0.0.0-20230801113412-fec58fc7b5f6

8
go.sum
View File

@@ -421,8 +421,8 @@ github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFe
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.72 h1:vhmr+TF2A3tuoGNkLDFK9zi36F2LS+hKTRW0Uf8kbzI=
github.com/miekg/dns v1.1.72/go.mod h1:+EuEPhdHOsfk6Wk5TT2CzssZdqkmFhf8r+aVyDEToIs=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
github.com/miekg/dns v1.1.59/go.mod h1:nZpewl5p6IvctfgrckopVx2OlSEHPRO/U4SYkRklrEk=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721 h1:RlZweED6sbSArvlE924+mUcZuXKLBHA35U7LN621Bws=
github.com/mikioh/ipaddr v0.0.0-20190404000644-d465c8ab6721/go.mod h1:Ickgr2WtCLZ2MDGd4Gr0geeCH5HybhRJbonOgQpvSxc=
github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw=
@@ -463,8 +463,8 @@ github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 h1:3tHlFmhTdX9ax
github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502/go.mod h1:CIMRFEJVL+0DS1a3Nx06NaMn4Dz63Ng6O7dl0qH0zVM=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45 h1:ujgviVYmx243Ksy7NdSwrdGPSRNE3pb8kEDSpH0QuAQ=
github.com/netbirdio/signal-dispatcher/dispatcher v0.0.0-20250805121659-6b4ac470ca45/go.mod h1:5/sjFmLb8O96B5737VCqhHyGRzNFIaN/Bu7ZodXc3qQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0 h1:h/QnNzm7xzHPm+gajcblYUOclrW2FeNeDlUNj6tTWKQ=
github.com/netbirdio/wireguard-go v0.0.0-20260107100953-33b7c9d03db0/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58 h1:6REpBYpJBLTTgqCcLGpTqvRDoEoLbA5r2nAXqMd2La0=
github.com/netbirdio/wireguard-go v0.0.0-20260422100739-63c67f59bf58/go.mod h1:rpwXGsirqLqN2L0JDJQlwOboGHmptD5ZD6T2VmcqhTw=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
github.com/nicksnyder/go-i18n/v2 v2.5.1 h1:IxtPxYsR9Gp60cGXjfuR/llTqV8aYMsC472zD0D1vHk=

View File

@@ -109,6 +109,35 @@ var debugStopCmd = &cobra.Command{
SilenceUsage: true,
}
// debugWGTuneCmd is the parent "wgtune" command. It has no run function of
// its own; the get and set subcommands are attached to it in init.
var debugWGTuneCmd = &cobra.Command{
Use: "wgtune",
Short: "Inspect and live-tune WireGuard pool settings",
}
// debugWGTuneGetCmd is "wgtune get": takes no arguments and runs
// runDebugWGTuneGet.
var debugWGTuneGetCmd = &cobra.Command{
Use: "get",
Short: "Show pool cap and batch size defaults",
Args: cobra.NoArgs,
RunE: runDebugWGTuneGet,
SilenceUsage: true,
}
// debugWGTuneSetCmd is "wgtune set <pool-cap>": takes exactly one argument
// and runs runDebugWGTuneSet.
var debugWGTuneSetCmd = &cobra.Command{
Use: "set <pool-cap>",
Short: "Set the pool cap (new and live clients)",
Args: cobra.ExactArgs(1),
RunE: runDebugWGTuneSet,
SilenceUsage: true,
}
// debugRuntimeCmd is "runtime": takes no arguments and runs runDebugRuntime.
var debugRuntimeCmd = &cobra.Command{
Use: "runtime",
Short: "Show runtime stats (heap, goroutines, RSS)",
Args: cobra.NoArgs,
RunE: runDebugRuntime,
SilenceUsage: true,
}
var debugCaptureCmd = &cobra.Command{
Use: "capture <account-id> [filter expression]",
Short: "Capture packets on a client's WireGuard interface",
@@ -159,6 +188,10 @@ func init() {
debugCmd.AddCommand(debugLogCmd)
debugCmd.AddCommand(debugStartCmd)
debugCmd.AddCommand(debugStopCmd)
debugWGTuneCmd.AddCommand(debugWGTuneGetCmd)
debugWGTuneCmd.AddCommand(debugWGTuneSetCmd)
debugCmd.AddCommand(debugWGTuneCmd)
debugCmd.AddCommand(debugRuntimeCmd)
debugCmd.AddCommand(debugCaptureCmd)
rootCmd.AddCommand(debugCmd)
@@ -220,6 +253,22 @@ func runDebugStop(cmd *cobra.Command, args []string) error {
return getDebugClient(cmd).StopClient(cmd.Context(), args[0])
}
// runDebugWGTuneGet is the RunE handler for "wgtune get". It asks the debug
// client for the current WireGuard tuning values using the command's context
// and returns the client's error, if any. Positional arguments are ignored.
func runDebugWGTuneGet(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).WGTuneGet(cmd.Context())
}
// runDebugWGTuneSet is the RunE handler for "wgtune set <pool-cap>". It
// parses the single positional argument as an unsigned base-10 number that
// fits in 32 bits and forwards it to the debug client. A value that does not
// parse yields an error that wraps the strconv failure.
func runDebugWGTuneSet(cmd *cobra.Command, args []string) error {
	raw := args[0]
	parsed, err := strconv.ParseUint(raw, 10, 32)
	if err != nil {
		return fmt.Errorf("invalid value %q: %w", raw, err)
	}

	// ParseUint with bitSize 32 guarantees the conversion cannot truncate.
	poolCap := uint32(parsed)
	return getDebugClient(cmd).WGTuneSet(cmd.Context(), poolCap)
}
// runDebugRuntime is the RunE handler for the "runtime" command. It asks the
// debug client for runtime stats using the command's context and returns the
// client's error, if any. Positional arguments are ignored.
func runDebugRuntime(cmd *cobra.Command, _ []string) error {
return getDebugClient(cmd).Runtime(cmd.Context())
}
func runDebugCapture(cmd *cobra.Command, args []string) error {
duration, _ := cmd.Flags().GetDuration("duration")
forcePcap, _ := cmd.Flags().GetBool("pcap")

View File

@@ -15,11 +15,22 @@ import (
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/client/embed"
"github.com/netbirdio/netbird/proxy"
nbacme "github.com/netbirdio/netbird/proxy/internal/acme"
"github.com/netbirdio/netbird/util"
)
// Names of the environment variables that tune the embedded WireGuard
// Devices. Both hold unsigned integers.
const (
// envWGPreallocatedBuffers caps the per-Device WireGuard buffer pool
// size. Zero (unset) keeps the uncapped upstream default.
envWGPreallocatedBuffers = "NB_WG_PREALLOCATED_BUFFERS"
// envWGMaxBatchSize overrides the per-Device WireGuard batch size,
// which controls how many buffers each receive/TUN worker eagerly
// allocates. Zero (unset) keeps the bind+tun default.
envWGMaxBatchSize = "NB_WG_MAX_BATCH_SIZE"
)
const DefaultManagementURL = "https://api.netbird.io:443"
// envProxyToken is the environment variable name for the proxy access token.
@@ -145,6 +156,42 @@ func runServer(cmd *cobra.Command, args []string) error {
logger.Infof("configured log level: %s", level)
var wgPool, wgBatch uint64
if raw := os.Getenv(envWGPreallocatedBuffers); raw != "" {
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
return fmt.Errorf("invalid %s %q: %w", envWGPreallocatedBuffers, raw, err)
}
wgPool = n
embed.SetWGDefaultPreallocatedBuffersPerPool(uint32(n))
logger.Infof("wireguard preallocated buffers per pool: %d", n)
}
if raw := os.Getenv(envWGMaxBatchSize); raw != "" {
n, err := strconv.ParseUint(raw, 10, 32)
if err != nil {
return fmt.Errorf("invalid %s %q: %w", envWGMaxBatchSize, raw, err)
}
wgBatch = n
embed.SetWGDefaultMaxBatchSize(uint32(n))
logger.Infof("wireguard max batch size override: %d", n)
}
if wgPool > 0 {
// Each bind recv goroutine (IPv4 + IPv6 + ICE relay) plus
// RoutineReadFromTUN eagerly reserves `batch` message buffers for
// the lifetime of the Device. A pool cap below that floor blocks
// the receive pipeline at startup.
batch := wgBatch
if batch == 0 {
batch = 128
}
const recvGoroutines = 4
floor := batch * recvGoroutines
if wgPool < floor {
logger.Warnf("%s=%d is below the eager-allocation floor (~%d for batch=%d); startup may deadlock",
envWGPreallocatedBuffers, wgPool, floor, batch)
}
}
switch forwardedProto {
case "auto", "http", "https":
default:

View File

@@ -284,6 +284,74 @@ func (c *Client) printLogLevelResult(data map[string]any) {
}
}
// WGTuneGet fetches the current WireGuard pool cap.
// It requests /debug/wgtune without a value parameter and renders the reply
// with printWGTuneGet; any error from fetchAndPrint is returned.
func (c *Client) WGTuneGet(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/wgtune", c.printWGTuneGet)
}
// WGTuneSet updates the WireGuard pool cap on the global default and all live clients.
// The new cap is sent as the "value" query parameter of /debug/wgtune and the
// reply is rendered with printWGTuneSet; any error from fetchAndPrint is
// returned.
func (c *Client) WGTuneSet(ctx context.Context, value uint32) error {
path := fmt.Sprintf("/debug/wgtune?value=%d", value)
return c.fetchAndPrint(ctx, path, c.printWGTuneSet)
}
// printWGTuneGet renders the reply of a wgtune read: the default pool cap
// ("default") and the batch-size override ("batch_size"). Keys that are
// missing or not numeric print as 0. Write errors are deliberately ignored,
// this is best-effort terminal output.
func (c *Client) printWGTuneGet(data map[string]any) {
	// number reads a JSON number field, yielding 0 when absent or mistyped.
	number := func(key string) uint32 {
		f, _ := data[key].(float64)
		return uint32(f)
	}

	_, _ = fmt.Fprintf(c.out, "Default: %d\n", number("default"))
	_, _ = fmt.Fprintf(c.out, "Batch size: %d (0 = unset)\n", number("batch_size"))
}
// printWGTuneSet renders the reply of a wgtune update. A non-empty "error"
// string is delegated to printError and nothing else is printed. Otherwise
// the new default, the number of live clients it was applied to and, when
// present and non-empty, the per-client failures are written out. Write
// errors are deliberately ignored, this is best-effort terminal output.
func (c *Client) printWGTuneSet(data map[string]any) {
	if msg, isString := data["error"].(string); isString && msg != "" {
		c.printError(data)
		return
	}

	poolCap, _ := data["default"].(float64)
	liveCount, _ := data["applied"].(float64)
	_, _ = fmt.Fprintf(c.out, "Default set to: %d\n", uint32(poolCap))
	_, _ = fmt.Fprintf(c.out, "Applied to %d live clients\n", int(liveCount))

	failures, isMap := data["failed"].(map[string]any)
	if !isMap || len(failures) == 0 {
		return
	}

	_, _ = fmt.Fprintln(c.out, "Failed:")
	// Map iteration order is unspecified, so failures print in no fixed order.
	for account, reason := range failures {
		_, _ = fmt.Fprintf(c.out, " %s: %v\n", account, reason)
	}
}
// Runtime fetches runtime stats (heap, goroutines, RSS).
// It requests /debug/runtime and renders the reply with printRuntime; any
// error from fetchAndPrint is returned.
func (c *Client) Runtime(ctx context.Context) error {
return c.fetchAndPrint(ctx, "/debug/runtime", c.printRuntime)
}
// printRuntime renders the reply of /debug/runtime as a human-readable
// report: uptime, Go runtime details, GC and heap figures, the optional
// process memory section (only when "vm_rss" is present) and client counts.
// Numeric keys that are missing or not numbers print as 0. Write errors are
// deliberately ignored, this is best-effort terminal output.
func (c *Client) printRuntime(data map[string]any) {
	// count reads a JSON number field as an unsigned integer.
	count := func(key string) uint64 {
		f, _ := data[key].(float64)
		return uint64(f)
	}
	// megabytes formats the byte count stored under key as mebibytes.
	megabytes := func(key string) string {
		return fmt.Sprintf("%.1f MB", float64(count(key))/(1<<20))
	}

	gcPause := time.Duration(count("pause_total_ns"))
	cpus := uint32(count("num_cpu"))
	maxProcs := uint32(count("gomaxprocs"))

	_, _ = fmt.Fprintf(c.out, "Uptime: %v\n", data["uptime"])
	_, _ = fmt.Fprintf(c.out, "Go: %v on %d CPU (GOMAXPROCS=%d)\n", data["go_version"], cpus, maxProcs)
	_, _ = fmt.Fprintf(c.out, "Goroutines: %d\n", count("goroutines"))
	_, _ = fmt.Fprintf(c.out, "Live objects: %d\n", count("live_objects"))
	_, _ = fmt.Fprintf(c.out, "GC: %d cycles, %v pause total\n", count("num_gc"), gcPause)

	_, _ = fmt.Fprintln(c.out, "Heap:")
	_, _ = fmt.Fprintf(c.out, " alloc: %s\n", megabytes("heap_alloc"))
	_, _ = fmt.Fprintf(c.out, " in-use: %s\n", megabytes("heap_inuse"))
	_, _ = fmt.Fprintf(c.out, " idle: %s\n", megabytes("heap_idle"))
	_, _ = fmt.Fprintf(c.out, " released: %s\n", megabytes("heap_released"))
	_, _ = fmt.Fprintf(c.out, " sys: %s\n", megabytes("heap_sys"))
	_, _ = fmt.Fprintf(c.out, "Total sys: %s\n", megabytes("sys"))

	// The process section is only sent when the server could read it.
	if _, hasProc := data["vm_rss"]; hasProc {
		_, _ = fmt.Fprintln(c.out, "Process:")
		_, _ = fmt.Fprintf(c.out, " VmRSS: %s\n", megabytes("vm_rss"))
		_, _ = fmt.Fprintf(c.out, " VmSize: %s\n", megabytes("vm_size"))
		_, _ = fmt.Fprintf(c.out, " VmData: %s\n", megabytes("vm_data"))
	}

	_, _ = fmt.Fprintf(c.out, "Clients: %d (%d started)\n", count("clients"), count("started"))
}
// StartClient starts a specific client.
func (c *Client) StartClient(ctx context.Context, accountID string) error {
path := "/debug/clients/" + url.PathEscape(accountID) + "/start"

View File

@@ -11,6 +11,8 @@ import (
"maps"
"net"
"net/http"
"os"
"runtime"
"slices"
"strconv"
"strings"
@@ -59,6 +61,7 @@ func sortedAccountIDs(m map[types.AccountID]roundtrip.ClientDebugInfo) []types.A
type clientProvider interface {
GetClient(accountID types.AccountID) (*nbembed.Client, bool)
ListClientsForDebug() map[types.AccountID]roundtrip.ClientDebugInfo
ListClientsForStartup() map[types.AccountID]*nbembed.Client
}
// healthChecker provides health probe state.
@@ -140,6 +143,10 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
h.handleListClients(w, r, wantJSON)
case "/debug/health":
h.handleHealth(w, r, wantJSON)
case "/debug/wgtune":
h.handleWGTune(w, r)
case "/debug/runtime":
h.handleRuntime(w, r)
default:
if h.handleClientRoutes(w, r, path, wantJSON) {
return
@@ -233,10 +240,10 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
}
if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients))
clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
clientsJSON = append(clientsJSON, map[string]any{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
@@ -245,7 +252,7 @@ func (h *Handler) handleIndex(w http.ResponseWriter, _ *http.Request, wantJSON b
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
resp := map[string]interface{}{
resp := map[string]any{
"version": version.NetbirdVersion(),
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
@@ -323,10 +330,10 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
sortedIDs := sortedAccountIDs(clients)
if wantJSON {
clientsJSON := make([]map[string]interface{}, 0, len(clients))
clientsJSON := make([]map[string]any, 0, len(clients))
for _, id := range sortedIDs {
info := clients[id]
clientsJSON = append(clientsJSON, map[string]interface{}{
clientsJSON = append(clientsJSON, map[string]any{
"account_id": info.AccountID,
"service_count": info.ServiceCount,
"service_keys": info.ServiceKeys,
@@ -335,7 +342,7 @@ func (h *Handler) handleListClients(w http.ResponseWriter, _ *http.Request, want
"age": time.Since(info.CreatedAt).Round(time.Second).String(),
})
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"uptime": time.Since(h.startTime).Round(time.Second).String(),
"client_count": len(clients),
"clients": clientsJSON,
@@ -421,7 +428,7 @@ func (h *Handler) handleClientStatus(w http.ResponseWriter, r *http.Request, acc
})
if wantJSON {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"account_id": accountID,
"status": overview.FullDetailSummary(),
})
@@ -504,20 +511,20 @@ func (h *Handler) handleClientTools(w http.ResponseWriter, _ *http.Request, acco
func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
host := r.URL.Query().Get("host")
portStr := r.URL.Query().Get("port")
if host == "" || portStr == "" {
h.writeJSON(w, map[string]interface{}{"error": "host and port parameters required"})
h.writeJSON(w, map[string]any{"error": "host and port parameters required"})
return
}
port, err := strconv.Atoi(portStr)
if err != nil || port < 1 || port > 65535 {
h.writeJSON(w, map[string]interface{}{"error": "invalid port"})
h.writeJSON(w, map[string]any{"error": "invalid port"})
return
}
@@ -541,7 +548,7 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
conn, err := client.Dial(ctx, network, address)
if err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"host": host,
"port": port,
@@ -556,39 +563,38 @@ func (h *Handler) handlePingTCP(w http.ResponseWriter, r *http.Request, accountI
}
latency := time.Since(start)
resp := map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"host": host,
"port": port,
"remote": remote,
"latency_ms": latency.Milliseconds(),
"latency": formatDuration(latency),
}
h.writeJSON(w, resp)
})
}
func (h *Handler) handleLogLevel(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
level := r.URL.Query().Get("level")
if level == "" {
h.writeJSON(w, map[string]interface{}{"error": "level parameter required (trace, debug, info, warn, error)"})
h.writeJSON(w, map[string]any{"error": "level parameter required (trace, debug, info, warn, error)"})
return
}
if err := client.SetLogLevel(level); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"level": level,
})
@@ -599,7 +605,7 @@ const clientActionTimeout = 30 * time.Second
func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
@@ -607,14 +613,14 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
defer cancel()
if err := client.Start(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"message": "client started",
})
@@ -623,7 +629,7 @@ func (h *Handler) handleClientStart(w http.ResponseWriter, r *http.Request, acco
func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accountID types.AccountID) {
client, ok := h.provider.GetClient(accountID)
if !ok {
h.writeJSON(w, map[string]interface{}{"error": "client not found"})
h.writeJSON(w, map[string]any{"error": "client not found"})
return
}
@@ -631,19 +637,136 @@ func (h *Handler) handleClientStop(w http.ResponseWriter, r *http.Request, accou
defer cancel()
if err := client.Stop(ctx); err != nil {
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": false,
"error": err.Error(),
})
return
}
h.writeJSON(w, map[string]interface{}{
h.writeJSON(w, map[string]any{
"success": true,
"message": "client stopped",
})
}
// handleWGTune serves /debug/wgtune.
//
// Without a "value" query parameter it reports the current default pool cap
// and batch-size override as JSON. With one, the value must be a non-empty
// unsigned base-10 number that fits in 32 bits (otherwise 400 is returned);
// it is stored as the new default and then applied to every client returned
// by ListClientsForStartup. The JSON reply carries the new default, the
// batch size, how many clients accepted the change and, if any, the
// per-account failures.
func (h *Handler) handleWGTune(w http.ResponseWriter, r *http.Request) {
	query := r.URL.Query()
	if !query.Has("value") {
		// Read-only request: report the current defaults.
		h.writeJSON(w, map[string]any{
			"default":    nbembed.WGDefaultPreallocatedBuffersPerPool(),
			"batch_size": nbembed.WGDefaultMaxBatchSize(),
		})
		return
	}

	raw := query.Get("value")
	if raw == "" {
		http.Error(w, "value parameter must not be empty", http.StatusBadRequest)
		return
	}

	parsed, err := strconv.ParseUint(raw, 10, 32)
	if err != nil {
		http.Error(w, fmt.Sprintf("invalid value %q: %v", raw, err), http.StatusBadRequest)
		return
	}
	poolCap := uint32(parsed)

	// Update the default first so Devices created from now on pick it up.
	nbembed.SetWGDefaultPreallocatedBuffersPerPool(poolCap)

	var applied int
	failed := make(map[string]string)
	for accountID, client := range h.provider.ListClientsForStartup() {
		// Each client gets a pointer to its own copy of the value.
		perClient := poolCap
		tuneErr := client.SetWGTuning(nbembed.WGTuning{PreallocatedBuffersPerPool: &perClient})
		if tuneErr == nil {
			applied++
			continue
		}
		failed[string(accountID)] = tuneErr.Error()
	}

	resp := map[string]any{
		"success":    true,
		"default":    poolCap,
		"batch_size": nbembed.WGDefaultMaxBatchSize(),
		"applied":    applied,
	}
	if len(failed) > 0 {
		resp["failed"] = failed
	}
	h.writeJSON(w, resp)
}
// handleRuntime writes a JSON snapshot of Go runtime statistics (heap sizes,
// GC counters, goroutine count), handler uptime, and the number of known and
// started clients. When readProcStatus yields data, the process VmRSS, VmSize
// and VmData sizes are included as well. No pprof profiles are read.
func (h *Handler) handleRuntime(w http.ResponseWriter, _ *http.Request) {
	var stats runtime.MemStats
	runtime.ReadMemStats(&stats)

	all := h.provider.ListClientsForDebug()
	running := 0
	for _, entry := range all {
		if !entry.HasClient {
			continue
		}
		running++
	}

	out := map[string]any{
		"uptime":         time.Since(h.startTime).Round(time.Second).String(),
		"goroutines":     runtime.NumGoroutine(),
		"num_cpu":        runtime.NumCPU(),
		"gomaxprocs":     runtime.GOMAXPROCS(0),
		"go_version":     runtime.Version(),
		"heap_alloc":     stats.HeapAlloc,
		"heap_inuse":     stats.HeapInuse,
		"heap_idle":      stats.HeapIdle,
		"heap_released":  stats.HeapReleased,
		"heap_sys":       stats.HeapSys,
		"sys":            stats.Sys,
		"live_objects":   stats.Mallocs - stats.Frees,
		"num_gc":         stats.NumGC,
		"pause_total_ns": stats.PauseTotalNs,
		"clients":        len(all),
		"started":        running,
	}

	// Process-level sizes are only added when /proc data could be read.
	status := readProcStatus()
	if status != nil {
		fields := [...]struct{ jsonKey, procKey string }{
			{"vm_rss", "VmRSS"},
			{"vm_size", "VmSize"},
			{"vm_data", "VmData"},
		}
		for _, f := range fields {
			out[f.jsonKey] = status[f.procKey]
		}
	}

	h.writeJSON(w, out)
}
// readProcStatus reads /proc/self/status and extracts the VmRSS, VmSize and
// VmData entries, converted to bytes. It returns nil when the file cannot be
// read (for example on systems without /proc). Entries that are missing or
// whose value does not parse are simply absent from the returned map.
func readProcStatus() map[string]uint64 {
	data, readErr := os.ReadFile("/proc/self/status")
	if readErr != nil {
		return nil
	}

	sizes := map[string]uint64{}
	for _, entry := range strings.Split(string(data), "\n") {
		key, rest, found := strings.Cut(entry, ":")
		if !found {
			continue
		}

		// Only the memory-size fields are of interest.
		switch key {
		case "VmRSS", "VmSize", "VmData":
		default:
			continue
		}

		// The numeric value is the first whitespace-separated token; any
		// trailing unit suffix is ignored.
		parts := strings.Fields(rest)
		if len(parts) == 0 {
			continue
		}
		kb, parseErr := strconv.ParseUint(parts[0], 10, 64)
		if parseErr != nil {
			continue
		}

		// Values are reported in kB; scale to bytes.
		sizes[key] = kb * 1024
	}
	return sizes
}
const maxCaptureDuration = 30 * time.Minute
// handleCapture streams a pcap or text packet capture for the given client.
@@ -772,7 +895,7 @@ func (h *Handler) handleHealth(w http.ResponseWriter, r *http.Request, wantJSON
h.writeJSON(w, resp)
}
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interface{}) {
func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data any) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
tmpl := h.getTemplates()
if tmpl == nil {
@@ -785,7 +908,7 @@ func (h *Handler) renderTemplate(w http.ResponseWriter, name string, data interf
}
}
func (h *Handler) writeJSON(w http.ResponseWriter, v interface{}) {
func (h *Handler) writeJSON(w http.ResponseWriter, v any) {
w.Header().Set("Content-Type", "application/json")
enc := json.NewEncoder(w)
enc.SetIndent("", " ")

View File

@@ -12,7 +12,6 @@ import (
"github.com/google/gopacket"
"github.com/google/gopacket/layers"
"github.com/miekg/dns"
)
// TextWriter writes human-readable one-line-per-packet summaries.
@@ -595,45 +594,19 @@ func formatDNSResponse(d *layers.DNS, rd string, plen int) string {
anCount := d.ANCount
nsCount := d.NSCount
arCount := d.ARCount
ede := formatEDE(d)
if d.ResponseCode != layers.DNSResponseCodeNoErr {
return fmt.Sprintf("%04x %d/%d/%d %s%s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, ede, plen)
return fmt.Sprintf("%04x %d/%d/%d %s (%d)", d.ID, anCount, nsCount, arCount, d.ResponseCode, plen)
}
if anCount > 0 && len(d.Answers) > 0 {
rr := d.Answers[0]
if rdata := shortRData(&rr); rdata != "" {
return fmt.Sprintf("%04x %d/%d/%d %s %s%s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, ede, plen)
return fmt.Sprintf("%04x %d/%d/%d %s %s (%d)", d.ID, anCount, nsCount, arCount, rr.Type, rdata, plen)
}
}
return fmt.Sprintf("%04x %d/%d/%d%s (%d)", d.ID, anCount, nsCount, arCount, ede, plen)
}
// dnsOPTCodeEDE is the EDNS0 option code for Extended DNS Errors (RFC 8914).
const dnsOPTCodeEDE layers.DNSOptionCode = layers.DNSOptionCode(dns.EDNS0EDE)
// formatEDE returns " EDE=Name" for the first Extended DNS Error option
// found in the response, or empty string if none is present.
func formatEDE(d *layers.DNS) string {
for _, rr := range d.Additionals {
if rr.Type != layers.DNSTypeOPT {
continue
}
for _, opt := range rr.OPT {
if opt.Code != dnsOPTCodeEDE || len(opt.Data) < 2 {
continue
}
info := binary.BigEndian.Uint16(opt.Data[:2])
name, ok := dns.ExtendedErrorCodeToString[info]
if !ok {
name = fmt.Sprintf("%d", info)
}
return " EDE=" + name
}
}
return ""
return fmt.Sprintf("%04x %d/%d/%d (%d)", d.ID, anCount, nsCount, arCount, plen)
}
func shortRData(rr *layers.DNSResourceRecord) string {