Compare commits

...

8 Commits

Author SHA1 Message Date
riccardom
ec98c930cb [Recheck watcher ctx cancellation under conn.mu in onWGDisconnected
onWGDisconnected only checked conn.ctx (the engine-scoped context), never
the watcher's own context. disableWgWatcherIfNeeded cancels the wgWatcherCtx,
not conn.ctx, so a disabled watcher's timeout callback did not see the
cancellation.

handshakeCheck runs lock-free, so between the ctx check in periodicHandshakeCheck
and acquiring conn.mu a fast disconnect/reconnect can slip in: the stale watcher
then acquires the lock and tears down the *new*, healthy connection based on the
old timeout, forcing the guard into an unnecessary reconnect (flap).

Recheck watcherCtx.Err() under conn.mu so a superseded watcher exits without
touching the connection that replaced it.
2026-07-03 12:15:24 +02:00
riccardom
60104e000b Discriminate not updated from timeout handshakes 2026-07-03 12:02:50 +02:00
riccardom
d5a212349f Stick new watcher creation to actual existence of af the conn
and its removal to the removal of such same conn.
Avoid debouncing and cross lock dead locking
2026-07-03 11:37:41 +02:00
Zoltan Papp
f6900fb07c [client] backport enforce a single selected exit node (#6640)
* routemanager: enforce a single selected exit node

Backport of the exit-node exclusivity reconcile from the 0.75.0 line
(upstream commit 966fbec11) onto v0.74.0. Exit nodes are mutually
exclusive, but the RouteSelector stores routes with default-on semantics,
so every available exit node reported as selected at once.

Reconcile exit-node selection on each network map: keep at most one
selected -- the user's persisted pick, else whatever management marks for
auto-apply (SkipAutoApply=false), else none. Never auto-activate an exit
node the map does not request.

Carries over only the manager/routeselector logic and its test; the
desktop-only client/server changes and the BumpNetworksRevision UI-push
feature from the original commit are intentionally excluded.

* routeselector: make exit-node reconciliation atomic

enforceSingleExitNode took the RouteSelector lock three separate times
(IsDeselectAll, then DeselectRoutes, then SelectRoutes), so a concurrent
DeselectAllRoutes could interleave and be silently undone: SelectRoutes on
its deselectAll branch clears the flag and re-selects the preferred exit
node, overriding the user's "all off".

Move the whole reconciliation into a single locked RouteSelector method
(SetExclusiveExitNode) that checks deselectAll inside the critical section,
so a deselect-all either fully precedes the reconcile (left untouched) or
fully follows it (honoured). No interleaving is possible.
2026-07-03 10:31:06 +02:00
Zoltan Papp
4b3dd9103d [client] Fix slow wg operations (#6633)
* [iface] Drop redundant device dump in kernel configure()

wgctrl.ConfigureDevice already returns an error when the interface is
missing, so the preceding wg.Device() existence check is redundant. That
check dumps the entire device (all peers) on every configure() call,
making it O(peers) per call and turning bulk peer insertion into
O(peers^2): inserting N peers one by one re-parsed the whole growing peer
list N times. Removing it keeps each peer write constant-time regardless
of how many peers are already configured.

* [iface] Cache WireGuard stats to collapse per-peer device dumps

Each peer runs a WGWatcher that polls GetStats(), and every call dumps
the whole device, so with N peers the watchers perform O(N) full dumps
per poll cycle (O(N^2) work) while each keeps only its own peer's entry.

Wrap the kernel and userspace configurer GetStats() in a short-TTL cache
with singleflight: the staggered per-peer calls share a single device
dump per window and concurrent misses collapse into one dump. The kernel
and userspace WireGuard APIs have no per-peer stats query (a get always
returns the whole device), so a shared cached snapshot avoids the
repeated full dumps.

* Ignore .claude directory
2026-07-02 20:42:43 +02:00
Riccardo Manfrin
8e3b284f4b [client] Increase mgmt grpc buff size to 16MB (#6641) 2026-07-02 17:50:18 +02:00
Maycon Santos
21aa933584 [misc] Fix GHCR image push after dockers_v2 migration (#6653) 2026-07-02 17:21:06 +02:00
Misha Bragin
1dfa85a917 [management] Add vLLM e2e test (#6649)
* Add vLLM to Agent Network

* Add vllm e2e test
2026-07-02 15:36:51 +02:00
16 changed files with 780 additions and 106 deletions

View File

@@ -293,8 +293,11 @@ jobs:
${{ steps.goreleaser.outputs.artifacts }}
JSON
# dockers_v2 artifacts have no top-level goarch field, so match the
# per-platform -amd64 tag suffix instead; it works for both the old
# dockers and the new dockers_v2 image naming.
mapfile -t src_images < <(
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
jq -r '.[] | select(.type == "Docker Image") | .name | select(startswith("ghcr.io/") and endswith("-amd64"))' /tmp/goreleaser-artifacts.json
)
for src in "${src_images[@]}"; do

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
.claude
.idea
.run
*.iml

View File

@@ -17,12 +17,15 @@ import (
type KernelConfigurer struct {
deviceName string
statsCache *statsCache
}
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
return &KernelConfigurer{
c := &KernelConfigurer{
deviceName: deviceName,
}
c.statsCache = newStatsCache(statsCacheTTL, c.fetchStats)
return c
}
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
@@ -246,12 +249,6 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
}
}()
// validate if device with name exists
_, err = wg.Device(c.deviceName)
if err != nil {
return err
}
return wg.ConfigureDevice(c.deviceName, config)
}
@@ -300,6 +297,14 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
}
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
return c.statsCache.get()
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}
func (c *KernelConfigurer) fetchStats() (map[string]WGStats, error) {
stats := make(map[string]WGStats)
wg, err := wgctrl.New()
if err != nil {
@@ -326,7 +331,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
}
return stats, nil
}
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
return nil
}

View File

@@ -0,0 +1,52 @@
package configurer
import (
"sync"
"time"
"golang.org/x/sync/singleflight"
)
const statsCacheTTL = 1 * time.Second
type statsCache struct {
ttl time.Duration
fetch func() (map[string]WGStats, error)
mu sync.RWMutex
value map[string]WGStats
expireAt time.Time
sf singleflight.Group
}
func newStatsCache(ttl time.Duration, fetch func() (map[string]WGStats, error)) *statsCache {
return &statsCache{ttl: ttl, fetch: fetch}
}
func (c *statsCache) get() (map[string]WGStats, error) {
c.mu.RLock()
if c.value != nil && time.Now().Before(c.expireAt) {
value := c.value
c.mu.RUnlock()
return value, nil
}
c.mu.RUnlock()
value, err, _ := c.sf.Do("stats", func() (interface{}, error) {
res, err := c.fetch()
if err != nil {
return nil, err
}
c.mu.Lock()
c.value = res
c.expireAt = time.Now().Add(c.ttl)
c.mu.Unlock()
return res, nil
})
if err != nil {
return nil, err
}
return value.(map[string]WGStats), nil
}

View File

@@ -0,0 +1,70 @@
package configurer
import (
"errors"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func TestStatsCache_CachesWithinTTL(t *testing.T) {
var calls atomic.Int64
c := newStatsCache(50*time.Millisecond, func() (map[string]WGStats, error) {
calls.Add(1)
return map[string]WGStats{"p": {}}, nil
})
for i := 0; i < 10; i++ {
_, err := c.get()
require.NoError(t, err)
}
require.Equal(t, int64(1), calls.Load(), "within TTL only one underlying fetch")
time.Sleep(60 * time.Millisecond)
_, err := c.get()
require.NoError(t, err)
require.Equal(t, int64(2), calls.Load(), "after TTL expiry a fresh fetch happens")
}
func TestStatsCache_SingleFlight(t *testing.T) {
var calls atomic.Int64
release := make(chan struct{})
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
calls.Add(1)
<-release
return map[string]WGStats{}, nil
})
const n = 50
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
_, _ = c.get()
}()
}
time.Sleep(20 * time.Millisecond)
close(release)
wg.Wait()
require.Equal(t, int64(1), calls.Load(), "concurrent misses collapse into one fetch")
}
func TestStatsCache_ErrorNotCached(t *testing.T) {
var calls atomic.Int64
wantErr := errors.New("dump failed")
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
calls.Add(1)
return nil, wantErr
})
_, err := c.get()
require.ErrorIs(t, err, wantErr)
_, err = c.get()
require.ErrorIs(t, err, wantErr)
require.Equal(t, int64(2), calls.Load(), "errors are not cached; each call retries")
}

