Compare commits

..

1 Commits

Author SHA1 Message Date
mlsmaycon
2d8b0310a4 [client, proxy] IPv6 in-place apply + accept-loop hardening on netstack listeners
Two related fixes for the embedded netbird client and the per-account
inbound listeners that ride on its gVisor netstack.

client/internal/engine.go — replace hasIPv6Changed with reconcileIPv6:

  - First v6 assignment (current had no v6, conf carries one) is applied
    in place via WGIface.UpdateAddr instead of returning ErrResetConnection.
    Pre-fix, every embedded client whose account had IPv6 enabled would
    reset on its first NetworkMap sync — boot config has no v6, the sync
    introduces one, the engine tore itself down to "apply" it. That
    teardown destroys the gVisor netstack and orphans every listener
    bound on it, which is what made the proxy's per-account :80/:443
    silently stop accepting traffic.
  - v6 removed clears in place.
  - v6 swapped to a different non-empty value still resets (gVisor
    netstack can't safely swap its address at runtime).
  - Mutates e.config.WgAddr to match the applied state so subsequent
    PeerConfig comparisons are stable.

proxy/internal/tcp/accept.go (new) + proxy/inbound.go +
proxy/internal/tcp/router.go — harden the two Accept() loops on
netstack-backed listeners:

  - IsClosedListenerErr recognises net.ErrClosed AND gVisor's
    "endpoint is in invalid state" — the latter survives gonet's
    *net.OpError wrapping in a way errors.Is(.., net.ErrClosed) does
    not. Without this the loop spins CPU-hot after the underlying
    netstack is destroyed (peer rekey, embedded-client reset, account
    churn), emitting one log line per iteration.
  - AcceptBackoff implements the exponential backoff that
    net/http.Server.Serve uses on transient Accept errors: 5ms doubling
    up to 1s. Defence-in-depth so an unknown sticky error cannot burn
    a CPU core even if IsClosedListenerErr misses its signature.

proxy/internal/roundtrip/netbird.go — emit a single structured INFO
line summarising every embed.Options flag (account_id, service_id,
public_key, management_url, wg_port, block_inbound, block_lan_access,
disable_ipv6, no_userspace, presence of credentials) when each
per-account embedded client is created. Secrets reduced to a "present"
boolean — never logged verbatim. Diagnostic-only; no behavior change,
but it makes the "why is this embedded peer misbehaving" loop a single
log read instead of a code dive.

Tests (real listeners, scripted errors, no mocks of production code):
  - engine_reconcileipv6_test.go: 8 cases for every transition (first
    assignment, no change, removed, prefix-length changed, value
    changed, invalid bytes, UpdateAddr error) plus a updateConfig
    integration check that the fix actually fires on a v6-added
    PeerConfig.
  - accept_test.go: IsClosedListenerErr matrix + AcceptBackoff
    progression / cap / reset / cancel-during-wait / cancel-before-call.
  - router_test.go, inbound_test.go: scriptedAcceptListener +
    TestRouter_Serve_ExitsOnGVisorInvalidEndpoint and
    TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint —
    regression guards that fail in 2 s if the loop ever spins.
2026-06-18 10:37:51 +02:00
15 changed files with 968 additions and 223 deletions

View File

@@ -247,7 +247,7 @@ dockers_v2:
- netbirdio/netbird
- ghcr.io/netbirdio/netbird
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: client/Dockerfile
extra_files:
@@ -295,7 +295,7 @@ dockers_v2:
- netbirdio/relay
- ghcr.io/netbirdio/relay
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: relay/Dockerfile
platforms:
@@ -317,7 +317,7 @@ dockers_v2:
- netbirdio/signal
- ghcr.io/netbirdio/signal
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: signal/Dockerfile
platforms:
@@ -339,7 +339,7 @@ dockers_v2:
- netbirdio/management
- ghcr.io/netbirdio/management
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: management/Dockerfile
platforms:
@@ -361,7 +361,7 @@ dockers_v2:
- netbirdio/upload
- ghcr.io/netbirdio/upload
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: upload-server/Dockerfile
platforms:
@@ -383,7 +383,7 @@ dockers_v2:
- netbirdio/netbird-server
- ghcr.io/netbirdio/netbird-server
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: combined/Dockerfile
platforms:
@@ -405,7 +405,7 @@ dockers_v2:
- netbirdio/reverse-proxy
- ghcr.io/netbirdio/reverse-proxy
tags:
- "{{ .Version }}"
- "v{{ .Version }}"
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
dockerfile: proxy/Dockerfile
platforms:

View File

@@ -64,7 +64,6 @@ import (
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -1078,11 +1077,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return ErrResetConnection
}
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
log.Infof("peer IPv6 address changed, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
if !e.config.DisableIPv6 {
reset, err := e.reconcileIPv6(conf)
if err != nil {
log.Warnf("reconcile IPv6 from PeerConfig: %v", err)
}
if reset {
log.Infof("peer IPv6 address changed value, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
}
if conf.GetSshConfig() != nil {
@@ -1104,25 +1109,58 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
// differs from the configured address (added, removed, or changed).
// Compares against e.config.WgAddr (not the interface address, which may have
// been cleared by ClearIPv6 if OS assignment failed).
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
current := e.config.WgAddr
// reconcileIPv6 applies the management-supplied IPv6 overlay address to the
// engine's WireGuard interface in place when possible. Three transitions:
//
// - First v6 assignment (current had no v6, conf carries one): apply via
// WGIface.UpdateAddr, no reset. Critical for embedded clients whose
// boot config has no v6 — without this we reset on every fresh start
// once management has v6 enabled, orphaning any netstack listeners
// held outside the engine.
// - v6 removed (current had v6, conf carries none): clear in place, no
// reset.
// - v6 swapped to a different non-empty value: returns reset=true so the
// caller falls back to the engine-recreate path — the underlying
// interface address can't be safely swapped in place across all
// backends (gVisor netstack in particular fixes its address at
// CreateNetTUN time).
//
// Mutates e.config.WgAddr to match the applied state so subsequent
// PeerConfig comparisons are stable.
func (e *Engine) reconcileIPv6(conf *mgmProto.PeerConfig) (reset bool, err error) {
raw := conf.GetAddressV6()
current := e.config.WgAddr
if len(raw) == 0 {
return current.HasIPv6()
if !current.HasIPv6() {
return false, nil
}
current.ClearIPv6()
e.config.WgAddr = current
if err := e.wgInterface.UpdateAddr(current); err != nil {
return false, fmt.Errorf("clear ipv6 on wg interface: %w", err)
}
return false, nil
}
prefix, err := netiputil.DecodePrefix(raw)
if err != nil {
log.Errorf("decode v6 overlay address: %v", err)
return false
incoming := current
if err := incoming.SetIPv6FromCompact(raw); err != nil {
return false, fmt.Errorf("decode v6 overlay address: %w", err)
}
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
if !current.HasIPv6() {
e.config.WgAddr = incoming
if err := e.wgInterface.UpdateAddr(incoming); err != nil {
return false, fmt.Errorf("apply ipv6 on wg interface: %w", err)
}
return false, nil
}
if current.IPv6 == incoming.IPv6 && current.IPv6Net == incoming.IPv6Net {
return false, nil
}
return true, nil
}
func (e *Engine) receiveJobEvents() {

View File

@@ -0,0 +1,305 @@
package internal
import (
"context"
"errors"
"net/netip"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/peer"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// reconcileIPv6 / updateConfig regression suite. Locks down the behavior that
// PR #5631 (main-side IPv6 overlay support) accidentally broke for embedded
// netstack clients: any first NetworkMap update that brings an IPv6 address
// used to trigger ErrResetConnection, which destroys the netstack and orphans
// every listener bound on it (proxy-side inbound listeners in particular).
// The fix in reconcileIPv6 distinguishes "v6 first-assigned" (apply in place)
// from "v6 swapped value" (must reset).
func mustEncodeV6Prefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err, "encode v6 prefix %s", p)
return b
}
// reconcileIPv6Fixture builds the smallest Engine the function under test
// needs: a config (with WgAddr being the load-bearing field) and a wgInterface
// whose UpdateAddr call we can observe.
func reconcileIPv6Fixture(t *testing.T, initial wgaddr.Address) (*Engine, *MockWGIface, *wgaddr.Address) {
t.Helper()
var applied wgaddr.Address
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return initial },
UpdateAddrFunc: func(a wgaddr.Address) error {
applied = a
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: initial},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
return e, mock, &applied
}
func TestReconcileIPv6_FirstAssignment_AppliesInPlace(t *testing.T) {
// Embedded clients boot v4-only; management later assigns a v6 overlay.
// The fix: apply v6 in place, return reset=false. Pre-fix this case
// fell through to the "v6 changed" branch and reset the engine.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, mock, applied := reconcileIPv6Fixture(t, v4)
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "first v6 assignment must NOT request an engine reset")
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the new v6")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6, "engine config v6 address must match")
assert.Equal(t, v6Prefix.Masked(), e.config.WgAddr.IPv6Net, "engine config v6 prefix must match")
require.True(t, applied.HasIPv6(), "WGIface.UpdateAddr must be called with v6 populated")
assert.Equal(t, v6Prefix.Addr(), applied.IPv6, "UpdateAddr must carry the new v6")
_ = mock
}
func TestReconcileIPv6_NoChange_NoOp(t *testing.T) {
// Steady state: management redelivers the same PeerConfig. No interface
// mutation, no reset. Guards against an infinite reset loop if the
// comparison ever drifts (e.g. address-vs-prefix masking bugs).
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "unchanged v6 must NOT trigger reset")
assert.False(t, updateAddrCalled, "unchanged v6 must NOT call UpdateAddr")
}
func TestReconcileIPv6_Removed_AppliesInPlace(t *testing.T) {
// Management withdraws v6 (e.g. account toggled off the v6 group).
// Cleared in place, no reset.
v6Prefix := netip.MustParsePrefix("fd00::1/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
e, _, applied := reconcileIPv6Fixture(t, addr)
e.config.WgAddr = addr
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: nil,
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.False(t, reset, "v6 removed must NOT trigger reset")
assert.False(t, e.config.WgAddr.HasIPv6(), "engine config must reflect v6 cleared")
assert.False(t, applied.HasIPv6(), "UpdateAddr must receive cleared v6")
}
func TestReconcileIPv6_PrefixLengthChanged_RequestsReset(t *testing.T) {
// Same v6 host, different mask (e.g. /64 → /80). Treated like a value
// change because the new netmask redefines the broadcast/scope.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::1/80")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 prefix length change must request a reset")
assert.False(t, updateAddrCalled, "v6 prefix length change must NOT touch the interface")
}
func TestReconcileIPv6_ValueChanged_RequestsReset(t *testing.T) {
// v6 was X, now Y. The netstack backend can't safely swap an existing
// address in place — fall back to the engine recreate path.
oldPrefix := netip.MustParsePrefix("fd00::1/64")
newPrefix := netip.MustParsePrefix("fd00::2/64")
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
updateAddrCalled := false
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return addr },
UpdateAddrFunc: func(a wgaddr.Address) error {
updateAddrCalled = true
return nil
},
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: addr},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
conf := &mgmtProto.PeerConfig{
Address: addr.String(),
AddressV6: mustEncodeV6Prefix(t, newPrefix),
}
reset, err := e.reconcileIPv6(conf)
require.NoError(t, err)
assert.True(t, reset, "v6 value change must request a reset")
assert.False(t, updateAddrCalled,
"v6 value change must NOT call UpdateAddr — caller will recreate the interface")
}
func TestReconcileIPv6_InvalidBytes_ReturnsError(t *testing.T) {
// Corrupt PeerConfig.AddressV6 must not crash the engine and must not
// trigger a spurious reset.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
e, _, applied := reconcileIPv6Fixture(t, v4)
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: []byte{0x00}, // truncated, definitely not a valid prefix
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "malformed v6 bytes must surface an error")
assert.False(t, reset, "decode error must NOT request a reset")
assert.False(t, applied.HasIPv6(), "decode error must NOT touch the interface")
}
func TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset(t *testing.T) {
// If WGIface.UpdateAddr fails (e.g. OS-side assignment error on a
// kernel device), reconcileIPv6 returns the error to the caller for
// logging — but it must NOT request a reset. The whole point of the
// fix is to AVOID the reset cascade on v6 transitions.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return errors.New("os refused address") },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
reset, err := e.reconcileIPv6(conf)
require.Error(t, err, "UpdateAddr failure must surface")
assert.False(t, reset, "UpdateAddr failure must NOT request a reset")
}
func TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine(t *testing.T) {
// The integration check: updateConfig must not return ErrResetConnection
// when the only change between current state and the new PeerConfig is
// "v6 added". Pre-fix this returned ErrResetConnection, tearing down
// every listener bound on the engine's netstack.
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
mock := &MockWGIface{
AddressFunc: func() wgaddr.Address { return v4 },
UpdateAddrFunc: func(_ wgaddr.Address) error { return nil },
IsUserspaceBindFunc: func() bool { return true },
}
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
e := &Engine{
ctx: ctx,
clientCtx: ctx,
clientCancel: cancel,
config: &EngineConfig{WgAddr: v4, WgPort: 51820},
wgInterface: mock,
syncMsgMux: &sync.Mutex{},
statusRecorder: peer.NewRecorder("https://mgm.test"),
}
v6Prefix := netip.MustParsePrefix("fd00::1/64")
conf := &mgmtProto.PeerConfig{
Address: v4.String(),
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
}
err := e.updateConfig(conf)
assert.NoError(t, err,
"updateConfig MUST NOT return ErrResetConnection when v6 is added for the first time — that's the bug fix")
assert.NotErrorIs(t, err, ErrResetConnection)
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the assigned v6 after updateConfig")
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6)
}

