mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-18 22:09:56 +00:00
Compare commits
1 Commits
main
...
fix/ipv6-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d8b0310a4 |
@@ -247,7 +247,7 @@ dockers_v2:
|
||||
- netbirdio/netbird
|
||||
- ghcr.io/netbirdio/netbird
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: client/Dockerfile
|
||||
extra_files:
|
||||
@@ -295,7 +295,7 @@ dockers_v2:
|
||||
- netbirdio/relay
|
||||
- ghcr.io/netbirdio/relay
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: relay/Dockerfile
|
||||
platforms:
|
||||
@@ -317,7 +317,7 @@ dockers_v2:
|
||||
- netbirdio/signal
|
||||
- ghcr.io/netbirdio/signal
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: signal/Dockerfile
|
||||
platforms:
|
||||
@@ -339,7 +339,7 @@ dockers_v2:
|
||||
- netbirdio/management
|
||||
- ghcr.io/netbirdio/management
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: management/Dockerfile
|
||||
platforms:
|
||||
@@ -361,7 +361,7 @@ dockers_v2:
|
||||
- netbirdio/upload
|
||||
- ghcr.io/netbirdio/upload
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: upload-server/Dockerfile
|
||||
platforms:
|
||||
@@ -383,7 +383,7 @@ dockers_v2:
|
||||
- netbirdio/netbird-server
|
||||
- ghcr.io/netbirdio/netbird-server
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: combined/Dockerfile
|
||||
platforms:
|
||||
@@ -405,7 +405,7 @@ dockers_v2:
|
||||
- netbirdio/reverse-proxy
|
||||
- ghcr.io/netbirdio/reverse-proxy
|
||||
tags:
|
||||
- "{{ .Version }}"
|
||||
- "v{{ .Version }}"
|
||||
- "{{ if eq .Env.SKIP_PUBLISH \"false\" }}latest{{ end }}"
|
||||
dockerfile: proxy/Dockerfile
|
||||
platforms:
|
||||
|
||||
@@ -64,7 +64,6 @@ import (
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
@@ -1078,11 +1077,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
return ErrResetConnection
|
||||
}
|
||||
|
||||
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
|
||||
log.Infof("peer IPv6 address changed, restarting client")
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return ErrResetConnection
|
||||
if !e.config.DisableIPv6 {
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
if err != nil {
|
||||
log.Warnf("reconcile IPv6 from PeerConfig: %v", err)
|
||||
}
|
||||
if reset {
|
||||
log.Infof("peer IPv6 address changed value, restarting client")
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return ErrResetConnection
|
||||
}
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
@@ -1104,25 +1109,58 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
|
||||
// differs from the configured address (added, removed, or changed).
|
||||
// Compares against e.config.WgAddr (not the interface address, which may have
|
||||
// been cleared by ClearIPv6 if OS assignment failed).
|
||||
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
|
||||
current := e.config.WgAddr
|
||||
// reconcileIPv6 applies the management-supplied IPv6 overlay address to the
|
||||
// engine's WireGuard interface in place when possible. Three transitions:
|
||||
//
|
||||
// - First v6 assignment (current had no v6, conf carries one): apply via
|
||||
// WGIface.UpdateAddr, no reset. Critical for embedded clients whose
|
||||
// boot config has no v6 — without this we reset on every fresh start
|
||||
// once management has v6 enabled, orphaning any netstack listeners
|
||||
// held outside the engine.
|
||||
// - v6 removed (current had v6, conf carries none): clear in place, no
|
||||
// reset.
|
||||
// - v6 swapped to a different non-empty value: returns reset=true so the
|
||||
// caller falls back to the engine-recreate path — the underlying
|
||||
// interface address can't be safely swapped in place across all
|
||||
// backends (gVisor netstack in particular fixes its address at
|
||||
// CreateNetTUN time).
|
||||
//
|
||||
// Mutates e.config.WgAddr to match the applied state so subsequent
|
||||
// PeerConfig comparisons are stable.
|
||||
func (e *Engine) reconcileIPv6(conf *mgmProto.PeerConfig) (reset bool, err error) {
|
||||
raw := conf.GetAddressV6()
|
||||
current := e.config.WgAddr
|
||||
|
||||
if len(raw) == 0 {
|
||||
return current.HasIPv6()
|
||||
if !current.HasIPv6() {
|
||||
return false, nil
|
||||
}
|
||||
current.ClearIPv6()
|
||||
e.config.WgAddr = current
|
||||
if err := e.wgInterface.UpdateAddr(current); err != nil {
|
||||
return false, fmt.Errorf("clear ipv6 on wg interface: %w", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
prefix, err := netiputil.DecodePrefix(raw)
|
||||
if err != nil {
|
||||
log.Errorf("decode v6 overlay address: %v", err)
|
||||
return false
|
||||
incoming := current
|
||||
if err := incoming.SetIPv6FromCompact(raw); err != nil {
|
||||
return false, fmt.Errorf("decode v6 overlay address: %w", err)
|
||||
}
|
||||
|
||||
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
|
||||
if !current.HasIPv6() {
|
||||
e.config.WgAddr = incoming
|
||||
if err := e.wgInterface.UpdateAddr(incoming); err != nil {
|
||||
return false, fmt.Errorf("apply ipv6 on wg interface: %w", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if current.IPv6 == incoming.IPv6 && current.IPv6Net == incoming.IPv6Net {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
|
||||
305
client/internal/engine_reconcileipv6_test.go
Normal file
305
client/internal/engine_reconcileipv6_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
// reconcileIPv6 / updateConfig regression suite. Locks down the behavior that
|
||||
// PR #5631 (main-side IPv6 overlay support) accidentally broke for embedded
|
||||
// netstack clients: any first NetworkMap update that brings an IPv6 address
|
||||
// used to trigger ErrResetConnection, which destroys the netstack and orphans
|
||||
// every listener bound on it (proxy-side inbound listeners in particular).
|
||||
// The fix in reconcileIPv6 distinguishes "v6 first-assigned" (apply in place)
|
||||
// from "v6 swapped value" (must reset).
|
||||
|
||||
func mustEncodeV6Prefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
require.NoError(t, err, "encode v6 prefix %s", p)
|
||||
return b
|
||||
}
|
||||
|
||||
// reconcileIPv6Fixture builds the smallest Engine the function under test
|
||||
// needs: a config (with WgAddr being the load-bearing field) and a wgInterface
|
||||
// whose UpdateAddr call we can observe.
|
||||
func reconcileIPv6Fixture(t *testing.T, initial wgaddr.Address) (*Engine, *MockWGIface, *wgaddr.Address) {
|
||||
t.Helper()
|
||||
var applied wgaddr.Address
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return initial },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
applied = a
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: initial},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
return e, mock, &applied
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_FirstAssignment_AppliesInPlace(t *testing.T) {
|
||||
// Embedded clients boot v4-only; management later assigns a v6 overlay.
|
||||
// The fix: apply v6 in place, return reset=false. Pre-fix this case
|
||||
// fell through to the "v6 changed" branch and reset the engine.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
e, mock, applied := reconcileIPv6Fixture(t, v4)
|
||||
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reset, "first v6 assignment must NOT request an engine reset")
|
||||
|
||||
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the new v6")
|
||||
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6, "engine config v6 address must match")
|
||||
assert.Equal(t, v6Prefix.Masked(), e.config.WgAddr.IPv6Net, "engine config v6 prefix must match")
|
||||
|
||||
require.True(t, applied.HasIPv6(), "WGIface.UpdateAddr must be called with v6 populated")
|
||||
assert.Equal(t, v6Prefix.Addr(), applied.IPv6, "UpdateAddr must carry the new v6")
|
||||
_ = mock
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_NoChange_NoOp(t *testing.T) {
|
||||
// Steady state: management redelivers the same PeerConfig. No interface
|
||||
// mutation, no reset. Guards against an infinite reset loop if the
|
||||
// comparison ever drifts (e.g. address-vs-prefix masking bugs).
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
|
||||
|
||||
updateAddrCalled := false
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return addr },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
updateAddrCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: addr},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reset, "unchanged v6 must NOT trigger reset")
|
||||
assert.False(t, updateAddrCalled, "unchanged v6 must NOT call UpdateAddr")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_Removed_AppliesInPlace(t *testing.T) {
|
||||
// Management withdraws v6 (e.g. account toggled off the v6 group).
|
||||
// Cleared in place, no reset.
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
|
||||
|
||||
e, _, applied := reconcileIPv6Fixture(t, addr)
|
||||
e.config.WgAddr = addr
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: nil,
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reset, "v6 removed must NOT trigger reset")
|
||||
|
||||
assert.False(t, e.config.WgAddr.HasIPv6(), "engine config must reflect v6 cleared")
|
||||
assert.False(t, applied.HasIPv6(), "UpdateAddr must receive cleared v6")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_PrefixLengthChanged_RequestsReset(t *testing.T) {
|
||||
// Same v6 host, different mask (e.g. /64 → /80). Treated like a value
|
||||
// change because the new netmask redefines the broadcast/scope.
|
||||
oldPrefix := netip.MustParsePrefix("fd00::1/64")
|
||||
newPrefix := netip.MustParsePrefix("fd00::1/80")
|
||||
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
|
||||
|
||||
updateAddrCalled := false
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return addr },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
updateAddrCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: addr},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, newPrefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, reset, "v6 prefix length change must request a reset")
|
||||
assert.False(t, updateAddrCalled, "v6 prefix length change must NOT touch the interface")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_ValueChanged_RequestsReset(t *testing.T) {
|
||||
// v6 was X, now Y. The netstack backend can't safely swap an existing
|
||||
// address in place — fall back to the engine recreate path.
|
||||
oldPrefix := netip.MustParsePrefix("fd00::1/64")
|
||||
newPrefix := netip.MustParsePrefix("fd00::2/64")
|
||||
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
|
||||
|
||||
updateAddrCalled := false
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return addr },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
updateAddrCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: addr},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, newPrefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, reset, "v6 value change must request a reset")
|
||||
assert.False(t, updateAddrCalled,
|
||||
"v6 value change must NOT call UpdateAddr — caller will recreate the interface")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_InvalidBytes_ReturnsError(t *testing.T) {
|
||||
// Corrupt PeerConfig.AddressV6 must not crash the engine and must not
|
||||
// trigger a spurious reset.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
e, _, applied := reconcileIPv6Fixture(t, v4)
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: []byte{0x00}, // truncated, definitely not a valid prefix
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.Error(t, err, "malformed v6 bytes must surface an error")
|
||||
assert.False(t, reset, "decode error must NOT request a reset")
|
||||
assert.False(t, applied.HasIPv6(), "decode error must NOT touch the interface")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset(t *testing.T) {
|
||||
// If WGIface.UpdateAddr fails (e.g. OS-side assignment error on a
|
||||
// kernel device), reconcileIPv6 returns the error to the caller for
|
||||
// logging — but it must NOT request a reset. The whole point of the
|
||||
// fix is to AVOID the reset cascade on v6 transitions.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return v4 },
|
||||
UpdateAddrFunc: func(_ wgaddr.Address) error { return errors.New("os refused address") },
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: v4},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.Error(t, err, "UpdateAddr failure must surface")
|
||||
assert.False(t, reset, "UpdateAddr failure must NOT request a reset")
|
||||
}
|
||||
|
||||
func TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine(t *testing.T) {
|
||||
// The integration check: updateConfig must not return ErrResetConnection
|
||||
// when the only change between current state and the new PeerConfig is
|
||||
// "v6 added". Pre-fix this returned ErrResetConnection, tearing down
|
||||
// every listener bound on the engine's netstack.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return v4 },
|
||||
UpdateAddrFunc: func(_ wgaddr.Address) error { return nil },
|
||||
IsUserspaceBindFunc: func() bool { return true },
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: v4, WgPort: 51820},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
statusRecorder: peer.NewRecorder("https://mgm.test"),
|
||||
}
|
||||
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
|
||||
err := e.updateConfig(conf)
|
||||
assert.NoError(t, err,
|
||||
"updateConfig MUST NOT return ErrResetConnection when v6 is added for the first time — that's the bug fix")
|
||||
assert.NotErrorIs(t, err, ErrResetConnection)
|
||||
|
||||
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the assigned v6 after updateConfig")
|
||||
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6)
|
||||
}
|
||||
@@ -66,7 +66,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
@@ -1707,82 +1706,12 @@ func getPeers(e *Engine) int {
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
require.NoError(t, err)
|
||||
return b
|
||||
}
|
||||
|
||||
func TestEngine_hasIPv6Changed(t *testing.T) {
|
||||
v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
|
||||
v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
v4v6.IPv6 = netip.MustParseAddr("fd00::1")
|
||||
v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
current wgaddr.Address
|
||||
confV6 []byte
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no v6 before, no v6 now",
|
||||
current: v4Only,
|
||||
confV6: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no v6 before, v6 added",
|
||||
current: v4Only,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "had v6, now removed",
|
||||
current: v4v6,
|
||||
confV6: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "had v6, same v6",
|
||||
current: v4v6,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "had v6, different v6",
|
||||
current: v4v6,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same v6 addr, different prefix length",
|
||||
current: v4v6,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "decode error keeps status quo",
|
||||
current: v4Only,
|
||||
confV6: []byte{1, 2, 3},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{WgAddr: tt.current},
|
||||
}
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
AddressV6: tt.confV6,
|
||||
}
|
||||
assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf))
|
||||
})
|
||||
}
|
||||
}
|
||||
// The former TestEngine_hasIPv6Changed has been superseded by
|
||||
// engine_reconcileipv6_test.go — the underlying function (hasIPv6Changed)
|
||||
// was replaced by reconcileIPv6, which applies "v6 added" / "v6 removed"
|
||||
// in place instead of demanding a full engine reset. The behavioral
|
||||
// matrix the old test enforced is now covered, with corrected expectations,
|
||||
// by TestReconcileIPv6_* in that sibling file.
|
||||
|
||||
func TestFilterAllowedIPs(t *testing.T) {
|
||||
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
|
||||
@@ -1026,12 +1026,7 @@ func (am *DefaultAccountManager) SyncPeer(ctx context.Context, sync types.PeerSy
|
||||
return err
|
||||
}
|
||||
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
|
||||
postureChecks, err = getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1129,7 +1124,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
|
||||
var peer *nbpeer.Peer
|
||||
var shouldStorePeer, shouldUpdatePeers bool
|
||||
var shouldStorePeer bool
|
||||
var peerGroupIDs []string
|
||||
|
||||
settings, err := am.Store.GetAccountSettings(ctx, store.LockingStrengthNone, accountID)
|
||||
@@ -1156,7 +1151,6 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
|
||||
if changed {
|
||||
shouldStorePeer = true
|
||||
shouldUpdatePeers = true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1180,16 +1174,13 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
}
|
||||
}
|
||||
|
||||
// This is needed to keep in memory for the peer config. Otherwise browser client will end in a retry loop
|
||||
peer.UpdateMetaIfNew(login.Meta)
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
isRequiresApproval, _, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
isRequiresApproval, isStatusChanged, err := am.integratedPeerValidator.IsNotValidPeer(ctx, accountID, peer, peerGroupIDs, settings.Extra)
|
||||
if err != nil {
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
@@ -1199,7 +1190,7 @@ func (am *DefaultAccountManager) LoginPeer(ctx context.Context, login types.Peer
|
||||
return nil, nil, nil, false, err
|
||||
}
|
||||
|
||||
if shouldUpdatePeers {
|
||||
if isStatusChanged || shouldStorePeer {
|
||||
changedPeerIDs := []string{peer.ID}
|
||||
affectedPeerIDs := am.resolveAffectedPeersForPeerChanges(ctx, am.Store, accountID, changedPeerIDs)
|
||||
if err = am.networkMapController.OnPeersUpdated(ctx, accountID, changedPeerIDs, affectedPeerIDs); err != nil {
|
||||
@@ -1295,22 +1286,12 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, nil, false, nil
|
||||
}
|
||||
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
peerGroupIDs, err := transaction.GetPeerGroupIDs(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
postureChecks, err := getPeerPostureChecks(ctx, transaction, accountID, peerGroupIDs, policies)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, peer, policies, peerGroupIDs)
|
||||
enableSSH, err := isPeerSSHEnabled(ctx, transaction, accountID, peer)
|
||||
if err != nil {
|
||||
return nil, nil, false, err
|
||||
}
|
||||
@@ -1318,16 +1299,32 @@ func getPeerLoginInfo(ctx context.Context, transaction store.Store, accountID st
|
||||
return network, postureChecks, enableSSH, nil
|
||||
}
|
||||
|
||||
func isPeerSSHEnabled(ctx context.Context, peer *nbpeer.Peer, policies []*types.Policy, peerGroupIDs []string) (bool, error) {
|
||||
groupIDsMap := make(map[string]struct{}, len(peerGroupIDs))
|
||||
for _, peerID := range peerGroupIDs {
|
||||
groupIDsMap[peerID] = struct{}{}
|
||||
func isPeerSSHEnabled(ctx context.Context, transaction store.Store, accountID string, peer *nbpeer.Peer) (bool, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, groupIDsMap, peer.SSHEnabled), nil
|
||||
|
||||
peerGroups, err := transaction.GetPeerGroups(ctx, store.LockingStrengthNone, accountID, peer.ID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
peerGroupIDs := make(map[string]struct{}, len(peerGroups))
|
||||
for _, g := range peerGroups {
|
||||
peerGroupIDs[g.ID] = struct{}{}
|
||||
}
|
||||
|
||||
return types.PeerSSHEnabledFromPolicies(policies, peer.ID, peerGroupIDs, peer.SSHEnabled), nil
|
||||
}
|
||||
|
||||
// getPeerPostureChecks returns the posture checks for the peer.
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID string, peerGroupIDs []string, policies []*types.Policy) ([]*posture.Checks, error) {
|
||||
func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountID, peerID string) ([]*posture.Checks, error) {
|
||||
policies, err := transaction.GetAccountPolicies(ctx, store.LockingStrengthNone, accountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(policies) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1339,7 +1336,11 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
continue
|
||||
}
|
||||
|
||||
postureChecksIDs := processPeerPostureChecks(policy, peerGroupIDs)
|
||||
postureChecksIDs, err := processPeerPostureChecks(ctx, transaction, policy, accountID, peerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
peerPostureChecksIDs = append(peerPostureChecksIDs, postureChecksIDs...)
|
||||
}
|
||||
|
||||
@@ -1352,19 +1353,29 @@ func getPeerPostureChecks(ctx context.Context, transaction store.Store, accountI
|
||||
}
|
||||
|
||||
// processPeerPostureChecks checks if the peer is in the source group of the policy and returns the posture checks.
|
||||
func processPeerPostureChecks(policy *types.Policy, peerGroupIDs []string) []string {
|
||||
func processPeerPostureChecks(ctx context.Context, transaction store.Store, policy *types.Policy, accountID, peerID string) ([]string, error) {
|
||||
for _, rule := range policy.Rules {
|
||||
if !rule.Enabled {
|
||||
continue
|
||||
}
|
||||
|
||||
sourceGroups, err := transaction.GetGroupsByIDs(ctx, store.LockingStrengthNone, accountID, rule.Sources)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, sourceGroup := range rule.Sources {
|
||||
if slices.Contains(peerGroupIDs, sourceGroup) {
|
||||
return policy.SourcePostureChecks
|
||||
group, ok := sourceGroups[sourceGroup]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to check peer in policy source group")
|
||||
}
|
||||
|
||||
if slices.Contains(group.Peers, peerID) {
|
||||
return policy.SourcePostureChecks, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// checkIFPeerNeedsLoginWithoutLock checks if the peer needs login without acquiring the account lock. The check validate if the peer was not added via SSO
|
||||
|
||||
@@ -466,15 +466,20 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
var backoff nbtcp.AcceptBackoff
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err)
|
||||
if !backoff.Backoff(ctx) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
|
||||
continue
|
||||
}
|
||||
backoff.Reset()
|
||||
router.HandleConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -533,3 +533,125 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
|
||||
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
|
||||
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
|
||||
-----END EC PRIVATE KEY-----`)
|
||||
|
||||
// scriptedAcceptListener returns pre-scripted errors from Accept(). Used
|
||||
// to drive the feedRouterFromListener tests without binding a real
|
||||
// socket — the production code path is a netstack-backed listener that
|
||||
// returns gVisor's "endpoint is in invalid state" forever after its
|
||||
// endpoint is destroyed.
|
||||
type scriptedAcceptListener struct {
|
||||
errs chan error
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
|
||||
s := &scriptedAcceptListener{
|
||||
errs: make(chan error, len(errs)+1),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
for _, e := range errs {
|
||||
s.errs <- e
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return nil, net.ErrClosed
|
||||
case err := <-s.errs:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Close() error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
default:
|
||||
close(s.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
||||
}
|
||||
|
||||
// errSentinel carries a literal error message so tests can synthesise
|
||||
// the exact gVisor text without importing the netstack package.
|
||||
type errSentinel string
|
||||
|
||||
func (e errSentinel) Error() string { return string(e) }
|
||||
|
||||
// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint is the
|
||||
// regression guard for the inbound side of the tight-loop bug. The
|
||||
// per-account plain-HTTP feeder must recognise gVisor's "endpoint is in
|
||||
// invalid state" and exit, otherwise it pegs a CPU core and floods the
|
||||
// account-scoped log with the same accept error every iteration.
|
||||
func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
|
||||
router := nbtcp.NewRouter(logger, nil, addr)
|
||||
|
||||
gvisorErr := &net.OpError{
|
||||
Op: "accept",
|
||||
Net: "tcp",
|
||||
Addr: addr,
|
||||
Err: errSentinel("endpoint is in invalid state"),
|
||||
}
|
||||
ln := newScriptedAcceptListener(gvisorErr)
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
feedRouterFromListener(context.Background(), ln, router, logger, "acct-1")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Expected: loop recognised the gVisor error and returned.
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFeedRouterFromListener_BacksOffOnTransientError asserts the
|
||||
// defence-in-depth path: an unknown sticky Accept error must NOT cause
|
||||
// CPU spin. The loop backs off and exits cleanly when ctx is cancelled.
|
||||
func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
|
||||
router := nbtcp.NewRouter(logger, nil, addr)
|
||||
|
||||
const transientCount = 5
|
||||
errs := make([]error, transientCount)
|
||||
for i := range errs {
|
||||
errs[i] = errSentinel("transient: temporary network error")
|
||||
}
|
||||
ln := newScriptedAcceptListener(errs...)
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
feedRouterFromListener(ctx, ln, router, logger, "acct-1")
|
||||
}()
|
||||
time.AfterFunc(150*time.Millisecond, cancel)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Expected.
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken")
|
||||
}
|
||||
|
||||
// Without backoff the 5 scripted errors would burn in microseconds.
|
||||
// With backoff the first delay alone is 5ms, so the loop must take
|
||||
// at least that long even though ctx fires at 150ms.
|
||||
elapsed := time.Since(start)
|
||||
assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond,
|
||||
"loop ran without backing off — would burn CPU in production")
|
||||
}
|
||||
|
||||
@@ -356,7 +356,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
// Create embedded NetBird client with the generated private key.
|
||||
// The peer has already been created via CreateProxyPeer RPC with the public key.
|
||||
wgPort := int(n.clientCfg.WGPort)
|
||||
client, err := embed.New(embed.Options{
|
||||
embedOpts := embed.Options{
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
@@ -371,7 +371,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
Performance: n.clientCfg.Performance,
|
||||
})
|
||||
}
|
||||
logEmbedOptions(n.logger, accountID, serviceID, publicKey.String(), embedOpts)
|
||||
client, err := embed.New(embedOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
@@ -847,3 +849,53 @@ func DirectUpstreamFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// logEmbedOptions emits a single structured INFO line summarising every
|
||||
// operationally meaningful flag handed to embed.New for this per-account
|
||||
// client. Secrets (PrivateKey, PreSharedKey) are reduced to a "present"
|
||||
// boolean — never logged verbatim. Use this when an embedded peer
|
||||
// silently misbehaves: most failure modes (inbound drops, wrong
|
||||
// management URL, v6 unexpectedly on, userspace flipped, port clash)
|
||||
// are obvious from these flags before any traffic flows.
|
||||
func logEmbedOptions(logger *log.Logger, accountID types.AccountID, serviceID types.ServiceID, publicKey string, opts embed.Options) {
|
||||
wgPort := 0
|
||||
if opts.WireguardPort != nil {
|
||||
wgPort = *opts.WireguardPort
|
||||
}
|
||||
mtu := uint16(0)
|
||||
if opts.MTU != nil {
|
||||
mtu = *opts.MTU
|
||||
}
|
||||
perfBuffers := uint32(0)
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
perfBuffers = *opts.Performance.PreallocatedBuffersPerPool
|
||||
}
|
||||
perfBatch := uint32(0)
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
perfBatch = *opts.Performance.MaxBatchSize
|
||||
}
|
||||
logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
"public_key": publicKey,
|
||||
"device_name": opts.DeviceName,
|
||||
"management_url": opts.ManagementURL,
|
||||
"log_level": opts.LogLevel,
|
||||
"wg_port": wgPort,
|
||||
"mtu": mtu,
|
||||
"block_inbound": opts.BlockInbound,
|
||||
"block_lan_access": opts.BlockLANAccess,
|
||||
"disable_ipv6": opts.DisableIPv6,
|
||||
"disable_client_routes": opts.DisableClientRoutes,
|
||||
"no_userspace": opts.NoUserspace,
|
||||
"config_path_set": opts.ConfigPath != "",
|
||||
"state_path_set": opts.StatePath != "",
|
||||
"private_key_present": opts.PrivateKey != "",
|
||||
"presharedkey_present": opts.PreSharedKey != "",
|
||||
"setup_key_present": opts.SetupKey != "",
|
||||
"jwt_token_present": opts.JWTToken != "",
|
||||
"dns_labels": opts.DNSLabels,
|
||||
"perf_buffers_per_pool": perfBuffers,
|
||||
"perf_max_batch_size": perfBatch,
|
||||
}).Info("starting embedded netbird client for account")
|
||||
}
|
||||
|
||||
85
proxy/internal/tcp/accept.go
Normal file
85
proxy/internal/tcp/accept.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// gvisorInvalidEndpointMsg is the canonical text gVisor netstack returns
|
||||
// when Accept() is called on a listener whose underlying endpoint has
|
||||
// been destroyed (peer rekey, embedded-client reset, account churn).
|
||||
// There is no exported sentinel from gvisor.dev/gvisor/pkg/tcpip that
|
||||
// survives gonet's *net.OpError wrapping in a way errors.Is can match,
|
||||
// so we fall back to a string check. Stable across the gVisor versions
|
||||
// netbird pins.
|
||||
const gvisorInvalidEndpointMsg = "endpoint is in invalid state"
|
||||
|
||||
// IsClosedListenerErr reports whether err signals that an accept loop
|
||||
// should exit because the underlying listener can no longer serve
|
||||
// connections. It recognises:
|
||||
//
|
||||
// - net.ErrClosed for stdlib listeners (Listener.Close was called).
|
||||
// - gVisor's "endpoint is in invalid state" for netstack-backed
|
||||
// listeners whose endpoint was destroyed out from under them
|
||||
// (typically when a per-account WireGuard netstack is reset without
|
||||
// also tearing the listener entry down).
|
||||
//
|
||||
// Without the gVisor branch an accept loop on a netstack listener spins
|
||||
// CPU-hot forever after the endpoint dies, because Accept never blocks
|
||||
// again and the error neither matches net.ErrClosed nor cancels ctx.
|
||||
func IsClosedListenerErr(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(err.Error(), gvisorInvalidEndpointMsg)
|
||||
}
|
||||
|
||||
// AcceptBackoff implements the exponential backoff used by
|
||||
// net/http.Server.Serve for transient Accept errors. Without it a loop
|
||||
// hitting a sticky unknown error burns a full CPU core. The zero value
|
||||
// is ready to use; call Reset after a successful Accept.
|
||||
type AcceptBackoff struct {
|
||||
delay time.Duration
|
||||
}
|
||||
|
||||
// minAcceptDelay / maxAcceptDelay mirror the stdlib defaults
|
||||
// (net/http.Server.Serve) and keep us well below 1 log line per second
|
||||
// per orphaned listener.
|
||||
const (
|
||||
minAcceptDelay = 5 * time.Millisecond
|
||||
maxAcceptDelay = time.Second
|
||||
)
|
||||
|
||||
// Backoff waits the next exponential delay (5ms doubling up to 1s) and
|
||||
// returns true when the wait completed. Returns false if ctx fired
|
||||
// during the wait — callers should treat that as "exit the loop".
|
||||
func (b *AcceptBackoff) Backoff(ctx context.Context) bool {
|
||||
b.advance()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-time.After(b.delay):
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears the accumulated delay so the next failure starts at the
|
||||
// minimum delay again. Call after a successful Accept.
|
||||
func (b *AcceptBackoff) Reset() { b.delay = 0 }
|
||||
|
||||
func (b *AcceptBackoff) advance() {
|
||||
if b.delay == 0 {
|
||||
b.delay = minAcceptDelay
|
||||
} else {
|
||||
b.delay *= 2
|
||||
}
|
||||
if b.delay > maxAcceptDelay {
|
||||
b.delay = maxAcceptDelay
|
||||
}
|
||||
}
|
||||
142
proxy/internal/tcp/accept_test.go
Normal file
142
proxy/internal/tcp/accept_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIsClosedListenerErr_NetErrClosed verifies the stdlib path: a
|
||||
// closed *net.Listener returns net.ErrClosed wrapped in *net.OpError,
|
||||
// and IsClosedListenerErr must unwrap it.
|
||||
func TestIsClosedListenerErr_NetErrClosed(t *testing.T) {
|
||||
wrapped := &net.OpError{Op: "accept", Net: "tcp", Err: net.ErrClosed}
|
||||
assert.True(t, IsClosedListenerErr(wrapped),
|
||||
"net.OpError wrapping net.ErrClosed must be recognised as closed")
|
||||
}
|
||||
|
||||
// TestIsClosedListenerErr_GVisorInvalidEndpoint is the load-bearing
|
||||
// regression guard. A gVisor netstack listener whose endpoint has been
|
||||
// destroyed returns this exact text. Without recognising it the accept
|
||||
// loop spins forever and burns a CPU core.
|
||||
func TestIsClosedListenerErr_GVisorInvalidEndpoint(t *testing.T) {
|
||||
err := fmt.Errorf("accept tcp 10.10.1.254:80: endpoint is in invalid state")
|
||||
assert.True(t, IsClosedListenerErr(err),
|
||||
"gVisor 'endpoint is in invalid state' must be recognised as closed")
|
||||
}
|
||||
|
||||
// TestIsClosedListenerErr_OtherError confirms we don't over-match —
|
||||
// transient errors must keep returning false so the backoff path runs.
|
||||
func TestIsClosedListenerErr_OtherError(t *testing.T) {
|
||||
cases := []error{
|
||||
errors.New("temporary failure"),
|
||||
errors.New("accept tcp 10.10.1.254:80: too many open files"),
|
||||
nil,
|
||||
}
|
||||
for _, c := range cases {
|
||||
assert.False(t, IsClosedListenerErr(c),
|
||||
"unexpected match on %v — must not be treated as closed", c)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_ProgressionAndCap asserts the doubling schedule:
|
||||
// 5ms, 10ms, 20ms, 40ms, ... capped at 1s. The test runs against a
|
||||
// real timer but uses tight bounds so a slow CI machine still passes.
|
||||
func TestAcceptBackoff_ProgressionAndCap(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
expected := []time.Duration{
|
||||
5 * time.Millisecond,
|
||||
10 * time.Millisecond,
|
||||
20 * time.Millisecond,
|
||||
40 * time.Millisecond,
|
||||
}
|
||||
for i, want := range expected {
|
||||
start := time.Now()
|
||||
ok := b.Backoff(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
require.True(t, ok, "Backoff %d must complete; ctx is alive", i)
|
||||
assert.GreaterOrEqual(t, elapsed, want,
|
||||
"backoff %d (%v) must wait at least the configured delay", i, want)
|
||||
assert.Less(t, elapsed, want*4,
|
||||
"backoff %d (%v) must not overshoot by more than 4x — caps misbehaving", i, want)
|
||||
}
|
||||
|
||||
// Burn enough rounds to reach the cap, then assert subsequent
|
||||
// rounds stay at exactly maxAcceptDelay (1s) — the timer should
|
||||
// never exceed it.
|
||||
for range 6 {
|
||||
b.Backoff(context.Background())
|
||||
}
|
||||
assert.Equal(t, maxAcceptDelay, b.delay,
|
||||
"after enough doublings the delay must clamp to maxAcceptDelay")
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_Reset confirms that a successful Accept resets the
|
||||
// schedule — a busy-then-quiet listener mustn't stay on a 1s timer
|
||||
// after recovery.
|
||||
func TestAcceptBackoff_Reset(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
for range 5 {
|
||||
b.Backoff(context.Background())
|
||||
}
|
||||
require.NotEqual(t, time.Duration(0), b.delay, "precondition: delay must have accumulated")
|
||||
|
||||
b.Reset()
|
||||
assert.Equal(t, time.Duration(0), b.delay, "Reset must zero the delay")
|
||||
|
||||
start := time.Now()
|
||||
ok := b.Backoff(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
require.True(t, ok, "Backoff after Reset must complete")
|
||||
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
|
||||
"after Reset the next backoff must restart at minAcceptDelay")
|
||||
assert.Less(t, elapsed, 50*time.Millisecond,
|
||||
"after Reset the next backoff must NOT carry over the prior delay")
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_CancelDuringWait proves the loop exits promptly
|
||||
// when ctx fires mid-wait. Without this, a tear-down would still take
|
||||
// up to 1 second per orphaned listener.
|
||||
func TestAcceptBackoff_CancelDuringWait(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
// Drive the backoff up so the next call will wait ~1s — long
|
||||
// enough that we can detect early cancellation.
|
||||
for range 10 {
|
||||
b.Backoff(context.Background())
|
||||
}
|
||||
require.Equal(t, maxAcceptDelay, b.delay)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
ok := b.Backoff(ctx)
|
||||
elapsed := time.Since(start)
|
||||
assert.False(t, ok, "Backoff must return false when ctx is cancelled mid-wait")
|
||||
assert.Less(t, elapsed, 200*time.Millisecond,
|
||||
"cancellation must short-circuit the timer; took %v", elapsed)
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_CancelBeforeCall — when ctx is already done the
|
||||
// loop exits without sleeping at all.
|
||||
func TestAcceptBackoff_CancelBeforeCall(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
ok := b.Backoff(ctx)
|
||||
elapsed := time.Since(start)
|
||||
assert.False(t, ok, "Backoff must return false when ctx is already cancelled")
|
||||
assert.Less(t, elapsed, 50*time.Millisecond,
|
||||
"already-cancelled ctx must return immediately; took %v", elapsed)
|
||||
}
|
||||
@@ -297,18 +297,23 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
}
|
||||
}()
|
||||
|
||||
var backoff AcceptBackoff
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
if ctx.Err() != nil || IsClosedListenerErr(err) {
|
||||
if ok := r.Drain(DefaultDrainTimeout); !ok {
|
||||
r.logger.Warn("timed out waiting for connections to drain")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
r.logger.Debugf("SNI router accept: %v", err)
|
||||
r.logger.Debugf("SNI router accept: %v; backing off", err)
|
||||
if !backoff.Backoff(ctx) {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
backoff.Reset()
|
||||
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
r.activeConns.Add(1)
|
||||
go func() {
|
||||
|
||||
@@ -1836,3 +1836,132 @@ func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
|
||||
t.Fatal("TLS conn never reached the TLS channel")
|
||||
}
|
||||
}
|
||||
|
||||
// scriptedAcceptListener is a net.Listener whose Accept() returns
|
||||
// pre-scripted errors. Used by the accept-loop exit tests to simulate
|
||||
// the failure mode that triggers the tight-loop bug: a netstack
|
||||
// listener whose endpoint has been destroyed and now returns the gVisor
|
||||
// "endpoint is in invalid state" error from every Accept call.
|
||||
type scriptedAcceptListener struct {
|
||||
errs chan error
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
|
||||
s := &scriptedAcceptListener{
|
||||
errs: make(chan error, len(errs)+1),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
for _, e := range errs {
|
||||
s.errs <- e
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return nil, net.ErrClosed
|
||||
case err := <-s.errs:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Close() error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
default:
|
||||
close(s.closed)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *scriptedAcceptListener) Addr() net.Addr {
|
||||
return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
|
||||
}
|
||||
|
||||
// TestRouter_Serve_ExitsOnGVisorInvalidEndpoint is the regression guard
|
||||
// for the tight-loop bug: when the underlying netstack endpoint is
|
||||
// destroyed, Accept returns "endpoint is in invalid state" forever. The
|
||||
// loop must recognise that signal and return, otherwise it pegs a CPU
|
||||
// core and floods logs.
|
||||
func TestRouter_Serve_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
router := NewRouter(logger, nil, addr)
|
||||
|
||||
gvisorErr := &net.OpError{
|
||||
Op: "accept",
|
||||
Net: "tcp",
|
||||
Addr: addr,
|
||||
Err: errSentinel("endpoint is in invalid state"),
|
||||
}
|
||||
ln := newScriptedAcceptListener(gvisorErr)
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- router.Serve(context.Background(), ln)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err, "Serve must return cleanly on a recognised closed-listener error")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Serve did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouter_Serve_BacksOffOnTransientError verifies the defence-in-
|
||||
// depth path: when Accept returns an unknown transient error, the loop
|
||||
// MUST not spin. It backs off, then exits cleanly once ctx is cancelled.
|
||||
// "Bounded call count" stands in for "no CPU spin" — without backoff
|
||||
// the goroutine would issue thousands of Accept calls in this window.
|
||||
func TestRouter_Serve_BacksOffOnTransientError(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
router := NewRouter(logger, nil, addr)
|
||||
|
||||
const transientErrCount = 5
|
||||
errs := make([]error, transientErrCount)
|
||||
for i := range errs {
|
||||
errs[i] = errSentinel("transient: too many open files")
|
||||
}
|
||||
ln := newScriptedAcceptListener(errs...)
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
start := time.Now()
|
||||
go func() {
|
||||
done <- router.Serve(ctx, ln)
|
||||
}()
|
||||
|
||||
// Cancel after enough time for the backoff to climb (5ms + 10ms +
|
||||
// 20ms + 40ms = 75ms minimum), but short enough that a spinning
|
||||
// loop would have made thousands of calls by now.
|
||||
time.AfterFunc(150*time.Millisecond, cancel)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err, "Serve must return cleanly on ctx cancellation")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Serve did not exit on ctx cancellation — backoff or exit path broken")
|
||||
}
|
||||
|
||||
// Without backoff the loop would burn through all 5 scripted errors
|
||||
// in microseconds and then block on the channel. With backoff the
|
||||
// total wall time should be at least 5ms (the first backoff).
|
||||
elapsed := time.Since(start)
|
||||
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
|
||||
"loop ran without backing off — would burn CPU in production")
|
||||
}
|
||||
|
||||
// errSentinel mirrors gVisor's tcpip error message exactly. We can't
|
||||
// import the gVisor package without dragging in the whole netstack, so
|
||||
// the test uses the canonical string the production error formatter
|
||||
// emits — same shape IsClosedListenerErr matches in production.
|
||||
type errSentinel string
|
||||
|
||||
func (e errSentinel) Error() string { return string(e) }
|
||||
|
||||
|
||||
@@ -26,10 +26,6 @@ type Peer struct {
|
||||
|
||||
// a gRpc connection stream to the Peer
|
||||
Stream proto.SignalExchange_ConnectStreamServer
|
||||
// sendMu serializes writes to Stream. gRPC forbids concurrent SendMsg on
|
||||
// the same ServerStream, and a peer can be the target of many senders at
|
||||
// once.
|
||||
sendMu sync.Mutex
|
||||
|
||||
// registration time
|
||||
RegisteredAt time.Time
|
||||
@@ -37,13 +33,6 @@ type Peer struct {
|
||||
Cancel context.CancelFunc
|
||||
}
|
||||
|
||||
// Send writes a message to the peer's stream, serializing concurrent senders.
|
||||
func (p *Peer) Send(msg *proto.EncryptedMessage) error {
|
||||
p.sendMu.Lock()
|
||||
defer p.sendMu.Unlock()
|
||||
return p.Stream.Send(msg)
|
||||
}
|
||||
|
||||
// NewPeer creates a new instance of a connected Peer
|
||||
func NewPeer(id string, stream proto.SignalExchange_ConnectStreamServer, cancel context.CancelFunc) *Peer {
|
||||
return &Peer{
|
||||
|
||||
@@ -1,67 +0,0 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
"github.com/netbirdio/netbird/signal/peer"
|
||||
)
|
||||
|
||||
// concurrencyCheckStream records the maximum number of Send calls in flight at
|
||||
// once. gRPC forbids concurrent SendMsg on the same ServerStream, so a correct
|
||||
// server must never have more than one in flight per peer.
|
||||
type concurrencyCheckStream struct {
|
||||
proto.SignalExchange_ConnectStreamServer
|
||||
ctx context.Context
|
||||
inflight atomic.Int32
|
||||
maxSeen atomic.Int32
|
||||
}
|
||||
|
||||
func (s *concurrencyCheckStream) Send(*proto.EncryptedMessage) error {
|
||||
n := s.inflight.Add(1)
|
||||
for {
|
||||
old := s.maxSeen.Load()
|
||||
if n <= old || s.maxSeen.CompareAndSwap(old, n) {
|
||||
break
|
||||
}
|
||||
}
|
||||
// Widen the window so overlapping callers are reliably observed.
|
||||
time.Sleep(time.Millisecond)
|
||||
s.inflight.Add(-1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *concurrencyCheckStream) Context() context.Context { return s.ctx }
|
||||
|
||||
// TestForwardMessageToPeerSerializesSend verifies that concurrent forwards to the
|
||||
// same peer never call Stream.Send concurrently, which would violate the gRPC
|
||||
// ServerStream contract.
|
||||
func TestForwardMessageToPeerSerializesSend(t *testing.T) {
|
||||
s, err := NewServer(context.Background(), otel.Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
const peerID = "peerX"
|
||||
stream := &concurrencyCheckStream{ctx: context.Background()}
|
||||
_, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
require.NoError(t, s.registry.Register(peer.NewPeer(peerID, stream, cancel)))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
s.forwardMessageToPeer(context.Background(), &proto.EncryptedMessage{Key: "sender", RemoteKey: peerID})
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, int32(1), stream.maxSeen.Load(), "Stream.Send must never run concurrently on the same peer stream")
|
||||
}
|
||||
@@ -179,7 +179,7 @@ func (s *Server) forwardMessageToPeer(ctx context.Context, msg *proto.EncryptedM
|
||||
sendResultChan := make(chan error, 1)
|
||||
go func() {
|
||||
select {
|
||||
case sendResultChan <- dstPeer.Send(msg):
|
||||
case sendResultChan <- dstPeer.Stream.Send(msg):
|
||||
return
|
||||
case <-dstPeer.Stream.Context().Done():
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user