mirror of
https://github.com/netbirdio/netbird.git
synced 2026-07-03 05:09:54 +00:00
Compare commits
17 Commits
dependabot
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b3dd9103d | ||
|
|
8e3b284f4b | ||
|
|
21aa933584 | ||
|
|
1dfa85a917 | ||
|
|
859fe19fff | ||
|
|
e40cb294f6 | ||
|
|
e203e0f42a | ||
|
|
167be3a30f | ||
|
|
1d8b5f6e5c | ||
|
|
7d4736de55 | ||
|
|
06839a4731 | ||
|
|
eb422a5cd3 | ||
|
|
0aa0f7c76b | ||
|
|
7c0d8cbae0 | ||
|
|
2ab99eefa6 | ||
|
|
ff04ffb534 | ||
|
|
980598ed4a |
9
.github/workflows/agent-network-e2e.yml
vendored
9
.github/workflows/agent-network-e2e.yml
vendored
@@ -1,10 +1,10 @@
|
||||
name: Agent Network E2E
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
pull_request:
|
||||
# Nightly at 03:00 UTC, plus on demand from the Actions tab.
|
||||
schedule:
|
||||
- cron: "0 3 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
@@ -13,7 +13,6 @@ concurrency:
|
||||
jobs:
|
||||
e2e:
|
||||
name: Agent Network E2E
|
||||
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
|
||||
2
.github/workflows/golang-test-linux.yml
vendored
2
.github/workflows/golang-test-linux.yml
vendored
@@ -158,7 +158,7 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Test
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags devcert -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -coverprofile=coverage.txt -tags 'devcert privileged' -exec 'sudo --preserve-env=CI,CGO_ENABLED' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/testutil/privileged)
|
||||
|
||||
- name: Upload coverage reports to Codecov
|
||||
if: matrix.arch == 'amd64'
|
||||
|
||||
5
.github/workflows/release.yml
vendored
5
.github/workflows/release.yml
vendored
@@ -293,8 +293,11 @@ jobs:
|
||||
${{ steps.goreleaser.outputs.artifacts }}
|
||||
JSON
|
||||
|
||||
# dockers_v2 artifacts have no top-level goarch field, so match the
|
||||
# per-platform -amd64 tag suffix instead; it works for both the old
|
||||
# dockers and the new dockers_v2 image naming.
|
||||
mapfile -t src_images < <(
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
|
||||
jq -r '.[] | select(.type == "Docker Image") | .name | select(startswith("ghcr.io/") and endswith("-amd64"))' /tmp/goreleaser-artifacts.json
|
||||
)
|
||||
|
||||
for src in "${src_images[@]}"; do
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
.claude
|
||||
.idea
|
||||
.run
|
||||
*.iml
|
||||
|
||||
@@ -10,7 +10,7 @@ var (
|
||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||
|
||||
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||
EnvKeyNBLazyConn = lazyconn.EnvLazyConn
|
||||
|
||||
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||
|
||||
@@ -71,12 +71,14 @@ var (
|
||||
extraIFaceBlackList []string
|
||||
anonymizeFlag bool
|
||||
dnsRouteInterval time.Duration
|
||||
lazyConnEnabled bool
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
// lazyConnEnabled is the parse target for the deprecated --enable-lazy-connection
|
||||
// flag. The flag is inert; the value is no longer read (use NB_LAZY_CONN instead).
|
||||
lazyConnEnabled bool
|
||||
mtu uint16
|
||||
profilesDisabled bool
|
||||
updateSettingsDisabled bool
|
||||
captureEnabled bool
|
||||
networksDisabled bool
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "netbird",
|
||||
@@ -210,7 +212,8 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassEnabled, enableRosenpassFlag, false, "[Experimental] Enable Rosenpass feature. If enabled, the connection will be post-quantum secured via Rosenpass.")
|
||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand. Note: this setting may be overridden by management configuration.")
|
||||
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "Deprecated: no longer used. Lazy connections are controlled by the server and the NB_LAZY_CONN environment variable.")
|
||||
_ = upCmd.PersistentFlags().MarkDeprecated(enableLazyConnectionFlag, "no longer used; lazy connections are controlled by the server and the NB_LAZY_CONN environment variable")
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -479,10 +479,6 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
req.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
req.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
|
||||
return &req
|
||||
}
|
||||
|
||||
@@ -600,9 +596,6 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.DisableIPv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &ic, nil
|
||||
}
|
||||
|
||||
@@ -718,9 +711,6 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.DisableIpv6 = &disableIPv6
|
||||
}
|
||||
|
||||
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||
}
|
||||
return &loginRequest, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -17,12 +17,15 @@ import (
|
||||
|
||||
type KernelConfigurer struct {
|
||||
deviceName string
|
||||
statsCache *statsCache
|
||||
}
|
||||
|
||||
func NewKernelConfigurer(deviceName string) *KernelConfigurer {
|
||||
return &KernelConfigurer{
|
||||
c := &KernelConfigurer{
|
||||
deviceName: deviceName,
|
||||
}
|
||||
c.statsCache = newStatsCache(statsCacheTTL, c.fetchStats)
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
@@ -246,12 +249,6 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
||||
}
|
||||
}()
|
||||
|
||||
// validate if device with name exists
|
||||
_, err = wg.Device(c.deviceName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return wg.ConfigureDevice(c.deviceName, config)
|
||||
}
|
||||
|
||||
@@ -300,6 +297,14 @@ func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
return c.statsCache.get()
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) fetchStats() (map[string]WGStats, error) {
|
||||
stats := make(map[string]WGStats)
|
||||
wg, err := wgctrl.New()
|
||||
if err != nil {
|
||||
@@ -326,7 +331,3 @@ func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
}
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (c *KernelConfigurer) LastActivities() map[string]monotime.Time {
|
||||
return nil
|
||||
}
|
||||
|
||||
52
client/iface/configurer/stats_cache.go
Normal file
52
client/iface/configurer/stats_cache.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/sync/singleflight"
|
||||
)
|
||||
|
||||
const statsCacheTTL = 1 * time.Second
|
||||
|
||||
type statsCache struct {
|
||||
ttl time.Duration
|
||||
fetch func() (map[string]WGStats, error)
|
||||
|
||||
mu sync.RWMutex
|
||||
value map[string]WGStats
|
||||
expireAt time.Time
|
||||
|
||||
sf singleflight.Group
|
||||
}
|
||||
|
||||
func newStatsCache(ttl time.Duration, fetch func() (map[string]WGStats, error)) *statsCache {
|
||||
return &statsCache{ttl: ttl, fetch: fetch}
|
||||
}
|
||||
|
||||
func (c *statsCache) get() (map[string]WGStats, error) {
|
||||
c.mu.RLock()
|
||||
if c.value != nil && time.Now().Before(c.expireAt) {
|
||||
value := c.value
|
||||
c.mu.RUnlock()
|
||||
return value, nil
|
||||
}
|
||||
c.mu.RUnlock()
|
||||
|
||||
value, err, _ := c.sf.Do("stats", func() (interface{}, error) {
|
||||
res, err := c.fetch()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.value = res
|
||||
c.expireAt = time.Now().Add(c.ttl)
|
||||
c.mu.Unlock()
|
||||
return res, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return value.(map[string]WGStats), nil
|
||||
}
|
||||
70
client/iface/configurer/stats_cache_test.go
Normal file
70
client/iface/configurer/stats_cache_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package configurer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestStatsCache_CachesWithinTTL(t *testing.T) {
|
||||
var calls atomic.Int64
|
||||
c := newStatsCache(50*time.Millisecond, func() (map[string]WGStats, error) {
|
||||
calls.Add(1)
|
||||
return map[string]WGStats{"p": {}}, nil
|
||||
})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err := c.get()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, int64(1), calls.Load(), "within TTL only one underlying fetch")
|
||||
|
||||
time.Sleep(60 * time.Millisecond)
|
||||
_, err := c.get()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), calls.Load(), "after TTL expiry a fresh fetch happens")
|
||||
}
|
||||
|
||||
func TestStatsCache_SingleFlight(t *testing.T) {
|
||||
var calls atomic.Int64
|
||||
release := make(chan struct{})
|
||||
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
|
||||
calls.Add(1)
|
||||
<-release
|
||||
return map[string]WGStats{}, nil
|
||||
})
|
||||
|
||||
const n = 50
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_, _ = c.get()
|
||||
}()
|
||||
}
|
||||
time.Sleep(20 * time.Millisecond)
|
||||
close(release)
|
||||
wg.Wait()
|
||||
|
||||
require.Equal(t, int64(1), calls.Load(), "concurrent misses collapse into one fetch")
|
||||
}
|
||||
|
||||
func TestStatsCache_ErrorNotCached(t *testing.T) {
|
||||
var calls atomic.Int64
|
||||
wantErr := errors.New("dump failed")
|
||||
c := newStatsCache(time.Minute, func() (map[string]WGStats, error) {
|
||||
calls.Add(1)
|
||||
return nil, wantErr
|
||||
})
|
||||
|
||||
_, err := c.get()
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
_, err = c.get()
|
||||
require.ErrorIs(t, err, wantErr)
|
||||
require.Equal(t, int64(2), calls.Load(), "errors are not cached; each call retries")
|
||||
}
|
||||
@@ -40,6 +40,7 @@ type WGUSPConfigurer struct {
|
||||
device *device.Device
|
||||
deviceName string
|
||||
activityRecorder *bind.ActivityRecorder
|
||||
statsCache *statsCache
|
||||
|
||||
uapiListener net.Listener
|
||||
}
|
||||
@@ -50,16 +51,19 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
|
||||
wgCfg.startUAPI()
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
return &WGUSPConfigurer{
|
||||
wgCfg := &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
wgCfg.statsCache = newStatsCache(statsCacheTTL, wgCfg.fetchStats)
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
@@ -348,6 +352,10 @@ func (t *WGUSPConfigurer) Close() {
|
||||
}
|
||||
|
||||
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||
return t.statsCache.get()
|
||||
}
|
||||
|
||||
func (t *WGUSPConfigurer) fetchStats() (map[string]WGStats, error) {
|
||||
ipc, err := t.device.IpcGet()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("ipc get: %w", err)
|
||||
|
||||
@@ -322,7 +322,6 @@ func (a *Auth) setSystemInfoFlags(info *system.Info) {
|
||||
a.config.BlockLANAccess,
|
||||
a.config.BlockInbound,
|
||||
a.config.DisableIPv6,
|
||||
a.config.LazyConnectionEnabled,
|
||||
a.config.EnableSSHRoot,
|
||||
a.config.EnableSSHSFTP,
|
||||
a.config.EnableSSHLocalPortForwarding,
|
||||
|
||||
@@ -16,6 +16,16 @@ import (
|
||||
"github.com/netbirdio/netbird/route"
|
||||
)
|
||||
|
||||
// lazyForce is the resolved local decision for lazy connections, layered above the
|
||||
// management feature flag. lazyForceNone defers to management.
|
||||
type lazyForce int
|
||||
|
||||
const (
|
||||
lazyForceNone lazyForce = iota
|
||||
lazyForceOn
|
||||
lazyForceOff
|
||||
)
|
||||
|
||||
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||
//
|
||||
// The connection manager is responsible for:
|
||||
@@ -28,7 +38,7 @@ type ConnMgr struct {
|
||||
peerStore *peerstore.Store
|
||||
statusRecorder *peer.Status
|
||||
iface lazyconn.WGIface
|
||||
enabledLocally bool
|
||||
force lazyForce
|
||||
rosenpassEnabled bool
|
||||
|
||||
lazyConnMgr *manager.Manager
|
||||
@@ -43,28 +53,34 @@ func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerSto
|
||||
peerStore: peerStore,
|
||||
statusRecorder: statusRecorder,
|
||||
iface: iface,
|
||||
force: resolveLazyForce(engineConfig.LazyConnection),
|
||||
rosenpassEnabled: engineConfig.RosenpassEnabled,
|
||||
}
|
||||
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||
e.enabledLocally = true
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||
// Start initializes the connection manager. It starts the lazy connection manager when a
|
||||
// local override forces it on; with no local override it waits for the management feature flag.
|
||||
func (e *ConnMgr) Start(ctx context.Context) {
|
||||
if e.lazyConnMgr != nil {
|
||||
log.Errorf("lazy connection manager is already started")
|
||||
return
|
||||
}
|
||||
|
||||
if !e.enabledLocally {
|
||||
log.Infof("lazy connection manager is disabled")
|
||||
switch e.force {
|
||||
case lazyForceOff:
|
||||
log.Infof("lazy connection manager is disabled by local override (%s or MDM policy)", lazyconn.EnvLazyConn)
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return
|
||||
case lazyForceNone:
|
||||
log.Infof("lazy connection manager is managed by the management feature flag")
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return
|
||||
}
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -76,8 +92,8 @@ func (e *ConnMgr) Start(ctx context.Context) {
|
||||
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||
// do not disable lazy connection manager if it was enabled by env var
|
||||
if e.enabledLocally {
|
||||
// a local override (NB_LAZY_CONN or local config) takes precedence over management
|
||||
if e.force != lazyForceNone {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -89,6 +105,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
|
||||
|
||||
if e.rosenpassEnabled {
|
||||
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -98,6 +115,7 @@ func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) er
|
||||
return e.addPeersToLazyConnManager()
|
||||
} else {
|
||||
if e.lazyConnMgr == nil {
|
||||
e.statusRecorder.UpdateLazyConnection(false)
|
||||
return nil
|
||||
}
|
||||
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||
@@ -309,6 +327,25 @@ func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
|
||||
}
|
||||
|
||||
// resolveLazyForce determines the local override. NB_LAZY_CONN takes precedence; when it
|
||||
// is unset the MDM policy override (mdmState) applies. Either wins in both directions over
|
||||
// the management feature flag; StateUnset for both defers to management.
|
||||
func resolveLazyForce(mdmState lazyconn.State) lazyForce {
|
||||
state := lazyconn.EnvState()
|
||||
if state == lazyconn.StateUnset {
|
||||
state = mdmState
|
||||
}
|
||||
|
||||
switch state {
|
||||
case lazyconn.StateOn:
|
||||
return lazyForceOn
|
||||
case lazyconn.StateOff:
|
||||
return lazyForceOff
|
||||
default:
|
||||
return lazyForceNone
|
||||
}
|
||||
}
|
||||
|
||||
func inactivityThresholdEnv() *time.Duration {
|
||||
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||
if envValue == "" {
|
||||
|
||||
40
client/internal/conn_mgr_test.go
Normal file
40
client/internal/conn_mgr_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
)
|
||||
|
||||
func TestResolveLazyForce(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
env string
|
||||
envSet bool
|
||||
mdm lazyconn.State
|
||||
want lazyForce
|
||||
}{
|
||||
{name: "env unset, mdm unset -> defer to management", mdm: lazyconn.StateUnset, want: lazyForceNone},
|
||||
{name: "env on -> force on", env: "on", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOn},
|
||||
{name: "env off -> force off", env: "off", envSet: true, mdm: lazyconn.StateUnset, want: lazyForceOff},
|
||||
{name: "env unset, mdm on -> force on", mdm: lazyconn.StateOn, want: lazyForceOn},
|
||||
{name: "env unset, mdm off -> force off", mdm: lazyconn.StateOff, want: lazyForceOff},
|
||||
{name: "env on beats mdm off", env: "on", envSet: true, mdm: lazyconn.StateOff, want: lazyForceOn},
|
||||
{name: "env off beats mdm on", env: "off", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOff},
|
||||
{name: "unrecognized env, mdm on -> mdm wins", env: "auto", envSet: true, mdm: lazyconn.StateOn, want: lazyForceOn},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Setenv(lazyconn.EnvLazyConn, tt.env)
|
||||
if !tt.envSet {
|
||||
os.Unsetenv(lazyconn.EnvLazyConn)
|
||||
}
|
||||
|
||||
if got := resolveLazyForce(tt.mdm); got != tt.want {
|
||||
t.Fatalf("resolveLazyForce(%v) = %v, want %v", tt.mdm, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -314,6 +315,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
|
||||
c.statusRecorder.MarkManagementConnected()
|
||||
|
||||
if metricsConfig := loginResp.GetNetbirdConfig().GetMetrics(); metricsConfig != nil {
|
||||
c.clientMetrics.UpdatePushFromMgm(c.ctx, metricsConfig.GetEnabled())
|
||||
}
|
||||
|
||||
localPeerState := peer.LocalPeerState{
|
||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||
PubKey: myPrivateKey.PublicKey().String(),
|
||||
@@ -399,6 +404,7 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
StateManager: stateManager,
|
||||
UpdateManager: c.updateManager,
|
||||
ClientMetrics: c.clientMetrics,
|
||||
MetricsCtx: c.ctx,
|
||||
}, mobileDependency)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engine = engine
|
||||
@@ -596,7 +602,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
||||
BlockInbound: config.BlockInbound,
|
||||
DisableIPv6: config.DisableIPv6,
|
||||
|
||||
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||
LazyConnection: lazyconn.ParseState(config.LazyConnection),
|
||||
|
||||
MTU: selectMTU(config.MTU, peerConfig.Mtu),
|
||||
LogPath: logPath,
|
||||
@@ -670,7 +676,6 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
||||
config.BlockLANAccess,
|
||||
config.BlockInbound,
|
||||
config.DisableIPv6,
|
||||
config.LazyConnectionEnabled,
|
||||
config.EnableSSHRoot,
|
||||
config.EnableSSHSFTP,
|
||||
config.EnableSSHLocalPortForwarding,
|
||||
|
||||
@@ -681,7 +681,7 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
||||
configContent.WriteString(fmt.Sprintf("ClientCertKeyPath: %s\n", g.internalConfig.ClientCertKeyPath))
|
||||
}
|
||||
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||
configContent.WriteString(fmt.Sprintf("LazyConnection: %q\n", g.internalConfig.LazyConnection))
|
||||
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
||||
}
|
||||
|
||||
|
||||
@@ -885,7 +885,7 @@ func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
||||
DNSRouteInterval: 5 * time.Second,
|
||||
ClientCertPath: "/tmp/cert",
|
||||
ClientCertKeyPath: "/tmp/key",
|
||||
LazyConnectionEnabled: true,
|
||||
LazyConnection: "on",
|
||||
MTU: 1280,
|
||||
}
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
@@ -147,7 +148,9 @@ type EngineConfig struct {
|
||||
BlockInbound bool
|
||||
DisableIPv6 bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
// LazyConnection is the MDM-sourced lazy-connection override; StateUnset defers to
|
||||
// the env var and management feature flag.
|
||||
LazyConnection lazyconn.State
|
||||
|
||||
MTU uint16
|
||||
|
||||
@@ -172,6 +175,7 @@ type EngineServices struct {
|
||||
StateManager *statemanager.Manager
|
||||
UpdateManager *updater.Manager
|
||||
ClientMetrics *metrics.ClientMetrics
|
||||
MetricsCtx context.Context
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
@@ -264,6 +268,7 @@ type Engine struct {
|
||||
|
||||
// clientMetrics collects and pushes metrics
|
||||
clientMetrics *metrics.ClientMetrics
|
||||
metricsCtx context.Context
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
@@ -316,6 +321,7 @@ func NewEngine(
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
clientMetrics: services.ClientMetrics,
|
||||
metricsCtx: services.MetricsCtx,
|
||||
updateManager: services.UpdateManager,
|
||||
syncStoreDir: config.StateDir,
|
||||
}
|
||||
@@ -997,6 +1003,8 @@ func (e *Engine) updateNetbirdConfig(wCfg *mgmProto.NetbirdConfig) error {
|
||||
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||
}
|
||||
|
||||
e.handleMetricsUpdate(wCfg.GetMetrics())
|
||||
|
||||
if err := e.PopulateNetbirdConfig(wCfg, nil); err != nil {
|
||||
log.Warnf("Failed to update DNS server config: %v", err)
|
||||
}
|
||||
@@ -1066,6 +1074,14 @@ func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
|
||||
return e.flowManager.Update(flowConfig)
|
||||
}
|
||||
|
||||
func (e *Engine) handleMetricsUpdate(config *mgmProto.MetricsConfig) {
|
||||
if config == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("received metrics configuration from management: enabled=%v", config.GetEnabled())
|
||||
e.clientMetrics.UpdatePushFromMgm(e.metricsCtx, config.GetEnabled())
|
||||
}
|
||||
|
||||
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
||||
if config.GetInterval() == nil {
|
||||
return nil, errors.New("flow interval is nil")
|
||||
@@ -1117,7 +1133,6 @@ func (e *Engine) applyInfoFlags(info *system.Info) {
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
@@ -1986,7 +2001,6 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
||||
e.config.BlockLANAccess,
|
||||
e.config.BlockInbound,
|
||||
e.config.DisableIPv6,
|
||||
e.config.LazyConnectionEnabled,
|
||||
e.config.EnableSSHRoot,
|
||||
e.config.EnableSSHSFTP,
|
||||
e.config.EnableSSHLocalPortForwarding,
|
||||
|
||||
@@ -3,24 +3,57 @@ package lazyconn
|
||||
import (
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
EnvEnableLazyConn = "NB_ENABLE_EXPERIMENTAL_LAZY_CONN"
|
||||
EnvLazyConn = "NB_LAZY_CONN"
|
||||
EnvInactivityThreshold = "NB_LAZY_CONN_INACTIVITY_THRESHOLD"
|
||||
)
|
||||
|
||||
func IsLazyConnEnabledByEnv() bool {
|
||||
val := os.Getenv(EnvEnableLazyConn)
|
||||
if val == "" {
|
||||
return false
|
||||
}
|
||||
enabled, err := strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse %s: %v", EnvEnableLazyConn, err)
|
||||
return false
|
||||
}
|
||||
return enabled
|
||||
// State is the tri-state local override for lazy connections read from the environment.
|
||||
type State int
|
||||
|
||||
const (
|
||||
// StateUnset means no local override; defer to the management feature flag.
|
||||
StateUnset State = iota
|
||||
// StateOn forces lazy connections on, overriding management.
|
||||
StateOn
|
||||
// StateOff forces lazy connections off, overriding management.
|
||||
StateOff
|
||||
)
|
||||
|
||||
// EnvState reads NB_LAZY_CONN and returns the local override state.
|
||||
func EnvState() State {
|
||||
return ParseState(os.Getenv(EnvLazyConn))
|
||||
}
|
||||
|
||||
// ParseState interprets a lazy-connection override value (from the environment or an MDM
|
||||
// policy). It accepts the on/off aliases plus any value strconv.ParseBool understands
|
||||
// (true/false/1/0). An empty or unrecognized value returns StateUnset so that the
|
||||
// management feature flag remains in control.
|
||||
func ParseState(raw string) State {
|
||||
if raw == "" {
|
||||
return StateUnset
|
||||
}
|
||||
|
||||
normalized := strings.ToLower(strings.TrimSpace(raw))
|
||||
switch normalized {
|
||||
case "on":
|
||||
return StateOn
|
||||
case "off":
|
||||
return StateOff
|
||||
}
|
||||
|
||||
enabled, err := strconv.ParseBool(normalized)
|
||||
if err != nil {
|
||||
log.Warnf("failed to parse lazy connection value %q (from %s env or MDM policy): %v", raw, EnvLazyConn, err)
|
||||
return StateUnset
|
||||
}
|
||||
if enabled {
|
||||
return StateOn
|
||||
}
|
||||
return StateOff
|
||||
}
|
||||
|
||||
45
client/internal/lazyconn/env_test.go
Normal file
45
client/internal/lazyconn/env_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package lazyconn
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestEnvState(t *testing.T) {
|
||||
tests := []struct {
|
||||
value string
|
||||
set bool
|
||||
want State
|
||||
}{
|
||||
{set: false, want: StateUnset},
|
||||
{value: "", set: true, want: StateUnset},
|
||||
{value: "on", set: true, want: StateOn},
|
||||
{value: "ON", set: true, want: StateOn},
|
||||
{value: "true", set: true, want: StateOn},
|
||||
{value: "1", set: true, want: StateOn},
|
||||
{value: " on ", set: true, want: StateOn},
|
||||
{value: "off", set: true, want: StateOff},
|
||||
{value: "OFF", set: true, want: StateOff},
|
||||
{value: "false", set: true, want: StateOff},
|
||||
{value: "0", set: true, want: StateOff},
|
||||
{value: "auto", set: true, want: StateUnset},
|
||||
{value: "garbage", set: true, want: StateUnset},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
name := tt.value
|
||||
if !tt.set {
|
||||
name = "unset"
|
||||
}
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Setenv(EnvLazyConn, tt.value)
|
||||
if !tt.set {
|
||||
os.Unsetenv(EnvLazyConn)
|
||||
}
|
||||
|
||||
if got := EnvState(); got != tt.want {
|
||||
t.Fatalf("EnvState() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -60,6 +60,13 @@ func getMetricsInterval() time.Duration {
|
||||
return interval
|
||||
}
|
||||
|
||||
// isMetricsPushEnvSet returns true if NB_METRICS_PUSH_ENABLED is explicitly set (to any value).
|
||||
// When set, the env var takes full precedence over management server configuration.
|
||||
func isMetricsPushEnvSet() bool {
|
||||
_, set := os.LookupEnv(EnvMetricsPushEnabled)
|
||||
return set
|
||||
}
|
||||
|
||||
func isForceSending() bool {
|
||||
force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending))
|
||||
return force
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -75,7 +76,7 @@ type ClientMetrics struct {
|
||||
agentInfo AgentInfo
|
||||
mu sync.RWMutex
|
||||
|
||||
push *Push
|
||||
push atomic.Pointer[Push]
|
||||
pushMu sync.Mutex
|
||||
wg sync.WaitGroup
|
||||
pushCancel context.CancelFunc
|
||||
@@ -167,10 +168,7 @@ func (c *ClientMetrics) UpdateAgentInfo(agentInfo AgentInfo, publicKey string) {
|
||||
c.agentInfo = agentInfo
|
||||
c.mu.Unlock()
|
||||
|
||||
c.pushMu.Lock()
|
||||
push := c.push
|
||||
c.pushMu.Unlock()
|
||||
if push != nil {
|
||||
if push := c.push.Load(); push != nil {
|
||||
push.SetPeerID(agentInfo.peerID)
|
||||
}
|
||||
}
|
||||
@@ -184,7 +182,7 @@ func (c *ClientMetrics) Export(w io.Writer) error {
|
||||
return c.impl.Export(w)
|
||||
}
|
||||
|
||||
// StartPush starts periodic pushing of metrics with the given configuration
|
||||
// StartPush starts periodic pushing of metrics with the given configuration.
|
||||
// Precedence: PushConfig.ServerAddress > remote config server_url
|
||||
func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
|
||||
if c == nil {
|
||||
@@ -194,11 +192,58 @@ func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
|
||||
c.pushMu.Lock()
|
||||
defer c.pushMu.Unlock()
|
||||
|
||||
if c.push != nil {
|
||||
if c.push.Load() != nil {
|
||||
log.Warnf("metrics push already running")
|
||||
return
|
||||
}
|
||||
|
||||
c.startPushLocked(ctx, config)
|
||||
}
|
||||
|
||||
// StopPush stops the periodic metrics push.
|
||||
func (c *ClientMetrics) StopPush() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.pushMu.Lock()
|
||||
defer c.pushMu.Unlock()
|
||||
|
||||
c.stopPushLocked()
|
||||
}
|
||||
|
||||
// UpdatePushFromMgm updates metrics push based on management server configuration.
|
||||
// If NB_METRICS_PUSH_ENABLED is explicitly set (true or false), management config is ignored.
|
||||
// When unset, management controls whether push is enabled.
|
||||
func (c *ClientMetrics) UpdatePushFromMgm(ctx context.Context, enabled bool) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if isMetricsPushEnvSet() {
|
||||
log.Debugf("ignoring management config, env var is explicitly set: %s", EnvMetricsPushEnabled)
|
||||
return
|
||||
}
|
||||
|
||||
c.pushMu.Lock()
|
||||
defer c.pushMu.Unlock()
|
||||
|
||||
if enabled {
|
||||
if c.push.Load() != nil {
|
||||
return
|
||||
}
|
||||
log.Infof("enabled metrics push by management")
|
||||
c.startPushLocked(ctx, PushConfigFromEnv())
|
||||
} else {
|
||||
if c.push.Load() == nil {
|
||||
return
|
||||
}
|
||||
log.Infof("disabled metrics push by management")
|
||||
c.stopPushLocked()
|
||||
}
|
||||
}
|
||||
|
||||
// startPushLocked starts push. Caller must hold pushMu.
|
||||
func (c *ClientMetrics) startPushLocked(ctx context.Context, config PushConfig) {
|
||||
c.mu.RLock()
|
||||
agentVersion := c.agentInfo.Version
|
||||
peerID := c.agentInfo.peerID
|
||||
@@ -214,26 +259,23 @@ func (c *ClientMetrics) StartPush(ctx context.Context, config PushConfig) {
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
c.pushCancel = cancel
|
||||
c.push.Store(push)
|
||||
|
||||
c.wg.Add(1)
|
||||
go func() {
|
||||
defer c.wg.Done()
|
||||
push.Start(ctx)
|
||||
c.push.CompareAndSwap(push, nil)
|
||||
}()
|
||||
c.push = push
|
||||
}
|
||||
|
||||
func (c *ClientMetrics) StopPush() {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
c.pushMu.Lock()
|
||||
defer c.pushMu.Unlock()
|
||||
if c.push == nil {
|
||||
// stopPushLocked stops push. Caller must hold pushMu.
|
||||
func (c *ClientMetrics) stopPushLocked() {
|
||||
if c.push.Load() == nil {
|
||||
return
|
||||
}
|
||||
|
||||
c.pushCancel()
|
||||
c.wg.Wait()
|
||||
c.push = nil
|
||||
c.push.Store(nil)
|
||||
}
|
||||
|
||||
@@ -803,15 +803,17 @@ func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
|
||||
}
|
||||
|
||||
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
|
||||
if !conn.wgWatcher.IsEnabled() {
|
||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||
conn.wgWatcherCancel = wgWatcherCancel
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
|
||||
}()
|
||||
if !conn.wgWatcher.PrepareInitialHandshake() {
|
||||
return
|
||||
}
|
||||
|
||||
wgWatcherCtx, wgWatcherCancel := context.WithCancel(conn.ctx)
|
||||
conn.wgWatcherCancel = wgWatcherCancel
|
||||
conn.wgWatcherWg.Add(1)
|
||||
go func() {
|
||||
defer conn.wgWatcherWg.Done()
|
||||
conn.wgWatcher.EnableWgWatcher(wgWatcherCtx, enabledTime, conn.onWGDisconnected, conn.onWGHandshakeSuccess)
|
||||
}()
|
||||
}
|
||||
|
||||
func (conn *Conn) disableWgWatcherIfNeeded() {
|
||||
|
||||
@@ -31,7 +31,9 @@ type WGWatcher struct {
|
||||
stateDump *stateDump
|
||||
|
||||
enabled bool
|
||||
muEnabled sync.RWMutex
|
||||
muEnabled sync.Mutex
|
||||
// initialHandshake is not thread-safe; never call PrepareInitialHandshake and EnableWgWatcher concurrently.
|
||||
initialHandshake time.Time
|
||||
|
||||
resetCh chan struct{}
|
||||
}
|
||||
@@ -46,38 +48,38 @@ func NewWGWatcher(log *log.Entry, wgIfaceStater WGInterfaceStater, peerKey strin
|
||||
}
|
||||
}
|
||||
|
||||
// EnableWgWatcher starts the WireGuard watcher. If it is already enabled, it will return immediately and do nothing.
|
||||
// The watcher runs until ctx is cancelled. Caller is responsible for context lifecycle management.
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
|
||||
// PrepareInitialHandshake reserves the watcher and reads the peer's current WireGuard
|
||||
// handshake time. It must be called before the peer is (re)configured on the WireGuard
|
||||
// interface, so the captured baseline reflects the state prior to this connection attempt
|
||||
// instead of racing with that configuration. Returns ok=false if the watcher is already
|
||||
// running, in which case EnableWgWatcher must not be called.
|
||||
func (w *WGWatcher) PrepareInitialHandshake() (ok bool) {
|
||||
w.muEnabled.Lock()
|
||||
if w.enabled {
|
||||
w.muEnabled.Unlock()
|
||||
return
|
||||
return false
|
||||
}
|
||||
|
||||
w.log.Debugf("enable WireGuard watcher")
|
||||
w.enabled = true
|
||||
w.muEnabled.Unlock()
|
||||
|
||||
initialHandshake, err := w.wgState()
|
||||
if err != nil {
|
||||
w.log.Warnf("failed to read initial wg stats: %v", err)
|
||||
}
|
||||
handshake, _ := w.wgState()
|
||||
w.initialHandshake = handshake
|
||||
return true
|
||||
}
|
||||
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, initialHandshake)
|
||||
// EnableWgWatcher runs the WireGuard watcher loop using the handshake baseline captured by
|
||||
// PrepareInitialHandshake. The watcher runs until ctx is cancelled. Caller is responsible
|
||||
// for context lifecycle management.
|
||||
func (w *WGWatcher) EnableWgWatcher(ctx context.Context, enabledTime time.Time, onDisconnectedFn func(), onHandshakeSuccessFn func(when time.Time)) {
|
||||
w.periodicHandshakeCheck(ctx, onDisconnectedFn, onHandshakeSuccessFn, enabledTime, w.initialHandshake)
|
||||
|
||||
w.muEnabled.Lock()
|
||||
w.enabled = false
|
||||
w.muEnabled.Unlock()
|
||||
}
|
||||
|
||||
// IsEnabled returns true if the WireGuard watcher is currently enabled
|
||||
func (w *WGWatcher) IsEnabled() bool {
|
||||
w.muEnabled.RLock()
|
||||
defer w.muEnabled.RUnlock()
|
||||
return w.enabled
|
||||
}
|
||||
|
||||
// Reset signals the watcher that the WireGuard peer has been reset and a new
|
||||
// handshake is expected. This restarts the handshake timeout from scratch.
|
||||
func (w *WGWatcher) Reset() {
|
||||
@@ -101,13 +103,16 @@ func (w *WGWatcher) periodicHandshakeCheck(ctx context.Context, onDisconnectedFn
|
||||
case <-timer.C:
|
||||
handshake, ok := w.handshakeCheck(lastHandshake)
|
||||
if !ok {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
onDisconnectedFn()
|
||||
return
|
||||
}
|
||||
if lastHandshake.IsZero() {
|
||||
elapsed := calcElapsed(enabledTime, *handshake)
|
||||
w.log.Infof("first wg handshake detected within: %.2fsec, (%s)", elapsed, handshake)
|
||||
if onHandshakeSuccessFn != nil {
|
||||
if onHandshakeSuccessFn != nil && ctx.Err() == nil {
|
||||
onHandshakeSuccessFn(*handshake)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||
)
|
||||
@@ -34,6 +35,9 @@ func TestWGWatcher_EnableWgWatcher(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ok := watcher.PrepareInitialHandshake()
|
||||
require.True(t, ok, "watcher should not be enabled yet")
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||
mlog.Infof("onDisconnectedFn")
|
||||
@@ -62,6 +66,9 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
watcher := NewWGWatcher(mlog, mocWgIface, "", newStateDump("peer", mlog, &Status{}))
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
ok := watcher.PrepareInitialHandshake()
|
||||
require.True(t, ok, "watcher should not be enabled yet")
|
||||
|
||||
wg := &sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
@@ -76,6 +83,9 @@ func TestWGWatcher_ReEnable(t *testing.T) {
|
||||
ctx, cancel = context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ok = watcher.PrepareInitialHandshake()
|
||||
require.True(t, ok, "watcher should be re-enabled after the previous run stopped")
|
||||
|
||||
onDisconnected := make(chan struct{}, 1)
|
||||
go watcher.EnableWgWatcher(ctx, time.Now(), func() {
|
||||
onDisconnected <- struct{}{}
|
||||
|
||||
@@ -101,8 +101,6 @@ type ConfigInput struct {
|
||||
|
||||
DNSLabels domain.List
|
||||
|
||||
LazyConnectionEnabled *bool
|
||||
|
||||
MTU *uint16
|
||||
}
|
||||
|
||||
@@ -180,7 +178,9 @@ type Config struct {
|
||||
|
||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
// LazyConnection is the MDM-managed lazy-connection override ("on"/"off"/"").
|
||||
// Runtime-only: re-derived from MDM policy on each load, never persisted.
|
||||
LazyConnection string `json:"-"`
|
||||
|
||||
MTU uint16
|
||||
|
||||
@@ -386,7 +386,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.NetworkMonitor != nil && input.NetworkMonitor != config.NetworkMonitor {
|
||||
if input.NetworkMonitor != nil && (config.NetworkMonitor == nil || *input.NetworkMonitor != *config.NetworkMonitor) {
|
||||
log.Infof("switching Network Monitor to %t", *input.NetworkMonitor)
|
||||
config.NetworkMonitor = input.NetworkMonitor
|
||||
updated = true
|
||||
@@ -454,7 +454,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRoot != nil && input.EnableSSHRoot != config.EnableSSHRoot {
|
||||
if input.EnableSSHRoot != nil && (config.EnableSSHRoot == nil || *input.EnableSSHRoot != *config.EnableSSHRoot) {
|
||||
if *input.EnableSSHRoot {
|
||||
log.Infof("enabling SSH root login")
|
||||
} else {
|
||||
@@ -464,7 +464,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHSFTP != nil && input.EnableSSHSFTP != config.EnableSSHSFTP {
|
||||
if input.EnableSSHSFTP != nil && (config.EnableSSHSFTP == nil || *input.EnableSSHSFTP != *config.EnableSSHSFTP) {
|
||||
if *input.EnableSSHSFTP {
|
||||
log.Infof("enabling SSH SFTP subsystem")
|
||||
} else {
|
||||
@@ -474,7 +474,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHLocalPortForwarding != nil && input.EnableSSHLocalPortForwarding != config.EnableSSHLocalPortForwarding {
|
||||
if input.EnableSSHLocalPortForwarding != nil && (config.EnableSSHLocalPortForwarding == nil || *input.EnableSSHLocalPortForwarding != *config.EnableSSHLocalPortForwarding) {
|
||||
if *input.EnableSSHLocalPortForwarding {
|
||||
log.Infof("enabling SSH local port forwarding")
|
||||
} else {
|
||||
@@ -484,7 +484,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.EnableSSHRemotePortForwarding != nil && input.EnableSSHRemotePortForwarding != config.EnableSSHRemotePortForwarding {
|
||||
if input.EnableSSHRemotePortForwarding != nil && (config.EnableSSHRemotePortForwarding == nil || *input.EnableSSHRemotePortForwarding != *config.EnableSSHRemotePortForwarding) {
|
||||
if *input.EnableSSHRemotePortForwarding {
|
||||
log.Infof("enabling SSH remote port forwarding")
|
||||
} else {
|
||||
@@ -494,7 +494,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||
if input.DisableSSHAuth != nil && (config.DisableSSHAuth == nil || *input.DisableSSHAuth != *config.DisableSSHAuth) {
|
||||
if *input.DisableSSHAuth {
|
||||
log.Infof("disabling SSH authentication")
|
||||
} else {
|
||||
@@ -504,7 +504,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.SSHJWTCacheTTL != nil && input.SSHJWTCacheTTL != config.SSHJWTCacheTTL {
|
||||
if input.SSHJWTCacheTTL != nil && (config.SSHJWTCacheTTL == nil || *input.SSHJWTCacheTTL != *config.SSHJWTCacheTTL) {
|
||||
log.Infof("updating SSH JWT cache TTL to %d seconds", *input.SSHJWTCacheTTL)
|
||||
config.SSHJWTCacheTTL = input.SSHJWTCacheTTL
|
||||
updated = true
|
||||
@@ -587,7 +587,7 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||
if input.DisableNotifications != nil && (config.DisableNotifications == nil || *input.DisableNotifications != *config.DisableNotifications) {
|
||||
if *input.DisableNotifications {
|
||||
log.Infof("disabling notifications")
|
||||
} else {
|
||||
@@ -632,12 +632,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||
updated = true
|
||||
}
|
||||
|
||||
if input.MTU != nil && *input.MTU != config.MTU {
|
||||
log.Infof("updating MTU to %d (old value %d)", *input.MTU, config.MTU)
|
||||
config.MTU = *input.MTU
|
||||
@@ -728,6 +722,15 @@ func (config *Config) applyMDMPolicy(policy *mdm.Policy) {
|
||||
log.Warnf("MDM wireguard port %d out of range [1,65535]; keeping previous value", v)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := policy.GetBool(mdm.KeyLazyConnection); ok {
|
||||
state := "off"
|
||||
if v {
|
||||
state = "on"
|
||||
}
|
||||
config.LazyConnection = state
|
||||
logApplied(mdm.KeyLazyConnection, state)
|
||||
}
|
||||
}
|
||||
|
||||
// parseURL parses and validates the URL for the named service. The URL
|
||||
|
||||
@@ -130,6 +130,37 @@ func TestApply_MDMBoolKeysOverrideOnDiskValue(t *testing.T) {
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyRosenpassEnabled))
|
||||
}
|
||||
|
||||
func TestApply_MDMLazyConnection(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
raw any
|
||||
want string
|
||||
}{
|
||||
{"native true", true, "on"},
|
||||
{"native false", false, "off"},
|
||||
{"string on", "on", "on"},
|
||||
{"string off", "off", "off"},
|
||||
{"string yes", "yes", "on"},
|
||||
{"string no", "no", "off"},
|
||||
}
|
||||
for _, c := range cases {
|
||||
t.Run(c.name, func(t *testing.T) {
|
||||
withMDMPolicy(t, mdm.NewPolicy(map[string]any{
|
||||
mdm.KeyLazyConnection: c.raw,
|
||||
}))
|
||||
|
||||
cfg, err := UpdateOrCreateConfig(ConfigInput{
|
||||
ConfigPath: filepath.Join(t.TempDir(), "config.json"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cfg)
|
||||
|
||||
assert.Equal(t, c.want, cfg.LazyConnection)
|
||||
assert.True(t, cfg.Policy().HasKey(mdm.KeyLazyConnection))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApply_MDMPreSharedKeyRedactionSentinelRejected(t *testing.T) {
|
||||
const maskSentinel = "**********"
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ func GetEnvKeyNBForceRelay() string {
|
||||
|
||||
// GetEnvKeyNBLazyConn Exports the environment variable for the iOS client
|
||||
func GetEnvKeyNBLazyConn() string {
|
||||
return lazyconn.EnvEnableLazyConn
|
||||
return lazyconn.EnvLazyConn
|
||||
}
|
||||
|
||||
// GetEnvKeyNBInactivityThreshold Exports the environment variable for the iOS client
|
||||
|
||||
@@ -27,6 +27,7 @@ var allKeys = []string{
|
||||
KeyWireguardPort,
|
||||
KeySplitTunnelMode,
|
||||
KeySplitTunnelApps,
|
||||
KeyLazyConnection,
|
||||
}
|
||||
|
||||
// canonicalKey maps the lowercase form of a managed-config value name to
|
||||
|
||||
@@ -11,6 +11,7 @@ package mdm
|
||||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -41,6 +42,11 @@ const (
|
||||
// construction — only one mode can be set at a time.
|
||||
KeySplitTunnelMode = "splitTunnelMode"
|
||||
KeySplitTunnelApps = "splitTunnelApps"
|
||||
|
||||
// KeyLazyConnection forces the lazy-connection feature on or off, overriding
|
||||
// the management feature flag. Read as a bool (native bool, or on/off,
|
||||
// true/false, 1/0, yes/no); absent = defer to management.
|
||||
KeyLazyConnection = "lazyConnection"
|
||||
)
|
||||
|
||||
// Split-tunnel mode literals (KeySplitTunnelMode values).
|
||||
@@ -62,12 +68,13 @@ var boolStringLiterals = map[string]bool{
|
||||
"true": true,
|
||||
"1": true,
|
||||
"yes": true,
|
||||
"on": true,
|
||||
"false": false,
|
||||
"0": false,
|
||||
"no": false,
|
||||
"off": false,
|
||||
}
|
||||
|
||||
|
||||
// Policy holds MDM-managed settings read from the platform source. A nil or
|
||||
// empty Policy means no enforcement is active.
|
||||
type Policy struct {
|
||||
@@ -150,7 +157,8 @@ func (p *Policy) GetString(key string) (string, bool) {
|
||||
}
|
||||
|
||||
// GetBool returns the managed value for key coerced to bool, and whether the
|
||||
// key was set. Accepts native bool and string literals "true"/"false"/"1"/"0".
|
||||
// key was set. Accepts native bool and string literals (true/false, 1/0,
|
||||
// yes/no, on/off), case-insensitively and trimmed of surrounding whitespace.
|
||||
func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
if p == nil {
|
||||
return false, false
|
||||
@@ -163,7 +171,7 @@ func (p *Policy) GetBool(key string) (bool, bool) {
|
||||
case bool:
|
||||
return t, true
|
||||
case string:
|
||||
b, known := boolStringLiterals[t]
|
||||
b, known := boolStringLiterals[strings.ToLower(strings.TrimSpace(t))]
|
||||
return b, known
|
||||
case int:
|
||||
return t != 0, true
|
||||
|
||||
@@ -31,8 +31,8 @@ func TestPolicy_Empty(t *testing.T) {
|
||||
|
||||
func TestPolicy_HasKey(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true,
|
||||
})
|
||||
assert.False(t, p.IsEmpty())
|
||||
assert.True(t, p.HasKey(KeyManagementURL))
|
||||
@@ -53,8 +53,8 @@ func TestPolicy_ManagedKeysSorted(t *testing.T) {
|
||||
func TestPolicy_GetString(t *testing.T) {
|
||||
p := NewPolicy(map[string]any{
|
||||
KeyManagementURL: "https://corp.example.com",
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
KeyDisableProfiles: true, // wrong type for GetString
|
||||
KeyPreSharedKey: "", // empty rejected
|
||||
})
|
||||
v, ok := p.GetString(KeyManagementURL)
|
||||
assert.True(t, ok)
|
||||
@@ -85,6 +85,11 @@ func TestPolicy_GetBool(t *testing.T) {
|
||||
{"string 0", "0", false, true},
|
||||
{"string yes", "yes", true, true},
|
||||
{"string no", "no", false, true},
|
||||
{"string on", "on", true, true},
|
||||
{"string off", "off", false, true},
|
||||
{"mixed case On", "On", true, true},
|
||||
{"upper TRUE", "TRUE", true, true},
|
||||
{"padded yes", " yes ", true, true},
|
||||
{"int nonzero", 1, true, true},
|
||||
{"int zero", 0, false, true},
|
||||
{"int64 nonzero", int64(2), true, true},
|
||||
|
||||
@@ -152,7 +152,6 @@ func (s *Server) restartEngineForMDMLocked() error {
|
||||
s.config = config
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
ctx, cancel := context.WithCancel(s.rootCtx)
|
||||
s.actCancel = cancel
|
||||
@@ -305,7 +304,6 @@ func setConfigRequestHasConfigOverrides(msg *proto.SetConfigRequest) bool {
|
||||
msg.DisableFirewall != nil ||
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil ||
|
||||
msg.DisableIpv6 != nil ||
|
||||
msg.EnableSSHRoot != nil ||
|
||||
@@ -348,7 +346,6 @@ func loginRequestHasConfigOverrides(msg *proto.LoginRequest) bool {
|
||||
msg.BlockLanAccess != nil ||
|
||||
msg.DisableNotifications != nil ||
|
||||
len(msg.DnsLabels) > 0 || msg.CleanDNSLabels ||
|
||||
msg.LazyConnectionEnabled != nil ||
|
||||
msg.BlockInbound != nil
|
||||
}
|
||||
|
||||
|
||||
@@ -214,7 +214,6 @@ func (s *Server) Start() error {
|
||||
|
||||
s.statusRecorder.UpdateManagementAddress(config.ManagementURL.String())
|
||||
s.statusRecorder.UpdateRosenpass(config.RosenpassEnabled, config.RosenpassPermissive)
|
||||
s.statusRecorder.UpdateLazyConnection(config.LazyConnectionEnabled)
|
||||
|
||||
if s.sessionWatcher == nil {
|
||||
s.sessionWatcher = internal.NewSessionWatcher(s.rootCtx, s.statusRecorder)
|
||||
@@ -463,7 +462,6 @@ func (s *Server) setConfigInputFromRequest(msg *proto.SetConfigRequest) (profile
|
||||
config.DisableFirewall = msg.DisableFirewall
|
||||
config.BlockLANAccess = msg.BlockLanAccess
|
||||
config.DisableNotifications = msg.DisableNotifications
|
||||
config.LazyConnectionEnabled = msg.LazyConnectionEnabled
|
||||
config.BlockInbound = msg.BlockInbound
|
||||
config.DisableIPv6 = msg.DisableIpv6
|
||||
config.EnableSSHRoot = msg.EnableSSHRoot
|
||||
@@ -1647,7 +1645,6 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||
RosenpassPermissive: cfg.RosenpassPermissive,
|
||||
LazyConnectionEnabled: cfg.LazyConnectionEnabled,
|
||||
BlockInbound: cfg.BlockInbound,
|
||||
DisableNotifications: disableNotifications,
|
||||
NetworkMonitor: networkMonitor,
|
||||
|
||||
@@ -69,43 +69,41 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
disableFirewall := true
|
||||
blockLANAccess := true
|
||||
disableNotifications := true
|
||||
lazyConnectionEnabled := true
|
||||
blockInbound := true
|
||||
disableIPv6 := true
|
||||
mtu := int64(1280)
|
||||
sshJWTCacheTTL := int32(300)
|
||||
|
||||
req := &proto.SetConfigRequest{
|
||||
ProfileName: profName,
|
||||
Username: currUser.Username,
|
||||
ManagementUrl: "https://new-api.netbird.io:443",
|
||||
AdminURL: "https://new-admin.netbird.io",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
DisableAutoConnect: &disableAutoConnect,
|
||||
NetworkMonitor: &networkMonitor,
|
||||
DisableClientRoutes: &disableClientRoutes,
|
||||
DisableServerRoutes: &disableServerRoutes,
|
||||
DisableDns: &disableDNS,
|
||||
DisableFirewall: &disableFirewall,
|
||||
BlockLanAccess: &blockLANAccess,
|
||||
DisableNotifications: &disableNotifications,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableIpv6: &disableIPv6,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
|
||||
DnsLabels: []string{"label1", "label2"},
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
SshJWTCacheTTL: &sshJWTCacheTTL,
|
||||
ProfileName: profName,
|
||||
Username: currUser.Username,
|
||||
ManagementUrl: "https://new-api.netbird.io:443",
|
||||
AdminURL: "https://new-admin.netbird.io",
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
RosenpassPermissive: &rosenpassPermissive,
|
||||
ServerSSHAllowed: &serverSSHAllowed,
|
||||
InterfaceName: &interfaceName,
|
||||
WireguardPort: &wireguardPort,
|
||||
OptionalPreSharedKey: &preSharedKey,
|
||||
DisableAutoConnect: &disableAutoConnect,
|
||||
NetworkMonitor: &networkMonitor,
|
||||
DisableClientRoutes: &disableClientRoutes,
|
||||
DisableServerRoutes: &disableServerRoutes,
|
||||
DisableDns: &disableDNS,
|
||||
DisableFirewall: &disableFirewall,
|
||||
BlockLanAccess: &blockLANAccess,
|
||||
DisableNotifications: &disableNotifications,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableIpv6: &disableIPv6,
|
||||
NatExternalIPs: []string{"1.2.3.4", "5.6.7.8"},
|
||||
CleanNATExternalIPs: false,
|
||||
CustomDNSAddress: []byte("1.1.1.1:53"),
|
||||
ExtraIFaceBlacklist: []string{"eth1", "eth2"},
|
||||
DnsLabels: []string{"label1", "label2"},
|
||||
CleanDNSLabels: false,
|
||||
DnsRouteInterval: durationpb.New(2 * time.Minute),
|
||||
Mtu: &mtu,
|
||||
SshJWTCacheTTL: &sshJWTCacheTTL,
|
||||
}
|
||||
|
||||
_, err = s.SetConfig(ctx, req)
|
||||
@@ -140,7 +138,6 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
||||
require.Equal(t, blockLANAccess, cfg.BlockLANAccess)
|
||||
require.NotNil(t, cfg.DisableNotifications)
|
||||
require.Equal(t, disableNotifications, *cfg.DisableNotifications)
|
||||
require.Equal(t, lazyConnectionEnabled, cfg.LazyConnectionEnabled)
|
||||
require.Equal(t, blockInbound, cfg.BlockInbound)
|
||||
require.Equal(t, disableIPv6, cfg.DisableIPv6)
|
||||
require.Equal(t, []string{"1.2.3.4", "5.6.7.8"}, cfg.NATExternalIPs)
|
||||
@@ -164,13 +161,14 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
t.Helper()
|
||||
|
||||
metadataFields := map[string]bool{
|
||||
"state": true, // protobuf internal
|
||||
"sizeCache": true, // protobuf internal
|
||||
"unknownFields": true, // protobuf internal
|
||||
"Username": true, // metadata
|
||||
"ProfileName": true, // metadata
|
||||
"CleanNATExternalIPs": true, // control flag for clearing
|
||||
"CleanDNSLabels": true, // control flag for clearing
|
||||
"state": true, // protobuf internal
|
||||
"sizeCache": true, // protobuf internal
|
||||
"unknownFields": true, // protobuf internal
|
||||
"Username": true, // metadata
|
||||
"ProfileName": true, // metadata
|
||||
"CleanNATExternalIPs": true, // control flag for clearing
|
||||
"CleanDNSLabels": true, // control flag for clearing
|
||||
"LazyConnectionEnabled": true, // deprecated: proto field retained for compat, no longer applied
|
||||
}
|
||||
|
||||
expectedFields := map[string]bool{
|
||||
@@ -190,7 +188,6 @@ func verifyAllFieldsCovered(t *testing.T, req *proto.SetConfigRequest) {
|
||||
"DisableFirewall": true,
|
||||
"BlockLanAccess": true,
|
||||
"DisableNotifications": true,
|
||||
"LazyConnectionEnabled": true,
|
||||
"BlockInbound": true,
|
||||
"DisableIpv6": true,
|
||||
"NatExternalIPs": true,
|
||||
@@ -252,7 +249,6 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
"block-lan-access": "BlockLanAccess",
|
||||
"block-inbound": "BlockInbound",
|
||||
"disable-ipv6": "DisableIpv6",
|
||||
"enable-lazy-connection": "LazyConnectionEnabled",
|
||||
"external-ip-map": "NatExternalIPs",
|
||||
"dns-resolver-address": "CustomDNSAddress",
|
||||
"extra-iface-blacklist": "ExtraIFaceBlacklist",
|
||||
@@ -269,7 +265,8 @@ func TestCLIFlags_MappedToSetConfig(t *testing.T) {
|
||||
|
||||
// SetConfigRequest fields that don't have CLI flags (settable only via UI or other means).
|
||||
fieldsWithoutCLIFlags := map[string]bool{
|
||||
"DisableNotifications": true, // Only settable via UI
|
||||
"DisableNotifications": true, // Only settable via UI
|
||||
"LazyConnectionEnabled": true, // deprecated: no longer settable (managed by server + NB_LAZY_CONN)
|
||||
}
|
||||
|
||||
// Get all SetConfigRequest fields to verify our map is complete.
|
||||
|
||||
@@ -74,8 +74,6 @@ type Info struct {
|
||||
BlockInbound bool
|
||||
DisableIPv6 bool
|
||||
|
||||
LazyConnectionEnabled bool
|
||||
|
||||
EnableSSHRoot bool
|
||||
EnableSSHSFTP bool
|
||||
EnableSSHLocalPortForwarding bool
|
||||
@@ -87,7 +85,7 @@ func (i *Info) SetFlags(
|
||||
rosenpassEnabled, rosenpassPermissive bool,
|
||||
serverSSHAllowed *bool,
|
||||
disableClientRoutes, disableServerRoutes,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6, lazyConnectionEnabled bool,
|
||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, disableIPv6 bool,
|
||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||
disableSSHAuth *bool,
|
||||
) {
|
||||
@@ -105,8 +103,6 @@ func (i *Info) SetFlags(
|
||||
i.BlockInbound = blockInbound
|
||||
i.DisableIPv6 = disableIPv6
|
||||
|
||||
i.LazyConnectionEnabled = lazyConnectionEnabled
|
||||
|
||||
if enableSSHRoot != nil {
|
||||
i.EnableSSHRoot = *enableSSHRoot
|
||||
}
|
||||
|
||||
@@ -266,7 +266,6 @@ type serviceClient struct {
|
||||
mAllowSSH *systray.MenuItem
|
||||
mAutoConnect *systray.MenuItem
|
||||
mEnableRosenpass *systray.MenuItem
|
||||
mLazyConnEnabled *systray.MenuItem
|
||||
mBlockInbound *systray.MenuItem
|
||||
mNotifications *systray.MenuItem
|
||||
mAdvancedSettings *systray.MenuItem
|
||||
@@ -336,11 +335,11 @@ type serviceClient struct {
|
||||
// mNetworks + mExitNode submenu items. Combines features.DisableNetworks
|
||||
// AND s.connected — both must be true for the menus to be active.
|
||||
// Zero value (false) matches the Disable() call at AddMenuItem time.
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
networksMenuEnabled bool
|
||||
showNetworks bool
|
||||
wNetworks fyne.Window
|
||||
wProfiles fyne.Window
|
||||
wQuickActions fyne.Window
|
||||
|
||||
eventManager *event.Manager
|
||||
|
||||
@@ -1094,7 +1093,6 @@ func (s *serviceClient) onTrayReady() {
|
||||
s.mAllowSSH = s.mSettings.AddSubMenuItemCheckbox("Allow SSH", allowSSHMenuDescr, false)
|
||||
s.mAutoConnect = s.mSettings.AddSubMenuItemCheckbox("Connect on Startup", autoConnectMenuDescr, false)
|
||||
s.mEnableRosenpass = s.mSettings.AddSubMenuItemCheckbox("Enable Quantum-Resistance", quantumResistanceMenuDescr, false)
|
||||
s.mLazyConnEnabled = s.mSettings.AddSubMenuItemCheckbox("Enable Lazy Connections", lazyConnMenuDescr, false)
|
||||
s.mBlockInbound = s.mSettings.AddSubMenuItemCheckbox("Block Inbound Connections", blockInboundMenuDescr, false)
|
||||
s.mNotifications = s.mSettings.AddSubMenuItemCheckbox("Notifications", notificationsMenuDescr, false)
|
||||
s.mSettings.AddSeparator()
|
||||
@@ -1578,7 +1576,6 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
||||
config.RosenpassEnabled = cfg.RosenpassEnabled
|
||||
config.RosenpassPermissive = cfg.RosenpassPermissive
|
||||
config.DisableNotifications = &cfg.DisableNotifications
|
||||
config.LazyConnectionEnabled = cfg.LazyConnectionEnabled
|
||||
config.BlockInbound = cfg.BlockInbound
|
||||
config.NetworkMonitor = &cfg.NetworkMonitor
|
||||
config.DisableDNS = cfg.DisableDns
|
||||
@@ -1682,12 +1679,6 @@ func (s *serviceClient) loadSettings() {
|
||||
s.mEnableRosenpass.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.LazyConnectionEnabled {
|
||||
s.mLazyConnEnabled.Check()
|
||||
} else {
|
||||
s.mLazyConnEnabled.Uncheck()
|
||||
}
|
||||
|
||||
if cfg.BlockInbound {
|
||||
s.mBlockInbound.Check()
|
||||
} else {
|
||||
@@ -1833,7 +1824,6 @@ func (s *serviceClient) updateConfig() error {
|
||||
disableAutoStart := !s.mAutoConnect.Checked()
|
||||
sshAllowed := s.mAllowSSH.Checked()
|
||||
rosenpassEnabled := s.mEnableRosenpass.Checked()
|
||||
lazyConnectionEnabled := s.mLazyConnEnabled.Checked()
|
||||
blockInbound := s.mBlockInbound.Checked()
|
||||
notificationsDisabled := !s.mNotifications.Checked()
|
||||
|
||||
@@ -1856,14 +1846,13 @@ func (s *serviceClient) updateConfig() error {
|
||||
}
|
||||
|
||||
req := proto.SetConfigRequest{
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
LazyConnectionEnabled: &lazyConnectionEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableNotifications: ¬ificationsDisabled,
|
||||
ProfileName: activeProf.ID.String(),
|
||||
Username: currUser.Username,
|
||||
DisableAutoConnect: &disableAutoStart,
|
||||
ServerSSHAllowed: &sshAllowed,
|
||||
RosenpassEnabled: &rosenpassEnabled,
|
||||
BlockInbound: &blockInbound,
|
||||
DisableNotifications: ¬ificationsDisabled,
|
||||
}
|
||||
|
||||
if _, err := conn.SetConfig(s.ctx, &req); err != nil {
|
||||
|
||||
@@ -4,7 +4,6 @@ const (
|
||||
allowSSHMenuDescr = "Allow SSH connections"
|
||||
autoConnectMenuDescr = "Connect automatically when the service starts"
|
||||
quantumResistanceMenuDescr = "Enable post-quantum security via Rosenpass"
|
||||
lazyConnMenuDescr = "[Experimental] Enable lazy connections"
|
||||
blockInboundMenuDescr = "Block inbound connections to the local machine and routed networks"
|
||||
notificationsMenuDescr = "Enable notifications"
|
||||
advancedSettingsMenuDescr = "Advanced settings of the application"
|
||||
|
||||
@@ -43,8 +43,6 @@ func (h *eventHandler) listen(ctx context.Context) {
|
||||
h.handleAutoConnectClick()
|
||||
case <-h.client.mEnableRosenpass.ClickedCh:
|
||||
h.handleRosenpassClick()
|
||||
case <-h.client.mLazyConnEnabled.ClickedCh:
|
||||
h.handleLazyConnectionClick()
|
||||
case <-h.client.mBlockInbound.ClickedCh:
|
||||
h.handleBlockInboundClick()
|
||||
case <-h.client.mAdvancedSettings.ClickedCh:
|
||||
@@ -152,15 +150,6 @@ func (h *eventHandler) handleRosenpassClick() {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleLazyConnectionClick() {
|
||||
h.toggleCheckbox(h.client.mLazyConnEnabled)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
||||
log.Errorf("failed to update config: %v", err)
|
||||
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *eventHandler) handleBlockInboundClick() {
|
||||
h.toggleCheckbox(h.client.mBlockInbound)
|
||||
if err := h.updateConfigWithErr(); err != nil {
|
||||
|
||||
140
e2e/agentnetwork/skiptls_test.go
Normal file
140
e2e/agentnetwork/skiptls_test.go
Normal file
@@ -0,0 +1,140 @@
|
||||
//go:build e2e
|
||||
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/e2e/harness"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// TestProviderSkipTLSVerification proves skip_tls_verification is per-provider:
|
||||
// two providers share one self-signed upstream, one skipping TLS verification
|
||||
// and one not. The skip=true provider's chat reaches the upstream and returns
|
||||
// 200; the skip=false provider's chat fails at the TLS handshake — same
|
||||
// upstream, opposite outcome. This is the behaviour a target-level flag could
|
||||
// not give, since all of an account's providers share one synthesised target.
|
||||
func TestProviderSkipTLSVerification(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
up, err := harness.StartFakeUpstream(ctx, srv)
|
||||
require.NoError(t, err, "start self-signed upstream")
|
||||
t.Cleanup(func() { _ = up.Terminate(context.Background()) })
|
||||
|
||||
grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-skiptls"})
|
||||
require.NoError(t, err, "create group")
|
||||
t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) })
|
||||
|
||||
ephemeral := false
|
||||
sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{
|
||||
Name: "e2e-skiptls-client",
|
||||
Type: "reusable",
|
||||
ExpiresIn: 86400,
|
||||
UsageLimit: 0,
|
||||
AutoGroups: []string{grp.Id},
|
||||
Ephemeral: &ephemeral,
|
||||
})
|
||||
require.NoError(t, err, "mint setup key")
|
||||
require.NotEmpty(t, sk.Key, "setup key plaintext")
|
||||
|
||||
const (
|
||||
insecureModel = "insecure-model"
|
||||
secureModel = "secure-model"
|
||||
)
|
||||
|
||||
// Two providers on the SAME self-signed upstream, distinguished only by their
|
||||
// skip_tls_verification and a unique model string so the router picks each
|
||||
// unambiguously.
|
||||
newReq := func(name, model string, skip bool) api.AgentNetworkProviderRequest {
|
||||
key := "sk-dummy-e2e"
|
||||
return api.AgentNetworkProviderRequest{
|
||||
Name: name,
|
||||
ProviderId: "openai_api",
|
||||
UpstreamUrl: up.URL,
|
||||
ApiKey: &key,
|
||||
Enabled: ptr(true),
|
||||
SkipTlsVerification: ptr(skip),
|
||||
Models: &[]api.AgentNetworkProviderModel{
|
||||
{Id: model, InputPer1k: 0.001, OutputPer1k: 0.002},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// First create bootstraps the account cluster.
|
||||
insecureReq := newReq("skip-tls", insecureModel, true)
|
||||
insecureReq.BootstrapCluster = ptr(harness.AgentNetworkCluster)
|
||||
insecureProv, err := srv.CreateProvider(ctx, insecureReq)
|
||||
require.NoError(t, err, "create skip-tls provider")
|
||||
t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), insecureProv.Id) })
|
||||
require.True(t, insecureProv.SkipTlsVerification, "response must echo skip_tls_verification=true")
|
||||
|
||||
secureProv, err := srv.CreateProvider(ctx, newReq("verify-tls", secureModel, false))
|
||||
require.NoError(t, err, "create verify-tls provider")
|
||||
t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), secureProv.Id) })
|
||||
require.False(t, secureProv.SkipTlsVerification, "response must echo skip_tls_verification=false")
|
||||
|
||||
enabled := true
|
||||
pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{
|
||||
Name: "e2e-skiptls-allow",
|
||||
Enabled: &enabled,
|
||||
SourceGroups: []string{grp.Id},
|
||||
DestinationProviderIds: []string{insecureProv.Id, secureProv.Id},
|
||||
})
|
||||
require.NoError(t, err, "create policy")
|
||||
t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) })
|
||||
|
||||
settings, err := srv.GetSettings(ctx)
|
||||
require.NoError(t, err, "read settings")
|
||||
require.NotEmpty(t, settings.Endpoint, "endpoint must be assigned")
|
||||
|
||||
proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-skiptls-proxy")
|
||||
require.NoError(t, err, "mint proxy token")
|
||||
px, err := harness.StartProxy(ctx, srv, proxyToken)
|
||||
require.NoError(t, err, "start proxy")
|
||||
t.Cleanup(func() { _ = px.Terminate(context.Background()) })
|
||||
|
||||
cl, err := harness.StartClient(ctx, srv, sk.Key)
|
||||
require.NoError(t, err, "start client")
|
||||
t.Cleanup(func() { _ = cl.Terminate(context.Background()) })
|
||||
|
||||
require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management")
|
||||
if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil {
|
||||
t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background()))
|
||||
}
|
||||
proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint)
|
||||
require.NoError(t, err, "resolve endpoint to proxy IP")
|
||||
|
||||
// Positive: skip=true reaches the self-signed upstream. Retry to absorb
|
||||
// tunnel/DNS jitter on the first call; success also proves the path works.
|
||||
var code int
|
||||
var body string
|
||||
deadline := time.Now().Add(90 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, b, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, insecureModel, "Reply with exactly: pong", "e2e-skiptls-insecure")
|
||||
if cerr == nil {
|
||||
code, body = c, b
|
||||
if code == 200 {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
require.Equal(t, 200, code,
|
||||
"skip_tls_verification=true must reach the self-signed upstream; body: %s\n=== upstream logs ===\n%s\n=== proxy logs ===\n%s",
|
||||
body, up.Logs(context.Background()), px.Logs(context.Background()))
|
||||
|
||||
// Negative: skip=false must fail the TLS handshake to the SAME upstream. The
|
||||
// path is already proven working, so a non-200 here is the cert rejection.
|
||||
secureCode, secureBody, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, secureModel, "Reply with exactly: pong", "e2e-skiptls-secure")
|
||||
require.NoError(t, cerr, "the chat call itself must complete (proxy returns an error status, not a transport error)")
|
||||
require.NotEqual(t, 200, secureCode,
|
||||
"skip_tls_verification=false must NOT reach the self-signed upstream; got %d, body: %s", secureCode, secureBody)
|
||||
require.GreaterOrEqual(t, secureCode, 500,
|
||||
"a TLS verification failure should surface as a 5xx from the proxy; got %d, body: %s", secureCode, secureBody)
|
||||
}
|
||||
171
e2e/agentnetwork/vllm_test.go
Normal file
171
e2e/agentnetwork/vllm_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
//go:build e2e
|
||||
|
||||
package agentnetwork
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/e2e/harness"
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// TestVLLMProvider proves the proxy supports a self-hosted vLLM backend. vLLM is
|
||||
// OpenAI-compatible, so it uses the "vllm" catalog entry (KindCustom) and is
|
||||
// reached over plain HTTP — no TLS anywhere on the path:
|
||||
//
|
||||
// client --tunnel--> netbird proxy --http--> vllm (:8000, OpenAI-compatible)
|
||||
//
|
||||
// The mock vLLM server answers /v1/chat/completions with an OpenAI-shaped
|
||||
// completion carrying a non-zero usage block. The test asserts the chat returns
|
||||
// 200 with the completion, that the request is recorded in the access log by its
|
||||
// session id, and that vLLM's usage block is metered into a consumption row —
|
||||
// which together prove request routing, response parsing, and token accounting
|
||||
// all work for a self-hosted OpenAI-compatible provider.
|
||||
//
|
||||
// It needs no external credentials (the mock ignores auth), so it always runs.
|
||||
func TestVLLMProvider(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
vllm, err := harness.StartVLLM(ctx, srv)
|
||||
require.NoError(t, err, "start mock vLLM server")
|
||||
t.Cleanup(func() { _ = vllm.Terminate(context.Background()) })
|
||||
|
||||
grp, err := srv.API().Groups.Create(ctx, api.PostApiGroupsJSONRequestBody{Name: "e2e-vllm"})
|
||||
require.NoError(t, err, "create group")
|
||||
t.Cleanup(func() { _ = srv.API().Groups.Delete(context.Background(), grp.Id) })
|
||||
|
||||
ephemeral := false
|
||||
sk, err := srv.API().SetupKeys.Create(ctx, api.PostApiSetupKeysJSONRequestBody{
|
||||
Name: "e2e-vllm-client",
|
||||
Type: "reusable",
|
||||
ExpiresIn: 86400,
|
||||
UsageLimit: 0,
|
||||
AutoGroups: []string{grp.Id},
|
||||
Ephemeral: &ephemeral,
|
||||
})
|
||||
require.NoError(t, err, "mint setup key")
|
||||
require.NotEmpty(t, sk.Key, "setup key plaintext")
|
||||
|
||||
// vLLM provider pointed at the mock over plain HTTP. The mock ignores auth,
|
||||
// so a dummy key satisfies the "Bearer ${API_KEY}" template. The served model
|
||||
// is enumerated so the router dispatches this model string to this provider.
|
||||
dummyKey := "sk-vllm-e2e"
|
||||
prov, err := srv.CreateProvider(ctx, api.AgentNetworkProviderRequest{
|
||||
Name: "vllm",
|
||||
ProviderId: "vllm",
|
||||
UpstreamUrl: vllm.URL,
|
||||
ApiKey: &dummyKey,
|
||||
Enabled: ptr(true),
|
||||
BootstrapCluster: ptr(harness.AgentNetworkCluster),
|
||||
Models: &[]api.AgentNetworkProviderModel{
|
||||
{Id: harness.VLLMModel, InputPer1k: 0.001, OutputPer1k: 0.002},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "create vllm provider")
|
||||
t.Cleanup(func() { _ = srv.DeleteProvider(context.Background(), prov.Id) })
|
||||
|
||||
// Token limit far above the handful of tokens this test drives, so it never
|
||||
// blocks but switches on usage metering — the switch that makes consumption
|
||||
// rows get recorded.
|
||||
enabled := true
|
||||
pol, err := srv.CreatePolicy(ctx, api.AgentNetworkPolicyRequest{
|
||||
Name: "e2e-vllm-allow",
|
||||
Enabled: &enabled,
|
||||
SourceGroups: []string{grp.Id},
|
||||
DestinationProviderIds: []string{prov.Id},
|
||||
Limits: &api.AgentNetworkPolicyLimits{
|
||||
TokenLimit: api.AgentNetworkPolicyTokenLimit{
|
||||
Enabled: true,
|
||||
GroupCap: 10_000_000,
|
||||
UserCap: 10_000_000,
|
||||
WindowSeconds: 60,
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err, "create policy")
|
||||
t.Cleanup(func() { _ = srv.DeletePolicy(context.Background(), pol.Id) })
|
||||
|
||||
settings, err := srv.GetSettings(ctx)
|
||||
require.NoError(t, err, "read settings")
|
||||
require.NotEmpty(t, settings.Endpoint, "endpoint must be assigned")
|
||||
|
||||
proxyToken, err := srv.CreateProxyTokenCLI(ctx, "e2e-vllm-proxy")
|
||||
require.NoError(t, err, "mint proxy token")
|
||||
px, err := harness.StartProxy(ctx, srv, proxyToken)
|
||||
require.NoError(t, err, "start proxy")
|
||||
t.Cleanup(func() { _ = px.Terminate(context.Background()) })
|
||||
|
||||
cl, err := harness.StartClient(ctx, srv, sk.Key)
|
||||
require.NoError(t, err, "start client")
|
||||
t.Cleanup(func() { _ = cl.Terminate(context.Background()) })
|
||||
|
||||
require.NoError(t, cl.WaitConnected(ctx, 90*time.Second), "client must connect to management")
|
||||
if err := cl.WaitProxyPeer(ctx, 180*time.Second); err != nil {
|
||||
t.Fatalf("client did not see the proxy peer: %v\n=== proxy logs ===\n%s", err, px.Logs(context.Background()))
|
||||
}
|
||||
proxyIP, err := cl.ResolveProxyIP(ctx, settings.Endpoint)
|
||||
require.NoError(t, err, "resolve endpoint to proxy IP")
|
||||
|
||||
before, _ := srv.ListAccessLogs(ctx)
|
||||
sessionID := "e2e-session-vllm"
|
||||
|
||||
// Retry to absorb tunnel/DNS jitter on the first call.
|
||||
var code int
|
||||
var body string
|
||||
deadline := time.Now().Add(90 * time.Second)
|
||||
for time.Now().Before(deadline) {
|
||||
c, b, cerr := cl.Chat(ctx, settings.Endpoint, proxyIP, harness.WireChat, harness.VLLMModel, "Reply with exactly: pong", sessionID)
|
||||
if cerr == nil {
|
||||
code, body = c, b
|
||||
if code == 200 {
|
||||
break
|
||||
}
|
||||
}
|
||||
time.Sleep(5 * time.Second)
|
||||
}
|
||||
require.Equal(t, 200, code,
|
||||
"chat through the vLLM provider must return 200; body: %s\n=== vllm logs ===\n%s\n=== proxy logs ===\n%s",
|
||||
body, vllm.Logs(context.Background()), px.Logs(context.Background()))
|
||||
require.True(t, strings.Contains(body, "chat.completion"),
|
||||
"body should be an OpenAI-compatible chat completion; got: %s", body)
|
||||
|
||||
// The request must surface as an access-log row carrying our session id.
|
||||
require.Eventually(t, func() bool {
|
||||
logs, lerr := srv.ListAccessLogs(ctx)
|
||||
return lerr == nil && logs.TotalRecords > before.TotalRecords
|
||||
}, 30*time.Second, 2*time.Second, "an access-log row should be ingested for the vLLM provider")
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
logs, lerr := srv.ListAccessLogs(ctx)
|
||||
if lerr != nil {
|
||||
return false
|
||||
}
|
||||
for _, r := range logs.Data {
|
||||
if r.SessionId != nil && *r.SessionId == sessionID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, 30*time.Second, 2*time.Second, "session id %q must be recorded in an access-log row", sessionID)
|
||||
|
||||
// vLLM's usage block (prompt_tokens=11, completion_tokens=2) must be parsed
|
||||
// and metered into a consumption row with positive token counts.
|
||||
require.Eventually(t, func() bool {
|
||||
rows, lerr := srv.ListConsumption(ctx)
|
||||
if lerr != nil {
|
||||
return false
|
||||
}
|
||||
for _, r := range rows {
|
||||
if r.TokensInput > 0 && r.TokensOutput > 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, 60*time.Second, 3*time.Second, "vLLM usage must be metered into a consumption row")
|
||||
}
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -108,9 +109,48 @@ func (cl *Client) WaitConnected(ctx context.Context, timeout time.Duration) erro
|
||||
return cl.pollStatus(ctx, timeout, "Management: Connected")
|
||||
}
|
||||
|
||||
// WaitProxyPeer polls until the client sees the proxy peer connected (1/1).
|
||||
// WaitProxyPeer polls until the client sees at least one connected peer — the
|
||||
// proxy serving the agent-network endpoint. It requires ">=1 connected" rather
|
||||
// than an exact "1/1" because proxy peers from earlier tests linger in the
|
||||
// account as disconnected (each proxy container registers a fresh WireGuard key
|
||||
// and the peer is not removed on teardown), so the count is e.g. "1/2". Only the
|
||||
// live proxy can be connected, and the caller's subsequent chat is the real
|
||||
// end-to-end assertion.
|
||||
func (cl *Client) WaitProxyPeer(ctx context.Context, timeout time.Duration) error {
|
||||
return cl.pollStatus(ctx, timeout, "1/1 Connected")
|
||||
deadline := time.Now().Add(timeout)
|
||||
var last string
|
||||
for time.Now().Before(deadline) {
|
||||
out, _ := cl.Status(ctx)
|
||||
last = out
|
||||
if connectedPeers(out) >= 1 {
|
||||
return nil
|
||||
}
|
||||
time.Sleep(3 * time.Second)
|
||||
}
|
||||
return fmt.Errorf("timed out waiting for a connected proxy peer; last status:\n%s", last)
|
||||
}
|
||||
|
||||
// connectedPeers parses the "Peers count: X/Y Connected" line from `netbird
|
||||
// status` and returns X (the connected count), or 0 when absent/unparseable.
|
||||
func connectedPeers(status string) int {
|
||||
for _, line := range strings.Split(status, "\n") {
|
||||
line = strings.TrimSpace(line)
|
||||
rest, ok := strings.CutPrefix(line, "Peers count:")
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
rest = strings.TrimSpace(rest)
|
||||
slash := strings.IndexByte(rest, '/')
|
||||
if slash <= 0 {
|
||||
return 0
|
||||
}
|
||||
n, err := strconv.Atoi(strings.TrimSpace(rest[:slash]))
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return n
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (cl *Client) pollStatus(ctx context.Context, timeout time.Duration, want string) error {
|
||||
|
||||
107
e2e/harness/upstream.go
Normal file
107
e2e/harness/upstream.go
Normal file
@@ -0,0 +1,107 @@
|
||||
//go:build e2e
|
||||
|
||||
package harness
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
const (
|
||||
fakeUpstreamImage = "nginx:alpine"
|
||||
fakeUpstreamAlias = "fakeupstream"
|
||||
fakeUpstreamPort = "443/tcp"
|
||||
)
|
||||
|
||||
// fakeUpstreamNginxConf serves a canned OpenAI-shaped chat completion for any
|
||||
// path over a self-signed certificate, so the proxy reaches it only when the
|
||||
// provider opts into skipping TLS verification.
|
||||
const fakeUpstreamNginxConf = `pid /tmp/nginx.pid;
|
||||
events {}
|
||||
http {
|
||||
server {
|
||||
listen 443 ssl;
|
||||
ssl_certificate /certs/tls.crt;
|
||||
ssl_certificate_key /certs/tls.key;
|
||||
location / {
|
||||
default_type application/json;
|
||||
return 200 '{"id":"chatcmpl-e2e","object":"chat.completion","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":1,"total_tokens":2}}';
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
// FakeUpstream is a self-signed HTTPS server on the combined server's network,
|
||||
// used to exercise provider skip_tls_verification: a proxy that verifies the
|
||||
// certificate rejects it, one that skips verification reaches it.
|
||||
type FakeUpstream struct {
|
||||
container testcontainers.Container
|
||||
workDir string
|
||||
// URL is the upstream URL providers point at (https://<alias>).
|
||||
URL string
|
||||
}
|
||||
|
||||
// StartFakeUpstream runs the self-signed upstream on the shared network.
|
||||
func StartFakeUpstream(ctx context.Context, c *Combined) (*FakeUpstream, error) {
|
||||
workDir, err := os.MkdirTemp("/tmp", "nb-e2e-upstream-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create upstream work dir: %w", err)
|
||||
}
|
||||
// Widen so the (non-root worker) nginx container can traverse the bind mount.
|
||||
if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e cert dir
|
||||
return nil, fmt.Errorf("chmod upstream dir: %w", err)
|
||||
}
|
||||
if err := writeSelfSignedCert(workDir, []string{fakeUpstreamAlias}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(workDir, "nginx.conf"), []byte(fakeUpstreamNginxConf), 0o644); err != nil { //nolint:gosec // non-secret e2e config
|
||||
return nil, fmt.Errorf("write nginx conf: %w", err)
|
||||
}
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: fakeUpstreamImage,
|
||||
ExposedPorts: []string{fakeUpstreamPort},
|
||||
Networks: []string{c.network.Name},
|
||||
NetworkAliases: map[string][]string{c.network.Name: {fakeUpstreamAlias}},
|
||||
Cmd: []string{"nginx", "-c", "/certs/nginx.conf", "-g", "daemon off;"},
|
||||
HostConfigModifier: func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, workDir+":/certs:ro")
|
||||
},
|
||||
WaitingFor: wait.ForListeningPort(fakeUpstreamPort).WithStartupTimeout(60 * time.Second),
|
||||
}
|
||||
|
||||
ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(workDir)
|
||||
return nil, fmt.Errorf("start fake upstream container: %w", err)
|
||||
}
|
||||
|
||||
return &FakeUpstream{container: ctr, workDir: workDir, URL: "https://" + fakeUpstreamAlias}, nil
|
||||
}
|
||||
|
||||
// Logs returns the upstream container logs, for diagnostics on failure.
|
||||
func (u *FakeUpstream) Logs(ctx context.Context) string {
|
||||
return containerLogs(ctx, u.container)
|
||||
}
|
||||
|
||||
// Terminate stops the upstream container and cleans its work dir.
|
||||
func (u *FakeUpstream) Terminate(ctx context.Context) error {
|
||||
var err error
|
||||
if u.container != nil {
|
||||
err = u.container.Terminate(ctx)
|
||||
}
|
||||
if u.workDir != "" {
|
||||
_ = os.RemoveAll(u.workDir)
|
||||
}
|
||||
return err
|
||||
}
|
||||
113
e2e/harness/vllm.go
Normal file
113
e2e/harness/vllm.go
Normal file
@@ -0,0 +1,113 @@
|
||||
//go:build e2e
|
||||
|
||||
package harness
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/testcontainers/testcontainers-go"
|
||||
"github.com/testcontainers/testcontainers-go/wait"
|
||||
)
|
||||
|
||||
const (
|
||||
vllmImage = "nginx:alpine"
|
||||
vllmAlias = "vllm"
|
||||
vllmPort = "8000/tcp"
|
||||
// VLLMModel is the served model id the mock advertises and echoes back. It
|
||||
// matches a real small model commonly served by vLLM so the provider's
|
||||
// enumerated model and the client's request line up.
|
||||
VLLMModel = "Qwen/Qwen2.5-0.5B-Instruct"
|
||||
)
|
||||
|
||||
// vllmNginxConf emulates a vLLM OpenAI-compatible server over plain HTTP (vLLM's
|
||||
// default: no TLS, port 8000). It answers /v1/models with a one-model list and
|
||||
// any chat/completions path with a canned OpenAI-shaped chat completion carrying
|
||||
// a non-zero usage block, so the proxy's OpenAI parser records real token
|
||||
// consumption. Running actual vLLM in CI is infeasible (GPU + multi-GB model
|
||||
// download), so this stands in for the wire contract the proxy depends on.
|
||||
const vllmNginxConf = `pid /tmp/nginx.pid;
|
||||
events {}
|
||||
http {
|
||||
server {
|
||||
listen 8000;
|
||||
location = /v1/models {
|
||||
default_type application/json;
|
||||
return 200 '{"object":"list","data":[{"id":"Qwen/Qwen2.5-0.5B-Instruct","object":"model","owned_by":"vllm"}]}';
|
||||
}
|
||||
location / {
|
||||
default_type application/json;
|
||||
return 200 '{"id":"chatcmpl-e2e-vllm","object":"chat.completion","created":1700000000,"model":"Qwen/Qwen2.5-0.5B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"pong"},"finish_reason":"stop"}],"usage":{"prompt_tokens":11,"completion_tokens":2,"total_tokens":13}}';
|
||||
}
|
||||
}
|
||||
}
|
||||
`
|
||||
|
||||
// VLLM is a mock vLLM OpenAI-compatible server on the combined server's network,
|
||||
// reachable at http://vllm:8000. A "vllm" provider points at it to exercise the
|
||||
// proxy's support for self-hosted OpenAI-compatible backends.
|
||||
type VLLM struct {
|
||||
container testcontainers.Container
|
||||
workDir string
|
||||
// URL is the upstream URL the vllm provider points at (http://<alias>:8000).
|
||||
URL string
|
||||
}
|
||||
|
||||
// StartVLLM runs the mock vLLM server on the shared network over plain HTTP.
|
||||
func StartVLLM(ctx context.Context, c *Combined) (*VLLM, error) {
|
||||
workDir, err := os.MkdirTemp("/tmp", "nb-e2e-vllm-*")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create vllm work dir: %w", err)
|
||||
}
|
||||
// Widen so the (non-root worker) nginx container can traverse the bind mount.
|
||||
if err := os.Chmod(workDir, 0o755); err != nil { //nolint:gosec // throwaway e2e config dir
|
||||
return nil, fmt.Errorf("chmod vllm dir: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filepath.Join(workDir, "nginx.conf"), []byte(vllmNginxConf), 0o644); err != nil { //nolint:gosec // non-secret e2e config
|
||||
return nil, fmt.Errorf("write nginx conf: %w", err)
|
||||
}
|
||||
|
||||
req := testcontainers.ContainerRequest{
|
||||
Image: vllmImage,
|
||||
ExposedPorts: []string{vllmPort},
|
||||
Networks: []string{c.network.Name},
|
||||
NetworkAliases: map[string][]string{c.network.Name: {vllmAlias}},
|
||||
Cmd: []string{"nginx", "-c", "/conf/nginx.conf", "-g", "daemon off;"},
|
||||
HostConfigModifier: func(hc *container.HostConfig) {
|
||||
hc.Binds = append(hc.Binds, workDir+":/conf:ro")
|
||||
},
|
||||
WaitingFor: wait.ForListeningPort(vllmPort).WithStartupTimeout(60 * time.Second),
|
||||
}
|
||||
|
||||
ctr, err := testcontainers.GenericContainer(ctx, testcontainers.GenericContainerRequest{
|
||||
ContainerRequest: req,
|
||||
Started: true,
|
||||
})
|
||||
if err != nil {
|
||||
_ = os.RemoveAll(workDir)
|
||||
return nil, fmt.Errorf("start vllm container: %w", err)
|
||||
}
|
||||
|
||||
return &VLLM{container: ctr, workDir: workDir, URL: "http://" + vllmAlias + ":8000"}, nil
|
||||
}
|
||||
|
||||
// Logs returns the vLLM container logs, for diagnostics on failure.
|
||||
func (v *VLLM) Logs(ctx context.Context) string {
|
||||
return containerLogs(ctx, v.container)
|
||||
}
|
||||
|
||||
// Terminate stops the vLLM container and cleans its work dir.
|
||||
func (v *VLLM) Terminate(ctx context.Context) error {
|
||||
var err error
|
||||
if v.container != nil {
|
||||
err = v.container.Terminate(ctx)
|
||||
}
|
||||
if v.workDir != "" {
|
||||
_ = os.RemoveAll(v.workDir)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@@ -351,11 +351,6 @@ initialize_default_values() {
|
||||
NETBIRD_STUN_PORT=3478
|
||||
|
||||
# Docker images
|
||||
# Record whether the operator explicitly pinned the server/proxy images via
|
||||
# env vars, so the agent-network preset can pick its own defaults without
|
||||
# clobbering an explicit override.
|
||||
NETBIRD_SERVER_IMAGE_EXPLICIT=${NETBIRD_SERVER_IMAGE:+true}
|
||||
NETBIRD_PROXY_IMAGE_EXPLICIT=${NETBIRD_PROXY_IMAGE:+true}
|
||||
DASHBOARD_IMAGE=${DASHBOARD_IMAGE:-"netbirdio/dashboard:latest"}
|
||||
# Combined server replaces separate signal, relay, and management containers
|
||||
NETBIRD_SERVER_IMAGE=${NETBIRD_SERVER_IMAGE:-"netbirdio/netbird-server:latest"}
|
||||
@@ -415,15 +410,6 @@ apply_agent_network_preset() {
|
||||
ENABLE_PROXY="true"
|
||||
ENABLE_CROWDSEC="false"
|
||||
|
||||
# Agent-network ships dedicated server/proxy images. Honor an explicit
|
||||
# env override; otherwise pin the agent-network builds.
|
||||
if [[ "${NETBIRD_SERVER_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_SERVER_IMAGE="netbirdio/netbird-server:0.74.0-rc.2"
|
||||
fi
|
||||
if [[ "${NETBIRD_PROXY_IMAGE_EXPLICIT}" != "true" ]]; then
|
||||
NETBIRD_PROXY_IMAGE="netbirdio/reverse-proxy:0.74.0-rc.2"
|
||||
fi
|
||||
|
||||
if [[ -n "${NETBIRD_LETSENCRYPT_EMAIL}" ]]; then
|
||||
TRAEFIK_ACME_EMAIL="${NETBIRD_LETSENCRYPT_EMAIL}"
|
||||
else
|
||||
|
||||
@@ -627,6 +627,21 @@ var providers = []Provider{
|
||||
},
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
// vLLM is an OpenAI-compatible self-hosted server. It behaves like
|
||||
// the generic custom entry; it gets its own catalog id purely so it
|
||||
// surfaces as a named "vLLM" choice in the provider picker.
|
||||
ID: "vllm",
|
||||
Kind: KindCustom,
|
||||
Name: "vLLM",
|
||||
Description: "Self-hosted vLLM (OpenAI-compatible)",
|
||||
DefaultHost: "",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderTemplate: "Bearer ${API_KEY}",
|
||||
DefaultContentType: "application/json",
|
||||
BrandColor: "#30A2FF",
|
||||
Models: []Model{},
|
||||
},
|
||||
{
|
||||
ID: "custom",
|
||||
Kind: KindCustom,
|
||||
|
||||
@@ -366,6 +366,10 @@ type routerProviderRoute struct {
|
||||
// + refreshes the OAuth token at request time instead of injecting a static
|
||||
// AuthHeaderValue.
|
||||
GCPServiceAccountKeyB64 string `json:"gcp_sa_key_b64,omitempty"`
|
||||
// SkipTLSVerify disables upstream TLS certificate verification when the
|
||||
// proxy dials this provider's upstream. For self-hosted / internal gateways
|
||||
// behind a private or self-signed certificate.
|
||||
SkipTLSVerify bool `json:"skip_tls_verify,omitempty"`
|
||||
}
|
||||
|
||||
// indexProviderGroups walks the enabled policies and returns, per
|
||||
@@ -450,6 +454,7 @@ func buildRouterConfigJSON(providers []*types.Provider, groupIndex map[string][]
|
||||
Vertex: catalog.IsVertexPathStyle(p.ProviderID),
|
||||
Bedrock: catalog.IsBedrockPathStyle(p.ProviderID),
|
||||
GCPServiceAccountKeyB64: gcpSAKeyB64,
|
||||
SkipTLSVerify: p.SkipTLSVerification,
|
||||
})
|
||||
}
|
||||
out, err := json.Marshal(cfg)
|
||||
|
||||
@@ -1057,6 +1057,41 @@ func TestSynthesizeServices_UpstreamURLPath_FlowsToRouter(t *testing.T) {
|
||||
"upstream path must be carried so the router can disambiguate same-model providers; trailing slash trimmed for stable string-prefix matching")
|
||||
}
|
||||
|
||||
func TestSynthesizeServices_SkipTLSVerification_FlowsToRouter(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
mockStore := store.NewMockStore(ctrl)
|
||||
|
||||
// A provider fronting a self-hosted / internal gateway opts into skipping
|
||||
// upstream TLS verification; the synthesiser must carry it into the router
|
||||
// route so the proxy dials that upstream insecurely.
|
||||
provider := newSynthTestProvider()
|
||||
provider.SkipTLSVerification = true
|
||||
policy := newSynthTestPolicy(provider.ID, "grp-eng", "")
|
||||
|
||||
expectSynthBaseInputs(mockStore, ctx, newSynthTestSettings(),
|
||||
[]*types.Provider{provider},
|
||||
[]*types.Policy{policy},
|
||||
[]*types.Guardrail{})
|
||||
|
||||
services, err := SynthesizeServices(ctx, mockStore, testAccountID)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, services, 1)
|
||||
|
||||
mws := services[0].Targets[0].Options.Middlewares
|
||||
var routerCfg routerConfig
|
||||
for _, m := range mws {
|
||||
if m.ID == middlewareIDLLMRouter {
|
||||
require.NoError(t, json.Unmarshal(m.ConfigJSON, &routerCfg))
|
||||
break
|
||||
}
|
||||
}
|
||||
require.Len(t, routerCfg.Providers, 1)
|
||||
assert.True(t, routerCfg.Providers[0].SkipTLSVerify,
|
||||
"provider skip_tls_verification must flow into the router route")
|
||||
}
|
||||
|
||||
func TestSynthesizeServices_UnknownProviderID_FailsClosed(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
@@ -46,6 +46,11 @@ type Provider struct {
|
||||
// Empty means all catalog models are allowed at catalog prices.
|
||||
Models []ProviderModel `gorm:"serializer:json"`
|
||||
Enabled bool
|
||||
// SkipTLSVerification disables upstream TLS certificate verification for
|
||||
// this provider's URL. For self-hosted / internal gateways fronted by a
|
||||
// private or self-signed certificate. The synthesiser propagates it into
|
||||
// the router route so the proxy dials that provider's upstream insecurely.
|
||||
SkipTLSVerification bool `gorm:"column:skip_tls_verification"`
|
||||
// SessionPrivateKey + SessionPublicKey are the ed25519 keypair the
|
||||
// synthesised reverse-proxy service uses to sign / verify session
|
||||
// JWTs after a successful OIDC handshake. Generated once on
|
||||
@@ -129,6 +134,9 @@ func (p *Provider) FromAPIRequest(req *api.AgentNetworkProviderRequest) {
|
||||
if req.Enabled != nil {
|
||||
p.Enabled = *req.Enabled
|
||||
}
|
||||
if req.SkipTlsVerification != nil {
|
||||
p.SkipTLSVerification = *req.SkipTlsVerification
|
||||
}
|
||||
// Identity-header overrides for catalogs flagged Customizable.
|
||||
// nil pointer = "field omitted on the wire" → leave the stored
|
||||
// value untouched (per the openapi description). Empty string is
|
||||
@@ -155,14 +163,15 @@ func (p *Provider) ToAPIResponse() *api.AgentNetworkProvider {
|
||||
created := p.CreatedAt
|
||||
updated := p.UpdatedAt
|
||||
resp := &api.AgentNetworkProvider{
|
||||
Id: p.ID,
|
||||
ProviderId: p.ProviderID,
|
||||
Name: p.Name,
|
||||
UpstreamUrl: p.UpstreamURL,
|
||||
Models: models,
|
||||
Enabled: p.Enabled,
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
Id: p.ID,
|
||||
ProviderId: p.ProviderID,
|
||||
Name: p.Name,
|
||||
UpstreamUrl: p.UpstreamURL,
|
||||
Models: models,
|
||||
Enabled: p.Enabled,
|
||||
SkipTlsVerification: p.SkipTLSVerification,
|
||||
CreatedAt: &created,
|
||||
UpdatedAt: &updated,
|
||||
}
|
||||
if len(p.ExtraValues) > 0 {
|
||||
out := make(map[string]string, len(p.ExtraValues))
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/netbirdio/netbird/shared/management/http/api"
|
||||
)
|
||||
|
||||
// TestProvider_SkipTLSVerification_RoundTrip covers the request→provider→
|
||||
// response mapping of skip_tls_verification, including the update semantics
|
||||
// (nil pointer preserves, explicit false clears).
|
||||
func TestProvider_SkipTLSVerification_RoundTrip(t *testing.T) {
|
||||
enable := true
|
||||
disable := false
|
||||
|
||||
base := func() *api.AgentNetworkProviderRequest {
|
||||
return &api.AgentNetworkProviderRequest{
|
||||
ProviderId: "openai_api",
|
||||
Name: "internal",
|
||||
UpstreamUrl: "https://gw.internal",
|
||||
}
|
||||
}
|
||||
|
||||
p := NewProvider("acc-1")
|
||||
|
||||
req := base()
|
||||
req.SkipTlsVerification = &enable
|
||||
p.FromAPIRequest(req)
|
||||
assert.True(t, p.SkipTLSVerification, "create with skip_tls_verification=true must set the field")
|
||||
assert.True(t, p.ToAPIResponse().SkipTlsVerification, "response must surface skip_tls_verification")
|
||||
|
||||
// Omitting the field on update leaves the stored value untouched.
|
||||
p.FromAPIRequest(base())
|
||||
assert.True(t, p.SkipTLSVerification, "omitting skip_tls_verification on update must preserve it")
|
||||
|
||||
// Explicit false clears it.
|
||||
req = base()
|
||||
req.SkipTlsVerification = &disable
|
||||
p.FromAPIRequest(req)
|
||||
assert.False(t, p.SkipTLSVerification, "explicit false must clear skip_tls_verification")
|
||||
assert.False(t, p.ToAPIResponse().SkipTlsVerification, "response must reflect the cleared value")
|
||||
}
|
||||
@@ -47,7 +47,11 @@ func init() {
|
||||
precomputedDeprecatedRemotePeersConstraint = constraint
|
||||
}
|
||||
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings) *proto.NetbirdConfig {
|
||||
// toNetbirdConfig converts the server configuration to the wire representation. It returns
|
||||
// nil when no server config is set (the fan-out network-map path) because clients treat any
|
||||
// non-nil config as authoritative: a config without a relay section is interpreted as relay
|
||||
// disabled and wipes the clients' relay URLs.
|
||||
func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken *Token, extraSettings *types.ExtraSettings, settings *types.Settings) *proto.NetbirdConfig {
|
||||
if config == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -110,6 +114,12 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
|
||||
Relay: relayCfg,
|
||||
}
|
||||
|
||||
if settings != nil {
|
||||
nbConfig.Metrics = &proto.MetricsConfig{
|
||||
Enabled: settings.MetricsPushEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
return nbConfig
|
||||
}
|
||||
|
||||
@@ -166,7 +176,7 @@ func ToSyncResponse(ctx context.Context, config *nbconfig.Config, httpConfig *nb
|
||||
Checks: toProtocolChecks(ctx, checks),
|
||||
}
|
||||
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings)
|
||||
nbConfig := toNetbirdConfig(config, turnCredentials, relayCredentials, extraSettings, settings)
|
||||
extendedConfig := integrationsConfig.ExtendNetBirdConfig(peer.ID, peerGroups, nbConfig, extraSettings)
|
||||
response.NetbirdConfig = extendedConfig
|
||||
|
||||
|
||||
@@ -8,11 +8,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||
"github.com/netbirdio/netbird/management/server/types"
|
||||
)
|
||||
|
||||
func TestToProtocolDNSConfigWithCache(t *testing.T) {
|
||||
@@ -263,3 +265,39 @@ func TestEncodeSessionExpiresAt(t *testing.T) {
|
||||
assert.True(t, got.AsTime().Equal(deadline))
|
||||
})
|
||||
}
|
||||
|
||||
// TestToNetbirdConfig_RelayInvariant guards against the v0.74.0 relay-wipe regression.
|
||||
// Clients treat any non-nil NetbirdConfig as authoritative and interpret a missing relay
|
||||
// section as relay disabled, wiping their relay URLs. toNetbirdConfig must therefore
|
||||
// return nil when no server config is set (the fan-out network-map path) instead of a
|
||||
// partial config, and a result built from a relay-enabled config must carry the relay
|
||||
// section.
|
||||
func TestToNetbirdConfig_RelayInvariant(t *testing.T) {
|
||||
settings := &types.Settings{MetricsPushEnabled: true}
|
||||
|
||||
t.Run("nil server config returns nil config", func(t *testing.T) {
|
||||
nbCfg := toNetbirdConfig(nil, nil, nil, nil, settings)
|
||||
assert.Nil(t, nbCfg, "fan-out updates must not carry a partial NetbirdConfig even when settings are present")
|
||||
})
|
||||
|
||||
t.Run("relay-enabled config carries relay section", func(t *testing.T) {
|
||||
cfg := &nbconfig.Config{
|
||||
Stuns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "stun:stun.example.com:3478"}},
|
||||
TURNConfig: &nbconfig.TURNConfig{
|
||||
Turns: []*nbconfig.Host{{Proto: nbconfig.UDP, URI: "turn:turn.example.com:3478", Username: "user", Password: "pass"}},
|
||||
},
|
||||
Relay: &nbconfig.Relay{Addresses: []string{"rels://relay.example.com:443"}},
|
||||
Signal: &nbconfig.Host{Proto: nbconfig.HTTP, URI: "signal.example.com:10000"},
|
||||
}
|
||||
relayToken := &Token{Payload: "token-payload", Signature: "token-signature"}
|
||||
|
||||
nbCfg := toNetbirdConfig(cfg, nil, relayToken, nil, settings)
|
||||
require.NotNil(t, nbCfg)
|
||||
require.NotNil(t, nbCfg.Relay, "non-nil NetbirdConfig must include the relay section")
|
||||
assert.Equal(t, cfg.Relay.Addresses, nbCfg.Relay.Urls, "relay URLs should match the server config")
|
||||
assert.Equal(t, relayToken.Payload, nbCfg.Relay.TokenPayload, "relay token payload should be set")
|
||||
assert.Equal(t, relayToken.Signature, nbCfg.Relay.TokenSignature, "relay token signature should be set")
|
||||
require.NotNil(t, nbCfg.Metrics)
|
||||
assert.True(t, nbCfg.Metrics.Enabled, "metrics flag should carry the settings value")
|
||||
})
|
||||
}
|
||||
|
||||
@@ -917,7 +917,7 @@ func (s *Server) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer, ne
|
||||
|
||||
// if peer has reached this point then it has logged in
|
||||
loginResp := &proto.LoginResponse{
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil, settings),
|
||||
PeerConfig: toPeerConfig(peer, network, s.networkMapController.GetDNSDomain(settings), settings, s.config.HttpConfig, s.config.DeviceAuthorizationFlow, enableSSH),
|
||||
Checks: toProtocolChecks(ctx, postureChecks),
|
||||
}
|
||||
|
||||
@@ -358,7 +358,8 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
oldSettings.AutoUpdateVersion != newSettings.AutoUpdateVersion ||
|
||||
oldSettings.AutoUpdateAlways != newSettings.AutoUpdateAlways ||
|
||||
oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled ||
|
||||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration {
|
||||
oldSettings.PeerLoginExpiration != newSettings.PeerLoginExpiration ||
|
||||
oldSettings.MetricsPushEnabled != newSettings.MetricsPushEnabled {
|
||||
// Session deadline is derived from LastLogin + PeerLoginExpiration
|
||||
// on every Login/Sync response. Without a fan-out push, connected
|
||||
// peers keep the deadline they received at login time and only see
|
||||
@@ -409,6 +410,7 @@ func (am *DefaultAccountManager) UpdateAccountSettings(ctx context.Context, acco
|
||||
am.handleAutoUpdateVersionSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleAutoUpdateAlwaysSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handlePeerExposeSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
am.handleMetricsPushSettings(ctx, oldSettings, newSettings, userID, accountID)
|
||||
if err = am.handleInactivityExpirationSettings(ctx, oldSettings, newSettings, userID, accountID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -563,6 +565,16 @@ func (am *DefaultAccountManager) handleLazyConnectionSettings(ctx context.Contex
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handleMetricsPushSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.MetricsPushEnabled != newSettings.MetricsPushEnabled {
|
||||
if newSettings.MetricsPushEnabled {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountMetricsPushEnabled, nil)
|
||||
} else {
|
||||
am.StoreEvent(ctx, userID, accountID, accountID, activity.AccountMetricsPushDisabled, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (am *DefaultAccountManager) handlePeerLoginExpirationSettings(ctx context.Context, oldSettings, newSettings *types.Settings, userID, accountID string) {
|
||||
if oldSettings.PeerLoginExpirationEnabled != newSettings.PeerLoginExpirationEnabled {
|
||||
event := activity.AccountPeerLoginExpirationEnabled
|
||||
@@ -2045,6 +2057,7 @@ func newAccountWithId(ctx context.Context, accountID, userID, domain, email, nam
|
||||
Extra: &types.ExtraSettings{
|
||||
UserApprovalRequired: true,
|
||||
},
|
||||
LazyConnectionEnabled: true,
|
||||
},
|
||||
Onboarding: types.AccountOnboarding{
|
||||
OnboardingFlowPending: true,
|
||||
|
||||
@@ -276,6 +276,11 @@ const (
|
||||
// AgentNetworkSettingsUpdated indicates that a user updated Agent Network account settings
|
||||
AgentNetworkSettingsUpdated Activity = 139
|
||||
|
||||
// AccountMetricsPushEnabled indicates that a user enabled metrics push for the account
|
||||
AccountMetricsPushEnabled Activity = 140
|
||||
// AccountMetricsPushDisabled indicates that a user disabled metrics push for the account
|
||||
AccountMetricsPushDisabled Activity = 141
|
||||
|
||||
AccountDeleted Activity = 99999
|
||||
)
|
||||
|
||||
@@ -449,6 +454,9 @@ var activityMap = map[Activity]Code{
|
||||
|
||||
AgentNetworkSettingsUpdated: {"Agent Network settings updated", "agent_network.settings.update"},
|
||||
|
||||
AccountMetricsPushEnabled: {"Account metrics push enabled", "account.setting.metrics.push.enable"},
|
||||
AccountMetricsPushDisabled: {"Account metrics push disabled", "account.setting.metrics.push.disable"},
|
||||
|
||||
DomainAdded: {"Domain added", "domain.add"},
|
||||
DomainDeleted: {"Domain deleted", "domain.delete"},
|
||||
DomainValidated: {"Domain validated", "domain.validate"},
|
||||
|
||||
@@ -283,6 +283,9 @@ func (h *handler) updateAccountRequestSettings(req api.PutApiAccountsAccountIdJS
|
||||
if req.Settings.Ipv6EnabledGroups != nil {
|
||||
returnSettings.IPv6EnabledGroups = *req.Settings.Ipv6EnabledGroups
|
||||
}
|
||||
if req.Settings.MetricsPushEnabled != nil {
|
||||
returnSettings.MetricsPushEnabled = *req.Settings.MetricsPushEnabled
|
||||
}
|
||||
|
||||
return returnSettings, nil
|
||||
}
|
||||
@@ -413,6 +416,7 @@ func toAccountResponse(accountID string, settings *types.Settings, meta *types.A
|
||||
AutoUpdateVersion: &settings.AutoUpdateVersion,
|
||||
AutoUpdateAlways: &settings.AutoUpdateAlways,
|
||||
Ipv6EnabledGroups: &settings.IPv6EnabledGroups,
|
||||
MetricsPushEnabled: &settings.MetricsPushEnabled,
|
||||
EmbeddedIdpEnabled: &settings.EmbeddedIdpEnabled,
|
||||
LocalAuthDisabled: &settings.LocalAuthDisabled,
|
||||
LocalMfaEnabled: &settings.LocalMfaEnabled,
|
||||
|
||||
@@ -129,6 +129,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateAlways: br(false),
|
||||
AutoUpdateVersion: sr(""),
|
||||
MetricsPushEnabled: br(false),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
@@ -156,6 +157,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateAlways: br(false),
|
||||
AutoUpdateVersion: sr(""),
|
||||
MetricsPushEnabled: br(false),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
@@ -183,6 +185,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateAlways: br(false),
|
||||
AutoUpdateVersion: sr("latest"),
|
||||
MetricsPushEnabled: br(false),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
@@ -210,6 +213,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateAlways: br(false),
|
||||
AutoUpdateVersion: sr(""),
|
||||
MetricsPushEnabled: br(false),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
@@ -237,6 +241,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateAlways: br(false),
|
||||
AutoUpdateVersion: sr(""),
|
||||
MetricsPushEnabled: br(false),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
@@ -264,6 +269,7 @@ func TestAccounts_AccountsHandler(t *testing.T) {
|
||||
DnsDomain: sr(""),
|
||||
AutoUpdateAlways: br(false),
|
||||
AutoUpdateVersion: sr(""),
|
||||
MetricsPushEnabled: br(false),
|
||||
EmbeddedIdpEnabled: br(false),
|
||||
LocalAuthDisabled: br(false),
|
||||
LocalMfaEnabled: br(false),
|
||||
|
||||
@@ -152,7 +152,11 @@ func (m *AuthMiddleware) checkJWTFromRequest(r *http.Request, authHeaderParts []
|
||||
return err
|
||||
}
|
||||
|
||||
err = m.syncUserJWTGroups(ctx, userAuth)
|
||||
// Detach the group-sync write from the request's cancellation: the dashboard
|
||||
// SPA aborts in-flight requests on re-render, which would otherwise cancel the
|
||||
// DB transaction mid-write and silently drop the synced groups. Context values
|
||||
// (request id, logger) are preserved; the store bounds the tx with its own timeout.
|
||||
err = m.syncUserJWTGroups(context.WithoutCancel(ctx), userAuth)
|
||||
if err != nil {
|
||||
log.WithContext(ctx).Errorf("HTTP server failed to sync user JWT groups: %s", err)
|
||||
}
|
||||
|
||||
@@ -241,6 +241,66 @@ func TestAuthMiddleware_Handler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestAuthMiddleware_SyncUserJWTGroupsDetachedFromRequestCancellation ensures the
|
||||
// JWT group sync write is not bound to the request context. The dashboard SPA
|
||||
// routinely aborts in-flight requests on re-render/navigation; if the sync ran in
|
||||
// the request context, the cancellation would roll back the DB transaction and the
|
||||
// synced groups would silently never persist. The sync must receive a context that
|
||||
// is not cancelled even when the originating request is.
|
||||
func TestAuthMiddleware_SyncUserJWTGroupsDetachedFromRequestCancellation(t *testing.T) {
|
||||
var (
|
||||
syncCalled bool
|
||||
syncCtxErr error
|
||||
)
|
||||
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
EnsureUserAccessByJWTGroupsFunc: mockEnsureUserAccessByJWTGroups,
|
||||
MarkPATUsedFunc: mockMarkPATUsed,
|
||||
GetPATInfoFunc: mockGetAccountInfoFromPAT,
|
||||
}
|
||||
|
||||
disabledLimiter := NewAPIRateLimiter(nil)
|
||||
disabledLimiter.SetEnabled(false)
|
||||
|
||||
authMiddleware := NewAuthMiddleware(
|
||||
mockAuth,
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (string, string, error) {
|
||||
return userAuth.AccountId, userAuth.UserId, nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) error {
|
||||
syncCalled = true
|
||||
syncCtxErr = ctx.Err()
|
||||
return nil
|
||||
},
|
||||
func(ctx context.Context, userAuth nbauth.UserAuth) (*types.User, error) {
|
||||
return &types.User{}, nil
|
||||
},
|
||||
disabledLimiter,
|
||||
nil,
|
||||
func(_ context.Context, _, _, _ string) bool { return false },
|
||||
)
|
||||
|
||||
handlerToTest := authMiddleware.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
|
||||
|
||||
// Simulate the dashboard aborting the request: it arrives already cancelled.
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
req := httptest.NewRequest("GET", "http://testing/test", nil).WithContext(ctx)
|
||||
req.Header.Set("Authorization", "Bearer "+JWT)
|
||||
rec := httptest.NewRecorder()
|
||||
|
||||
handlerToTest.ServeHTTP(rec, req)
|
||||
|
||||
if !syncCalled {
|
||||
t.Fatal("syncUserJWTGroups was not called")
|
||||
}
|
||||
if syncCtxErr != nil {
|
||||
t.Fatalf("syncUserJWTGroups received a cancelled context (%v); the group-sync write must be detached from request cancellation", syncCtxErr)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAuthMiddleware_RateLimiting(t *testing.T) {
|
||||
mockAuth := &auth.MockManager{
|
||||
ValidateAndParseTokenFunc: mockValidateAndParseToken,
|
||||
|
||||
@@ -1048,7 +1048,7 @@ func testUpdateAccountPeers(t *testing.T) {
|
||||
|
||||
for _, channel := range peerChannels {
|
||||
update := <-channel
|
||||
assert.Nil(t, update.Update.NetbirdConfig)
|
||||
assert.Nil(t, update.Update.NetbirdConfig, "fan-out updates must not carry a NetbirdConfig; clients treat a config without relay as relay disabled and wipe their relay URLs")
|
||||
assert.Equal(t, tc.peers, len(update.Update.NetworkMap.RemotePeers))
|
||||
assert.Equal(t, tc.peers*2, len(update.Update.NetworkMap.FirewallRules))
|
||||
}
|
||||
|
||||
@@ -1605,7 +1605,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
settings_jwt_groups_enabled, settings_jwt_groups_claim_name, settings_jwt_allow_groups,
|
||||
settings_routing_peer_dns_resolution_enabled, settings_dns_domain, settings_network_range,
|
||||
settings_network_range_v6, settings_ipv6_enabled_groups, settings_lazy_connection_enabled,
|
||||
settings_local_mfa_enabled,
|
||||
settings_local_mfa_enabled, settings_metrics_push_enabled,
|
||||
-- Embedded ExtraSettings
|
||||
settings_extra_peer_approval_enabled, settings_extra_user_approval_required,
|
||||
settings_extra_integrated_validator, settings_extra_integrated_validator_groups
|
||||
@@ -1628,6 +1628,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
sIPv6EnabledGroups sql.NullString
|
||||
sLazyConnectionEnabled sql.NullBool
|
||||
sLocalMFAEnabled sql.NullBool
|
||||
sMetricsPushEnabled sql.NullBool
|
||||
sExtraPeerApprovalEnabled sql.NullBool
|
||||
sExtraUserApprovalRequired sql.NullBool
|
||||
sExtraIntegratedValidator sql.NullString
|
||||
@@ -1650,7 +1651,7 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
&sJWTGroupsEnabled, &sJWTGroupsClaimName, &sJWTAllowGroups,
|
||||
&sRoutingPeerDNSResolutionEnabled, &sDNSDomain, &sNetworkRange,
|
||||
&sNetworkRangeV6, &sIPv6EnabledGroups, &sLazyConnectionEnabled,
|
||||
&sLocalMFAEnabled,
|
||||
&sLocalMFAEnabled, &sMetricsPushEnabled,
|
||||
&sExtraPeerApprovalEnabled, &sExtraUserApprovalRequired,
|
||||
&sExtraIntegratedValidator, &sExtraIntegratedValidatorGroups,
|
||||
)
|
||||
@@ -1716,6 +1717,9 @@ func (s *SqlStore) getAccount(ctx context.Context, accountID string) (*types.Acc
|
||||
if sLocalMFAEnabled.Valid {
|
||||
account.Settings.LocalMfaEnabled = sLocalMFAEnabled.Bool
|
||||
}
|
||||
if sMetricsPushEnabled.Valid {
|
||||
account.Settings.MetricsPushEnabled = sMetricsPushEnabled.Bool
|
||||
}
|
||||
if sJWTAllowGroups.Valid {
|
||||
_ = json.Unmarshal([]byte(sJWTAllowGroups.String), &account.Settings.JWTAllowGroups)
|
||||
}
|
||||
|
||||
@@ -73,6 +73,9 @@ type Settings struct {
|
||||
// For new accounts this defaults to the All group.
|
||||
IPv6EnabledGroups []string `gorm:"serializer:json"`
|
||||
|
||||
// MetricsPushEnabled globally enables or disables client metrics push for the account
|
||||
MetricsPushEnabled bool `gorm:"default:false"`
|
||||
|
||||
// EmbeddedIdpEnabled indicates if the embedded identity provider is enabled.
|
||||
// This is a runtime-only field, not stored in the database.
|
||||
EmbeddedIdpEnabled bool `gorm:"-"`
|
||||
@@ -110,6 +113,7 @@ func (s *Settings) Copy() *Settings {
|
||||
AutoUpdateVersion: s.AutoUpdateVersion,
|
||||
AutoUpdateAlways: s.AutoUpdateAlways,
|
||||
IPv6EnabledGroups: slices.Clone(s.IPv6EnabledGroups),
|
||||
MetricsPushEnabled: s.MetricsPushEnabled,
|
||||
EmbeddedIdpEnabled: s.EmbeddedIdpEnabled,
|
||||
LocalAuthDisabled: s.LocalAuthDisabled,
|
||||
LocalMfaEnabled: s.LocalMfaEnabled,
|
||||
|
||||
@@ -59,6 +59,10 @@ type ProviderRoute struct {
|
||||
// (instead of the static AuthHeaderValue) — so the gateway holds a durable
|
||||
// Vertex credential rather than a 1-hour token.
|
||||
GCPServiceAccountKeyB64 string `json:"gcp_sa_key_b64,omitempty"`
|
||||
// SkipTLSVerify disables upstream TLS certificate verification when dialing
|
||||
// this route's upstream. For self-hosted / internal gateways behind a
|
||||
// private or self-signed certificate.
|
||||
SkipTLSVerify bool `json:"skip_tls_verify,omitempty"`
|
||||
}
|
||||
|
||||
// Config is the on-wire configuration accepted by the factory. An
|
||||
|
||||
@@ -615,8 +615,9 @@ func (m *Middleware) allowWithRoute(route ProviderRoute, userGroups []string) *m
|
||||
// path is silently dropped and the gateway returns a 4xx for
|
||||
// the malformed URL. Empty value leaves the original
|
||||
// target's path untouched.
|
||||
Path: route.UpstreamPath,
|
||||
StripHeaders: append([]string(nil), strippedAuthHeaders...),
|
||||
Path: route.UpstreamPath,
|
||||
StripHeaders: append([]string(nil), strippedAuthHeaders...),
|
||||
SkipTLSVerify: route.SkipTLSVerify,
|
||||
}
|
||||
authValue := route.AuthHeaderValue
|
||||
if route.GCPServiceAccountKeyB64 != "" {
|
||||
|
||||
@@ -107,6 +107,41 @@ func TestRouter_HappyPath(t *testing.T) {
|
||||
assert.Equal(t, "allow", dec, "decision metadata must be allow on a match")
|
||||
}
|
||||
|
||||
func TestRouter_SkipTLSVerifyPropagates(t *testing.T) {
|
||||
base := ProviderRoute{
|
||||
ID: "internal-gw",
|
||||
Models: []string{"gpt-4o"},
|
||||
AllowedGroupIDs: []string{defaultTestGroup},
|
||||
UpstreamScheme: "https",
|
||||
UpstreamHost: "gateway.internal",
|
||||
AuthHeaderName: "Authorization",
|
||||
AuthHeaderValue: "Bearer sk-test-123",
|
||||
}
|
||||
|
||||
t.Run("enabled", func(t *testing.T) {
|
||||
route := base
|
||||
route.SkipTLSVerify = true
|
||||
mw := New(Config{Providers: []ProviderRoute{route}})
|
||||
|
||||
out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, out.Mutations, "matched route must emit mutations")
|
||||
require.NotNil(t, out.Mutations.RewriteUpstream, "matched route must emit upstream rewrite")
|
||||
assert.True(t, out.Mutations.RewriteUpstream.SkipTLSVerify,
|
||||
"skip_tls_verify on the route must ride on the upstream rewrite")
|
||||
})
|
||||
|
||||
t.Run("default off", func(t *testing.T) {
|
||||
mw := New(Config{Providers: []ProviderRoute{base}})
|
||||
|
||||
out, err := mw.Invoke(context.Background(), newInputWithModel("gpt-4o"))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, out.Mutations.RewriteUpstream, "matched route must emit upstream rewrite")
|
||||
assert.False(t, out.Mutations.RewriteUpstream.SkipTLSVerify,
|
||||
"skip_tls_verify must default to false when the route does not set it")
|
||||
})
|
||||
}
|
||||
|
||||
func TestRouter_MissingModel(t *testing.T) {
|
||||
mw := New(Config{Providers: []ProviderRoute{{
|
||||
ID: "openai-prod",
|
||||
|
||||
@@ -243,6 +243,10 @@ type UpstreamRewrite struct {
|
||||
StripPathPrefix string
|
||||
AuthHeader *AuthHeader
|
||||
StripHeaders []string
|
||||
// SkipTLSVerify, when true, makes the proxy dial the rewritten upstream
|
||||
// without verifying its TLS certificate. Set by llm_router from the
|
||||
// provider's skip_tls_verification for self-hosted / internal gateways.
|
||||
SkipTLSVerify bool
|
||||
}
|
||||
|
||||
// AuthHeader is a single name/value pair the proxy injects on the
|
||||
|
||||
@@ -346,6 +346,11 @@ func (p *ReverseProxy) forwardUpstream(respWriter http.ResponseWriter, r *http.R
|
||||
r.Host = effectiveURL.Host
|
||||
applyUpstreamHeaders(r, upstreamRewrite)
|
||||
stripUpstreamPathPrefix(r, upstreamRewrite.StripPathPrefix)
|
||||
// A router-selected route (e.g. agent-network provider) can opt into
|
||||
// skipping upstream TLS verification per its provider config.
|
||||
if upstreamRewrite.SkipTLSVerify {
|
||||
ctx = roundtrip.WithSkipTLSVerify(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
rp := &httputil.ReverseProxy{
|
||||
|
||||
@@ -33,10 +33,15 @@ const ConnectTimeout = 10 * time.Second
|
||||
const healthCheckTimeout = 5 * time.Second
|
||||
|
||||
const (
|
||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size (4 MB)
|
||||
// EnvMaxRecvMsgSize overrides the default gRPC max receive message size
|
||||
// for the management client connection. Value is in bytes.
|
||||
EnvMaxRecvMsgSize = "NB_MANAGEMENT_GRPC_MAX_MSG_SIZE"
|
||||
|
||||
// defaultMaxRecvMsgSize is the max gRPC receive message size used for the
|
||||
// management client connection when EnvMaxRecvMsgSize is unset or invalid.
|
||||
// It overrides the gRPC library default of 4 MB.
|
||||
defaultMaxRecvMsgSize = 1024 * 1024 * 16
|
||||
|
||||
errMsgMgmtPublicKey = "failed getting Management Service public key: %s"
|
||||
errMsgNoMgmtConnection = "no connection to management"
|
||||
)
|
||||
@@ -84,22 +89,22 @@ type ExposeResponse struct {
|
||||
}
|
||||
|
||||
// MaxRecvMsgSize returns the configured max gRPC receive message size from
|
||||
// the environment, or 0 if unset (which uses the gRPC default of 4 MB).
|
||||
// the environment, or defaultMaxRecvMsgSize (16 MB) if unset or invalid.
|
||||
func MaxRecvMsgSize() int {
|
||||
val := os.Getenv(EnvMaxRecvMsgSize)
|
||||
if val == "" {
|
||||
return 0
|
||||
return defaultMaxRecvMsgSize
|
||||
}
|
||||
|
||||
size, err := strconv.Atoi(val)
|
||||
if err != nil {
|
||||
log.Warnf("invalid %s value %q, using default: %v", EnvMaxRecvMsgSize, val, err)
|
||||
return 0
|
||||
return defaultMaxRecvMsgSize
|
||||
}
|
||||
|
||||
if size <= 0 {
|
||||
log.Warnf("invalid %s value %d, must be positive, using default", EnvMaxRecvMsgSize, size)
|
||||
return 0
|
||||
return defaultMaxRecvMsgSize
|
||||
}
|
||||
|
||||
return size
|
||||
@@ -536,7 +541,7 @@ func (c *GrpcClient) IsHealthy() bool {
|
||||
ctx, cancel := context.WithTimeout(c.ctx, healthCheckTimeout)
|
||||
defer cancel()
|
||||
|
||||
_, err := c.realClient.GetServerKey(ctx, &proto.Empty{})
|
||||
_, err := c.realClient.IsHealthy(ctx, &proto.Empty{})
|
||||
if err != nil {
|
||||
c.notifyDisconnected(err)
|
||||
log.Warnf("health check returned: %s", err)
|
||||
@@ -1030,8 +1035,6 @@ func infoToMetaData(info *system.Info) *proto.PeerSystemMeta {
|
||||
BlockLANAccess: info.BlockLANAccess,
|
||||
BlockInbound: info.BlockInbound,
|
||||
DisableIPv6: info.DisableIPv6,
|
||||
|
||||
LazyConnectionEnabled: info.LazyConnectionEnabled,
|
||||
},
|
||||
|
||||
Capabilities: peerCapabilities(*info),
|
||||
|
||||
@@ -21,11 +21,11 @@ func TestMaxRecvMsgSize(t *testing.T) {
|
||||
envValue string
|
||||
expected int
|
||||
}{
|
||||
{name: "unset returns 0", envValue: "", expected: 0},
|
||||
{name: "unset returns default", envValue: "", expected: defaultMaxRecvMsgSize},
|
||||
{name: "valid value", envValue: "10485760", expected: 10485760},
|
||||
{name: "non-numeric returns 0", envValue: "abc", expected: 0},
|
||||
{name: "negative returns 0", envValue: "-1", expected: 0},
|
||||
{name: "zero returns 0", envValue: "0", expected: 0},
|
||||
{name: "non-numeric returns default", envValue: "abc", expected: defaultMaxRecvMsgSize},
|
||||
{name: "negative returns default", envValue: "-1", expected: defaultMaxRecvMsgSize},
|
||||
{name: "zero returns default", envValue: "0", expected: defaultMaxRecvMsgSize},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
|
||||
@@ -371,6 +371,10 @@ components:
|
||||
description: When true, updates are installed automatically in the background. When false, updates require user interaction from the UI.
|
||||
type: boolean
|
||||
example: false
|
||||
metrics_push_enabled:
|
||||
description: Enables or disables client metrics push for all peers in the account
|
||||
type: boolean
|
||||
example: false
|
||||
embedded_idp_enabled:
|
||||
description: Indicates whether the embedded identity provider (Dex) is enabled for this account. This is a read-only field.
|
||||
type: boolean
|
||||
@@ -5115,6 +5119,10 @@ components:
|
||||
type: boolean
|
||||
description: Whether the provider is enabled.
|
||||
example: true
|
||||
skip_tls_verification:
|
||||
type: boolean
|
||||
description: Whether upstream TLS certificate verification is skipped when the proxy dials this provider's URL. Intended for self-hosted / internal gateways behind a private or self-signed certificate.
|
||||
example: false
|
||||
created_at:
|
||||
type: string
|
||||
format: date-time
|
||||
@@ -5134,6 +5142,7 @@ components:
|
||||
- upstream_url
|
||||
- models
|
||||
- enabled
|
||||
- skip_tls_verification
|
||||
- created_at
|
||||
- updated_at
|
||||
AgentNetworkProviderRequest:
|
||||
@@ -5186,6 +5195,10 @@ components:
|
||||
type: boolean
|
||||
description: Whether the provider is enabled. Defaults to true on create.
|
||||
example: true
|
||||
skip_tls_verification:
|
||||
type: boolean
|
||||
description: Skip upstream TLS certificate verification when the proxy dials this provider's URL. For self-hosted / internal gateways behind a private or self-signed certificate. Defaults to false. When omitted on update, the stored value is left unchanged.
|
||||
example: false
|
||||
required:
|
||||
- provider_id
|
||||
- name
|
||||
|
||||
@@ -1684,6 +1684,9 @@ type AccountSettings struct {
|
||||
// LocalMfaEnabled Enables or disables TOTP multi-factor authentication for local users. Only applicable when the embedded identity provider is enabled.
|
||||
LocalMfaEnabled *bool `json:"local_mfa_enabled,omitempty"`
|
||||
|
||||
// MetricsPushEnabled Enables or disables client metrics push for all peers in the account
|
||||
MetricsPushEnabled *bool `json:"metrics_push_enabled,omitempty"`
|
||||
|
||||
// NetworkRange Allows to define a custom network range for the account in CIDR format
|
||||
NetworkRange *string `json:"network_range,omitempty"`
|
||||
|
||||
@@ -2221,6 +2224,9 @@ type AgentNetworkProvider struct {
|
||||
// ProviderId Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom).
|
||||
ProviderId string `json:"provider_id"`
|
||||
|
||||
// SkipTlsVerification Whether upstream TLS certificate verification is skipped when the proxy dials this provider's URL. Intended for self-hosted / internal gateways behind a private or self-signed certificate.
|
||||
SkipTlsVerification bool `json:"skip_tls_verification"`
|
||||
|
||||
// UpdatedAt Timestamp when the provider was last updated.
|
||||
UpdatedAt *time.Time `json:"updated_at,omitempty"`
|
||||
|
||||
@@ -2269,6 +2275,9 @@ type AgentNetworkProviderRequest struct {
|
||||
// ProviderId Catalog identifier for the upstream AI provider (e.g. openai_api, anthropic_api, azure_openai_api, bedrock_api, vertex_ai_api, mistral_api, custom).
|
||||
ProviderId string `json:"provider_id"`
|
||||
|
||||
// SkipTlsVerification Skip upstream TLS certificate verification when the proxy dials this provider's URL. For self-hosted / internal gateways behind a private or self-signed certificate. Defaults to false. When omitted on update, the stored value is left unchanged.
|
||||
SkipTlsVerification *bool `json:"skip_tls_verification,omitempty"`
|
||||
|
||||
// UpstreamUrl Full upstream URL (with scheme) that NetBird forwards traffic to.
|
||||
UpstreamUrl string `json:"upstream_url"`
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -312,6 +312,8 @@ message NetbirdConfig {
|
||||
RelayConfig relay = 4;
|
||||
|
||||
FlowConfig flow = 5;
|
||||
|
||||
MetricsConfig metrics = 6;
|
||||
}
|
||||
|
||||
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
||||
@@ -350,6 +352,10 @@ message FlowConfig {
|
||||
bool dnsCollection = 8;
|
||||
}
|
||||
|
||||
message MetricsConfig {
|
||||
bool enabled = 1;
|
||||
}
|
||||
|
||||
// JWTConfig represents JWT authentication configuration for validating tokens.
|
||||
message JWTConfig {
|
||||
string issuer = 1;
|
||||
|
||||
Reference in New Issue
Block a user