View File

@@ -66,7 +66,6 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -1707,82 +1706,12 @@ func getPeers(e *Engine) int {
return len(e.peerStore.PeersPubKey())
}
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
t.Helper()
b, err := netiputil.EncodePrefix(p)
require.NoError(t, err)
return b
}
func TestEngine_hasIPv6Changed(t *testing.T) {
v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
v4v6.IPv6 = netip.MustParseAddr("fd00::1")
v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
tests := []struct {
name string
current wgaddr.Address
confV6 []byte
expected bool
}{
{
name: "no v6 before, no v6 now",
current: v4Only,
confV6: nil,
expected: false,
},
{
name: "no v6 before, v6 added",
current: v4Only,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: true,
},
{
name: "had v6, now removed",
current: v4v6,
confV6: nil,
expected: true,
},
{
name: "had v6, same v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
expected: false,
},
{
name: "had v6, different v6",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")),
expected: true,
},
{
name: "same v6 addr, different prefix length",
current: v4v6,
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")),
expected: true,
},
{
name: "decode error keeps status quo",
current: v4Only,
confV6: []byte{1, 2, 3},
expected: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{WgAddr: tt.current},
}
conf := &mgmtProto.PeerConfig{
AddressV6: tt.confV6,
}
assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf))
})
}
}
// The former TestEngine_hasIPv6Changed has been superseded by
// engine_reconcileipv6_test.go — the underlying function (hasIPv6Changed)
// was replaced by reconcileIPv6, which applies "v6 added" / "v6 removed"
// in place instead of demanding a full engine reset. The behavioral
// matrix the old test enforced is now covered, with corrected expectations,
// by TestReconcileIPv6_* in that sibling file.
func TestFilterAllowedIPs(t *testing.T) {
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")

View File

@@ -1026,12 +1026,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
return err
}
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return err
}
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return err
}
@@ -1129,7 +1124,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
var peer *nbpeer.Peer
var shouldStorePeer, shouldUpdatePeers bool
var shouldStorePeer bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1156,7 +1151,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
shouldUpdatePeers = true
}
}
@@ -1180,16 +1174,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
}
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
peer.UpdateMetaIfNew(login.Meta)
return nil
})
if err != nil {
return nil, nil, nil, false, err
}
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, false, err
}
@@ -1199,7 +1190,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
if shouldUpdatePeers {
if isStatusChanged || shouldStorePeer {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
@@ -1295,22 +1286,12 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, nil, false, nil
}
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return nil, nil, false, err
}
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
if err != nil {
return nil, nil, false, err
}
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
if err != nil {
return nil, nil, false, err
}
@@ -1318,16 +1299,32 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
return network, postureChecks, enableSSH, nil
}
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
for _, peerID := range peerGroupIDs {
groupIDsMap[peerID] = struct{}{}
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return false, err
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
if err != nil {
return false, err
}
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
for _, g := range peerGroups {
peerGroupIDs[g.ID] = struct{}{}
}
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
}
// getPeerPostureChecks returns the posture checks for the peer.
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
if err != nil {
return nil, err
}
if len(policies) == 0 {
return nil, nil
}
@@ -1339,7 +1336,11 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
continue
}
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
if err != nil {
return nil, err
}
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
}
@@ -1352,19 +1353,29 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
}
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
for _, rule := range policy.Rules {
if !rule.Enabled {
continue
}
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
if err != nil {
return nil, err
}
for _, sourceGroup := range rule.Sources {
if slices.Contains(peerGroupIDs, sourceGroup) {
return policy.SourcePostureChecks
group, ok := sourceGroups[sourceGroup]
if !ok {
return nil, fmt.Errorf("failed to check peer in policy source group")
}
if slices.Contains(group.Peers, peerID) {
return policy.SourcePostureChecks, nil
}
}
}
return nil
return nil, nil
}
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO

