Compare commits

..

7 Commits

Author SHA1 Message Date
pascal
ad6e0b244c rename 2026-06-18 20:10:38 +02:00
pascal
dd041e2136 log meta diff with context 2026-06-18 19:40:58 +02:00
pascal
698acf5dc2 log on info 2026-06-18 19:37:45 +02:00
pascal
9009784e1a log wt version 2026-06-18 19:37:45 +02:00
pascal
f4183ab0c3 log meta diff 2026-06-18 19:37:44 +02:00
Pascal Fischer
60a9544656 [management] pass meta update for browser clients (#6465) 2026-06-18 17:22:42 +02:00
Viktor Liu
d3710d4bb2 [signal] Serialize concurrent sends to a peer signal stream (#6463) 2026-06-18 15:00:19 +02:00
15 changed files with 298 additions and 966 deletions

View File

@@ -64,6 +64,7 @@ import (
mgm "github.com/netbirdio/netbird/shared/management/client"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
@@ -1077,17 +1078,11 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return ErrResetConnection
}
if !e.config.DisableIPv6 {
reset, err := e.reconcileIPv6(conf)
if err != nil {
log.Warnf("reconcile IPv6 from PeerConfig: %v", err)
}
if reset {
log.Infof("peer IPv6 address changed value, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
log.Infof("peer IPv6 address changed, restarting client")
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
e.clientCancel()
return ErrResetConnection
}
if conf.GetSshConfig() != nil {
@@ -1109,58 +1104,25 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
return nil
}
// reconcileIPv6 applies the management-supplied IPv6 overlay address to the
// engine's WireGuard interface in place when possible. Three transitions:
//
// - First v6 assignment (current had no v6, conf carries one): apply via
// WGIface.UpdateAddr, no reset. Critical for embedded clients whose
// boot config has no v6 — without this we reset on every fresh start
// once management has v6 enabled, orphaning any netstack listeners
// held outside the engine.
// - v6 removed (current had v6, conf carries none): clear in place, no
// reset.
// - v6 swapped to a different non-empty value: returns reset=true so the
// caller falls back to the engine-recreate path — the underlying
// interface address can't be safely swapped in place across all
// backends (gVisor netstack in particular fixes its address at
// CreateNetTUN time).
//
// Mutates e.config.WgAddr to match the applied state so subsequent
// PeerConfig comparisons are stable.
func (e *Engine) reconcileIPv6(conf *mgmProto.PeerConfig) (reset bool, err error) {
raw := conf.GetAddressV6()
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
// differs from the configured address (added, removed, or changed).
// Compares against e.config.WgAddr (not the interface address, which may have
// been cleared by ClearIPv6 if OS assignment failed).
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
current := e.config.WgAddr
raw := conf.GetAddressV6()
if len(raw) == 0 {
if !current.HasIPv6() {
return false, nil
}
current.ClearIPv6()
e.config.WgAddr = current
if err := e.wgInterface.UpdateAddr(current); err != nil {
return false, fmt.Errorf("clear ipv6 on wg interface: %w", err)
}
return false, nil
return current.HasIPv6()
}
incoming := current
if err := incoming.SetIPv6FromCompact(raw); err != nil {
return false, fmt.Errorf("decode v6 overlay address: %w", err)
prefix, err := netiputil.DecodePrefix(raw)
if err != nil {
log.Errorf("decode v6 overlay address: %v", err)
return false
}
if !current.HasIPv6() {
e.config.WgAddr = incoming
if err := e.wgInterface.UpdateAddr(incoming); err != nil {
return false, fmt.Errorf("apply ipv6 on wg interface: %w", err)
}
return false, nil
}
if current.IPv6 == incoming.IPv6 && current.IPv6Net == incoming.IPv6Net {
return false, nil
}
return true, nil
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
}
func (e *Engine) receiveJobEvents() {

View File

@@ -1,305 +0,0 @@
package internal
import (
"context"
"errors"
"net/netip"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/peer"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
)
// reconcileIPv6 / updateConfig regression suite. Locks down the behavior that
// PR #5631 (main-side IPv6 overlay support) accidentally broke for embedded
// netstack clients: any first NetworkMap update that brings an IPv6 address
// used to trigger ErrResetConnection, which destroys the netstack and orphans
// every listener bound on it (proxy-side inbound listeners in particular).
// The fix in reconcileIPv6 distinguishes "v6 first-assigned" (apply in place)
// from "v6 swapped value" (must reset).
// mustEncodeV6Prefix encodes p with netiputil.EncodePrefix and fails the test
// immediately if the encoder reports an error.
func mustEncodeV6Prefix(t *testing.T, p netip.Prefix) []byte {
	t.Helper()

	encoded, encodeErr := netiputil.EncodePrefix(p)
	require.NoError(t, encodeErr, "encode v6 prefix %s", p)

	return encoded
}
// reconcileIPv6Fixture assembles a minimal Engine for exercising
// reconcileIPv6. The engine config starts with WgAddr set to initial, and the
// wgInterface is a MockWGIface whose Address reports initial and whose
// UpdateAddr stores the address it was handed. The returned *wgaddr.Address
// points at that stored value so callers can inspect what was applied. The
// context backing the engine is cancelled when the test finishes.
func reconcileIPv6Fixture(t *testing.T, initial wgaddr.Address) (*Engine, *MockWGIface, *wgaddr.Address) {
	t.Helper()

	lastApplied := new(wgaddr.Address)
	iface := &MockWGIface{
		AddressFunc: func() wgaddr.Address { return initial },
		UpdateAddrFunc: func(a wgaddr.Address) error {
			*lastApplied = a
			return nil
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	engine := &Engine{
		ctx:          ctx,
		clientCtx:    ctx,
		clientCancel: cancel,
		config:       &EngineConfig{WgAddr: initial},
		wgInterface:  iface,
		syncMsgMux:   &sync.Mutex{},
	}

	return engine, iface, lastApplied
}
// TestReconcileIPv6_FirstAssignment_AppliesInPlace covers the transition from
// a v4-only engine to one that has just been handed a v6 overlay address.
// reconcileIPv6 must report reset=false, record the address in
// e.config.WgAddr, and pass the populated address to WGIface.UpdateAddr.
func TestReconcileIPv6_FirstAssignment_AppliesInPlace(t *testing.T) {
	v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
	// The mock is not inspected directly: the fixture's applied pointer already
	// exposes what UpdateAddr received, so the mock return value is discarded.
	e, _, applied := reconcileIPv6Fixture(t, v4)

	v6Prefix := netip.MustParsePrefix("fd00::1/64")
	conf := &mgmtProto.PeerConfig{
		Address:   v4.String(),
		AddressV6: mustEncodeV6Prefix(t, v6Prefix),
	}

	reset, err := e.reconcileIPv6(conf)
	require.NoError(t, err)
	assert.False(t, reset, "first v6 assignment must NOT request an engine reset")

	// Engine config must mirror the newly assigned address and prefix.
	require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the new v6")
	assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6, "engine config v6 address must match")
	assert.Equal(t, v6Prefix.Masked(), e.config.WgAddr.IPv6Net, "engine config v6 prefix must match")

	// The interface must have been updated with the same address.
	require.True(t, applied.HasIPv6(), "WGIface.UpdateAddr must be called with v6 populated")
	assert.Equal(t, v6Prefix.Addr(), applied.IPv6, "UpdateAddr must carry the new v6")
}
// TestReconcileIPv6_NoChange_NoOp redelivers a PeerConfig whose v6 address
// equals the one already in the engine config and expects neither a reset
// request nor a call to WGIface.UpdateAddr.
func TestReconcileIPv6_NoChange_NoOp(t *testing.T) {
	prefix := netip.MustParsePrefix("fd00::1/64")

	current := wgaddr.MustParseWGAddress("100.64.0.1/16")
	require.NoError(t, current.SetIPv6FromCompact(mustEncodeV6Prefix(t, prefix)))

	var ifaceTouched bool
	iface := &MockWGIface{
		AddressFunc: func() wgaddr.Address { return current },
		UpdateAddrFunc: func(wgaddr.Address) error {
			ifaceTouched = true
			return nil
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	engine := &Engine{
		ctx:          ctx,
		clientCtx:    ctx,
		clientCancel: cancel,
		config:       &EngineConfig{WgAddr: current},
		wgInterface:  iface,
		syncMsgMux:   &sync.Mutex{},
	}

	reset, err := engine.reconcileIPv6(&mgmtProto.PeerConfig{
		Address:   current.String(),
		AddressV6: mustEncodeV6Prefix(t, prefix),
	})

	require.NoError(t, err)
	assert.False(t, reset, "unchanged v6 must NOT trigger reset")
	assert.False(t, ifaceTouched, "unchanged v6 must NOT call UpdateAddr")
}
// TestReconcileIPv6_Removed_AppliesInPlace starts from an engine that has a v6
// address and feeds it a PeerConfig without one. The address must be cleared
// from the engine config and from the value passed to UpdateAddr, without a
// reset request.
func TestReconcileIPv6_Removed_AppliesInPlace(t *testing.T) {
	prefix := netip.MustParsePrefix("fd00::1/64")

	withV6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
	require.NoError(t, withV6.SetIPv6FromCompact(mustEncodeV6Prefix(t, prefix)))

	engine, _, applied := reconcileIPv6Fixture(t, withV6)
	engine.config.WgAddr = withV6

	reset, err := engine.reconcileIPv6(&mgmtProto.PeerConfig{
		Address:   withV6.String(),
		AddressV6: nil,
	})

	require.NoError(t, err)
	assert.False(t, reset, "v6 removed must NOT trigger reset")
	assert.False(t, engine.config.WgAddr.HasIPv6(), "engine config must reflect v6 cleared")
	assert.False(t, applied.HasIPv6(), "UpdateAddr must receive cleared v6")
}
// TestReconcileIPv6_PrefixLengthChanged_RequestsReset keeps the v6 host
// address but changes the prefix length (/64 to /80). reconcileIPv6 must ask
// for a reset and must leave the interface untouched.
func TestReconcileIPv6_PrefixLengthChanged_RequestsReset(t *testing.T) {
	before := netip.MustParsePrefix("fd00::1/64")
	after := netip.MustParsePrefix("fd00::1/80")

	current := wgaddr.MustParseWGAddress("100.64.0.1/16")
	require.NoError(t, current.SetIPv6FromCompact(mustEncodeV6Prefix(t, before)))

	var ifaceTouched bool
	iface := &MockWGIface{
		AddressFunc: func() wgaddr.Address { return current },
		UpdateAddrFunc: func(wgaddr.Address) error {
			ifaceTouched = true
			return nil
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	engine := &Engine{
		ctx:          ctx,
		clientCtx:    ctx,
		clientCancel: cancel,
		config:       &EngineConfig{WgAddr: current},
		wgInterface:  iface,
		syncMsgMux:   &sync.Mutex{},
	}

	reset, err := engine.reconcileIPv6(&mgmtProto.PeerConfig{
		Address:   current.String(),
		AddressV6: mustEncodeV6Prefix(t, after),
	})

	require.NoError(t, err)
	assert.True(t, reset, "v6 prefix length change must request a reset")
	assert.False(t, ifaceTouched, "v6 prefix length change must NOT touch the interface")
}
// TestReconcileIPv6_ValueChanged_RequestsReset swaps one non-empty v6 address
// for another. reconcileIPv6 must request a reset and must not call
// UpdateAddr, leaving interface recreation to the caller.
func TestReconcileIPv6_ValueChanged_RequestsReset(t *testing.T) {
	before := netip.MustParsePrefix("fd00::1/64")
	after := netip.MustParsePrefix("fd00::2/64")

	current := wgaddr.MustParseWGAddress("100.64.0.1/16")
	require.NoError(t, current.SetIPv6FromCompact(mustEncodeV6Prefix(t, before)))

	var ifaceTouched bool
	iface := &MockWGIface{
		AddressFunc: func() wgaddr.Address { return current },
		UpdateAddrFunc: func(wgaddr.Address) error {
			ifaceTouched = true
			return nil
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	engine := &Engine{
		ctx:          ctx,
		clientCtx:    ctx,
		clientCancel: cancel,
		config:       &EngineConfig{WgAddr: current},
		wgInterface:  iface,
		syncMsgMux:   &sync.Mutex{},
	}

	reset, err := engine.reconcileIPv6(&mgmtProto.PeerConfig{
		Address:   current.String(),
		AddressV6: mustEncodeV6Prefix(t, after),
	})

	require.NoError(t, err)
	assert.True(t, reset, "v6 value change must request a reset")
	assert.False(t, ifaceTouched,
		"v6 value change must NOT call UpdateAddr — caller will recreate the interface")
}
// TestReconcileIPv6_InvalidBytes_ReturnsError passes undecodable AddressV6
// bytes and expects an error, no reset request, and no v6 address handed to
// the interface.
func TestReconcileIPv6_InvalidBytes_ReturnsError(t *testing.T) {
	v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
	engine, _, applied := reconcileIPv6Fixture(t, v4)

	// A single byte is too short to be a valid encoded prefix.
	malformed := []byte{0x00}

	reset, err := engine.reconcileIPv6(&mgmtProto.PeerConfig{
		Address:   v4.String(),
		AddressV6: malformed,
	})

	require.Error(t, err, "malformed v6 bytes must surface an error")
	assert.False(t, reset, "decode error must NOT request a reset")
	assert.False(t, applied.HasIPv6(), "decode error must NOT touch the interface")
}
// TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset makes
// WGIface.UpdateAddr fail during a first v6 assignment. The failure must be
// returned as an error, but reconcileIPv6 must still report reset=false.
func TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset(t *testing.T) {
	v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")

	failing := &MockWGIface{
		AddressFunc: func() wgaddr.Address { return v4 },
		UpdateAddrFunc: func(wgaddr.Address) error {
			return errors.New("os refused address")
		},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	engine := &Engine{
		ctx:          ctx,
		clientCtx:    ctx,
		clientCancel: cancel,
		config:       &EngineConfig{WgAddr: v4},
		wgInterface:  failing,
		syncMsgMux:   &sync.Mutex{},
	}

	prefix := netip.MustParsePrefix("fd00::1/64")
	reset, err := engine.reconcileIPv6(&mgmtProto.PeerConfig{
		Address:   v4.String(),
		AddressV6: mustEncodeV6Prefix(t, prefix),
	})

	require.Error(t, err, "UpdateAddr failure must surface")
	assert.False(t, reset, "UpdateAddr failure must NOT request a reset")
}
// TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine drives updateConfig
// end to end with a PeerConfig whose only difference from the engine state is
// a newly added v6 address. It expects no ErrResetConnection and expects the
// engine config to hold the assigned address afterwards.
func TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine(t *testing.T) {
	v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")

	iface := &MockWGIface{
		AddressFunc:         func() wgaddr.Address { return v4 },
		UpdateAddrFunc:      func(wgaddr.Address) error { return nil },
		IsUserspaceBindFunc: func() bool { return true },
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	engine := &Engine{
		ctx:            ctx,
		clientCtx:      ctx,
		clientCancel:   cancel,
		config:         &EngineConfig{WgAddr: v4, WgPort: 51820},
		wgInterface:    iface,
		syncMsgMux:     &sync.Mutex{},
		statusRecorder: peer.NewRecorder("https://mgm.test"),
	}

	prefix := netip.MustParsePrefix("fd00::1/64")
	err := engine.updateConfig(&mgmtProto.PeerConfig{
		Address:   v4.String(),
		AddressV6: mustEncodeV6Prefix(t, prefix),
	})

	assert.NoError(t, err,
		"updateConfig MUST NOT return ErrResetConnection when v6 is added for the first time — that's the bug fix")
	assert.NotErrorIs(t, err, ErrResetConnection)

	require.True(t, engine.config.WgAddr.HasIPv6(), "engine config must record the assigned v6 after updateConfig")
	assert.Equal(t, prefix.Addr(), engine.config.WgAddr.IPv6)
}

View File

@@ -66,6 +66,7 @@ import (
"github.com/netbirdio/netbird/route"
mgmt "github.com/netbirdio/netbird/shared/management/client"
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/shared/netiputil"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/shared/signal/proto"
@@ -1706,12 +1707,82 @@ func getPeers(e *Engine) int {
return len(e.peerStore.PeersPubKey())
}
// The former TestEngine_hasIPv6Changed has been superseded by
// engine_reconcileipv6_test.go — the underlying function (hasIPv6Changed)
// was replaced by reconcileIPv6, which applies "v6 added" / "v6 removed"
// in place instead of demanding a full engine reset. The behavioral
// matrix the old test enforced is now covered, with corrected expectations,
// by TestReconcileIPv6_* in that sibling file.
// mustEncodePrefix encodes p with netiputil.EncodePrefix and aborts the test
// if encoding fails.
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
	t.Helper()

	encoded, encodeErr := netiputil.EncodePrefix(p)
	require.NoError(t, encodeErr)

	return encoded
}
// TestEngine_hasIPv6Changed is a table-driven check of hasIPv6Changed. Each
// case sets the engine's configured WgAddr, supplies the AddressV6 bytes of an
// incoming PeerConfig, and states whether a change must be reported.
func TestEngine_hasIPv6Changed(t *testing.T) {
	withoutV6 := wgaddr.MustParseWGAddress("100.64.0.1/16")

	withV6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
	withV6.IPv6 = netip.MustParseAddr("fd00::1")
	withV6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()

	encode := func(prefix string) []byte {
		return mustEncodePrefix(t, netip.MustParsePrefix(prefix))
	}

	cases := []struct {
		name     string
		current  wgaddr.Address
		incoming []byte
		want     bool
	}{
		{name: "no v6 before, no v6 now", current: withoutV6, incoming: nil, want: false},
		{name: "no v6 before, v6 added", current: withoutV6, incoming: encode("fd00::1/64"), want: true},
		{name: "had v6, now removed", current: withV6, incoming: nil, want: true},
		{name: "had v6, same v6", current: withV6, incoming: encode("fd00::1/64"), want: false},
		{name: "had v6, different v6", current: withV6, incoming: encode("fd00::2/64"), want: true},
		{name: "same v6 addr, different prefix length", current: withV6, incoming: encode("fd00::1/80"), want: true},
		// Undecodable bytes are expected to report "no change".
		{name: "decode error keeps status quo", current: withoutV6, incoming: []byte{1, 2, 3}, want: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			engine := &Engine{config: &EngineConfig{WgAddr: tc.current}}
			conf := &mgmtProto.PeerConfig{AddressV6: tc.incoming}

			assert.Equal(t, tc.want, engine.hasIPv6Changed(conf))
		})
	}
}
func TestFilterAllowedIPs(t *testing.T) {
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")

View File

@@ -1017,7 +1017,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
}
oldHasIPv6Cap := peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
updated, versionChanged = peer.UpdateMetaIfNew(sync.Meta)
updated, versionChanged = peer.UpdateMetaIfNew(ctx, sync.Meta)
ipv6CapabilityChanged = oldHasIPv6Cap != peer.HasCapability(nbpeer.PeerCapabilityIPv6Overlay)
if updated {
am.metrics.AccountManagerMetrics().CountPeerMetUpdate()
@@ -1124,7 +1124,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
var peer *nbpeer.Peer
var shouldStorePeer bool
var shouldStorePeer, shouldUpdatePeers bool
var peerGroupIDs []string
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
@@ -1151,6 +1151,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
if changed {
shouldStorePeer = true
shouldUpdatePeers = true
}
}
@@ -1174,13 +1175,16 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
}
}
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
peer.UpdateMetaIfNew(ctx, login.Meta)
return nil
})
if err != nil {
return nil, nil, nil, false, err
}
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
if err != nil {
return nil, nil, nil, false, err
}
@@ -1190,7 +1194,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
return nil, nil, nil, false, err
}
if isStatusChanged || shouldStorePeer {
if shouldUpdatePeers {
changedPeerIDs := []string{peer.ID}
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {

View File

@@ -1,12 +1,17 @@
package peer
import (
"context"
"fmt"
"net"
"net/netip"
"slices"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/management/server/util"
"github.com/netbirdio/netbird/shared/management/http/api"
)
@@ -162,49 +167,7 @@ type PeerSystemMeta struct { //nolint:revive
}
func (p PeerSystemMeta) isEqual(other PeerSystemMeta) bool {
sort.Slice(p.NetworkAddresses, func(i, j int) bool {
return p.NetworkAddresses[i].Mac < p.NetworkAddresses[j].Mac
})
sort.Slice(other.NetworkAddresses, func(i, j int) bool {
return other.NetworkAddresses[i].Mac < other.NetworkAddresses[j].Mac
})
equalNetworkAddresses := slices.EqualFunc(p.NetworkAddresses, other.NetworkAddresses, func(addr NetworkAddress, oAddr NetworkAddress) bool {
return addr.Mac == oAddr.Mac && addr.NetIP == oAddr.NetIP
})
if !equalNetworkAddresses {
return false
}
sort.Slice(p.Files, func(i, j int) bool {
return p.Files[i].Path < p.Files[j].Path
})
sort.Slice(other.Files, func(i, j int) bool {
return other.Files[i].Path < other.Files[j].Path
})
equalFiles := slices.EqualFunc(p.Files, other.Files, func(file File, oFile File) bool {
return file.Path == oFile.Path && file.Exist == oFile.Exist && file.ProcessIsRunning == oFile.ProcessIsRunning
})
if !equalFiles {
return false
}
return p.Hostname == other.Hostname &&
p.GoOS == other.GoOS &&
p.Kernel == other.Kernel &&
p.KernelVersion == other.KernelVersion &&
p.Core == other.Core &&
p.Platform == other.Platform &&
p.OS == other.OS &&
p.OSVersion == other.OSVersion &&
p.WtVersion == other.WtVersion &&
p.UIVersion == other.UIVersion &&
p.SystemSerialNumber == other.SystemSerialNumber &&
p.SystemProductName == other.SystemProductName &&
p.SystemManufacturer == other.SystemManufacturer &&
p.Environment.Cloud == other.Environment.Cloud &&
p.Environment.Platform == other.Environment.Platform &&
p.Flags.isEqual(other.Flags) &&
capabilitiesEqual(p.Capabilities, other.Capabilities)
return len(metaDiff(p, other)) == 0
}
func (p PeerSystemMeta) isEmpty() bool {
@@ -296,7 +259,7 @@ func (p *Peer) Copy() *Peer {
// UpdateMetaIfNew updates peer's system metadata if new information is provided
// returns true if meta was updated, false otherwise
func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged bool) {
func (p *Peer) UpdateMetaIfNew(ctx context.Context, meta PeerSystemMeta) (updated, versionChanged bool) {
if meta.isEmpty() {
return updated, versionChanged
}
@@ -308,14 +271,113 @@ func (p *Peer) UpdateMetaIfNew(meta PeerSystemMeta) (updated, versionChanged boo
meta.UIVersion = p.Meta.UIVersion
}
if p.Meta.isEqual(meta) {
return updated, versionChanged
oldVersion := p.Meta.WtVersion
diff := metaDiff(p.Meta, meta)
if len(diff) != 0 {
p.Meta = meta
updated = true
}
p.Meta = meta
updated = true
versionInfo := ""
if versionChanged {
versionInfo = fmt.Sprintf("version changed: %s -> %s, ", oldVersion, meta.WtVersion)
}
if len(diff) > 0 || versionChanged {
log.WithContext(ctx).
Debugf("peer meta updated, %s%d field(s) changed: %s", versionInfo, len(diff), strings.Join(diff, ", "))
}
return updated, versionChanged
}
// metaDiff returns a human-readable list of the fields that differ between the
// old and new meta, each formatted as `field: <old> -> <new>`. An empty (nil)
// result means the two metas are considered equal.
//
// NetworkAddresses and Files are compared as order-insensitive collections:
// both sides are cloned and sorted before the element-wise comparison, so the
// callers' slices are never reordered. Network addresses are ordered by MAC
// and then by the textual form of NetIP; without the second key, several
// addresses sharing one MAC could end up in a different relative order on each
// side (sort.Slice is not stable) and be reported as a change even though the
// same set of addresses was supplied.
func metaDiff(oldMeta, newMeta PeerSystemMeta) []string {
	var diff []string
	add := func(field string, oldVal, newVal any) {
		diff = append(diff, fmt.Sprintf("%s: %v -> %v", field, oldVal, newVal))
	}

	if oldMeta.Hostname != newMeta.Hostname {
		add("hostname", oldMeta.Hostname, newMeta.Hostname)
	}
	if oldMeta.GoOS != newMeta.GoOS {
		add("goos", oldMeta.GoOS, newMeta.GoOS)
	}
	if oldMeta.Kernel != newMeta.Kernel {
		add("kernel", oldMeta.Kernel, newMeta.Kernel)
	}
	if oldMeta.KernelVersion != newMeta.KernelVersion {
		add("kernel_version", oldMeta.KernelVersion, newMeta.KernelVersion)
	}
	if oldMeta.Core != newMeta.Core {
		add("core", oldMeta.Core, newMeta.Core)
	}
	if oldMeta.Platform != newMeta.Platform {
		add("platform", oldMeta.Platform, newMeta.Platform)
	}
	if oldMeta.OS != newMeta.OS {
		add("os", oldMeta.OS, newMeta.OS)
	}
	if oldMeta.OSVersion != newMeta.OSVersion {
		add("os_version", oldMeta.OSVersion, newMeta.OSVersion)
	}
	if oldMeta.WtVersion != newMeta.WtVersion {
		add("wt_version", oldMeta.WtVersion, newMeta.WtVersion)
	}
	if oldMeta.UIVersion != newMeta.UIVersion {
		add("ui_version", oldMeta.UIVersion, newMeta.UIVersion)
	}
	if oldMeta.SystemSerialNumber != newMeta.SystemSerialNumber {
		add("system_serial_number", oldMeta.SystemSerialNumber, newMeta.SystemSerialNumber)
	}
	if oldMeta.SystemProductName != newMeta.SystemProductName {
		add("system_product_name", oldMeta.SystemProductName, newMeta.SystemProductName)
	}
	if oldMeta.SystemManufacturer != newMeta.SystemManufacturer {
		add("system_manufacturer", oldMeta.SystemManufacturer, newMeta.SystemManufacturer)
	}
	if oldMeta.Environment.Cloud != newMeta.Environment.Cloud {
		add("environment_cloud", oldMeta.Environment.Cloud, newMeta.Environment.Cloud)
	}
	if oldMeta.Environment.Platform != newMeta.Environment.Platform {
		add("environment_platform", oldMeta.Environment.Platform, newMeta.Environment.Platform)
	}
	if !oldMeta.Flags.isEqual(newMeta.Flags) {
		add("flags", fmt.Sprintf("%+v", oldMeta.Flags), fmt.Sprintf("%+v", newMeta.Flags))
	}
	if !capabilitiesEqual(oldMeta.Capabilities, newMeta.Capabilities) {
		add("capabilities", oldMeta.Capabilities, newMeta.Capabilities)
	}

	// Total order for network addresses: MAC first, then the printed NetIP as a
	// tie-breaker so entries sharing a MAC land in the same position on both
	// sides regardless of the order in which they were reported.
	addrLess := func(a, b NetworkAddress) bool {
		if a.Mac != b.Mac {
			return a.Mac < b.Mac
		}
		return fmt.Sprint(a.NetIP) < fmt.Sprint(b.NetIP)
	}
	oldAddrs := slices.Clone(oldMeta.NetworkAddresses)
	newAddrs := slices.Clone(newMeta.NetworkAddresses)
	sort.Slice(oldAddrs, func(i, j int) bool { return addrLess(oldAddrs[i], oldAddrs[j]) })
	sort.Slice(newAddrs, func(i, j int) bool { return addrLess(newAddrs[i], newAddrs[j]) })
	if !slices.EqualFunc(oldAddrs, newAddrs, func(a, b NetworkAddress) bool {
		return a.Mac == b.Mac && a.NetIP == b.NetIP
	}) {
		add("network_addresses", fmt.Sprintf("%v", oldAddrs), fmt.Sprintf("%v", newAddrs))
	}

	// Files are matched up by path before their state is compared.
	oldFiles := slices.Clone(oldMeta.Files)
	newFiles := slices.Clone(newMeta.Files)
	sort.Slice(oldFiles, func(i, j int) bool { return oldFiles[i].Path < oldFiles[j].Path })
	sort.Slice(newFiles, func(i, j int) bool { return newFiles[i].Path < newFiles[j].Path })
	if !slices.EqualFunc(oldFiles, newFiles, func(a, b File) bool {
		return a.Path == b.Path && a.Exist == b.Exist && a.ProcessIsRunning == b.ProcessIsRunning
	}) {
		add("files", fmt.Sprintf("%v", oldFiles), fmt.Sprintf("%v", newFiles))
	}

	return diff
}
// GetLastLogin returns the last login time of the peer.
func (p *Peer) GetLastLogin() time.Time {
if p.LastLogin != nil {

View File

@@ -466,20 +466,15 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.
_ = ln.Close()
}()
var backoff nbtcp.AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
return
}
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
continue
}
backoff.Reset()
router.HandleConn(ctx, conn)
}
}

View File

@@ -533,125 +533,3 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
-----END EC PRIVATE KEY-----`)
// scriptedAcceptListener is a net.Listener test double that never yields a
// connection. Accept hands back the errors queued at construction time and,
// once the listener has been closed, net.ErrClosed. No real socket is bound.
type scriptedAcceptListener struct {
	errs   chan error
	closed chan struct{}
}

// newScriptedAcceptListener queues errs, in order, as the results of
// subsequent Accept calls.
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
	queue := make(chan error, len(errs)+1)
	for i := range errs {
		queue <- errs[i]
	}

	return &scriptedAcceptListener{
		errs:   queue,
		closed: make(chan struct{}),
	}
}

// Accept blocks until either a scripted error is available or the listener is
// closed. It always returns a nil connection.
func (l *scriptedAcceptListener) Accept() (net.Conn, error) {
	select {
	case scripted := <-l.errs:
		return nil, scripted
	case <-l.closed:
		return nil, net.ErrClosed
	}
}

// Close marks the listener as closed. Calling it again after it has returned
// is a no-op; it always returns nil.
func (l *scriptedAcceptListener) Close() error {
	select {
	case <-l.closed:
		return nil
	default:
	}

	close(l.closed)
	return nil
}

// Addr reports a fixed loopback TCP address with port 0.
func (l *scriptedAcceptListener) Addr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1)}
}
// errSentinel is a string-backed error type. Tests use it to build an error
// with an exact message without depending on the package that would normally
// produce it.
type errSentinel string

// Error returns the string value unchanged.
func (e errSentinel) Error() string {
	return string(e)
}
// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint feeds the accept
// loop a single *net.OpError whose inner message is "endpoint is in invalid
// state" and requires feedRouterFromListener to return within two seconds
// instead of continuing to call Accept.
func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
	logger := log.StandardLogger()
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
	router := nbtcp.NewRouter(logger, nil, addr)

	ln := newScriptedAcceptListener(&net.OpError{
		Op:   "accept",
		Net:  "tcp",
		Addr: addr,
		Err:  errSentinel("endpoint is in invalid state"),
	})
	defer ln.Close()

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		feedRouterFromListener(context.Background(), ln, router, logger, "acct-1")
	}()

	select {
	case <-finished:
		// The loop treated the error as terminal and returned.
	case <-time.After(2 * time.Second):
		t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
	}
}
// TestFeedRouterFromListener_BacksOffOnTransientError scripts five
// non-terminal Accept errors, cancels the context after 150ms, and requires
// feedRouterFromListener to return within two seconds. It then asserts a
// lower bound on the elapsed time as evidence that the loop did not spin.
func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) {
	logger := log.StandardLogger()
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
	router := nbtcp.NewRouter(logger, nil, addr)

	const transientCount = 5
	scripted := make([]error, 0, transientCount)
	for len(scripted) < transientCount {
		scripted = append(scripted, errSentinel("transient: temporary network error"))
	}

	ln := newScriptedAcceptListener(scripted...)
	defer ln.Close()

	ctx, cancel := context.WithCancel(context.Background())
	startedAt := time.Now()

	finished := make(chan struct{})
	go func() {
		defer close(finished)
		feedRouterFromListener(ctx, ln, router, logger, "acct-1")
	}()

	time.AfterFunc(150*time.Millisecond, cancel)

	select {
	case <-finished:
		// The loop returned after cancellation.
	case <-time.After(2 * time.Second):
		t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken")
	}

	// NOTE(review): once the scripted queue is drained Accept blocks until the
	// listener is closed, so the loop presumably cannot return before the
	// 150ms cancel fires. That would make this 5ms lower bound hold with or
	// without backoff — confirm whether a tighter check is needed.
	elapsed := time.Since(startedAt)
	assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond,
		"loop ran without backing off — would burn CPU in production")
}

View File

@@ -356,7 +356,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
// Create embedded NetBird client with the generated private key.
// The peer has already been created via CreateProxyPeer RPC with the public key.
wgPort := int(n.clientCfg.WGPort)
embedOpts := embed.Options{
client, err := embed.New(embed.Options{
DeviceName: deviceNamePrefix + n.proxyID,
ManagementURL: n.clientCfg.MgmtAddr,
PrivateKey: privateKey.String(),
@@ -371,9 +371,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
WireguardPort: &wgPort,
PreSharedKey: n.clientCfg.PreSharedKey,
Performance: n.clientCfg.Performance,
}
logEmbedOptions(n.logger, accountID, serviceID, publicKey.String(), embedOpts)
client, err := embed.New(embedOpts)
})
if err != nil {
return nil, fmt.Errorf("create netbird client: %w", err)
}
@@ -849,53 +847,3 @@ func DirectUpstreamFromContext(ctx context.Context) bool {
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
return v
}
// logEmbedOptions writes one structured INFO entry describing the
// embed.Options about to be passed to embed.New for a per-account client.
// Optional pointer settings (WireGuard port, MTU, performance tuning) are
// logged as their pointed-to value, or zero when unset. Paths, keys and tokens
// are logged only as booleans indicating whether they are non-empty; their
// values never appear in the entry.
func logEmbedOptions(logger *log.Logger, accountID types.AccountID, serviceID types.ServiceID, publicKey string, opts embed.Options) {
	var (
		wgPort      int
		mtu         uint16
		perfBuffers uint32
		perfBatch   uint32
	)
	if p := opts.WireguardPort; p != nil {
		wgPort = *p
	}
	if p := opts.MTU; p != nil {
		mtu = *p
	}
	if p := opts.Performance.PreallocatedBuffersPerPool; p != nil {
		perfBuffers = *p
	}
	if p := opts.Performance.MaxBatchSize; p != nil {
		perfBatch = *p
	}

	fields := log.Fields{
		// Identity of the client being started.
		"account_id":     accountID,
		"service_id":     serviceID,
		"public_key":     publicKey,
		"device_name":    opts.DeviceName,
		"management_url": opts.ManagementURL,
		"log_level":      opts.LogLevel,

		// Interface and behaviour switches.
		"wg_port":               wgPort,
		"mtu":                   mtu,
		"block_inbound":         opts.BlockInbound,
		"block_lan_access":      opts.BlockLANAccess,
		"disable_ipv6":          opts.DisableIPv6,
		"disable_client_routes": opts.DisableClientRoutes,
		"no_userspace":          opts.NoUserspace,

		// Presence flags only — the underlying values are not logged.
		"config_path_set":      opts.ConfigPath != "",
		"state_path_set":       opts.StatePath != "",
		"private_key_present":  opts.PrivateKey != "",
		"presharedkey_present": opts.PreSharedKey != "",
		"setup_key_present":    opts.SetupKey != "",
		"jwt_token_present":    opts.JWTToken != "",

		"dns_labels":            opts.DNSLabels,
		"perf_buffers_per_pool": perfBuffers,
		"perf_max_batch_size":   perfBatch,
	}

	logger.WithFields(fields).Info("starting embedded netbird client for account")
}

View File

@@ -1,85 +0,0 @@
package tcp
import (
"context"
"errors"
"net"
"strings"
"time"
)
// gvisorInvalidEndpointMsg is the message text matched to detect a
// gVisor netstack listener whose underlying endpoint has been destroyed
// (peer rekey, embedded-client reset, account churn). The match is a
// plain substring check because no exported gVisor sentinel survives
// gonet's *net.OpError wrapping in a form errors.Is can match.
const gvisorInvalidEndpointMsg = "endpoint is in invalid state"

// IsClosedListenerErr reports whether err means an accept loop should
// stop because its listener can no longer serve connections.
//
// It returns true when:
//
//   - err wraps net.ErrClosed (a stdlib listener was closed), or
//   - err's message contains gVisor's "endpoint is in invalid state"
//     text (a netstack-backed listener lost its endpoint).
//
// A nil err and every other error return false, so callers can route
// those through their transient-error handling instead of exiting.
func IsClosedListenerErr(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, net.ErrClosed):
		return true
	default:
		// String fallback for the gVisor case; see the constant above.
		return strings.Contains(err.Error(), gvisorInvalidEndpointMsg)
	}
}
// AcceptBackoff implements the exponential backoff used by
// net/http.Server.Serve for transient Accept errors. Without it a loop
// hitting a sticky unknown error burns a full CPU core. The zero value
// is ready to use; call Reset after a successful Accept.
//
// An AcceptBackoff carries no locking; it is meant to be owned by a
// single accept loop.
type AcceptBackoff struct {
	// delay is the wait armed by the most recent Backoff call; zero
	// means "no failure since the last Reset".
	delay time.Duration
}

// minAcceptDelay / maxAcceptDelay mirror the stdlib defaults
// (net/http.Server.Serve) and keep us well below 1 log line per second
// per orphaned listener.
const (
	minAcceptDelay = 5 * time.Millisecond
	maxAcceptDelay = time.Second
)

// Backoff advances to the next exponential delay (5ms doubling up to
// 1s) and waits for it. It returns true when the wait completed and
// false when ctx was done first — callers should treat false as "exit
// the loop". The delay is advanced even when the wait is cut short.
func (b *AcceptBackoff) Backoff(ctx context.Context) bool {
	b.advance()

	// Use an explicit timer rather than time.After so the timer is
	// released as soon as ctx cancels the wait, instead of lingering
	// until it would have fired.
	timer := time.NewTimer(b.delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}

// Reset clears the accumulated delay so the next failure starts at the
// minimum delay again. Call after a successful Accept.
func (b *AcceptBackoff) Reset() { b.delay = 0 }

// advance moves delay one step along the schedule: the first step sets
// minAcceptDelay, each later step doubles, and the result is clamped to
// maxAcceptDelay.
func (b *AcceptBackoff) advance() {
	if b.delay == 0 {
		b.delay = minAcceptDelay
	} else {
		b.delay *= 2
	}
	if b.delay > maxAcceptDelay {
		b.delay = maxAcceptDelay
	}
}

View File

@@ -1,142 +0,0 @@
package tcp
import (
"context"
"errors"
"fmt"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestIsClosedListenerErr_NetErrClosed verifies the stdlib path: a
// closed *net.Listener returns net.ErrClosed wrapped in *net.OpError,
// and IsClosedListenerErr must unwrap it.
func TestIsClosedListenerErr_NetErrClosed(t *testing.T) {
wrapped := &net.OpError{Op: "accept", Net: "tcp", Err: net.ErrClosed}
assert.True(t, IsClosedListenerErr(wrapped),
"net.OpError wrapping net.ErrClosed must be recognised as closed")
}
// TestIsClosedListenerErr_GVisorInvalidEndpoint is the load-bearing
// regression guard. A gVisor netstack listener whose endpoint has been
// destroyed returns this exact text. Without recognising it the accept
// loop spins forever and burns a CPU core.
func TestIsClosedListenerErr_GVisorInvalidEndpoint(t *testing.T) {
err := fmt.Errorf("accept tcp 10.10.1.254:80: endpoint is in invalid state")
assert.True(t, IsClosedListenerErr(err),
"gVisor 'endpoint is in invalid state' must be recognised as closed")
}
// TestIsClosedListenerErr_OtherError confirms we don't over-match —
// transient errors must keep returning false so the backoff path runs.
func TestIsClosedListenerErr_OtherError(t *testing.T) {
cases := []error{
errors.New("temporary failure"),
errors.New("accept tcp 10.10.1.254:80: too many open files"),
nil,
}
for _, c := range cases {
assert.False(t, IsClosedListenerErr(c),
"unexpected match on %v — must not be treated as closed", c)
}
}
// TestAcceptBackoff_ProgressionAndCap asserts the doubling schedule:
// 5ms, 10ms, 20ms, 40ms, ... capped at 1s.
//
// The first four rounds run against a real timer and check two things:
// the delay that was armed (exact, deterministic) and that the call
// really waited at least that long. No upper bound is placed on the
// wall-clock time, because a loaded CI machine can overshoot a 5ms
// timer by an arbitrary amount. The cap is then checked by stepping the
// schedule directly, which avoids several seconds of real sleeping.
func TestAcceptBackoff_ProgressionAndCap(t *testing.T) {
	var b AcceptBackoff
	expected := []time.Duration{
		5 * time.Millisecond,
		10 * time.Millisecond,
		20 * time.Millisecond,
		40 * time.Millisecond,
	}
	for i, want := range expected {
		start := time.Now()
		ok := b.Backoff(context.Background())
		elapsed := time.Since(start)

		require.True(t, ok, "Backoff %d must complete; ctx is alive", i)
		assert.Equal(t, want, b.delay,
			"backoff %d must arm a delay of %v", i, want)
		assert.GreaterOrEqual(t, elapsed, want,
			"backoff %d (%v) must wait at least the configured delay", i, want)
	}

	// Step the schedule past the cap without sleeping: from 40ms, six
	// steps give 80ms, 160ms, 320ms, 640ms, 1s (clamped), 1s.
	for range 6 {
		b.advance()
	}
	assert.Equal(t, maxAcceptDelay, b.delay,
		"after enough doublings the delay must clamp to maxAcceptDelay")
}
// TestAcceptBackoff_Reset checks that Reset puts the schedule back at
// its starting point: after a run of failures followed by a Reset, the
// next Backoff must wait the minimum delay, not the accumulated one.
func TestAcceptBackoff_Reset(t *testing.T) {
	var b AcceptBackoff

	// Accumulate some delay first so there is something to reset.
	for round := 0; round < 5; round++ {
		b.Backoff(context.Background())
	}
	require.NotEqual(t, time.Duration(0), b.delay, "precondition: delay must have accumulated")

	b.Reset()
	assert.Equal(t, time.Duration(0), b.delay, "Reset must zero the delay")

	// The first wait after Reset is timed to prove it restarted small.
	begin := time.Now()
	completed := b.Backoff(context.Background())
	took := time.Since(begin)

	require.True(t, completed, "Backoff after Reset must complete")
	assert.GreaterOrEqual(t, took, minAcceptDelay,
		"after Reset the next backoff must restart at minAcceptDelay")
	assert.Less(t, took, 50*time.Millisecond,
		"after Reset the next backoff must NOT carry over the prior delay")
}
// TestAcceptBackoff_CancelDuringWait proves the loop exits promptly
// when ctx fires mid-wait. Without this, a tear-down would still take
// up to 1 second per orphaned listener.
func TestAcceptBackoff_CancelDuringWait(t *testing.T) {
	var b AcceptBackoff

	// Drive the schedule to its cap so the next Backoff would wait ~1s,
	// long enough to tell an early cancellation apart from a completed
	// wait. The schedule is stepped directly rather than through
	// Backoff, which would sleep for over three seconds just to set up.
	for range 10 {
		b.advance()
	}
	require.Equal(t, maxAcceptDelay, b.delay)

	ctx, cancel := context.WithCancel(context.Background())
	go func() {
		time.Sleep(20 * time.Millisecond)
		cancel()
	}()

	start := time.Now()
	ok := b.Backoff(ctx)
	elapsed := time.Since(start)

	assert.False(t, ok, "Backoff must return false when ctx is cancelled mid-wait")
	assert.Less(t, elapsed, 200*time.Millisecond,
		"cancellation must short-circuit the timer; took %v", elapsed)
}
// TestAcceptBackoff_CancelBeforeCall covers a ctx that is already done
// when Backoff is entered: the call must report false and come back
// without serving the delay.
func TestAcceptBackoff_CancelBeforeCall(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	var b AcceptBackoff
	begin := time.Now()
	completed := b.Backoff(ctx)
	took := time.Since(begin)

	assert.False(t, completed, "Backoff must return false when ctx is already cancelled")
	assert.Less(t, took, 50*time.Millisecond,
		"already-cancelled ctx must return immediately; took %v", took)
}

View File

@@ -297,23 +297,18 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
}
}()
var backoff AcceptBackoff
for {
conn, err := ln.Accept()
if err != nil {
if ctx.Err() != nil || IsClosedListenerErr(err) {
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
if ok := r.Drain(DefaultDrainTimeout); !ok {
r.logger.Warn("timed out waiting for connections to drain")
}
return nil
}
r.logger.Debugf("SNI router accept: %v; backing off", err)
if !backoff.Backoff(ctx) {
return nil
}
r.logger.Debugf("SNI router accept: %v", err)
continue
}
backoff.Reset()
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
r.activeConns.Add(1)
go func() {

View File

@@ -1836,132 +1836,3 @@ func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
t.Fatal("TLS conn never reached the TLS channel")
}
}
// scriptedAcceptListener is a net.Listener test double whose Accept
// never yields a connection: it hands back the errors it was
// constructed with, one per call, and net.ErrClosed once Close has been
// called. The accept-loop exit tests use it to reproduce a listener
// that fails every Accept, such as a netstack listener returning
// gVisor's "endpoint is in invalid state".
type scriptedAcceptListener struct {
	// errs holds the scripted errors still waiting to be returned.
	errs chan error
	// closed is closed by Close to unblock and fail pending Accepts.
	closed chan struct{}
}

// newScriptedAcceptListener builds a listener pre-loaded with errs, in
// order. The queue has one spare slot beyond the scripted errors.
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
	queue := make(chan error, len(errs)+1)
	for i := range errs {
		queue <- errs[i]
	}
	return &scriptedAcceptListener{
		errs:   queue,
		closed: make(chan struct{}),
	}
}

// Accept returns the next scripted error, or net.ErrClosed once the
// listener has been closed. It blocks while the script is exhausted and
// the listener is still open. The returned net.Conn is always nil.
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
	select {
	case scripted := <-s.errs:
		return nil, scripted
	case <-s.closed:
		return nil, net.ErrClosed
	}
}

// Close marks the listener closed. Calling it again after it has
// returned is a no-op. It always returns nil.
func (s *scriptedAcceptListener) Close() error {
	select {
	case <-s.closed:
		// Already closed; closing the channel again would panic.
		return nil
	default:
	}
	close(s.closed)
	return nil
}

// Addr reports a fixed loopback TCP address with port 0.
func (s *scriptedAcceptListener) Addr() net.Addr {
	loopback := net.IPv4(127, 0, 0, 1)
	return &net.TCPAddr{IP: loopback, Port: 0}
}
// TestRouter_Serve_ExitsOnGVisorInvalidEndpoint is the regression guard
// for the tight-loop bug: when the underlying netstack endpoint is
// destroyed, Accept returns "endpoint is in invalid state" forever. The
// loop must recognise that signal and return, otherwise it pegs a CPU
// core and floods logs.
func TestRouter_Serve_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
logger := log.StandardLogger()
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
router := NewRouter(logger, nil, addr)
gvisorErr := &net.OpError{
Op: "accept",
Net: "tcp",
Addr: addr,
Err: errSentinel("endpoint is in invalid state"),
}
ln := newScriptedAcceptListener(gvisorErr)
defer ln.Close()
done := make(chan error, 1)
go func() {
done <- router.Serve(context.Background(), ln)
}()
select {
case err := <-done:
assert.NoError(t, err, "Serve must return cleanly on a recognised closed-listener error")
case <-time.After(2 * time.Second):
t.Fatal("Serve did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
}
}
// TestRouter_Serve_BacksOffOnTransientError verifies the defence-in-
// depth path: when Accept returns an unknown transient error, the loop
// MUST not spin. It backs off, then exits cleanly once ctx is cancelled.
//
// "Bounded call count" stands in for "no CPU spin". The listener is
// scripted with far more errors than a backing-off loop can consume
// before cancellation; a loop without backoff would drain all of them
// almost immediately. Wall-clock time cannot make that distinction,
// because Serve only returns once ctx is cancelled either way.
func TestRouter_Serve_BacksOffOnTransientError(t *testing.T) {
	logger := log.StandardLogger()
	addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
	router := NewRouter(logger, nil, addr)

	const (
		// scriptedErrCount is large enough that only a spinning loop
		// can consume every entry before ctx is cancelled.
		scriptedErrCount = 1000
		// maxConsumed is a generous ceiling. With the 5ms doubling
		// schedule about five Accept calls fit into 150ms; the slack
		// absorbs a late cancel on a slow machine.
		maxConsumed = 50
	)
	errs := make([]error, scriptedErrCount)
	for i := range errs {
		errs[i] = errSentinel("transient: too many open files")
	}
	ln := newScriptedAcceptListener(errs...)
	defer ln.Close()

	ctx, cancel := context.WithCancel(context.Background())
	done := make(chan error, 1)
	go func() {
		done <- router.Serve(ctx, ln)
	}()

	// Cancel after enough time for the backoff to climb (5ms + 10ms +
	// 20ms + 40ms = 75ms minimum), but short enough that a spinning
	// loop would have made thousands of calls by now.
	time.AfterFunc(150*time.Millisecond, cancel)

	select {
	case err := <-done:
		assert.NoError(t, err, "Serve must return cleanly on ctx cancellation")
	case <-time.After(2 * time.Second):
		t.Fatal("Serve did not exit on ctx cancellation — backoff or exit path broken")
	}

	// Serve has returned, so nothing is calling Accept any more and
	// the number of errors left in the script is stable.
	consumed := scriptedErrCount - len(ln.errs)
	assert.Less(t, consumed, maxConsumed,
		"loop ran without backing off — would burn CPU in production")
}
// errSentinel is a string-backed error used to fabricate Accept
// failures in tests. Its message is exactly the string it was created
// from, so a test can reproduce gVisor's tcpip error text — the shape
// IsClosedListenerErr matches in production — without importing the
// gVisor package and dragging in the whole netstack.
type errSentinel string

// Error returns the underlying string unchanged.
func (e errSentinel) Error() string {
	msg := string(e)
	return msg
}

View File

@@ -26,6 +26,10 @@ type Peer struct {
// a gRpc connection stream to the Peer
Stream proto.SignalExchange_ConnectStreamServer
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
// the same ServerStream, and a peer can be the target of many senders at
// once.
sendMu sync.Mutex
// registration time
RegisteredAt time.Time
@@ -33,6 +37,13 @@ type Peer struct {
Cancel context.CancelFunc
}
// Send writes a message to the peer's stream, serializing concurrent senders.
//
// sendMu is held for the full duration of the underlying Stream.Send call,
// so callers targeting the same peer take turns and a slow send delays the
// senders queued behind it. The error returned by Stream.Send is passed
// through unchanged.
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
p.sendMu.Lock()
// Deferred so the mutex is released even if Stream.Send panics.
defer p.sendMu.Unlock()
return p.Stream.Send(msg)
}
// NewPeer creates a new instance of a connected Peer
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
return &Peer{

View File

@@ -0,0 +1,67 @@
package server
import (
"context"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/signal/peer"
)
// concurrencyCheckStream is a stream test double that tracks how many
// Send calls overlap. gRPC forbids concurrent SendMsg on one
// ServerStream, so against a correct server maxSeen never exceeds 1.
// Only Send and Context are implemented here; every other method comes
// from the embedded interface, which the test leaves nil.
type concurrencyCheckStream struct {
	proto.SignalExchange_ConnectStreamServer
	ctx context.Context
	// inflight is the number of Send calls currently executing.
	inflight atomic.Int32
	// maxSeen is the highest inflight value observed so far.
	maxSeen atomic.Int32
}

// Send counts itself as in flight, raises maxSeen if this call set a
// new high-water mark, lingers briefly, and then returns nil. The
// message itself is ignored.
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
	current := s.inflight.Add(1)

	// Lock-free "store max": retry the compare-and-swap until either it
	// succeeds or another caller has already recorded a value >= current.
	for peak := s.maxSeen.Load(); current > peak; peak = s.maxSeen.Load() {
		if s.maxSeen.CompareAndSwap(peak, current) {
			break
		}
	}

	// Widen the window so overlapping callers are reliably observed.
	time.Sleep(time.Millisecond)

	s.inflight.Add(-1)
	return nil
}

// Context returns the context the stream was constructed with.
func (s *concurrencyCheckStream) Context() context.Context {
	return s.ctx
}
// TestForwardMessageToPeerSerializesSend registers one peer backed by a
// concurrencyCheckStream, forwards 50 messages to it from parallel
// goroutines, and requires that the stream never saw more than one Send
// in flight — overlapping sends would violate the gRPC ServerStream
// contract.
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
	const (
		peerID  = "peerX"
		senders = 50
	)

	srv, err := NewServer(context.Background(), otel.Meter(""))
	require.NoError(t, err)

	_, cancel := context.WithCancel(context.Background())
	t.Cleanup(cancel)

	stream := &concurrencyCheckStream{ctx: context.Background()}
	target := peer.NewPeer(peerID, stream, cancel)
	require.NoError(t, srv.registry.Register(target))

	var wg sync.WaitGroup
	wg.Add(senders)
	for i := 0; i < senders; i++ {
		go func() {
			defer wg.Done()
			msg := &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID}
			srv.forwardMessageToPeer(context.Background(), msg)
		}()
	}
	wg.Wait()

	require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
}

View File

@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
sendResultChan := make(chan error, 1)
go func() {
select {
case sendResultChan <- dstPeer.Stream.Send(msg):
case sendResultChan <- dstPeer.Send(msg):
return
case <-dstPeer.Stream.Context().Done():
return