View File

@@ -40,6 +40,7 @@ type WGUSPConfigurer struct {
device *device.Device
deviceName string
activityRecorder *bind.ActivityRecorder
statsCache *statsCache
uapiListener net.Listener
}
@@ -50,16 +51,19 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
deviceName: deviceName,
activityRecorder: activityRecorder,
}
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
wgCfg.startUAPI()
return wgCfg
}
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
return &WGUSPConfigurer{
wgCfg := &WGUSPConfigurer{
device: device,
deviceName: deviceName,
activityRecorder: activityRecorder,
}
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
return wgCfg
}
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
@@ -348,6 +352,10 @@ func (t *WGUSPConfigurer) Close() {
}
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
return t.statsCache.get()
}
func (t *WGUSPConfigurer) fetchStats() (map[string]WGStats, error) {
ipc, err := t.device.IpcGet()
if err != nil {
return nil, fmt.Errorf("ipc get: %w", err)

View File

@@ -195,7 +195,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
statusICE: worker.NewAtomicStatus(),
dumpState: dumpState,
endpointUpdater: NewEndpointUpdater(connLog, config.WgConfig, isController(config)),
wgWatcher: NewWGWatcher(connLog, config.WgConfig.WgInterface, config.Key, dumpState),
metricsRecorder: services.MetricsRecorder,
}
@@ -663,11 +662,16 @@ func (conn *Conn) onGuardEvent() {
}
}
func (conn *Conn) onWGDisconnected() {
// onWGDisconnected is invoked by the watcher goroutine when a handshake timeout is detected.
// watcherCtx is the context of the watcher that fired: the timeout check runs lock-free, so by
// the time we acquire conn.mu the watcher may have been cancelled (disabled) and a new connection
// (and watcher) may already be in place. Re-checking watcherCtx under the lock prevents a stale
// watcher from tearing down the connection that superseded it.
func (conn *Conn) onWGDisconnected(watcherCtx context.Context) {
conn.mu.Lock()
defer conn.mu.Unlock()
if conn.ctx.Err() != nil {
if conn.ctx.Err() != nil || watcherCtx.Err() != nil {
return
}
@@ -802,25 +806,44 @@ func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
})
}
// enableWgWatcherIfNeeded starts a fresh watcher for the current connection. A new WGWatcher
// instance is created per attempt (rather than reusing one) so its lifecycle is bound entirely
// to conn.mu: enable/disable can never race against an old watcher goroutine's shutdown, which
// was the source of the "watcher silently fails to restart on a fast reconnect" bug. Caller must
// hold conn.mu.
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
if !conn.wgWatcher.PrepareInitialHandshake() {
if conn.wgWatcher != nil {
// a watcher is already running for the current connection
return
}
watcher := NewWGWatcher(conn.Log, conn.config.WgConfig.WgInterface, conn.config.Key, conn.dumpState)
watcher.PrepareInitialHandshake()
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
conn.wgWatcher = watcher
conn.wgWatcherCancel = wgWatcherCancel
conn.wgWatcherWg.Add(1)
go func() {
defer conn.wgWatcherWg.Done()
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
onDisconnected := func() { conn.onWGDisconnected(wgWatcherCtx) }
watcher.EnableWgWatcher(wgWatcherCtx, enabledTime, onDisconnected, conn.onWGHandshakeSuccess)
}()
}
// disableWgWatcherIfNeeded stops and drops the current watcher once no transport is active. It
// only signals the watcher goroutine (cancel) and clears the reference; it never waits for the
// goroutine to exit, because the watcher's own timeout path reentrantly calls back here under
// conn.mu (via onWGDisconnected), so blocking would deadlock. The cancelled goroutine drains
// harmlessly. Caller must hold conn.mu.
func (conn *Conn) disableWgWatcherIfNeeded() {
if conn.currentConnPriority == conntype.None && conn.wgWatcherCancel != nil {
conn.wgWatcherCancel()
conn.wgWatcherCancel = nil
if conn.currentConnPriority != conntype.None || conn.wgWatcher == nil {
return
}
conn.wgWatcherCancel()
conn.wgWatcher = nil
conn.wgWatcherCancel = nil
}
func (conn *Conn) newProxy(remoteConn net.Conn) (wgproxy.Proxy, error) {
@@ -843,7 +866,9 @@ func (conn *Conn) resetEndpoint() {
return
}
conn.Log.Infof("reset wg endpoint")
conn.wgWatcher.Reset()
if conn.wgWatcher != nil {
conn.wgWatcher.Reset()
}
if err := conn.endpointUpdater.RemoveEndpointAddress(); err != nil {
conn.Log.Warnf("failed to remove endpoint address before update: %v", err)
}

View File

@@ -3,7 +3,6 @@ package peer
import (
"context"
"fmt"
"sync"
"time"
log "github.com/sirupsen/logrus"
@@ -24,14 +23,15 @@ type WGInterfaceStater interface {
GetStats() (map[string]configurer.WGStats, error)
}
// WGWatcher is single-shot: create one instance per connection attempt, run it once via
// EnableWgWatcher, then discard it. Lifecycle (start/stop) is owned by Conn under conn.mu,
// so the watcher itself keeps no "enabled" state to go stale on a fast disconnect/reconnect.
type WGWatcher struct {
log *log.Entry
wgIfaceStater WGInterfaceStater
peerKey string
stateDump *stateDump
enabled bool
muEnabled sync.Mutex
// initialHandshake is not thread-safe; never call PrepareInitialHandshake and EnableWgWatcher concurrently.
initialHandshake time.Time
@@ -48,25 +48,14 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
}
}
// PrepareInitialHandshake reserves the watcher and reads the peer's current WireGuard
// handshake time. It must be called before the peer is (re)configured on the WireGuard
// interface, so the captured baseline reflects the state prior to this connection attempt
// instead of racing with that configuration. Returns ok=false if the watcher is already
// running, in which case EnableWgWatcher must not be called.
func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
w.muEnabled.Lock()
if w.enabled {
w.muEnabled.Unlock()
return false
}
// PrepareInitialHandshake reads the peer's current WireGuard handshake time. It must be
// called before the peer is (re)configured on the WireGuard interface, so the captured
// baseline reflects the state prior to this connection attempt instead of racing with
// that configuration.
func (w *WGWatcher) PrepareInitialHandshake() {
w.log.Debugf("enable WireGuard watcher")
w.enabled = true
w.muEnabled.Unlock()
handshake, _ := w.wgState()
w.initialHandshake = handshake
return true
}
// EnableWgWatcher runs the WireGuard watcher loop using the handshake baseline captured by
@@ -74,10 +63,6 @@ func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
// for context lifecycle management.
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, w.initialHandshake)
w.muEnabled.Lock()
w.enabled = false
w.muEnabled.Unlock()
}
// Reset signals the watcher that the WireGuard peer has been reset and a new
@@ -103,6 +88,7 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
case <-timer.C:
handshake, ok := w.handshakeCheck(lastHandshake)
if !ok {
// early ctx cancel check return
if ctx.Err() != nil {
return
}
@@ -147,9 +133,9 @@ func (w *WGWatcher) handshakeCheck(lastHandshake time.Time) (*time.Time, bool) {
w.log.Tracef("previous handshake, handshake: %v, %v", lastHandshake, handshake)
// the current know handshake did not change
// the current known handshake did not change
if handshake.Equal(lastHandshake) {
w.log.Warnf("WireGuard handshake timed out: %v", handshake)
w.log.Warnf("WireGuard handshake not updated: %v", handshake)
return nil, false
}

View File

@@ -7,7 +7,6 @@ import (
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/iface/configurer"
)
@@ -35,8 +34,7 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ok := watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should not be enabled yet")
watcher.PrepareInitialHandshake()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
@@ -66,8 +64,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
ctx, cancel := context.WithCancel(context.Background())
ok := watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should not be enabled yet")
watcher.PrepareInitialHandshake()
wg := &sync.WaitGroup{}
wg.Add(1)
@@ -83,8 +80,7 @@ func TestWGWatcher_ReEnable(t *testing.T) {
ctx, cancel = context.WithCancel(context.Background())
defer cancel()
ok = watcher.PrepareInitialHandshake()
require.True(t, ok, "watcher should be re-enabled after the previous run stopped")
watcher.PrepareInitialHandshake()
onDisconnected := make(chan struct{}, 1)
go watcher.EnableWgWatcher(ctx, time.Now(), func() {

View File

@@ -0,0 +1,191 @@
package routemanager
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/routeselector"
"github.com/netbirdio/netbird/route"
)
func newExitNodeTestManager() *DefaultManager {
return &DefaultManager{routeSelector: routeselector.NewRouteSelector()}
}
func exitRoute(netID, peer string, skipAutoApply bool) *route.Route {
return &route.Route{
NetID: route.NetID(netID),
Network: netip.MustParsePrefix("0.0.0.0/0"),
Peer: peer,
SkipAutoApply: skipAutoApply,
}
}
func TestPickPreferredExitNode(t *testing.T) {
tests := []struct {
name string
info exitNodeInfo
want route.NetID
}{
{
name: "persisted user selection wins over management",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b", "c"},
userSelected: []route.NetID{"b"},
selectedByManagement: []route.NetID{"a"},
},
want: "b",
},
{
name: "multiple user-selected self-heal to deterministic min",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b", "c"},
userSelected: []route.NetID{"c", "a"},
},
want: "a",
},
{
name: "explicit opt-out keeps none",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b"},
userDeselected: []route.NetID{"a", "b"},
},
want: "",
},
{
name: "fresh defaults to management auto-apply pick",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b", "c"},
selectedByManagement: []route.NetID{"b"},
},
want: "b",
},
{
name: "no user pick and no management auto-apply selects none",
info: exitNodeInfo{
allIDs: []route.NetID{"c", "a", "b"},
},
want: "",
},
{
name: "user-deselect does not block a management auto-apply sibling",
info: exitNodeInfo{
allIDs: []route.NetID{"a", "b"},
userDeselected: []route.NetID{"a"},
selectedByManagement: []route.NetID{"b"},
},
want: "b",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, pickPreferredExitNode(tt.info), "preferred exit node")
})
}
}
func TestEnforceSingleExitNode(t *testing.T) {
m := newExitNodeTestManager()
all := []route.NetID{"a", "b", "c"}
m.enforceSingleExitNode("b", all)
assert.False(t, m.routeSelector.IsSelected("a"), "a should be deselected")
assert.True(t, m.routeSelector.IsSelected("b"), "b should be the only selected exit node")
assert.False(t, m.routeSelector.IsSelected("c"), "c should be deselected")
// Switching the preferred node moves the single selection.
m.enforceSingleExitNode("c", all)
assert.False(t, m.routeSelector.IsSelected("a"), "a stays deselected")
assert.False(t, m.routeSelector.IsSelected("b"), "b should now be deselected")
assert.True(t, m.routeSelector.IsSelected("c"), "c should now be selected")
// Empty preferred turns every exit node off.
m.enforceSingleExitNode("", all)
for _, id := range all {
assert.False(t, m.routeSelector.IsSelected(id), "no exit node should be selected")
}
}
func TestEnforceSingleExitNode_RespectsDeselectAll(t *testing.T) {
m := newExitNodeTestManager()
m.routeSelector.DeselectAllRoutes()
m.enforceSingleExitNode("b", []route.NetID{"a", "b"})
assert.True(t, m.routeSelector.IsDeselectAll(), "global deselect-all must stay in effect")
assert.False(t, m.routeSelector.IsSelected("b"), "no exit node should be forced on while deselect-all is set")
}
func TestUpdateRouteSelectorFromManagement_FreshSelectsOne(t *testing.T) {
m := newExitNodeTestManager()
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
"lan|192.168.1.0/24": {{NetID: "lan", Network: netip.MustParsePrefix("192.168.1.0/24"), Peer: "p3"}},
"exitC|0.0.0.0/0": {exitRoute("exitC", "p4", false)},
}
m.updateRouteSelectorFromManagement(routes)
// Exactly one exit node (the deterministic first) is selected.
assert.True(t, m.routeSelector.IsSelected("exitA"), "exitA is the deterministic default")
assert.False(t, m.routeSelector.IsSelected("exitB"), "exitB must not also be selected")
assert.False(t, m.routeSelector.IsSelected("exitC"), "exitC must not also be selected")
// Non-exit routes are left at their default-on state.
assert.True(t, m.routeSelector.IsSelected("lan"), "non-exit route selection is untouched")
}
func TestUpdateRouteSelectorFromManagement_HonorsPersistedPick(t *testing.T) {
m := newExitNodeTestManager()
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
}
all := []route.NetID{"exitA", "exitB"}
// Simulate the state the runtime select path leaves behind: exactly one
// exit node explicitly selected, its sibling deselected.
require.NoError(t, m.routeSelector.SelectRoutes([]route.NetID{"exitB"}, true, all))
require.NoError(t, m.routeSelector.DeselectRoutes([]route.NetID{"exitA"}, all))
m.updateRouteSelectorFromManagement(routes)
assert.True(t, m.routeSelector.IsSelected("exitB"), "persisted pick must stay selected")
assert.False(t, m.routeSelector.IsSelected("exitA"), "the other exit node stays deselected")
}
func TestUpdateRouteSelectorFromManagement_OptOutKeepsNone(t *testing.T) {
m := newExitNodeTestManager()
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", false)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", false)},
}
all := []route.NetID{"exitA", "exitB"}
// User deselected exit nodes and selected none.
require.NoError(t, m.routeSelector.DeselectRoutes(all, all))
m.updateRouteSelectorFromManagement(routes)
assert.False(t, m.routeSelector.IsSelected("exitA"), "opt-out keeps exitA off")
assert.False(t, m.routeSelector.IsSelected("exitB"), "opt-out keeps exitB off")
}
func TestUpdateRouteSelectorFromManagement_NoAutoApplySelectsNone(t *testing.T) {
m := newExitNodeTestManager()
// SkipAutoApply=true: management offers the exit nodes but doesn't request
// auto-activation, so none should be selected until the user picks one.
routes := route.HAMap{
"exitA|0.0.0.0/0": {exitRoute("exitA", "p1", true)},
"exitB|0.0.0.0/0": {exitRoute("exitB", "p2", true)},
}
m.updateRouteSelectorFromManagement(routes)
assert.False(t, m.routeSelector.IsSelected("exitA"), "no auto-apply keeps exitA off")
assert.False(t, m.routeSelector.IsSelected("exitB"), "no auto-apply keeps exitB off")
}