View File

@@ -466,15 +466,20 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.
_ = ln.Close()
}()
var backoff nbtcp.AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
continue
}
backoff.Reset()
router.HandleConn(ctx, conn)
}
}

View File

@@ -533,3 +533,125 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
// scriptedAcceptListener returns pre-scripted errors from Accept(). Used
// to drive the feedRouterFromListener tests without binding a real
// socket — the production code path is a netstack-backed listener that
// returns gVisor's "endpoint is in invalid state" forever after its
// endpoint is destroyed.
type scriptedAcceptListener struct {
errs chan error
closed chan struct{}
}
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
s := &scriptedAcceptListener{
errs: make(chan error, len(errs)+1),
closed: make(chan struct{}),
}
for _, e := range errs {
s.errs <- e
}
return s
}
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
select {
case <-s.closed:
return nil, net.ErrClosed
case err := <-s.errs:
return nil, err
}
}
func (s *scriptedAcceptListener) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
func (s *scriptedAcceptListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
// errSentinel carries a literal error message so tests can synthesise
// the exact gVisor text without importing the netstack package.
type errSentinel string
func (e errSentinel) Error() string { return string(e) }
// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint is the
// regression guard for the inbound side of the tight-loop bug. The
// per-account plain-HTTP feeder must recognise gVisor's "endpoint is in
// invalid state" and exit, otherwise it pegs a CPU core and floods the
// account-scoped log with the same accept error every iteration.
func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
router := nbtcp.NewRouter(logger, nil, addr)
gvisorErr := &net.OpError{
Op: "accept",
Net: "tcp",
Addr: addr,
Err: errSentinel("endpoint is in invalid state"),
}
ln := newScriptedAcceptListener(gvisorErr)
defer ln.Close()
done := make(chan struct{})
go func() {
defer close(done)
feedRouterFromListener(context.Background(), ln, router, logger, "acct-1")
}()
select {
case <-done:
// Expected: loop recognised the gVisor error and returned.
case <-time.After(2 * time.Second):
t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
}
}
// TestFeedRouterFromListener_BacksOffOnTransientError asserts the
// defence-in-depth path: an unknown sticky Accept error must NOT cause
// CPU spin. The loop backs off and exits cleanly when ctx is cancelled.
func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
router := nbtcp.NewRouter(logger, nil, addr)
const transientCount = 5
errs := make([]error, transientCount)
for i := range errs {
errs[i] = errSentinel("transient: temporary network error")
}
ln := newScriptedAcceptListener(errs...)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
start := time.Now()
done := make(chan struct{})
go func() {
defer close(done)
feedRouterFromListener(ctx, ln, router, logger, "acct-1")
}()
time.AfterFunc(150*time.Millisecond, cancel)
select {
case <-done:
// Expected.
case <-time.After(2 * time.Second):
t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken")
}
// Without backoff the 5 scripted errors would burn in microseconds.
// With backoff the first delay alone is 5ms, so the loop must take
// at least that long even though ctx fires at 150ms.
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond,
"loop ran without backing off — would burn CPU in production")
}

