mirror of
https://github.com/netbirdio/netbird.git
synced 2026-06-19 06:19:54 +00:00
Compare commits
1 Commits
test-proxy
...
fix/ipv6-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d8b0310a4 |
@@ -64,7 +64,6 @@ import (
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
@@ -1078,11 +1077,17 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
return ErrResetConnection
|
||||
}
|
||||
|
||||
if !e.config.DisableIPv6 && e.hasIPv6Changed(conf) {
|
||||
log.Infof("peer IPv6 address changed, restarting client")
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return ErrResetConnection
|
||||
if !e.config.DisableIPv6 {
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
if err != nil {
|
||||
log.Warnf("reconcile IPv6 from PeerConfig: %v", err)
|
||||
}
|
||||
if reset {
|
||||
log.Infof("peer IPv6 address changed value, restarting client")
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return ErrResetConnection
|
||||
}
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
@@ -1104,25 +1109,58 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// hasIPv6Changed reports whether the IPv6 overlay address in the peer config
|
||||
// differs from the configured address (added, removed, or changed).
|
||||
// Compares against e.config.WgAddr (not the interface address, which may have
|
||||
// been cleared by ClearIPv6 if OS assignment failed).
|
||||
func (e *Engine) hasIPv6Changed(conf *mgmProto.PeerConfig) bool {
|
||||
current := e.config.WgAddr
|
||||
// reconcileIPv6 applies the management-supplied IPv6 overlay address to the
|
||||
// engine's WireGuard interface in place when possible. Three transitions:
|
||||
//
|
||||
// - First v6 assignment (current had no v6, conf carries one): apply via
|
||||
// WGIface.UpdateAddr, no reset. Critical for embedded clients whose
|
||||
// boot config has no v6 — without this we reset on every fresh start
|
||||
// once management has v6 enabled, orphaning any netstack listeners
|
||||
// held outside the engine.
|
||||
// - v6 removed (current had v6, conf carries none): clear in place, no
|
||||
// reset.
|
||||
// - v6 swapped to a different non-empty value: returns reset=true so the
|
||||
// caller falls back to the engine-recreate path — the underlying
|
||||
// interface address can't be safely swapped in place across all
|
||||
// backends (gVisor netstack in particular fixes its address at
|
||||
// CreateNetTUN time).
|
||||
//
|
||||
// Mutates e.config.WgAddr to match the applied state so subsequent
|
||||
// PeerConfig comparisons are stable.
|
||||
func (e *Engine) reconcileIPv6(conf *mgmProto.PeerConfig) (reset bool, err error) {
|
||||
raw := conf.GetAddressV6()
|
||||
current := e.config.WgAddr
|
||||
|
||||
if len(raw) == 0 {
|
||||
return current.HasIPv6()
|
||||
if !current.HasIPv6() {
|
||||
return false, nil
|
||||
}
|
||||
current.ClearIPv6()
|
||||
e.config.WgAddr = current
|
||||
if err := e.wgInterface.UpdateAddr(current); err != nil {
|
||||
return false, fmt.Errorf("clear ipv6 on wg interface: %w", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
prefix, err := netiputil.DecodePrefix(raw)
|
||||
if err != nil {
|
||||
log.Errorf("decode v6 overlay address: %v", err)
|
||||
return false
|
||||
incoming := current
|
||||
if err := incoming.SetIPv6FromCompact(raw); err != nil {
|
||||
return false, fmt.Errorf("decode v6 overlay address: %w", err)
|
||||
}
|
||||
|
||||
return !current.HasIPv6() || current.IPv6 != prefix.Addr() || current.IPv6Net != prefix.Masked()
|
||||
if !current.HasIPv6() {
|
||||
e.config.WgAddr = incoming
|
||||
if err := e.wgInterface.UpdateAddr(incoming); err != nil {
|
||||
return false, fmt.Errorf("apply ipv6 on wg interface: %w", err)
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if current.IPv6 == incoming.IPv6 && current.IPv6Net == incoming.IPv6Net {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (e *Engine) receiveJobEvents() {
|
||||
|
||||
305
client/internal/engine_reconcileipv6_test.go
Normal file
305
client/internal/engine_reconcileipv6_test.go
Normal file
@@ -0,0 +1,305 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/netip"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
)
|
||||
|
||||
// reconcileIPv6 / updateConfig regression suite. Locks down the behavior that
|
||||
// PR #5631 (main-side IPv6 overlay support) accidentally broke for embedded
|
||||
// netstack clients: any first NetworkMap update that brings an IPv6 address
|
||||
// used to trigger ErrResetConnection, which destroys the netstack and orphans
|
||||
// every listener bound on it (proxy-side inbound listeners in particular).
|
||||
// The fix in reconcileIPv6 distinguishes "v6 first-assigned" (apply in place)
|
||||
// from "v6 swapped value" (must reset).
|
||||
|
||||
func mustEncodeV6Prefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
require.NoError(t, err, "encode v6 prefix %s", p)
|
||||
return b
|
||||
}
|
||||
|
||||
// reconcileIPv6Fixture builds the smallest Engine the function under test
|
||||
// needs: a config (with WgAddr being the load-bearing field) and a wgInterface
|
||||
// whose UpdateAddr call we can observe.
|
||||
func reconcileIPv6Fixture(t *testing.T, initial wgaddr.Address) (*Engine, *MockWGIface, *wgaddr.Address) {
|
||||
t.Helper()
|
||||
var applied wgaddr.Address
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return initial },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
applied = a
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
t.Cleanup(cancel)
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: initial},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
return e, mock, &applied
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_FirstAssignment_AppliesInPlace(t *testing.T) {
|
||||
// Embedded clients boot v4-only; management later assigns a v6 overlay.
|
||||
// The fix: apply v6 in place, return reset=false. Pre-fix this case
|
||||
// fell through to the "v6 changed" branch and reset the engine.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
e, mock, applied := reconcileIPv6Fixture(t, v4)
|
||||
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reset, "first v6 assignment must NOT request an engine reset")
|
||||
|
||||
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the new v6")
|
||||
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6, "engine config v6 address must match")
|
||||
assert.Equal(t, v6Prefix.Masked(), e.config.WgAddr.IPv6Net, "engine config v6 prefix must match")
|
||||
|
||||
require.True(t, applied.HasIPv6(), "WGIface.UpdateAddr must be called with v6 populated")
|
||||
assert.Equal(t, v6Prefix.Addr(), applied.IPv6, "UpdateAddr must carry the new v6")
|
||||
_ = mock
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_NoChange_NoOp(t *testing.T) {
|
||||
// Steady state: management redelivers the same PeerConfig. No interface
|
||||
// mutation, no reset. Guards against an infinite reset loop if the
|
||||
// comparison ever drifts (e.g. address-vs-prefix masking bugs).
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
|
||||
|
||||
updateAddrCalled := false
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return addr },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
updateAddrCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: addr},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reset, "unchanged v6 must NOT trigger reset")
|
||||
assert.False(t, updateAddrCalled, "unchanged v6 must NOT call UpdateAddr")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_Removed_AppliesInPlace(t *testing.T) {
|
||||
// Management withdraws v6 (e.g. account toggled off the v6 group).
|
||||
// Cleared in place, no reset.
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, v6Prefix)))
|
||||
|
||||
e, _, applied := reconcileIPv6Fixture(t, addr)
|
||||
e.config.WgAddr = addr
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: nil,
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, reset, "v6 removed must NOT trigger reset")
|
||||
|
||||
assert.False(t, e.config.WgAddr.HasIPv6(), "engine config must reflect v6 cleared")
|
||||
assert.False(t, applied.HasIPv6(), "UpdateAddr must receive cleared v6")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_PrefixLengthChanged_RequestsReset(t *testing.T) {
|
||||
// Same v6 host, different mask (e.g. /64 → /80). Treated like a value
|
||||
// change because the new netmask redefines the broadcast/scope.
|
||||
oldPrefix := netip.MustParsePrefix("fd00::1/64")
|
||||
newPrefix := netip.MustParsePrefix("fd00::1/80")
|
||||
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
|
||||
|
||||
updateAddrCalled := false
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return addr },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
updateAddrCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: addr},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, newPrefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, reset, "v6 prefix length change must request a reset")
|
||||
assert.False(t, updateAddrCalled, "v6 prefix length change must NOT touch the interface")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_ValueChanged_RequestsReset(t *testing.T) {
|
||||
// v6 was X, now Y. The netstack backend can't safely swap an existing
|
||||
// address in place — fall back to the engine recreate path.
|
||||
oldPrefix := netip.MustParsePrefix("fd00::1/64")
|
||||
newPrefix := netip.MustParsePrefix("fd00::2/64")
|
||||
|
||||
addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
require.NoError(t, addr.SetIPv6FromCompact(mustEncodeV6Prefix(t, oldPrefix)))
|
||||
|
||||
updateAddrCalled := false
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return addr },
|
||||
UpdateAddrFunc: func(a wgaddr.Address) error {
|
||||
updateAddrCalled = true
|
||||
return nil
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: addr},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: addr.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, newPrefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, reset, "v6 value change must request a reset")
|
||||
assert.False(t, updateAddrCalled,
|
||||
"v6 value change must NOT call UpdateAddr — caller will recreate the interface")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_InvalidBytes_ReturnsError(t *testing.T) {
|
||||
// Corrupt PeerConfig.AddressV6 must not crash the engine and must not
|
||||
// trigger a spurious reset.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
e, _, applied := reconcileIPv6Fixture(t, v4)
|
||||
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: []byte{0x00}, // truncated, definitely not a valid prefix
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.Error(t, err, "malformed v6 bytes must surface an error")
|
||||
assert.False(t, reset, "decode error must NOT request a reset")
|
||||
assert.False(t, applied.HasIPv6(), "decode error must NOT touch the interface")
|
||||
}
|
||||
|
||||
func TestReconcileIPv6_UpdateAddrError_DoesNotPropagateReset(t *testing.T) {
|
||||
// If WGIface.UpdateAddr fails (e.g. OS-side assignment error on a
|
||||
// kernel device), reconcileIPv6 returns the error to the caller for
|
||||
// logging — but it must NOT request a reset. The whole point of the
|
||||
// fix is to AVOID the reset cascade on v6 transitions.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return v4 },
|
||||
UpdateAddrFunc: func(_ wgaddr.Address) error { return errors.New("os refused address") },
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: v4},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
}
|
||||
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
reset, err := e.reconcileIPv6(conf)
|
||||
require.Error(t, err, "UpdateAddr failure must surface")
|
||||
assert.False(t, reset, "UpdateAddr failure must NOT request a reset")
|
||||
}
|
||||
|
||||
func TestUpdateConfig_V6FirstAssignment_DoesNotResetEngine(t *testing.T) {
|
||||
// The integration check: updateConfig must not return ErrResetConnection
|
||||
// when the only change between current state and the new PeerConfig is
|
||||
// "v6 added". Pre-fix this returned ErrResetConnection, tearing down
|
||||
// every listener bound on the engine's netstack.
|
||||
v4 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
mock := &MockWGIface{
|
||||
AddressFunc: func() wgaddr.Address { return v4 },
|
||||
UpdateAddrFunc: func(_ wgaddr.Address) error { return nil },
|
||||
IsUserspaceBindFunc: func() bool { return true },
|
||||
}
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
e := &Engine{
|
||||
ctx: ctx,
|
||||
clientCtx: ctx,
|
||||
clientCancel: cancel,
|
||||
config: &EngineConfig{WgAddr: v4, WgPort: 51820},
|
||||
wgInterface: mock,
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
statusRecorder: peer.NewRecorder("https://mgm.test"),
|
||||
}
|
||||
|
||||
v6Prefix := netip.MustParsePrefix("fd00::1/64")
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
Address: v4.String(),
|
||||
AddressV6: mustEncodeV6Prefix(t, v6Prefix),
|
||||
}
|
||||
|
||||
err := e.updateConfig(conf)
|
||||
assert.NoError(t, err,
|
||||
"updateConfig MUST NOT return ErrResetConnection when v6 is added for the first time — that's the bug fix")
|
||||
assert.NotErrorIs(t, err, ErrResetConnection)
|
||||
|
||||
require.True(t, e.config.WgAddr.HasIPv6(), "engine config must record the assigned v6 after updateConfig")
|
||||
assert.Equal(t, v6Prefix.Addr(), e.config.WgAddr.IPv6)
|
||||
}
|
||||
@@ -66,7 +66,6 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgmt "github.com/netbirdio/netbird/shared/management/client"
|
||||
mgmtProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/shared/netiputil"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||
"github.com/netbirdio/netbird/shared/signal/proto"
|
||||
@@ -1707,82 +1706,12 @@ func getPeers(e *Engine) int {
|
||||
return len(e.peerStore.PeersPubKey())
|
||||
}
|
||||
|
||||
func mustEncodePrefix(t *testing.T, p netip.Prefix) []byte {
|
||||
t.Helper()
|
||||
b, err := netiputil.EncodePrefix(p)
|
||||
require.NoError(t, err)
|
||||
return b
|
||||
}
|
||||
|
||||
func TestEngine_hasIPv6Changed(t *testing.T) {
|
||||
v4Only := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
|
||||
v4v6 := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
v4v6.IPv6 = netip.MustParseAddr("fd00::1")
|
||||
v4v6.IPv6Net = netip.MustParsePrefix("fd00::1/64").Masked()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
current wgaddr.Address
|
||||
confV6 []byte
|
||||
expected bool
|
||||
}{
|
||||
{
|
||||
name: "no v6 before, no v6 now",
|
||||
current: v4Only,
|
||||
confV6: nil,
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "no v6 before, v6 added",
|
||||
current: v4Only,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "had v6, now removed",
|
||||
current: v4v6,
|
||||
confV6: nil,
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "had v6, same v6",
|
||||
current: v4v6,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/64")),
|
||||
expected: false,
|
||||
},
|
||||
{
|
||||
name: "had v6, different v6",
|
||||
current: v4v6,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::2/64")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "same v6 addr, different prefix length",
|
||||
current: v4v6,
|
||||
confV6: mustEncodePrefix(t, netip.MustParsePrefix("fd00::1/80")),
|
||||
expected: true,
|
||||
},
|
||||
{
|
||||
name: "decode error keeps status quo",
|
||||
current: v4Only,
|
||||
confV6: []byte{1, 2, 3},
|
||||
expected: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
engine := &Engine{
|
||||
config: &EngineConfig{WgAddr: tt.current},
|
||||
}
|
||||
conf := &mgmtProto.PeerConfig{
|
||||
AddressV6: tt.confV6,
|
||||
}
|
||||
assert.Equal(t, tt.expected, engine.hasIPv6Changed(conf))
|
||||
})
|
||||
}
|
||||
}
|
||||
// The former TestEngine_hasIPv6Changed has been superseded by
|
||||
// engine_reconcileipv6_test.go — the underlying function (hasIPv6Changed)
|
||||
// was replaced by reconcileIPv6, which applies "v6 added" / "v6 removed"
|
||||
// in place instead of demanding a full engine reset. The behavioral
|
||||
// matrix the old test enforced is now covered, with corrected expectations,
|
||||
// by TestReconcileIPv6_* in that sibling file.
|
||||
|
||||
func TestFilterAllowedIPs(t *testing.T) {
|
||||
v4v6Addr := wgaddr.MustParseWGAddress("100.64.0.1/16")
|
||||
|
||||
@@ -434,7 +434,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
||||
t.Helper()
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(context.Background(), testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(context.Background(), testCacheStore(t))
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
srv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
return srv
|
||||
}
|
||||
|
||||
@@ -723,7 +723,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
@@ -1147,7 +1147,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
||||
|
||||
tokenStore := nbgrpc.NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := nbgrpc.NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil, nil)
|
||||
proxySrv := nbgrpc.NewProxyServiceServer(nil, tokenStore, pkceStore, nbgrpc.ProxyOIDCConfig{}, nil, nil, nil, nil)
|
||||
|
||||
proxyController, err := proxymanager.NewGRPCController(proxySrv, noop.NewMeterProvider().Meter(""))
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -219,7 +219,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
||||
|
||||
func (s *BaseServer) ReverseProxyGRPCServer() *nbgrpc.ProxyServiceServer {
|
||||
return Create(s, func() *nbgrpc.ProxyServiceServer {
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.IdpManager(), s.ProxyManager(), s.Store())
|
||||
proxyService := nbgrpc.NewProxyServiceServer(s.AccessLogsManager(), s.ProxyTokenStore(), s.PKCEVerifierStore(), s.proxyOIDCConfig(), s.PeersManager(), s.UsersManager(), s.ProxyManager(), s.Store())
|
||||
s.AfterInit(func(s *BaseServer) {
|
||||
proxyService.SetServiceManager(s.ServiceManager())
|
||||
proxyService.SetProxyController(s.ServiceProxyController())
|
||||
|
||||
@@ -33,8 +33,6 @@ import (
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
rpservice "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/sessionkey"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/management/server/users"
|
||||
proxyauth "github.com/netbirdio/netbird/proxy/auth"
|
||||
@@ -84,9 +82,6 @@ type ProxyServiceServer struct {
|
||||
// Manager for users
|
||||
usersManager users.Manager
|
||||
|
||||
// Manager for IdP-enriched user data (may be nil when no IdP is configured)
|
||||
idpManager idp.Manager
|
||||
|
||||
// Store for one-time authentication tokens
|
||||
tokenStore *OneTimeTokenStore
|
||||
|
||||
@@ -162,7 +157,7 @@ func enforceAccountScope(ctx context.Context, requestAccountID string) error {
|
||||
}
|
||||
|
||||
// NewProxyServiceServer creates a new proxy service server.
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, idpManager idp.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeTokenStore, pkceStore *PKCEVerifierStore, oidcConfig ProxyOIDCConfig, peersManager peers.Manager, usersManager users.Manager, proxyMgr proxy.Manager, tokenChecker ProxyTokenChecker) *ProxyServiceServer {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
s := &ProxyServiceServer{
|
||||
accessLogManager: accessLogMgr,
|
||||
@@ -171,7 +166,6 @@ func NewProxyServiceServer(accessLogMgr accesslogs.Manager, tokenStore *OneTimeT
|
||||
pkceVerifierStore: pkceStore,
|
||||
peersManager: peersManager,
|
||||
usersManager: usersManager,
|
||||
idpManager: idpManager,
|
||||
proxyManager: proxyMgr,
|
||||
tokenChecker: tokenChecker,
|
||||
snapshotBatchSize: snapshotBatchSizeFromEnv(),
|
||||
@@ -1708,7 +1702,22 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}
|
||||
|
||||
groupIDs, groupNames := pairGroupIDsAndNames(peerGroups)
|
||||
principalID, displayIdentity := s.getTunnelPeerInfo(ctx, domain, service, peer)
|
||||
|
||||
// Resolve the principal: when the peer is linked to a user, the human
|
||||
// is the principal so multiple peers owned by the same user share a
|
||||
// single identity. Unlinked peers (machine agents) are their own
|
||||
// principal keyed on peer.ID. displayIdentity is what upstream gateways
|
||||
// tag spend with — user.Email when linked, peer.Name when not.
|
||||
principalID := peer.ID
|
||||
displayIdentity := peer.Name
|
||||
if peer.UserID != "" {
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := checkPeerGroupAccess(service, groupIDs); err != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "peer_id": peer.ID, "error": err.Error()}).Debug("ValidateTunnelPeer: access denied")
|
||||
@@ -1745,45 +1754,6 @@ func (s *ProxyServiceServer) ValidateTunnelPeer(ctx context.Context, req *proto.
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getTunnelPeerInfo returns the principal ID and display name for a peer, e.g. a
|
||||
// user or peer ID, and peer name or user email.
|
||||
func (s *ProxyServiceServer) getTunnelPeerInfo(ctx context.Context, domain string, service *rpservice.Service, peer *peer.Peer) (string, string) {
|
||||
// Resolve the principal: when the peer is linked to a user, the human is the
|
||||
// principal so multiple peers owned by the same user share a single
|
||||
// identity. Unlinked peers (machine agents) are their own principal keyed on
|
||||
// peer.ID. displayIdentity is what upstream gateways tag spend with —
|
||||
// user.Email when linked, peer.Name when not.
|
||||
|
||||
// If the peer isn't associated with a user, return the peer info directly.
|
||||
if peer.UserID == "" {
|
||||
return peer.ID, peer.Name
|
||||
}
|
||||
|
||||
// Otherwise, if the peer is linked to a user, the user is the principal and
|
||||
// if an IdP is available, we gather details on the user from it.
|
||||
principalID := peer.UserID
|
||||
displayIdentity := peer.Name
|
||||
// Stored column first (cheap, but often empty for OIDC-provisioned users).
|
||||
if user, uerr := s.usersManager.GetUser(ctx, peer.UserID); uerr == nil && user != nil {
|
||||
principalID = user.Id
|
||||
if user.Email != "" {
|
||||
displayIdentity = user.Email
|
||||
}
|
||||
}
|
||||
// IdP enrichment wins when available — the stored email column is a
|
||||
// best-effort cache and is frequently empty for OIDC users. Enrichment
|
||||
// failures must never fail the RPC; we simply keep the stored/peer identity.
|
||||
if s.idpManager != nil {
|
||||
if ud, uerr := s.idpManager.GetUserDataByID(ctx, peer.UserID, idp.AppMetadata{WTAccountID: service.AccountID}); uerr == nil && ud != nil && ud.Email != "" {
|
||||
displayIdentity = ud.Email
|
||||
} else if uerr != nil {
|
||||
log.WithFields(log.Fields{"domain": domain, "user_id": peer.UserID, "error": uerr.Error()}).Debug("ValidateTunnelPeer: IdP user enrichment failed; using stored/peer identity")
|
||||
}
|
||||
}
|
||||
|
||||
return principalID, displayIdentity
|
||||
}
|
||||
|
||||
// checkPeerGroupAccess gates ValidateTunnelPeer by the service's required
|
||||
// groups. Private services authorise against AccessGroups (empty list fails
|
||||
// closed — Validate() rejects that at save time but the RPC is the security
|
||||
|
||||
@@ -3,19 +3,14 @@ package grpc
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/management/internals/modules/peers"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/service"
|
||||
"github.com/netbirdio/netbird/management/server/idp"
|
||||
"github.com/netbirdio/netbird/management/server/peer"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
"github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
type mockReverseProxyManager struct {
|
||||
@@ -142,52 +137,6 @@ func (m *mockUsersManager) GetUserWithGroups(ctx context.Context, userID string)
|
||||
return user, nil, nil
|
||||
}
|
||||
|
||||
// mockTunnelPeersManager implements only the two peers.Manager methods that
|
||||
// ValidateTunnelPeer calls; the embedded interface satisfies the rest (and
|
||||
// panics if any unexpected method is invoked).
|
||||
type mockTunnelPeersManager struct {
|
||||
peers.Manager
|
||||
peer *peer.Peer
|
||||
peerErr error
|
||||
groups []*types.Group
|
||||
groupsErr error
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerByTunnelIP(_ context.Context, _ string, _ net.IP) (*peer.Peer, error) {
|
||||
return m.peer, m.peerErr
|
||||
}
|
||||
|
||||
func (m *mockTunnelPeersManager) GetPeerWithGroups(_ context.Context, _, _ string) (*peer.Peer, []*types.Group, error) {
|
||||
return m.peer, m.groups, m.groupsErr
|
||||
}
|
||||
|
||||
// mockTunnelIdpManager implements only GetUserDataByID; the embedded interface
|
||||
// satisfies the rest of idp.Manager. hasData==false returns (nil, nil) to model
|
||||
// an IdP that knows nothing about the user.
|
||||
type mockTunnelIdpManager struct {
|
||||
idp.Manager
|
||||
email string
|
||||
hasData bool
|
||||
err error
|
||||
gotCalls int
|
||||
gotMeta []idp.AppMetadata
|
||||
}
|
||||
|
||||
func (m *mockTunnelIdpManager) GetUserDataByID(_ context.Context, userID string, meta idp.AppMetadata) (*idp.UserData, error) {
|
||||
m.gotCalls++
|
||||
m.gotMeta = append(m.gotMeta, meta)
|
||||
if m.err != nil {
|
||||
return nil, m.err
|
||||
}
|
||||
if !m.hasData {
|
||||
// This might not be a thing any of the actual IDP implementations do,
|
||||
// i.e. return a nil value with no error, but it seems valuable to test
|
||||
// that behavior here.
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
return &idp.UserData{ID: userID, Email: m.email}, nil
|
||||
}
|
||||
|
||||
func TestValidateUserGroupAccess(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -405,163 +354,6 @@ func TestValidateUserGroupAccess(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestValidateTunnelPeerUserEmailEnrichment verifies the UserEmail/UserId
|
||||
// resolution in ValidateTunnelPeer, including the IdP-enrichment fallback order
|
||||
// (IdP email -> stored User.Email -> peer.Name).
|
||||
func TestValidateTunnelPeerUserEmailEnrichment(t *testing.T) {
|
||||
const (
|
||||
domain = "app.example.com"
|
||||
accountID = "account1"
|
||||
peerID = "peer1"
|
||||
peerName = "peer-display-name"
|
||||
userID = "user1"
|
||||
)
|
||||
|
||||
storedUser := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: "stored@example.com"}}
|
||||
storedUserNoEmail := map[string]*types.User{userID: {Id: userID, AccountID: accountID, Email: ""}}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
peerUserID string
|
||||
storedUsers map[string]*types.User
|
||||
storedErr error
|
||||
noIdP bool
|
||||
idpEmail string
|
||||
idpHasData bool
|
||||
idpErr error
|
||||
expectEmail string
|
||||
expectUserID string
|
||||
expectIdPHit bool
|
||||
}{
|
||||
{
|
||||
name: "idp email wins over stored email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp returns empty email",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "",
|
||||
idpHasData: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp has no data",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpHasData: false,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when idp errors",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
idpErr: errors.New("idp unreachable"),
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "stored email when no idp manager",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUser,
|
||||
noIdP: true,
|
||||
expectEmail: "stored@example.com",
|
||||
expectUserID: userID,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored email is empty",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "idp email when stored user missing keeps peer.UserID as principal",
|
||||
peerUserID: userID,
|
||||
storedUsers: map[string]*types.User{},
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: "idp@example.com",
|
||||
expectUserID: userID,
|
||||
expectIdPHit: true,
|
||||
},
|
||||
{
|
||||
name: "unlinked peer uses peer name and never consults idp",
|
||||
peerUserID: "",
|
||||
storedUsers: storedUser,
|
||||
idpEmail: "idp@example.com",
|
||||
idpHasData: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: peerID,
|
||||
expectIdPHit: false,
|
||||
},
|
||||
{
|
||||
name: "linked peer with empty stored email and no idp falls back to peer name",
|
||||
peerUserID: userID,
|
||||
storedUsers: storedUserNoEmail,
|
||||
noIdP: true,
|
||||
expectEmail: peerName,
|
||||
expectUserID: userID,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
svc := &service.Service{Domain: domain, AccountID: accountID}
|
||||
server := &ProxyServiceServer{
|
||||
serviceManager: &mockReverseProxyManager{
|
||||
proxiesByAccount: map[string][]*service.Service{accountID: {svc}},
|
||||
},
|
||||
peersManager: &mockTunnelPeersManager{
|
||||
peer: &peer.Peer{ID: peerID, Name: peerName, UserID: tt.peerUserID},
|
||||
},
|
||||
usersManager: &mockUsersManager{users: tt.storedUsers, err: tt.storedErr},
|
||||
}
|
||||
|
||||
var idpMock *mockTunnelIdpManager
|
||||
if !tt.noIdP {
|
||||
idpMock = &mockTunnelIdpManager{email: tt.idpEmail, hasData: tt.idpHasData, err: tt.idpErr}
|
||||
server.idpManager = idpMock
|
||||
}
|
||||
|
||||
resp, err := server.ValidateTunnelPeer(context.Background(), &proto.ValidateTunnelPeerRequest{
|
||||
Domain: domain,
|
||||
TunnelIp: "100.64.0.1",
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, resp)
|
||||
assert.True(t, resp.GetValid(), "expected access granted")
|
||||
assert.Equal(t, tt.expectEmail, resp.GetUserEmail())
|
||||
assert.Equal(t, tt.expectUserID, resp.GetUserId())
|
||||
|
||||
if idpMock != nil {
|
||||
if tt.expectIdPHit {
|
||||
assert.Equal(t, 1, idpMock.gotCalls, "expected IdP to be consulted")
|
||||
require.Len(t, idpMock.gotMeta, 1)
|
||||
assert.Equal(t, accountID, idpMock.gotMeta[0].WTAccountID)
|
||||
} else {
|
||||
assert.Equal(t, 0, idpMock.gotCalls, "expected IdP to not be consulted")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAccountProxyByDomain(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupValidateSessionTest(t *testing.T) *validateSessionTestSetup {
|
||||
tokenStore := NewOneTimeTokenStore(ctx, testCacheStore(t))
|
||||
pkceStore := NewPKCEVerifierStore(ctx, testCacheStore(t))
|
||||
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, nil, proxyManager, nil)
|
||||
proxyService := NewProxyServiceServer(nil, tokenStore, pkceStore, ProxyOIDCConfig{}, nil, usersManager, proxyManager, nil)
|
||||
proxyService.SetServiceManager(serviceManager)
|
||||
|
||||
createTestProxies(t, ctx, testStore)
|
||||
|
||||
@@ -3215,7 +3215,7 @@ func createManager(t testing.TB) (*DefaultAccountManager, *update_channel.PeersU
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, nil, proxyManager, nil)
|
||||
proxyGrpcServer := nbgrpc.NewProxyServiceServer(nil, nil, nil, nbgrpc.ProxyOIDCConfig{}, peersManager, nil, proxyManager, nil)
|
||||
proxyController, err := proxymanager.NewGRPCController(proxyGrpcServer, noop.Meter{})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
|
||||
@@ -217,7 +217,6 @@ func setupAuthCallbackTest(t *testing.T) *testSetup {
|
||||
usersManager,
|
||||
nil,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
proxyService.SetServiceManager(&testServiceManager{store: testStore})
|
||||
|
||||
@@ -110,7 +110,7 @@ func BuildApiBlackBoxWithDBState(t testing_tools.TB, sqlFile string, expectedPee
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
@@ -240,7 +240,7 @@ func BuildApiBlackBoxWithDBStateAndPeerChannel(t testing_tools.TB, sqlFile strin
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create proxy manager: %v", err)
|
||||
}
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, nil, proxyMgr, nil)
|
||||
proxyServiceServer := nbgrpc.NewProxyServiceServer(accessLogsManager, proxyTokenStore, pkceverifierStore, nbgrpc.ProxyOIDCConfig{}, peersManager, userManager, proxyMgr, nil)
|
||||
domainManager := manager.NewManager(store, proxyMgr, permissionsManager, am)
|
||||
serviceProxyController, err := proxymanager.NewGRPCController(proxyServiceServer, noopMeter)
|
||||
if err != nil {
|
||||
|
||||
@@ -466,15 +466,20 @@ func feedRouterFromListener(ctx context.Context, ln net.Listener, router *nbtcp.
|
||||
_ = ln.Close()
|
||||
}()
|
||||
|
||||
var backoff nbtcp.AcceptBackoff
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
if ctx.Err() != nil || nbtcp.IsClosedListenerErr(err) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v; backing off", err)
|
||||
if !backoff.Backoff(ctx) {
|
||||
return
|
||||
}
|
||||
logger.WithField("account_id", accountID).Debugf("plain inbound accept: %v", err)
|
||||
continue
|
||||
}
|
||||
backoff.Reset()
|
||||
router.HandleConn(ctx, conn)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -533,3 +533,125 @@ MHcCAQEEIIrYSSNQFaA2Hwf1duRSxKtLYX5CB04fSeQ6tF1aY/PuoAoGCCqGSM49
|
||||
AwEHoUQDQgAEPR3tU2Fta9ktY+6P9G0cWO+0kETA6SFs38GecTyudlHz6xvCdz8q
|
||||
EKTcWGekdmdDPsHloRNtsiCa697B2O9IFA==
|
||||
-----END EC PRIVATE KEY-----`)
|
||||
|
||||
// scriptedAcceptListener is a net.Listener test double whose Accept never
// yields a connection: it returns the pre-scripted errors in order and, once
// closed, net.ErrClosed. It drives the feedRouterFromListener tests without
// binding a real socket — the production code path is a netstack-backed
// listener that returns gVisor's "endpoint is in invalid state" forever after
// its endpoint is destroyed.
type scriptedAcceptListener struct {
	errs   chan error    // scripted Accept errors, delivered in FIFO order
	closed chan struct{} // closed by Close to signal shutdown
}

// newScriptedAcceptListener builds a listener pre-loaded with errs. The
// channel is sized so that queuing every scripted error never blocks.
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
	s := &scriptedAcceptListener{
		errs:   make(chan error, len(errs)+1),
		closed: make(chan struct{}),
	}
	for _, e := range errs {
		s.errs <- e
	}
	return s
}

// Accept returns the next scripted error, or net.ErrClosed once the listener
// has been closed. When the script is exhausted it blocks until Close.
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
	// Check for closure first: a select with several ready cases picks one at
	// random, so without this guard a closed listener could still hand out
	// leftover scripted errors instead of net.ErrClosed.
	select {
	case <-s.closed:
		return nil, net.ErrClosed
	default:
	}

	select {
	case <-s.closed:
		return nil, net.ErrClosed
	case err := <-s.errs:
		return nil, err
	}
}

// Close marks the listener closed. Repeated sequential calls are harmless.
func (s *scriptedAcceptListener) Close() error {
	select {
	case <-s.closed:
		// Already closed; closing the channel again would panic.
	default:
		close(s.closed)
	}
	return nil
}

// Addr reports a fixed loopback TCP address with port 0.
func (s *scriptedAcceptListener) Addr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
|
||||
|
||||
// errSentinel is a string-backed error whose message is the string itself,
// letting tests reproduce the exact gVisor error text without importing the
// netstack package.
type errSentinel string

// Error implements the error interface by returning the underlying string.
func (e errSentinel) Error() string {
	msg := string(e)
	return msg
}
|
||||
|
||||
// TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint is the
|
||||
// regression guard for the inbound side of the tight-loop bug. The
|
||||
// per-account plain-HTTP feeder must recognise gVisor's "endpoint is in
|
||||
// invalid state" and exit, otherwise it pegs a CPU core and floods the
|
||||
// account-scoped log with the same accept error every iteration.
|
||||
func TestFeedRouterFromListener_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
|
||||
router := nbtcp.NewRouter(logger, nil, addr)
|
||||
|
||||
gvisorErr := &net.OpError{
|
||||
Op: "accept",
|
||||
Net: "tcp",
|
||||
Addr: addr,
|
||||
Err: errSentinel("endpoint is in invalid state"),
|
||||
}
|
||||
ln := newScriptedAcceptListener(gvisorErr)
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
feedRouterFromListener(context.Background(), ln, router, logger, "acct-1")
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Expected: loop recognised the gVisor error and returned.
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("feedRouterFromListener did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
|
||||
}
|
||||
}
|
||||
|
||||
// TestFeedRouterFromListener_BacksOffOnTransientError asserts the
|
||||
// defence-in-depth path: an unknown sticky Accept error must NOT cause
|
||||
// CPU spin. The loop backs off and exits cleanly when ctx is cancelled.
|
||||
func TestFeedRouterFromListener_BacksOffOnTransientError(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 80}
|
||||
router := nbtcp.NewRouter(logger, nil, addr)
|
||||
|
||||
const transientCount = 5
|
||||
errs := make([]error, transientCount)
|
||||
for i := range errs {
|
||||
errs[i] = errSentinel("transient: temporary network error")
|
||||
}
|
||||
ln := newScriptedAcceptListener(errs...)
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
start := time.Now()
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
feedRouterFromListener(ctx, ln, router, logger, "acct-1")
|
||||
}()
|
||||
time.AfterFunc(150*time.Millisecond, cancel)
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
// Expected.
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("feedRouterFromListener did not exit on ctx cancellation — backoff or exit path broken")
|
||||
}
|
||||
|
||||
// Without backoff the 5 scripted errors would burn in microseconds.
|
||||
// With backoff the first delay alone is 5ms, so the loop must take
|
||||
// at least that long even though ctx fires at 150ms.
|
||||
elapsed := time.Since(start)
|
||||
assert.GreaterOrEqual(t, elapsed, 5*time.Millisecond,
|
||||
"loop ran without backing off — would burn CPU in production")
|
||||
}
|
||||
|
||||
@@ -356,7 +356,7 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
// Create embedded NetBird client with the generated private key.
|
||||
// The peer has already been created via CreateProxyPeer RPC with the public key.
|
||||
wgPort := int(n.clientCfg.WGPort)
|
||||
client, err := embed.New(embed.Options{
|
||||
embedOpts := embed.Options{
|
||||
DeviceName: deviceNamePrefix + n.proxyID,
|
||||
ManagementURL: n.clientCfg.MgmtAddr,
|
||||
PrivateKey: privateKey.String(),
|
||||
@@ -371,7 +371,9 @@ func (n *NetBird) createClientEntry(ctx context.Context, accountID types.Account
|
||||
WireguardPort: &wgPort,
|
||||
PreSharedKey: n.clientCfg.PreSharedKey,
|
||||
Performance: n.clientCfg.Performance,
|
||||
})
|
||||
}
|
||||
logEmbedOptions(n.logger, accountID, serviceID, publicKey.String(), embedOpts)
|
||||
client, err := embed.New(embedOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create netbird client: %w", err)
|
||||
}
|
||||
@@ -847,3 +849,53 @@ func DirectUpstreamFromContext(ctx context.Context) bool {
|
||||
v, _ := ctx.Value(directUpstreamContextKey{}).(bool)
|
||||
return v
|
||||
}
|
||||
|
||||
// logEmbedOptions emits a single structured INFO line summarising every
|
||||
// operationally meaningful flag handed to embed.New for this per-account
|
||||
// client. Secrets (PrivateKey, PreSharedKey) are reduced to a "present"
|
||||
// boolean — never logged verbatim. Use this when an embedded peer
|
||||
// silently misbehaves: most failure modes (inbound drops, wrong
|
||||
// management URL, v6 unexpectedly on, userspace flipped, port clash)
|
||||
// are obvious from these flags before any traffic flows.
|
||||
func logEmbedOptions(logger *log.Logger, accountID types.AccountID, serviceID types.ServiceID, publicKey string, opts embed.Options) {
|
||||
wgPort := 0
|
||||
if opts.WireguardPort != nil {
|
||||
wgPort = *opts.WireguardPort
|
||||
}
|
||||
mtu := uint16(0)
|
||||
if opts.MTU != nil {
|
||||
mtu = *opts.MTU
|
||||
}
|
||||
perfBuffers := uint32(0)
|
||||
if opts.Performance.PreallocatedBuffersPerPool != nil {
|
||||
perfBuffers = *opts.Performance.PreallocatedBuffersPerPool
|
||||
}
|
||||
perfBatch := uint32(0)
|
||||
if opts.Performance.MaxBatchSize != nil {
|
||||
perfBatch = *opts.Performance.MaxBatchSize
|
||||
}
|
||||
logger.WithFields(log.Fields{
|
||||
"account_id": accountID,
|
||||
"service_id": serviceID,
|
||||
"public_key": publicKey,
|
||||
"device_name": opts.DeviceName,
|
||||
"management_url": opts.ManagementURL,
|
||||
"log_level": opts.LogLevel,
|
||||
"wg_port": wgPort,
|
||||
"mtu": mtu,
|
||||
"block_inbound": opts.BlockInbound,
|
||||
"block_lan_access": opts.BlockLANAccess,
|
||||
"disable_ipv6": opts.DisableIPv6,
|
||||
"disable_client_routes": opts.DisableClientRoutes,
|
||||
"no_userspace": opts.NoUserspace,
|
||||
"config_path_set": opts.ConfigPath != "",
|
||||
"state_path_set": opts.StatePath != "",
|
||||
"private_key_present": opts.PrivateKey != "",
|
||||
"presharedkey_present": opts.PreSharedKey != "",
|
||||
"setup_key_present": opts.SetupKey != "",
|
||||
"jwt_token_present": opts.JWTToken != "",
|
||||
"dns_labels": opts.DNSLabels,
|
||||
"perf_buffers_per_pool": perfBuffers,
|
||||
"perf_max_batch_size": perfBatch,
|
||||
}).Info("starting embedded netbird client for account")
|
||||
}
|
||||
|
||||
85
proxy/internal/tcp/accept.go
Normal file
85
proxy/internal/tcp/accept.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// gvisorInvalidEndpointMsg is the error text gVisor netstack produces when
// Accept() is called on a listener whose underlying endpoint has been
// destroyed (peer rekey, embedded-client reset, account churn). The text is
// matched as a substring because, per the original authors, no exported
// sentinel from gvisor.dev/gvisor/pkg/tcpip survives gonet's *net.OpError
// wrapping in a form errors.Is can match; the message is reported as stable
// across the gVisor versions netbird pins.
const gvisorInvalidEndpointMsg = "endpoint is in invalid state"

// IsClosedListenerErr reports whether err means an accept loop should stop
// because its listener can no longer serve connections. It returns true for:
//
//   - any error wrapping net.ErrClosed (a stdlib listener was closed);
//   - any error whose message contains gVisor's "endpoint is in invalid
//     state" (a netstack-backed listener whose endpoint was destroyed,
//     typically when a per-account WireGuard netstack is reset without the
//     listener entry being torn down as well).
//
// A nil err yields false. Without the gVisor case an accept loop on such a
// listener would spin CPU-hot forever, because Accept keeps failing
// immediately and the error neither matches net.ErrClosed nor cancels ctx.
func IsClosedListenerErr(err error) bool {
	switch {
	case err == nil:
		return false
	case errors.Is(err, net.ErrClosed):
		return true
	default:
		return strings.Contains(err.Error(), gvisorInvalidEndpointMsg)
	}
}
|
||||
|
||||
// AcceptBackoff implements an exponential backoff for transient Accept
// errors, modelled on the one in net/http.Server.Serve. Without it, an accept
// loop hitting a sticky unknown error burns a full CPU core. The zero value
// is ready to use; call Reset after a successful Accept.
//
// An AcceptBackoff holds unsynchronised state and is meant to be owned by a
// single accept loop.
type AcceptBackoff struct {
	// delay is the wait applied by the most recent Backoff call; zero means
	// no failure has been recorded since construction or the last Reset.
	delay time.Duration
}

// minAcceptDelay / maxAcceptDelay mirror the stdlib defaults
// (net/http.Server.Serve). The one-second cap bounds steady-state retries,
// and therefore per-retry log lines, to roughly one per second per orphaned
// listener.
const (
	minAcceptDelay = 5 * time.Millisecond
	maxAcceptDelay = time.Second
)

// Backoff waits the next exponential delay (5ms, doubling up to 1s) and
// returns true when the wait completed. It returns false if ctx is done
// before the wait elapses — callers should treat that as "exit the loop".
// The delay is advanced even when the wait is cut short.
func (b *AcceptBackoff) Backoff(ctx context.Context) bool {
	b.advance()

	// Use an explicit timer rather than time.After so it is released as soon
	// as ctx wins the race instead of lingering until it would have fired.
	timer := time.NewTimer(b.delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return false
	case <-timer.C:
		return true
	}
}

// Reset clears the accumulated delay so the next failure starts at the
// minimum delay again. Call after a successful Accept.
func (b *AcceptBackoff) Reset() { b.delay = 0 }

// advance moves delay to the next step of the schedule: minAcceptDelay on
// the first failure, double the previous value afterwards, never exceeding
// maxAcceptDelay.
func (b *AcceptBackoff) advance() {
	if b.delay == 0 {
		b.delay = minAcceptDelay
	} else {
		b.delay *= 2
	}
	if b.delay > maxAcceptDelay {
		b.delay = maxAcceptDelay
	}
}
|
||||
142
proxy/internal/tcp/accept_test.go
Normal file
142
proxy/internal/tcp/accept_test.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestIsClosedListenerErr_NetErrClosed verifies the stdlib path: a
|
||||
// closed *net.Listener returns net.ErrClosed wrapped in *net.OpError,
|
||||
// and IsClosedListenerErr must unwrap it.
|
||||
func TestIsClosedListenerErr_NetErrClosed(t *testing.T) {
|
||||
wrapped := &net.OpError{Op: "accept", Net: "tcp", Err: net.ErrClosed}
|
||||
assert.True(t, IsClosedListenerErr(wrapped),
|
||||
"net.OpError wrapping net.ErrClosed must be recognised as closed")
|
||||
}
|
||||
|
||||
// TestIsClosedListenerErr_GVisorInvalidEndpoint is the load-bearing
|
||||
// regression guard. A gVisor netstack listener whose endpoint has been
|
||||
// destroyed returns this exact text. Without recognising it the accept
|
||||
// loop spins forever and burns a CPU core.
|
||||
func TestIsClosedListenerErr_GVisorInvalidEndpoint(t *testing.T) {
|
||||
err := fmt.Errorf("accept tcp 10.10.1.254:80: endpoint is in invalid state")
|
||||
assert.True(t, IsClosedListenerErr(err),
|
||||
"gVisor 'endpoint is in invalid state' must be recognised as closed")
|
||||
}
|
||||
|
||||
// TestIsClosedListenerErr_OtherError confirms we don't over-match —
|
||||
// transient errors must keep returning false so the backoff path runs.
|
||||
func TestIsClosedListenerErr_OtherError(t *testing.T) {
|
||||
cases := []error{
|
||||
errors.New("temporary failure"),
|
||||
errors.New("accept tcp 10.10.1.254:80: too many open files"),
|
||||
nil,
|
||||
}
|
||||
for _, c := range cases {
|
||||
assert.False(t, IsClosedListenerErr(c),
|
||||
"unexpected match on %v — must not be treated as closed", c)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_ProgressionAndCap asserts the doubling schedule:
|
||||
// 5ms, 10ms, 20ms, 40ms, ... capped at 1s. The test runs against a
|
||||
// real timer but uses tight bounds so a slow CI machine still passes.
|
||||
func TestAcceptBackoff_ProgressionAndCap(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
expected := []time.Duration{
|
||||
5 * time.Millisecond,
|
||||
10 * time.Millisecond,
|
||||
20 * time.Millisecond,
|
||||
40 * time.Millisecond,
|
||||
}
|
||||
for i, want := range expected {
|
||||
start := time.Now()
|
||||
ok := b.Backoff(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
require.True(t, ok, "Backoff %d must complete; ctx is alive", i)
|
||||
assert.GreaterOrEqual(t, elapsed, want,
|
||||
"backoff %d (%v) must wait at least the configured delay", i, want)
|
||||
assert.Less(t, elapsed, want*4,
|
||||
"backoff %d (%v) must not overshoot by more than 4x — caps misbehaving", i, want)
|
||||
}
|
||||
|
||||
// Burn enough rounds to reach the cap, then assert subsequent
|
||||
// rounds stay at exactly maxAcceptDelay (1s) — the timer should
|
||||
// never exceed it.
|
||||
for range 6 {
|
||||
b.Backoff(context.Background())
|
||||
}
|
||||
assert.Equal(t, maxAcceptDelay, b.delay,
|
||||
"after enough doublings the delay must clamp to maxAcceptDelay")
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_Reset confirms that a successful Accept resets the
|
||||
// schedule — a busy-then-quiet listener mustn't stay on a 1s timer
|
||||
// after recovery.
|
||||
func TestAcceptBackoff_Reset(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
for range 5 {
|
||||
b.Backoff(context.Background())
|
||||
}
|
||||
require.NotEqual(t, time.Duration(0), b.delay, "precondition: delay must have accumulated")
|
||||
|
||||
b.Reset()
|
||||
assert.Equal(t, time.Duration(0), b.delay, "Reset must zero the delay")
|
||||
|
||||
start := time.Now()
|
||||
ok := b.Backoff(context.Background())
|
||||
elapsed := time.Since(start)
|
||||
require.True(t, ok, "Backoff after Reset must complete")
|
||||
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
|
||||
"after Reset the next backoff must restart at minAcceptDelay")
|
||||
assert.Less(t, elapsed, 50*time.Millisecond,
|
||||
"after Reset the next backoff must NOT carry over the prior delay")
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_CancelDuringWait proves the loop exits promptly
|
||||
// when ctx fires mid-wait. Without this, a tear-down would still take
|
||||
// up to 1 second per orphaned listener.
|
||||
func TestAcceptBackoff_CancelDuringWait(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
// Drive the backoff up so the next call will wait ~1s — long
|
||||
// enough that we can detect early cancellation.
|
||||
for range 10 {
|
||||
b.Backoff(context.Background())
|
||||
}
|
||||
require.Equal(t, maxAcceptDelay, b.delay)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
go func() {
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
cancel()
|
||||
}()
|
||||
|
||||
start := time.Now()
|
||||
ok := b.Backoff(ctx)
|
||||
elapsed := time.Since(start)
|
||||
assert.False(t, ok, "Backoff must return false when ctx is cancelled mid-wait")
|
||||
assert.Less(t, elapsed, 200*time.Millisecond,
|
||||
"cancellation must short-circuit the timer; took %v", elapsed)
|
||||
}
|
||||
|
||||
// TestAcceptBackoff_CancelBeforeCall — when ctx is already done the
|
||||
// loop exits without sleeping at all.
|
||||
func TestAcceptBackoff_CancelBeforeCall(t *testing.T) {
|
||||
var b AcceptBackoff
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
start := time.Now()
|
||||
ok := b.Backoff(ctx)
|
||||
elapsed := time.Since(start)
|
||||
assert.False(t, ok, "Backoff must return false when ctx is already cancelled")
|
||||
assert.Less(t, elapsed, 50*time.Millisecond,
|
||||
"already-cancelled ctx must return immediately; took %v", elapsed)
|
||||
}
|
||||
@@ -297,18 +297,23 @@ func (r *Router) Serve(ctx context.Context, ln net.Listener) error {
|
||||
}
|
||||
}()
|
||||
|
||||
var backoff AcceptBackoff
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
|
||||
if ctx.Err() != nil || IsClosedListenerErr(err) {
|
||||
if ok := r.Drain(DefaultDrainTimeout); !ok {
|
||||
r.logger.Warn("timed out waiting for connections to drain")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
r.logger.Debugf("SNI router accept: %v", err)
|
||||
r.logger.Debugf("SNI router accept: %v; backing off", err)
|
||||
if !backoff.Backoff(ctx) {
|
||||
return nil
|
||||
}
|
||||
continue
|
||||
}
|
||||
backoff.Reset()
|
||||
r.logger.Debugf("SNI router accepted conn from %s on %s", conn.RemoteAddr(), conn.LocalAddr())
|
||||
r.activeConns.Add(1)
|
||||
go func() {
|
||||
|
||||
@@ -1836,3 +1836,132 @@ func TestRouter_TLS_StaysOnTLSChannel_WhenPlainEnabled(t *testing.T) {
|
||||
t.Fatal("TLS conn never reached the TLS channel")
|
||||
}
|
||||
}
|
||||
|
||||
// scriptedAcceptListener is a net.Listener whose Accept() never yields a
// connection: it returns pre-scripted errors in order and, once closed,
// net.ErrClosed. The accept-loop exit tests use it to simulate the failure
// mode behind the tight-loop bug: a netstack listener whose endpoint has been
// destroyed and now returns the gVisor "endpoint is in invalid state" error
// from every Accept call.
type scriptedAcceptListener struct {
	errs   chan error    // scripted Accept errors, delivered in FIFO order
	closed chan struct{} // closed by Close to signal shutdown
}

// newScriptedAcceptListener builds a listener pre-loaded with errs. The
// channel is sized so that queuing every scripted error never blocks.
func newScriptedAcceptListener(errs ...error) *scriptedAcceptListener {
	s := &scriptedAcceptListener{
		errs:   make(chan error, len(errs)+1),
		closed: make(chan struct{}),
	}
	for _, e := range errs {
		s.errs <- e
	}
	return s
}

// Accept returns the next scripted error, or net.ErrClosed once the listener
// has been closed. When the script is exhausted it blocks until Close.
func (s *scriptedAcceptListener) Accept() (net.Conn, error) {
	// Check for closure first: a select with several ready cases picks one at
	// random, so without this guard a closed listener could still hand out
	// leftover scripted errors instead of net.ErrClosed.
	select {
	case <-s.closed:
		return nil, net.ErrClosed
	default:
	}

	select {
	case <-s.closed:
		return nil, net.ErrClosed
	case err := <-s.errs:
		return nil, err
	}
}

// Close marks the listener closed. Repeated sequential calls are harmless.
func (s *scriptedAcceptListener) Close() error {
	select {
	case <-s.closed:
		// Already closed; closing the channel again would panic.
	default:
		close(s.closed)
	}
	return nil
}

// Addr reports a fixed loopback TCP address with port 0.
func (s *scriptedAcceptListener) Addr() net.Addr {
	return &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0}
}
|
||||
|
||||
// TestRouter_Serve_ExitsOnGVisorInvalidEndpoint is the regression guard
|
||||
// for the tight-loop bug: when the underlying netstack endpoint is
|
||||
// destroyed, Accept returns "endpoint is in invalid state" forever. The
|
||||
// loop must recognise that signal and return, otherwise it pegs a CPU
|
||||
// core and floods logs.
|
||||
func TestRouter_Serve_ExitsOnGVisorInvalidEndpoint(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
router := NewRouter(logger, nil, addr)
|
||||
|
||||
gvisorErr := &net.OpError{
|
||||
Op: "accept",
|
||||
Net: "tcp",
|
||||
Addr: addr,
|
||||
Err: errSentinel("endpoint is in invalid state"),
|
||||
}
|
||||
ln := newScriptedAcceptListener(gvisorErr)
|
||||
defer ln.Close()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- router.Serve(context.Background(), ln)
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err, "Serve must return cleanly on a recognised closed-listener error")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Serve did not exit on gVisor 'endpoint is in invalid state' — accept loop is spinning")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRouter_Serve_BacksOffOnTransientError verifies the defence-in-
|
||||
// depth path: when Accept returns an unknown transient error, the loop
|
||||
// MUST not spin. It backs off, then exits cleanly once ctx is cancelled.
|
||||
// "Bounded call count" stands in for "no CPU spin" — without backoff
|
||||
// the goroutine would issue thousands of Accept calls in this window.
|
||||
func TestRouter_Serve_BacksOffOnTransientError(t *testing.T) {
|
||||
logger := log.StandardLogger()
|
||||
addr := &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 443}
|
||||
router := NewRouter(logger, nil, addr)
|
||||
|
||||
const transientErrCount = 5
|
||||
errs := make([]error, transientErrCount)
|
||||
for i := range errs {
|
||||
errs[i] = errSentinel("transient: too many open files")
|
||||
}
|
||||
ln := newScriptedAcceptListener(errs...)
|
||||
defer ln.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
done := make(chan error, 1)
|
||||
start := time.Now()
|
||||
go func() {
|
||||
done <- router.Serve(ctx, ln)
|
||||
}()
|
||||
|
||||
// Cancel after enough time for the backoff to climb (5ms + 10ms +
|
||||
// 20ms + 40ms = 75ms minimum), but short enough that a spinning
|
||||
// loop would have made thousands of calls by now.
|
||||
time.AfterFunc(150*time.Millisecond, cancel)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
assert.NoError(t, err, "Serve must return cleanly on ctx cancellation")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("Serve did not exit on ctx cancellation — backoff or exit path broken")
|
||||
}
|
||||
|
||||
// Without backoff the loop would burn through all 5 scripted errors
|
||||
// in microseconds and then block on the channel. With backoff the
|
||||
// total wall time should be at least 5ms (the first backoff).
|
||||
elapsed := time.Since(start)
|
||||
assert.GreaterOrEqual(t, elapsed, minAcceptDelay,
|
||||
"loop ran without backing off — would burn CPU in production")
|
||||
}
|
||||
|
||||
// errSentinel is a string-backed error whose message is the string itself.
// Tests use it to reproduce gVisor's error text — the same text
// IsClosedListenerErr matches on — without importing the gVisor package and
// dragging in the whole netstack.
type errSentinel string

// Error implements the error interface by returning the underlying string.
func (e errSentinel) Error() string {
	msg := string(e)
	return msg
}
|
||||
|
||||
|
||||
@@ -125,7 +125,6 @@ func setupBYOPIntegrationTest(t *testing.T) *byopTestSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
realProxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
@@ -140,7 +140,6 @@ func setupIntegrationTest(t *testing.T) *integrationTestSetup {
|
||||
oidcConfig,
|
||||
nil,
|
||||
usersManager,
|
||||
nil,
|
||||
proxyManager,
|
||||
nil,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user