View File

@@ -701,7 +701,13 @@ func resolveURLsToIPs(urls []string) []net.IP {
return ips
}
// updateRouteSelectorFromManagement updates the route selector based on the isSelected status from the management server
// updateRouteSelectorFromManagement reconciles exit-node selection on every
// network map: it keeps at most one exit node selected — the user's persisted
// pick, else whatever management marks for auto-apply (SkipAutoApply=false),
// else none. We never auto-activate an exit node the map doesn't request; it
// stays off until the user picks it. Exit nodes are mutually exclusive, but the
// RouteSelector stores routes with default-on semantics, so without this every
// available exit node would report selected at once.
func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HAMap) {
m.mirrorV6ExitPairSelections(clientRoutes)
@@ -712,13 +718,14 @@ func (m *DefaultManager) updateRouteSelectorFromManagement(clientRoutes route.HA
return
}
exitNodeInfo := m.collectExitNodeInfo(clientRoutes)
if len(exitNodeInfo.allIDs) == 0 {
info := m.collectExitNodeInfo(clientRoutes)
if len(info.allIDs) == 0 {
return
}
m.updateExitNodeSelections(exitNodeInfo)
m.logExitNodeUpdate(exitNodeInfo)
preferred := pickPreferredExitNode(info)
m.enforceSingleExitNode(preferred, info.allIDs)
m.logExitNodeUpdate(info, preferred)
}
// mirrorV6ExitPairSelections keeps every synthesized "-v6" exit route's selection
@@ -746,6 +753,10 @@ type exitNodeInfo struct {
userDeselected []route.NetID
}
// collectExitNodeInfo categorises the available exit nodes by their persisted
// selection state. It keys on the base (v4) NetID and skips the synthesized
// "-v6" partner, which inherits its base's selection through the RouteSelector
// — counting it separately would double-count the pair.
func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeInfo {
var info exitNodeInfo
@@ -755,6 +766,9 @@ func (m *DefaultManager) collectExitNodeInfo(clientRoutes route.HAMap) exitNodeI
}
netID := haID.NetID()
if strings.HasSuffix(string(netID), route.V6ExitSuffix) {
continue
}
info.allIDs = append(info.allIDs, netID)
if m.routeSelector.HasUserSelectionForRoute(netID) {
@@ -791,45 +805,52 @@ func (m *DefaultManager) checkManagementSelection(routes []*route.Route, netID r
}
}
func (m *DefaultManager) updateExitNodeSelections(info exitNodeInfo) {
routesToDeselect := m.getRoutesToDeselect(info.allIDs)
m.deselectExitNodes(routesToDeselect)
m.selectExitNodesByManagement(info.selectedByManagement, info.allIDs)
// pickPreferredExitNode chooses the single exit node to keep selected. In order:
// - a persisted user selection wins (deterministic if several survive from
// legacy state, so the set self-heals down to one);
// - otherwise activate only what management marks for auto-apply
// (SkipAutoApply=false); the lexicographically first if it marks several.
//
// Returns "" when neither holds — we never force an arbitrary exit node on. A
// route the map doesn't auto-apply stays off until the user selects it.
// info.userDeselected is informational only: an explicit deselect simply keeps
// that route out of both lists above, so it can't be picked.
func pickPreferredExitNode(info exitNodeInfo) route.NetID {
if len(info.userSelected) > 0 {
return minNetID(info.userSelected)
}
if len(info.selectedByManagement) > 0 {
return minNetID(info.selectedByManagement)
}
return ""
}
func (m *DefaultManager) getRoutesToDeselect(allIDs []route.NetID) []route.NetID {
var routesToDeselect []route.NetID
for _, netID := range allIDs {
if !m.routeSelector.HasUserSelectionForRoute(netID) {
routesToDeselect = append(routesToDeselect, netID)
// enforceSingleExitNode makes preferred the only selected exit node: every other
// available exit node is deselected and preferred (if any) is selected, without
// disturbing non-exit route selections. The whole reconciliation runs under a
// single RouteSelector lock (SetExclusiveExitNode) so a concurrent deselect-all
// cannot interleave and get undone; a global deselect-all is left untouched so
// the user's "all off" stays in effect.
func (m *DefaultManager) enforceSingleExitNode(preferred route.NetID, allIDs []route.NetID) {
m.routeSelector.SetExclusiveExitNode(preferred, allIDs)
}
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo, preferred route.NetID) {
log.Debugf("Exit node selection: %d available, preferred=%q (%d user-selected, %d user-deselected, %d management-selected)",
len(info.allIDs), preferred, len(info.userSelected), len(info.userDeselected), len(info.selectedByManagement))
}
// minNetID returns the lexicographically smallest NetID, for a deterministic
// default pick that stays stable across restarts.
func minNetID(ids []route.NetID) route.NetID {
if len(ids) == 0 {
return ""
}
best := ids[0]
for _, id := range ids[1:] {
if id < best {
best = id
}
}
return routesToDeselect
}
func (m *DefaultManager) deselectExitNodes(routesToDeselect []route.NetID) {
if len(routesToDeselect) == 0 {
return
}
err := m.routeSelector.DeselectRoutes(routesToDeselect, routesToDeselect)
if err != nil {
log.Warnf("Failed to deselect exit nodes: %v", err)
}
}
func (m *DefaultManager) selectExitNodesByManagement(selectedByManagement []route.NetID, allIDs []route.NetID) {
if len(selectedByManagement) == 0 {
return
}
err := m.routeSelector.SelectRoutes(selectedByManagement, true, allIDs)
if err != nil {
log.Warnf("Failed to select exit nodes: %v", err)
}
}
func (m *DefaultManager) logExitNodeUpdate(info exitNodeInfo) {
log.Debugf("Updated route selector: %d exit nodes available, %d selected by management, %d user-selected, %d user-deselected",
len(info.allIDs), len(info.selectedByManagement), len(info.userSelected), len(info.userDeselected))
return best
}

View File

@@ -115,7 +115,38 @@ func (rs *RouteSelector) DeselectAllRoutes() {
clear(rs.selectedRoutes)
}
// IsDeselectAll reports whether the user has explicitly deselected all routes.
// SetExclusiveExitNode atomically makes preferred the only selected exit node
// among exitIDs: every other ID in exitIDs is deselected and preferred (when
// non-empty) is selected, all under a single lock. Holding the lock across the
// whole reconciliation prevents a concurrent DeselectAllRoutes from interleaving
// between the deselect and select steps and being silently undone. A global
// deselect-all is left untouched so the user's "all off" stays in effect;
// non-exit routes are never referenced, so their selection is preserved.
func (rs *RouteSelector) SetExclusiveExitNode(preferred route.NetID, exitIDs []route.NetID) {
rs.mu.Lock()
defer rs.mu.Unlock()
if rs.deselectAll {
return
}
for _, id := range exitIDs {
if id == preferred {
continue
}
rs.deselectedRoutes[id] = struct{}{}
delete(rs.selectedRoutes, id)
}
if preferred != "" {
delete(rs.deselectedRoutes, preferred)
rs.selectedRoutes[preferred] = struct{}{}
}
}
// IsDeselectAll reports whether the global "deselect all" flag is set, i.e. the
// user explicitly disabled every route. Callers enforcing per-route invariants
// (e.g. single exit node) should leave the selection untouched when it is.
func (rs *RouteSelector) IsDeselectAll() bool {
rs.mu.RLock()
defer rs.mu.RUnlock()

View File

@@ -0,0 +1,171 @@
//go:build e2e
package agentnetwork
import (
"context"
"strings"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/e2e/harness"
"github.com/netbirdio/netbird/shared/management/http/api"
)
// TestVLLMProvider proves the proxy supports a self-hosted vLLM backend. vLLM is
// OpenAI-compatible, so it uses the "vllm" catalog entry (KindCustom) and is
// reached over plain HTTP — no TLS anywhere on the path:
//
// client --tunnel--> netbird proxy --http--> vllm (:8000, OpenAI-compatible)
//
// The mock vLLM server answers /v1/chat/completions with an OpenAI-shaped
// completion carrying a non-zero usage block. The test asserts the chat returns
// 200 with the completion, that the request is recorded in the access log by its
// session id, and that vLLM's usage block is metered into a consumption row —
// which together prove request routing, response parsing, and token accounting
// all work for a self-hosted OpenAI-compatible provider.
//
// It needs no external credentials (the mock ignores auth), so it always runs.
func TestVLLMProvider(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
vllm, err := harness.StartVLLM(ctx, srv)
require.NoError(t, err, "start mock vLLM server")
t.Cleanup(func() { _ = vllm.Terminate(context.Background()) })
grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-vllm"})
require.NoError(t, err, "create group")
t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) })
ephemeral := false
sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{
Name: "e2e-vllm-client",
Type: "reusable",
ExpiresIn: 86400,
UsageLimit: 0,
AutoGroups: []string{grp.Id},
Ephemeral: &ephemeral,
})
require.NoError(t, err, "mint setup key")
require.NotEmpty(t, sk.Key, "setup key plaintext")
// vLLM provider pointed at the mock over plain HTTP. The mock ignores auth,
// so a dummy key satisfies the "Bearer ${API_KEY}" template. The served model
// is enumerated so the router dispatches this model string to this provider.
dummyKey := "sk-vllm-e2e"
prov, err := srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{
Name: "vllm",
ProviderId: "vllm",
UpstreamUrl: vllm.URL,
ApiKey: &dummyKey,
Enabled: ptr(true),
BootstrapCluster: ptr(harness.AgentNetworkCluster),
Models: &[]api.AgentNetworkProviderModel{
{Id: harness.VLLMModel, InputPer1k: 0.001, OutputPer1k: 0.002},
},
})
require.NoError(t, err, "create vllm provider")
t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), prov.Id) })
// Token limit far above the handful of tokens this test drives, so it never
// blocks but switches on usage metering — the switch that makes consumption
// rows get recorded.
enabled := true
pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{
Name: "e2e-vllm-allow",
Enabled: &enabled,
SourceGroups: []string{grp.Id},
DestinationProviderIds: []string{prov.Id},
Limits: &api.AgentNetworkPolicyLimits{
TokenLimit: api.AgentNetworkPolicyTokenLimit{
Enabled: true,
GroupCap: 10_000_000,
UserCap: 10_000_000,
WindowSeconds: 60,
},
},
})
require.NoError(t, err, "create policy")
t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) })
settings, err := srv.GetSettings(ctx)
require.NoError(t, err, "read settings")
require.NotEmpty(t, settings.Endpoint, "endpoint must be assigned")
proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-vllm-proxy")
require.NoError(t, err, "mint proxy token")
px, err := harness.StartProxy(ctx, srv, proxyToken)
require.NoError(t, err, "start proxy")
t.Cleanup(func() { _ = px.Terminate(context.Background()) })
cl, err := harness.StartClient(ctx, srv, sk.Key)
require.NoError(t, err, "start client")
t.Cleanup(func() { _ = cl.Terminate(context.Background()) })
require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management")
if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil {
t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background()))
}
proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint)
require.NoError(t, err, "resolve endpoint to proxy IP")
before, _ := srv.ListAccessLogs(ctx)
sessionID := "e2e-session-vllm"
// Retry to absorb tunnel/DNS jitter on the first call.
var code int
var body string
deadline := time.Now().Add(90 * time.Second)
for time.Now().Before(deadline) {
c, b, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, harness.VLLMModel, "Reply with exactly: pong", sessionID)
if cerr == nil {
code, body = c, b
if code == 200 {
break
}
}
time.Sleep(5 * time.Second)
}
require.Equal(t, 200, code,
"chat through the vLLM provider must return 200; body: %s\n=== vllm logs ===\n%s\n=== proxy logs ===\n%s",
body, vllm.Logs(context.Background()), px.Logs(context.Background()))
require.True(t, strings.Contains(body, "chat.completion"),
"body should be an OpenAI-compatible chat completion; got: %s", body)
// The request must surface as an access-log row carrying our session id.
require.Eventually(t, func() bool {
logs, lerr := srv.ListAccessLogs(ctx)
return lerr == nil && logs.TotalRecords > before.TotalRecords
}, 30*time.Second, 2*time.Second, "an access-log row should be ingested for the vLLM provider")
require.Eventually(t, func() bool {
logs, lerr := srv.ListAccessLogs(ctx)
if lerr != nil {
return false
}
for _, r := range logs.Data {
if r.SessionId != nil && *r.SessionId == sessionID {
return true
}
}
return false
}, 30*time.Second, 2*time.Second, "session id %q must be recorded in an access-log row", sessionID)
// vLLM's usage block (prompt_tokens=11, completion_tokens=2) must be parsed
// and metered into a consumption row with positive token counts.
require.Eventually(t, func() bool {
rows, lerr := srv.ListConsumption(ctx)
if lerr != nil {
return false
}
for _, r := range rows {
if r.TokensInput > 0 && r.TokensOutput > 0 {
return true
}
}
return false
}, 60*time.Second, 3*time.Second, "vLLM usage must be metered into a consumption row")
}