View File

@@ -356,7 +356,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
wgPort := int(n.clientCfg.WGPort)
client, err := embed.New(embed.Options{
embedOpts := embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
@@ -371,7 +371,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
Performance: n.clientCfg.Performance,
})
}
logEmbedOptions(n.logger, accountID, serviceID, publicKey.String(), embedOpts)
client, err := embed.New(embedOpts)
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
}
@@ -847,3 +849,53 @@ func DirectUpstreamFromContext(ctx context.Context) bool {
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
return v
}
// logEmbedOptions emits a single structured INFO line summarising every
// operationally meaningful flag handed to embed.New for this per-account
// client. Secrets (PrivateKey, PreSharedKey) are reduced to a "present"
// boolean — never logged verbatim. Use this when an embedded peer
// silently misbehaves: most failure modes (inbound drops, wrong
// management URL, v6 unexpectedly on, userspace flipped, port clash)
// are obvious from these flags before any traffic flows.
func logEmbedOptions(logger *log.Logger, accountID types.AccountID, serviceID types.ServiceID, publicKey string, opts embed.Options) {
wgPort := 0
if opts.WireguardPort != nil {
wgPort = *opts.WireguardPort
}
mtu := uint16(0)
if opts.MTU != nil {
mtu = *opts.MTU
}
perfBuffers := uint32(0)
if opts.Performance.PreallocatedBuffersPerPool != nil {
perfBuffers = *opts.Performance.PreallocatedBuffersPerPool
}
perfBatch := uint32(0)
if opts.Performance.MaxBatchSize != nil {
perfBatch = *opts.Performance.MaxBatchSize
}
logger.WithFields(log.Fields{
"account_id": accountID,
"service_id": serviceID,
"public_key": publicKey,
"device_name": opts.DeviceName,
"management_url": opts.ManagementURL,
"log_level": opts.LogLevel,
"wg_port": wgPort,
"mtu": mtu,
"block_inbound": opts.BlockInbound,
"block_lan_access": opts.BlockLANAccess,
"disable_ipv6": opts.DisableIPv6,
"disable_client_routes": opts.DisableClientRoutes,
"no_userspace": opts.NoUserspace,
"config_path_set": opts.ConfigPath != "",
"state_path_set": opts.StatePath != "",
"private_key_present": opts.PrivateKey != "",
"presharedkey_present": opts.PreSharedKey != "",
"setup_key_present": opts.SetupKey != "",
"jwt_token_present": opts.JWTToken != "",
"dns_labels": opts.DNSLabels,
"perf_buffers_per_pool": perfBuffers,
"perf_max_batch_size": perfBatch,
}).Info("starting embedded netbird client for account")
}

View File

@@ -0,0 +1,85 @@
package tcp
import (
"context"
"errors"
"net"
"strings"
"time"
)
// gvisorInvalidEndpointMsg is the canonical text gVisor netstack returns
// when Accept() is called on a listener whose underlying endpoint has
// been destroyed (peer rekey, embedded-client reset, account churn).
// There is no exported sentinel from gvisor.dev/gvisor/pkg/tcpip that
// survives gonet's *net.OpError wrapping in a way errors.Is can match,
// so we fall back to a string check. Stable across the gVisor versions
// netbird pins.
const gvisorInvalidEndpointMsg = "endpoint is in invalid state"
// IsClosedListenerErr reports whether err signals that an accept loop
// should exit because the underlying listener can no longer serve
// connections. It recognises:
//
// - net.ErrClosed for stdlib listeners (Listener.Close was called).
// - gVisor's "endpoint is in invalid state" for netstack-backed
// listeners whose endpoint was destroyed out from under them
// (typically when a per-account WireGuard netstack is reset without
// also tearing the listener entry down).
//
// Without the gVisor branch an accept loop on a netstack listener spins
// CPU-hot forever after the endpoint dies, because Accept never blocks
// again and the error neither matches net.ErrClosed nor cancels ctx.
func IsClosedListenerErr(err error) bool {
if err == nil {
return false
}
if errors.Is(err, net.ErrClosed) {
return true
}
return strings.Contains(err.Error(), gvisorInvalidEndpointMsg)
}
// AcceptBackoff implements the exponential backoff used by
// net/http.Server.Serve for transient Accept errors. Without it a loop
// hitting a sticky unknown error burns a full CPU core. The zero value
// is ready to use; call Reset after a successful Accept.
type AcceptBackoff struct {
delay time.Duration
}
// minAcceptDelay / maxAcceptDelay mirror the stdlib defaults
// (net/http.Server.Serve) and keep us well below 1 log line per second
// per orphaned listener.
const (
minAcceptDelay = 5 * time.Millisecond
maxAcceptDelay = time.Second
)
// Backoff waits the next exponential delay (5ms doubling up to 1s) and
// returns true when the wait completed. Returns false if ctx fired
// during the wait — callers should treat that as "exit the loop".
func (b *AcceptBackoff) Backoff(ctx context.Context) bool {
b.advance()
select {
case <-ctx.Done():
return false
case <-time.After(b.delay):
return true
}
}
// Reset clears the accumulated delay so the next failure starts at the
// minimum delay again. Call after a successful Accept.
func (b *AcceptBackoff) Reset() { b.delay = 0 }
func (b *AcceptBackoff) advance() {
if b.delay == 0 {
b.delay = minAcceptDelay
} else {
b.delay *= 2
}
if b.delay > maxAcceptDelay {
b.delay = maxAcceptDelay
}
}