113
e2e/harness/vllm.go Normal file
View File

@@ -0,0 +1,113 @@
//go:build e2e
package harness
import (
"context"
"fmt"
"os"
"path/filepath"
"time"
"github.com/docker/docker/api/types/container"
"github.com/testcontainers/testcontainers-go"
"github.com/testcontainers/testcontainers-go/wait"
)
const (
vllmImage = "nginx:alpine"
vllmAlias = "vllm"
vllmPort = "8000/tcp"
// VLLMModel is the served model id the mock advertises and echoes back. It
// matches a real small model commonly served by vLLM so the provider's
// enumerated model and the client's request line up.
VLLMModel = "Qwen/Qwen2.5-0.5B-Instruct"
)
// vllmNginxConf emulates a vLLM OpenAI-compatible server over plain HTTP (vLLM's
// default: no TLS, port 8000). It answers /v1/models with a one-model list and
// any chat/completions path with a canned OpenAI-shaped chat completion carrying
// a non-zero usage block, so the proxy's OpenAI parser records real token
// consumption. Running actual vLLM in CI is infeasible (GPU + multi-GB model
// download), so this stands in for the wire contract the proxy depends on.
const vllmNginxConf = `pid /tmp/nginx.pid;
events {}
http {
server {
listen 8000;
location = /v1/models {
default_type application/json;
return 200 '{"object":"list","data":[{"id":"Qwen/Qwen2.5-0.5B-Instruct","object":"model","owned_by":"vllm"}]}';
}
location / {
default_type application/json;
return 200 '{"id":"chatcmpl-e2e-vllm","object":"chat.completion","created":1700000000,"model":"Qwen/Qwen2.5-0.5B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":2,"total_tokens":13}}';
}
}
}
`
// VLLM is a mock vLLM OpenAI-compatible server on the combined server's network,
// reachable at http://vllm:8000. A "vllm" provider points at it to exercise the
// proxy's support for self-hosted OpenAI-compatible backends.
type VLLM struct {
container testcontainers.Container
workDir string
// URL is the upstream URL the vllm provider points at (http://<alias>:8000).
URL string
}
// StartVLLM runs the mock vLLM server on the shared network over plain HTTP.
func StartVLLM(ctx context.Context, c *Combined) (*VLLM, error) {
workDir, err := os.MkdirTemp("/tmp", "nb-e2e-vllm-*")
if err != nil {
return nil, fmt.Errorf("create vllm work dir: %w", err)
}
// Widen so the (non-root worker) nginx container can traverse the bind mount.
if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e config dir
return nil, fmt.Errorf("chmod vllm dir: %w", err)
}
if err := os.WriteFile(filepath.Join(workDir, "nginx.conf"), []byte(vllmNginxConf), 0o644); err != nil { //nolint:gosec // non-secret e2e config
return nil, fmt.Errorf("write nginx conf: %w", err)
}
req := testcontainers.ContainerRequest{
Image: vllmImage,
ExposedPorts: []string{vllmPort},
Networks: []string{c.network.Name},
NetworkAliases: map[string][]string{c.network.Name: {vllmAlias}},
Cmd: []string{"nginx", "-c", "/conf/nginx.conf", "-g", "daemon off;"},
HostConfigModifier: func(hc *container.HostConfig) {
hc.Binds = append(hc.Binds, workDir+":/conf:ro")
},
WaitingFor: wait.ForListeningPort(vllmPort).WithStartupTimeout(60 * time.Second),
}
ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
ContainerRequest: req,
Started: true,
})
if err != nil {
_ = os.RemoveAll(workDir)
return nil, fmt.Errorf("start vllm container: %w", err)
}
return &VLLM{container: ctr, workDir: workDir, URL: "http://" + vllmAlias + ":8000"}, nil
}
// Logs returns the vLLM container logs, for diagnostics on failure.
func (v *VLLM) Logs(ctx context.Context) string {
return containerLogs(ctx, v.container)
}
// Terminate stops the vLLM container and cleans its work dir.
func (v *VLLM) Terminate(ctx context.Context) error {
var err error
if v.container != nil {
err = v.container.Terminate(ctx)
}
if v.workDir != "" {
_ = os.RemoveAll(v.workDir)
}
return err
}