View File

@@ -0,0 +1,142 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIsClosedListenerErr_NetErrClosed verifies the stdlib path: a
// closed *net.Listener returns net.ErrClosed wrapped in *net.OpError,
// and IsClosedListenerErr must unwrap it.
func TestIsClosedListenerErr_NetErrClosed(t *testing.T) {
wrapped := &net.OpError{Op: "accept", Net: "tcp", Err: net.ErrClosed}
assert.True(t, IsClosedListenerErr(wrapped),
"net.OpError wrapping net.ErrClosed must be recognised as closed")
}
// TestIsClosedListenerErr_GVisorInvalidEndpoint is the load-bearing
// regression guard. A gVisor netstack listener whose endpoint has been
// destroyed returns this exact text. Without recognising it the accept
// loop spins forever and burns a CPU core.
func TestIsClosedListenerErr_GVisorInvalidEndpoint(t *testing.T) {
err := fmt.Errorf("accept tcp 10.10.1.254:80: endpoint is in invalid state")
assert.True(t, IsClosedListenerErr(err),
"gVisor 'endpoint is in invalid state' must be recognised as closed")
}
// TestIsClosedListenerErr_OtherError confirms we don't over-match —
// transient errors must keep returning false so the backoff path runs.
func TestIsClosedListenerErr_OtherError(t *testing.T) {
cases := []error{
errors.New("temporary failure"),
errors.New("accept tcp 10.10.1.254:80: too many open files"),
nil,
}
for _, c := range cases {
assert.False(t, IsClosedListenerErr(c),
"unexpected match on %v — must not be treated as closed", c)
}
}
// TestAcceptBackoff_ProgressionAndCap asserts the doubling schedule:
// 5ms, 10ms, 20ms, 40ms, ... capped at 1s. The test runs against a
// real timer but uses tight bounds so a slow CI machine still passes.
func TestAcceptBackoff_ProgressionAndCap(t *testing.T) {
var b AcceptBackoff
expected := []time.Duration{
5 * time.Millisecond,
10 * time.Millisecond,
20 * time.Millisecond,
40 * time.Millisecond,
}
for i, want := range expected {
start := time.Now()
ok := b.Backoff(context.Background())
elapsed := time.Since(start)
require.True(t, ok, "Backoff %d must complete; ctx is alive", i)
assert.GreaterOrEqual(t, elapsed, want,
"backoff %d (%v) must wait at least the configured delay", i, want)
assert.Less(t, elapsed, want*4,
"backoff %d (%v) must not overshoot by more than 4x — caps misbehaving", i, want)
}
// Burn enough rounds to reach the cap, then assert subsequent
// rounds stay at exactly maxAcceptDelay (1s) — the timer should
// never exceed it.
for range 6 {
b.Backoff(context.Background())
}
assert.Equal(t, maxAcceptDelay, b.delay,
"after enough doublings the delay must clamp to maxAcceptDelay")
}
// TestAcceptBackoff_Reset confirms that a successful Accept resets the
// schedule — a busy-then-quiet listener mustn't stay on a 1s timer
// after recovery.
func TestAcceptBackoff_Reset(t *testing.T) {
var b AcceptBackoff
for range 5 {
b.Backoff(context.Background())
}
require.NotEqual(t, time.Duration(0), b.delay, "precondition: delay must have accumulated")
b.Reset()
assert.Equal(t, time.Duration(0), b.delay, "Reset must zero the delay")
start := time.Now()
ok := b.Backoff(context.Background())
elapsed := time.Since(start)
require.True(t, ok, "Backoff after Reset must complete")
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
"after Reset the next backoff must restart at minAcceptDelay")
assert.Less(t, elapsed, 50*time.Millisecond,
"after Reset the next backoff must NOT carry over the prior delay")
}
// TestAcceptBackoff_CancelDuringWait proves the loop exits promptly
// when ctx fires mid-wait. Without this, a tear-down would still take
// up to 1 second per orphaned listener.
func TestAcceptBackoff_CancelDuringWait(t *testing.T) {
var b AcceptBackoff
// Drive the backoff up so the next call will wait ~1s — long
// enough that we can detect early cancellation.
for range 10 {
b.Backoff(context.Background())
}
require.Equal(t, maxAcceptDelay, b.delay)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(20 * time.Millisecond)
cancel()
}()
start := time.Now()
ok := b.Backoff(ctx)
elapsed := time.Since(start)
assert.False(t, ok, "Backoff must return false when ctx is cancelled mid-wait")
assert.Less(t, elapsed, 200*time.Millisecond,
"cancellation must short-circuit the timer; took %v", elapsed)
}
// TestAcceptBackoff_CancelBeforeCall — when ctx is already done the
// loop exits without sleeping at all.
func TestAcceptBackoff_CancelBeforeCall(t *testing.T) {
var b AcceptBackoff
ctx, cancel := context.WithCancel(context.Background())
cancel()
start := time.Now()
ok := b.Backoff(ctx)
elapsed := time.Since(start)
assert.False(t, ok, "Backoff must return false when ctx is already cancelled")
assert.Less(t, elapsed, 50*time.Millisecond,
"already-cancelled ctx must return immediately; took %v", elapsed)
}

View File

@@ -297,18 +297,23 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
}
}()
var backoff AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
if ctx.Err() != nil || IsClosedListenerErr(err) {
if ok := r.Drain(DefaultDrainTimeout); !ok {
r.logger.Warn("timed out waiting for connections to drain")
}
return nil
}
r.logger.Debugf("SNI router accept: %v", err)
r.logger.Debugf("SNI router accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
return nil
}
continue
}
backoff.Reset()
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
r.activeConns.Add(1)
go func() {

View File

@@ -1836,3 +1836,132 @@ func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
t.Fatal("TLS conn never reached the TLS channel")
}
}
// scriptedAcceptListener is a net.Listener whose Accept() returns
// pre-scripted errors. Used by the accept-loop exit tests to simulate
// the failure mode that triggers the tight-loop bug: a netstack
// listener whose endpoint has been destroyed and now returns the gVisor
// "endpoint is in invalid state" error from every Accept call.
type scriptedAcceptListener struct {
errs chan error
closed chan struct{}
}
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
s := &scriptedAcceptListener{
errs: make(chan error, len(errs)+1),
closed: make(chan struct{}),
}
for _, e := range errs {
s.errs <- e
}
return s
}
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
select {
case <-s.closed:
return nil, net.ErrClosed
case err := <-s.errs:
return nil, err
}
}
func (s *scriptedAcceptListener) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
func (s *scriptedAcceptListener) Addr() net.Addr {
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
// TestRouter_Serve_ExitsOnGVisorInvalidEndpoint is the regression guard
// for the tight-loop bug: when the underlying netstack endpoint is
// destroyed, Accept returns "endpoint is in invalid state" forever. The
// loop must recognise that signal and return, otherwise it pegs a CPU
// core and floods logs.
func TestRouter_Serve_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
gvisorErr := &net.OpError{
Op: "accept",
Net: "tcp",
Addr: addr,
Err: errSentinel("endpoint is in invalid state"),
}
ln := newScriptedAcceptListener(gvisorErr)
defer ln.Close()
done := make(chan error, 1)
go func() {
done <- router.Serve(context.Background(), ln)
}()
select {
case err := <-done:
assert.NoError(t, err, "Serve must return cleanly on a recognised closed-listener error")
case <-time.After(2 * time.Second):
t.Fatal("Serve did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
}
}
// TestRouter_Serve_BacksOffOnTransientError verifies the defence-in-
// depth path: when Accept returns an unknown transient error, the loop
// MUST not spin. It backs off, then exits cleanly once ctx is cancelled.
// "Bounded call count" stands in for "no CPU spin" — without backoff
// the goroutine would issue thousands of Accept calls in this window.
func TestRouter_Serve_BacksOffOnTransientError(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
const transientErrCount = 5
errs := make([]error, transientErrCount)
for i := range errs {
errs[i] = errSentinel("transient: too many open files")
}
ln := newScriptedAcceptListener(errs...)
defer ln.Close()
ctx, cancel := context.WithCancel(context.Background())
done := make(chan error, 1)
start := time.Now()
go func() {
done <- router.Serve(ctx, ln)
}()
// Cancel after enough time for the backoff to climb (5ms + 10ms +
// 20ms + 40ms = 75ms minimum), but short enough that a spinning
// loop would have made thousands of calls by now.
time.AfterFunc(150*time.Millisecond, cancel)
select {
case err := <-done:
assert.NoError(t, err, "Serve must return cleanly on ctx cancellation")
case <-time.After(2 * time.Second):
t.Fatal("Serve did not exit on ctx cancellation — backoff or exit path broken")
}
// Without backoff the loop would burn through all 5 scripted errors
// in microseconds and then block on the channel. With backoff the
// total wall time should be at least 5ms (the first backoff).
elapsed := time.Since(start)
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
"loop ran without backing off — would burn CPU in production")
}
// errSentinel mirrors gVisor's tcpip error message exactly. We can't
// import the gVisor package without dragging in the whole netstack, so
// the test uses the canonical string the production error formatter
// emits — same shape IsClosedListenerErr matches in production.
type errSentinel string
func (e errSentinel) Error() string { return string(e) }