View File

@@ -33,10 +33,15 @@ const ConnectTimeout = 10 * time.Second
const healthCheckTimeout = 5 * time.Second
const (
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size
// for the management client connection. Value is in bytes.
EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE"
// defaultMaxRecvMsgSize is the max gRPC receive message size used for the
// management client connection when EnvMaxRecvMsgSize is unset or invalid.
// It overrides the gRPC library default of 4 MB.
defaultMaxRecvMsgSize = 1024 * 1024 * 16
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
errMsgNoMgmtConnection = "no connection to management"
)
@@ -84,22 +89,22 @@ type ExposeResponse struct {
}
// MaxRecvMsgSize returns the configured max gRPC receive message size from
// the environment, or 0 if unset (which uses the gRPC default of 4 MB).
// the environment, or defaultMaxRecvMsgSize (16 MB) if unset or invalid.
func MaxRecvMsgSize() int {
val := os.Getenv(EnvMaxRecvMsgSize)
if val == "" {
return 0
return defaultMaxRecvMsgSize
}
size, err := strconv.Atoi(val)
if err != nil {
log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err)
return 0
return defaultMaxRecvMsgSize
}
if size <= 0 {
log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size)
return 0
return defaultMaxRecvMsgSize
}
return size

View File

@@ -21,11 +21,11 @@ func TestMaxRecvMsgSize(t *testing.T) {
envValue string
expected int
}{
{name: "unset returns 0", envValue: "", expected: 0},
{name: "unset returns default", envValue: "", expected: defaultMaxRecvMsgSize},
{name: "valid value", envValue: "10485760", expected: 10485760},
{name: "non-numeric returns 0", envValue: "abc", expected: 0},
{name: "negative returns 0", envValue: "-1", expected: 0},
{name: "zero returns 0", envValue: "0", expected: 0},
{name: "non-numeric returns default", envValue: "abc", expected: defaultMaxRecvMsgSize},
{name: "negative returns default", envValue: "-1", expected: defaultMaxRecvMsgSize},
{name: "zero returns default", envValue: "0", expected: defaultMaxRecvMsgSize},
}
for _, tt := range tests {