View File

@@ -26,10 +26,6 @@ type Peer struct {
// a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
// the same ServerStream, and a peer can be the target of many senders at
// once.
sendMu sync.Mutex
// registration time
RegisteredAt time.Time
@@ -37,13 +33,6 @@ type Peer struct {
Cancel context.CancelFunc
}
// Send writes a message to the peer's stream, serializing concurrent senders.
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
p.sendMu.Lock()
defer p.sendMu.Unlock()
return p.Stream.Send(msg)
}
// NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
return &Peer{

View File

@@ -1,67 +0,0 @@
package server
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/peer"
)
// concurrencyCheckStream records the maximum number of Send calls in flight at
// once. gRPC forbids concurrent SendMsg on the same ServerStream, so a correct
// server must never have more than one in flight per peer.
type concurrencyCheckStream struct {
proto.SignalExchange_ConnectStreamServer
ctx context.Context
inflight atomic.Int32
maxSeen atomic.Int32
}
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
n := s.inflight.Add(1)
for {
old := s.maxSeen.Load()
if n <= old || s.maxSeen.CompareAndSwap(old, n) {
break
}
}
// Widen the window so overlapping callers are reliably observed.
time.Sleep(time.Millisecond)
s.inflight.Add(-1)
return nil
}
func (s *concurrencyCheckStream) Context() context.Context { return s.ctx }
// TestForwardMessageToPeerSerializesSend verifies that concurrent forwards to the
// same peer never call Stream.Send concurrently, which would violate the gRPC
// ServerStream contract.
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
s, err := NewServer(context.Background(), otel.Meter(""))
require.NoError(t, err)
const peerID = "peerX"
stream := &concurrencyCheckStream{ctx: context.Background()}
_, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
require.NoError(t, s.registry.Register(peer.NewPeer(peerID, stream, cancel)))
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
s.forwardMessageToPeer(context.Background(), &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID})
}()
}
wg.Wait()
require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
}

View File

@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
sendResultChan := make(chan error, 1)
go func() {
select {
case sendResultChan <- dstPeer.Send(msg):
case sendResultChan <- dstPeer.Stream.Send(msg):
return
case <-dstPeer.Stream.Context().Done():
return